metal : add poc for normalized Q4_0 and Q4_1

This commit is contained in:
Georgi Gerganov 2023-08-30 18:32:43 +03:00
parent 9ffe54ed10
commit b4e70822f6
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 45 additions and 29 deletions

View File

@ -697,6 +697,9 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_MUL: case GGML_OP_MUL:
{ {
GGML_ASSERT(ne00 % 4 == 0);
const int64_t nb = ne00/4;
if (ggml_nelements(src1) == ne10) { if (ggml_nelements(src1) == ne10) {
// src1 is a row // src1 is a row
[encoder setComputePipelineState:ctx->pipeline_mul_row]; [encoder setComputePipelineState:ctx->pipeline_mul_row];
@ -706,9 +709,9 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
const int64_t n = ggml_nelements(dst); const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;

View File

@ -4,17 +4,22 @@ using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y)) #define MAX(x, y) ((x) > (y) ? (x) : (y))
#define Q4_0DM (1.0f/8.0f)
#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)
#define QK4_0 32 #define QK4_0 32
#define QR4_0 2 #define QR4_0 2
typedef struct { typedef struct {
half d; // delta int8_t d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0; } block_q4_0;
#define Q4_1DM (2.0f/15.0f)
#define Q4_1MM (2.0f )
#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f)
#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)
#define QK4_1 32 #define QK4_1 32
typedef struct { typedef struct {
half d; // delta uint16_t dm;
half m; // min
uint8_t qs[QK4_1 / 2]; // nibbles / quants uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1; } block_q4_1;
@ -44,9 +49,9 @@ kernel void kernel_add_row(
} }
kernel void kernel_mul( kernel void kernel_mul(
device const float * src0, device const float4 * src0,
device const float * src1, device const float4 * src1,
device float * dst, device float4 * dst,
uint tpig[[thread_position_in_grid]]) { uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig]; dst[tpig] = src0[tpig] * src1[tpig];
} }
@ -54,12 +59,12 @@ kernel void kernel_mul(
// assumption: src1 is a row // assumption: src1 is a row
// broadcast src1 into src0 // broadcast src1 into src0
kernel void kernel_mul_row( kernel void kernel_mul_row(
device const float * src0, device const float4 * src0,
device const float * src1, device const float4 * src1,
device float * dst, device float4 * dst,
constant int64_t & ne00, constant int64_t & nb,
uint tpig[[thread_position_in_grid]]) { uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig % ne00]; dst[tpig] = src0[tpig] * src1[tpig % nb];
} }
kernel void kernel_scale( kernel void kernel_scale(
@ -314,14 +319,18 @@ kernel void kernel_rms_norm(
// we assume that the yl's have been multiplied with the appropriate scale factor // we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d; float d = Q4_0D(qb_curr->d);
float2 acc = 0.f; float2 acc = 0.f;
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); device const uint8_t * qs = ((device const uint8_t *)qb_curr->qs + il);
uint16_t qs16;
for (int i = 0; i < 8; i+=2) { for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) qs16 = qs[i+1];
+ yl[i + 1] * (qs[i / 2] & 0x0F00); qs16 <<= 8;
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) qs16 |= qs[i];
+ yl[i + 9] * (qs[i / 2] & 0xF000); acc[0] += yl[i + 0] * (qs16 & 0x000F)
+ yl[i + 1] * (qs16 & 0x0F00);
acc[1] += yl[i + 8] * (qs16 & 0x00F0)
+ yl[i + 9] * (qs16 & 0xF000);
} }
return d * (sumy * -8.f + acc[0] + acc[1]); return d * (sumy * -8.f + acc[0] + acc[1]);
} }
@ -331,9 +340,9 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
// we assume that the yl's have been multiplied with the appropriate scale factor // we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d; float d = Q4_1D(qb_curr->dm);
float m = qb_curr->m; float m = Q4_1M(qb_curr->dm);
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
float2 acc = 0.f; float2 acc = 0.f;
for (int i = 0; i < 8; i+=2) { for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
@ -1686,23 +1695,27 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
template <typename type4x4> template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1); device const uint8_t * qs = ((device const uint8_t *)xb->qs);
const half d = il ? (xb->d / 16.h) : xb->d; const half d = il ? (Q4_0D(xb->d) / 16.h) : Q4_0D(xb->d);
const half m = il ? ( -8.h * 16.h) : -8.h; const half m = il ? ( -8.h * 16.h) : -8.h;
const ushort mask0 = il ? 0x00F0 : 0x000F; const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = il ? 0xF000 : 0x0F00; const ushort mask1 = il ? 0xF000 : 0x0F00;
uint16_t qs16;
for (int i=0;i<8;i++) { for (int i=0;i<8;i++) {
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d; qs16 = qs[2*i+1];
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d; qs16 <<= 8;
qs16 |= qs[2*i];
reg[i/2][2*(i%2)] = (((qs16 & mask0) ) + m) * d;
reg[i/2][2*(i%2)+1] = (((qs16 & mask1) >> 8) + m) * d;
} }
} }
template <typename type4x4> template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2); device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const half d = il ? (xb->d / 16.h) : xb->d; const half d = il ? (Q4_1D(xb->dm) / 16.h) : Q4_1D(xb->dm);
const half m = xb->m; const half m = Q4_1M(xb->dm);
const ushort mask0 = il ? 0x00F0 : 0x000F; const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = il ? 0xF000 : 0x0F00; const ushort mask1 = il ? 0xF000 : 0x0F00;