mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
Metal: faster Q4_0 and Q4_1 matrix x vector kernels (#2212)
* 3-5% faster Q4_0 on Metal * 7-25% faster Q4_1 on Metal * Oops, forgot to delete the original Q4_1 kernel --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
32c5411631
commit
27ad57a69b
@ -739,12 +739,8 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
||||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
||||||
|
|
||||||
if (src0t == GGML_TYPE_Q4_0) {
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01 / 8+((ne01 % 8) & 0x01), ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
|
||||||
else if (src0t == GGML_TYPE_Q4_1) {
|
|
||||||
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_Q2_K ||
|
else if (src0t == GGML_TYPE_Q2_K ||
|
||||||
src0t == GGML_TYPE_Q3_K ||
|
src0t == GGML_TYPE_Q3_K ||
|
||||||
|
144
ggml-metal.metal
144
ggml-metal.metal
@ -395,9 +395,12 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|||||||
// each thread in a SIMD group deals with 1 block.
|
// each thread in a SIMD group deals with 1 block.
|
||||||
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
|
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
|
||||||
|
|
||||||
|
float sumy = 0;
|
||||||
for (int i = 0; i < QK4_0 / 4; i++) {
|
for (int i = 0; i < QK4_0 / 4; i++) {
|
||||||
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
|
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
|
||||||
|
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||||
}
|
}
|
||||||
|
sumy *= (-8.f);
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
for (int row = 0; row < N_DST; row++) {
|
||||||
// prefetch next x block
|
// prefetch next x block
|
||||||
@ -405,19 +408,30 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|||||||
|
|
||||||
// calculate
|
// calculate
|
||||||
float d = qb_curr.d;
|
float d = qb_curr.d;
|
||||||
float2 acc = {0.0f, 0.0f};
|
float acc = sumy;
|
||||||
for (int i = 0; i < 16; i++) {
|
for (int i = 0; i < 16; i++) {
|
||||||
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
||||||
acc[1] += yl[i] + yl[i+16];
|
|
||||||
}
|
}
|
||||||
sumf[row] += d * (acc[0] - 8.f*acc[1]);
|
sumf[row] += d * acc;
|
||||||
qb_curr = qb_next;
|
qb_curr = qb_next;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (nb % N_SIMDWIDTH == 0) {
|
||||||
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
|
all_sum = simd_sum(sumf[row]);
|
||||||
|
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
|
||||||
|
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
float sumy = 0;
|
||||||
for (int i = 0; i < QK4_0 / 4; i++) {
|
for (int i = 0; i < QK4_0 / 4; i++) {
|
||||||
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
|
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
|
||||||
|
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||||
}
|
}
|
||||||
|
sumy *= (-8.f);
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
for (int row = 0; row < N_DST; row++) {
|
||||||
// prefetch next x block
|
// prefetch next x block
|
||||||
@ -425,13 +439,12 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|||||||
|
|
||||||
// calculate
|
// calculate
|
||||||
float d = qb_curr.d;
|
float d = qb_curr.d;
|
||||||
float2 acc = {0.0f, 0.0f};
|
float acc = sumy;
|
||||||
for (int i = 0; i < 16; i++) {
|
for (int i = 0; i < 16; i++) {
|
||||||
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
||||||
acc[1] += yl[i] + yl[i+16];
|
|
||||||
}
|
}
|
||||||
if (tiisg < nb % N_SIMDWIDTH) {
|
if (tiisg < nb % N_SIMDWIDTH) {
|
||||||
sumf[row] += d * (acc[0] - 8.f*acc[1]);
|
sumf[row] += d * acc;
|
||||||
}
|
}
|
||||||
qb_curr = qb_next;
|
qb_curr = qb_next;
|
||||||
|
|
||||||
@ -440,6 +453,7 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|||||||
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q4_1_f32(
|
kernel void kernel_mul_mat_q4_1_f32(
|
||||||
@ -449,65 +463,83 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne10,
|
constant int64_t & ne10,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
threadgroup float * sum [[threadgroup(0)]],
|
constant int64_t & ne01[[buffer(4)]],
|
||||||
uint2 tgpig[[threadgroup_position_in_grid]],
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint2 tpitg[[thread_position_in_threadgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint2 tptg[[threads_per_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
const int nb = ne00/QK4_1;
|
const int nb = ne00/QK4_0;
|
||||||
|
const int r0 = tgpig.x;
|
||||||
const int64_t r0 = tgpig.x;
|
const int r1 = tgpig.y;
|
||||||
const int64_t r1 = tgpig.y;
|
device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
|
||||||
|
|
||||||
device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
|
|
||||||
device const float * y = (device const float *) src1 + r1*ne10;
|
device const float * y = (device const float *) src1 + r1*ne10;
|
||||||
|
block_q4_1 qb_curr, qb_next;
|
||||||
|
float4 y_curr[8]; // src1 vector cache
|
||||||
|
float sumf[N_DST]={0.f}, all_sum;
|
||||||
|
thread float * yl=(thread float *)y_curr;
|
||||||
|
|
||||||
const uint nth = tptg.x*tptg.y;
|
// bootstrap
|
||||||
const uint ith = tptg.y*tpitg.x + tpitg.y;
|
qb_curr = x[tiisg];
|
||||||
|
// each thread in a SIMD group deals with 1 block.
|
||||||
const int ix = tpitg.y/4; // 0 or 1
|
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
|
||||||
const int iy = tpitg.y - 4*ix; // 0...3
|
|
||||||
|
|
||||||
const int first = 4 * iy;
|
|
||||||
|
|
||||||
float sumf = 0;
|
|
||||||
|
|
||||||
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
|
|
||||||
|
|
||||||
const float d = (float)x[i].d;
|
|
||||||
const float m = (float)x[i].m;
|
|
||||||
|
|
||||||
device const uint8_t * xl = x[i].qs + first;
|
|
||||||
device const float * yl = y + i * QK4_1 + first;
|
|
||||||
|
|
||||||
float2 acc = {0.0f, 0.0f};
|
|
||||||
|
|
||||||
for (int j = 0; j < 4; ++j) {
|
|
||||||
|
|
||||||
acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
|
|
||||||
acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
|
|
||||||
|
|
||||||
|
float sumy = 0;
|
||||||
|
for (int i = 0; i < QK4_0 / 4; i++) {
|
||||||
|
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
|
||||||
|
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||||
}
|
}
|
||||||
|
|
||||||
sumf += acc[0] + acc[1];
|
for (int row = 0; row < N_DST; row++) {
|
||||||
|
// prefetch next x block
|
||||||
|
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
|
||||||
|
|
||||||
|
// calculate
|
||||||
|
const float d = qb_curr.d;
|
||||||
|
const float m = qb_curr.m;
|
||||||
|
float acc = 0.f;
|
||||||
|
for (int i = 0; i < 16; i++) {
|
||||||
|
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
||||||
|
}
|
||||||
|
sumf[row] += d * acc + m * sumy;
|
||||||
|
qb_curr = qb_next;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sum[ith] = sumf;
|
if (nb % N_SIMDWIDTH == 0) {
|
||||||
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
|
all_sum = simd_sum(sumf[row]);
|
||||||
|
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
|
||||||
|
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
//
|
float sumy = 0;
|
||||||
// Accumulate the sum from all threads in the threadgroup
|
for (int i = 0; i < QK4_0 / 4; i++) {
|
||||||
//
|
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||||
if (ith%4 == 0) {
|
}
|
||||||
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
|
|
||||||
|
for (int row = 0; row < N_DST; row++) {
|
||||||
|
// prefetch next x block
|
||||||
|
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
|
||||||
|
|
||||||
|
// calculate
|
||||||
|
const float d = qb_curr.d;
|
||||||
|
const float m = qb_curr.m;
|
||||||
|
float acc = 0.f;
|
||||||
|
for (int i = 0; i < 16; i++) {
|
||||||
|
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
||||||
|
}
|
||||||
|
if (tiisg < nb % N_SIMDWIDTH) {
|
||||||
|
sumf[row] += d * acc + m * sumy;
|
||||||
|
}
|
||||||
|
qb_curr = qb_next;
|
||||||
|
|
||||||
|
all_sum = simd_sum(sumf[row]);
|
||||||
|
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
|
||||||
|
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
if (ith%16 == 0) {
|
|
||||||
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
|
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
if (ith == 0) {
|
|
||||||
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
|
|
||||||
dst[r1*ne0 + r0] = sum[0];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user