From 8c1b186cb5eac4abd26b73a8cb2a2248808491e7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 12 Nov 2024 15:30:51 +0200 Subject: [PATCH] metal : minor Q4_0 optimization --- ggml/src/ggml-metal.metal | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index ff6aff28e..a5b631cf9 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1441,18 +1441,18 @@ kernel void kernel_group_norm( inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; - float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + float acc = -8.0f*sumy; device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2); - for (int i = 0; i < 8; i += 2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); - acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); - acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); + for (short i = 0; i < 4; ++i) { + acc += yl[2*i + 0] * (qs[i] & 0x000F); + acc += yl[2*i + 1] * (qs[i] & 0x0F00); + acc += yl[2*i + 8] * (qs[i] & 0x00F0); + acc += yl[2*i + 9] * (qs[i] & 0xF000); } - return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]); + return d * acc; } // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) @@ -1567,29 +1567,28 @@ void mul_vec_q_n_f32_impl( float yl[16]; // src1 vector cache float sumf[nr] = {0.f}; - const short ix = (tiisg/2); - const short il = (tiisg%2)*8; + const short ix = (tiisg%16); + const short il = (tiisg/16)*8; device const float * yb = y + ix*QK4_0 + il; // each thread in a SIMD group deals with half a block. for (int ib = ix; ib < nb; ib += nw/2) { - float sumy[2] = { 0.f, 0.f }; + float sumy = 0.0f; -#pragma unroll +#pragma unroll(4) for (int i = 0; i < 8; i += 2) { - sumy[0] += yb[i + 0] + yb[i + 1]; + sumy += yb[i + 0] + yb[i + 1] + yb[i + 16] + yb[i + 17]; + yl[i + 0] = yb[i + 0]; yl[i + 1] = yb[i + 1]/256.f; - - sumy[1] += yb[i + 16] + yb[i + 17]; yl[i + 8] = yb[i + 16]/16.f; yl[i + 9] = yb[i + 17]/4096.f; } -#pragma unroll - for (int row = 0; row < nr; row++) { - sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); +#pragma unroll(nr) + for (short row = 0; row < nr; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il); } yb += QK4_0 * 16; @@ -1597,7 +1596,7 @@ void mul_vec_q_n_f32_impl( device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; - for (int row = 0; row < nr; ++row) { + for (short row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < args.ne01) {