mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
metal : add poc for normalized Q4_0 and Q4_1
This commit is contained in:
parent
9ffe54ed10
commit
b4e70822f6
@ -697,6 +697,9 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
{
|
{
|
||||||
|
GGML_ASSERT(ne00 % 4 == 0);
|
||||||
|
const int64_t nb = ne00/4;
|
||||||
|
|
||||||
if (ggml_nelements(src1) == ne10) {
|
if (ggml_nelements(src1) == ne10) {
|
||||||
// src1 is a row
|
// src1 is a row
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
||||||
@ -706,9 +709,9 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst)/4;
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
@ -4,17 +4,22 @@ using namespace metal;
|
|||||||
|
|
||||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||||
|
|
||||||
|
#define Q4_0DM (1.0f/8.0f)
|
||||||
|
#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)
|
||||||
#define QK4_0 32
|
#define QK4_0 32
|
||||||
#define QR4_0 2
|
#define QR4_0 2
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
int8_t d; // delta
|
||||||
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
||||||
} block_q4_0;
|
} block_q4_0;
|
||||||
|
|
||||||
|
#define Q4_1DM (2.0f/15.0f)
|
||||||
|
#define Q4_1MM (2.0f )
|
||||||
|
#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f)
|
||||||
|
#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)
|
||||||
#define QK4_1 32
|
#define QK4_1 32
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
uint16_t dm;
|
||||||
half m; // min
|
|
||||||
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||||
} block_q4_1;
|
} block_q4_1;
|
||||||
|
|
||||||
@ -44,9 +49,9 @@ kernel void kernel_add_row(
|
|||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul(
|
kernel void kernel_mul(
|
||||||
device const float * src0,
|
device const float4 * src0,
|
||||||
device const float * src1,
|
device const float4 * src1,
|
||||||
device float * dst,
|
device float4 * dst,
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
dst[tpig] = src0[tpig] * src1[tpig];
|
dst[tpig] = src0[tpig] * src1[tpig];
|
||||||
}
|
}
|
||||||
@ -54,12 +59,12 @@ kernel void kernel_mul(
|
|||||||
// assumption: src1 is a row
|
// assumption: src1 is a row
|
||||||
// broadcast src1 into src0
|
// broadcast src1 into src0
|
||||||
kernel void kernel_mul_row(
|
kernel void kernel_mul_row(
|
||||||
device const float * src0,
|
device const float4 * src0,
|
||||||
device const float * src1,
|
device const float4 * src1,
|
||||||
device float * dst,
|
device float4 * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & nb,
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
dst[tpig] = src0[tpig] * src1[tpig % ne00];
|
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_scale(
|
kernel void kernel_scale(
|
||||||
@ -314,14 +319,18 @@ kernel void kernel_rms_norm(
|
|||||||
// we assume that the yl's have been multiplied with the appropriate scale factor
|
// we assume that the yl's have been multiplied with the appropriate scale factor
|
||||||
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
||||||
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
||||||
float d = qb_curr->d;
|
float d = Q4_0D(qb_curr->d);
|
||||||
float2 acc = 0.f;
|
float2 acc = 0.f;
|
||||||
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
device const uint8_t * qs = ((device const uint8_t *)qb_curr->qs + il);
|
||||||
|
uint16_t qs16;
|
||||||
for (int i = 0; i < 8; i+=2) {
|
for (int i = 0; i < 8; i+=2) {
|
||||||
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
qs16 = qs[i+1];
|
||||||
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
qs16 <<= 8;
|
||||||
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
|
qs16 |= qs[i];
|
||||||
+ yl[i + 9] * (qs[i / 2] & 0xF000);
|
acc[0] += yl[i + 0] * (qs16 & 0x000F)
|
||||||
|
+ yl[i + 1] * (qs16 & 0x0F00);
|
||||||
|
acc[1] += yl[i + 8] * (qs16 & 0x00F0)
|
||||||
|
+ yl[i + 9] * (qs16 & 0xF000);
|
||||||
}
|
}
|
||||||
return d * (sumy * -8.f + acc[0] + acc[1]);
|
return d * (sumy * -8.f + acc[0] + acc[1]);
|
||||||
}
|
}
|
||||||
@ -331,9 +340,9 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
|
|||||||
// we assume that the yl's have been multiplied with the appropriate scale factor
|
// we assume that the yl's have been multiplied with the appropriate scale factor
|
||||||
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
||||||
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
|
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
|
||||||
float d = qb_curr->d;
|
float d = Q4_1D(qb_curr->dm);
|
||||||
float m = qb_curr->m;
|
float m = Q4_1M(qb_curr->dm);
|
||||||
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
||||||
float2 acc = 0.f;
|
float2 acc = 0.f;
|
||||||
for (int i = 0; i < 8; i+=2) {
|
for (int i = 0; i < 8; i+=2) {
|
||||||
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
||||||
@ -1686,23 +1695,27 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
|||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
device const uint8_t * qs = ((device const uint8_t *)xb->qs);
|
||||||
const half d = il ? (xb->d / 16.h) : xb->d;
|
const half d = il ? (Q4_0D(xb->d) / 16.h) : Q4_0D(xb->d);
|
||||||
const half m = il ? ( -8.h * 16.h) : -8.h;
|
const half m = il ? ( -8.h * 16.h) : -8.h;
|
||||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||||
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
||||||
|
|
||||||
|
uint16_t qs16;
|
||||||
for (int i=0;i<8;i++) {
|
for (int i=0;i<8;i++) {
|
||||||
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
|
qs16 = qs[2*i+1];
|
||||||
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
|
qs16 <<= 8;
|
||||||
|
qs16 |= qs[2*i];
|
||||||
|
reg[i/2][2*(i%2)] = (((qs16 & mask0) ) + m) * d;
|
||||||
|
reg[i/2][2*(i%2)+1] = (((qs16 & mask1) >> 8) + m) * d;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||||
const half d = il ? (xb->d / 16.h) : xb->d;
|
const half d = il ? (Q4_1D(xb->dm) / 16.h) : Q4_1D(xb->dm);
|
||||||
const half m = xb->m;
|
const half m = Q4_1M(xb->dm);
|
||||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||||
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user