metal : add poc for normalized Q4_0 and Q4_1

This commit is contained in:
Georgi Gerganov 2023-08-30 18:32:43 +03:00
parent 9ffe54ed10
commit b4e70822f6
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 45 additions and 29 deletions

View File

@ -697,6 +697,9 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_MUL:
{
GGML_ASSERT(ne00 % 4 == 0);
const int64_t nb = ne00/4;
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_mul_row];
@ -706,9 +709,9 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
const int64_t n = ggml_nelements(dst);
const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;

View File

@ -4,17 +4,22 @@ using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define Q4_0DM (1.0f/8.0f)
#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)
#define QK4_0 32
#define QR4_0 2
typedef struct {
half d; // delta
int8_t d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
#define Q4_1DM (2.0f/15.0f)
#define Q4_1MM (2.0f )
#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f)
#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)
#define QK4_1 32
typedef struct {
half d; // delta
half m; // min
uint16_t dm;
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
@ -44,9 +49,9 @@ kernel void kernel_add_row(
}
kernel void kernel_mul(
device const float * src0,
device const float * src1,
device float * dst,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig];
}
@ -54,12 +59,12 @@ kernel void kernel_mul(
// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_mul_row(
device const float * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
constant int64_t & nb,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig % ne00];
dst[tpig] = src0[tpig] * src1[tpig % nb];
}
kernel void kernel_scale(
@ -314,14 +319,18 @@ kernel void kernel_rms_norm(
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float d = Q4_0D(qb_curr->d);
float2 acc = 0.f;
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
device const uint8_t * qs = ((device const uint8_t *)qb_curr->qs + il);
uint16_t qs16;
for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ yl[i + 9] * (qs[i / 2] & 0xF000);
qs16 = qs[i+1];
qs16 <<= 8;
qs16 |= qs[i];
acc[0] += yl[i + 0] * (qs16 & 0x000F)
+ yl[i + 1] * (qs16 & 0x0F00);
acc[1] += yl[i + 8] * (qs16 & 0x00F0)
+ yl[i + 9] * (qs16 & 0xF000);
}
return d * (sumy * -8.f + acc[0] + acc[1]);
}
@ -331,9 +340,9 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float m = qb_curr->m;
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
float d = Q4_1D(qb_curr->dm);
float m = Q4_1M(qb_curr->dm);
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
float2 acc = 0.f;
for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
@ -1686,23 +1695,27 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const half d = il ? (xb->d / 16.h) : xb->d;
device const uint8_t * qs = ((device const uint8_t *)xb->qs);
const half d = il ? (Q4_0D(xb->d) / 16.h) : Q4_0D(xb->d);
const half m = il ? ( -8.h * 16.h) : -8.h;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = il ? 0xF000 : 0x0F00;
uint16_t qs16;
for (int i=0;i<8;i++) {
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
qs16 = qs[2*i+1];
qs16 <<= 8;
qs16 |= qs[2*i];
reg[i/2][2*(i%2)] = (((qs16 & mask0) ) + m) * d;
reg[i/2][2*(i%2)+1] = (((qs16 & mask1) >> 8) + m) * d;
}
}
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const half d = il ? (xb->d / 16.h) : xb->d;
const half m = xb->m;
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const half d = il ? (Q4_1D(xb->dm) / 16.h) : Q4_1D(xb->dm);
const half m = Q4_1M(xb->dm);
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = il ? 0xF000 : 0x0F00;