From b4e70822f6282a6b3c9ae53a282d30c9d5ccf70f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 30 Aug 2023 18:32:43 +0300 Subject: [PATCH] metal : add poc for normalized Q4_0 and Q4_1 --- ggml-metal.m | 7 +++-- ggml-metal.metal | 67 +++++++++++++++++++++++++++++------------------- 2 files changed, 45 insertions(+), 29 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e929c4b07..1aaff6a93 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -697,6 +697,9 @@ void ggml_metal_graph_compute( } break; case GGML_OP_MUL: { + GGML_ASSERT(ne00 % 4 == 0); + const int64_t nb = ne00/4; + if (ggml_nelements(src1) == ne10) { // src1 is a row [encoder setComputePipelineState:ctx->pipeline_mul_row]; @@ -706,9 +709,9 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:3]; - const int64_t n = ggml_nelements(dst); + const int64_t n = ggml_nelements(dst)/4; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 82e1a0c7a..bfb32eccd 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4,17 +4,22 @@ using namespace metal; #define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define Q4_0DM (1.0f/8.0f) +#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f) #define QK4_0 32 #define QR4_0 2 typedef struct { - half d; // delta + int8_t d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; +#define Q4_1DM (2.0f/15.0f) +#define Q4_1MM (2.0f ) +#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f) +#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f) #define QK4_1 32 typedef struct { - half d; // delta - half m; // min + uint16_t dm; uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; @@ -44,9 +49,9 @@ kernel void kernel_add_row( } kernel void kernel_mul( - device const float * src0, - device const float * src1, - device float * dst, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * src1[tpig]; } @@ -54,12 +59,12 @@ kernel void kernel_mul( // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_mul_row( - device const float * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant int64_t & nb, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig % ne00]; + dst[tpig] = src0[tpig] * src1[tpig % nb]; } kernel void kernel_scale( @@ -314,14 +319,18 @@ kernel void kernel_rms_norm( // we assume that the yl's have been multiplied with the appropriate scale factor // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; + float d = Q4_0D(qb_curr->d); float2 acc = 0.f; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + device const uint8_t * qs = ((device const uint8_t *)qb_curr->qs + il); + uint16_t qs16; for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); + qs16 = qs[i+1]; + qs16 <<= 8; + qs16 |= qs[i]; + acc[0] += yl[i + 0] * (qs16 & 0x000F) + + yl[i + 1] * (qs16 & 0x0F00); + acc[1] += yl[i + 8] * (qs16 & 0x00F0) + + yl[i + 9] * (qs16 & 0xF000); } return d * (sumy * -8.f + acc[0] + acc[1]); } @@ -331,9 +340,9 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre // we assume that the yl's have been multiplied with the appropriate scale factor // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float d = Q4_1D(qb_curr->dm); + float m = Q4_1M(qb_curr->dm); + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); float2 acc = 0.f; for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) @@ -1686,23 +1695,27 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) template void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const half d = il ? (xb->d / 16.h) : xb->d; + device const uint8_t * qs = ((device const uint8_t *)xb->qs); + const half d = il ? (Q4_0D(xb->d) / 16.h) : Q4_0D(xb->d); const half m = il ? ( -8.h * 16.h) : -8.h; const ushort mask0 = il ? 0x00F0 : 0x000F; const ushort mask1 = il ? 0xF000 : 0x0F00; + uint16_t qs16; for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d; - reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d; + qs16 = qs[2*i+1]; + qs16 <<= 8; + qs16 |= qs[2*i]; + reg[i/2][2*(i%2)] = (((qs16 & mask0) ) + m) * d; + reg[i/2][2*(i%2)+1] = (((qs16 & mask1) >> 8) + m) * d; } } template void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 2); - const half d = il ? (xb->d / 16.h) : xb->d; - const half m = xb->m; + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const half d = il ? (Q4_1D(xb->dm) / 16.h) : Q4_1D(xb->dm); + const half m = Q4_1M(xb->dm); const ushort mask0 = il ? 0x00F0 : 0x000F; const ushort mask1 = il ? 0xF000 : 0x0F00;