From 0e323856a67a274d9b42cf3da2d9095679cbda3c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 09:45:06 +0200 Subject: [PATCH] cont : shmem style --- ggml/src/ggml-metal.metal | 215 +++++++++++++++++++------------------- 1 file changed, 107 insertions(+), 108 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 710274174..bdf49610a 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1630,7 +1630,7 @@ void mul_vec_q_n_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -1753,7 +1753,7 @@ void kernel_mul_mv_q8_0_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -2571,7 +2571,7 @@ kernel void kernel_flash_attn_ext( device const char * v, device const char * mask, device char * dst, - threadgroup half * shared [[threadgroup(0)]], + threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -2591,17 +2591,17 @@ kernel void kernel_flash_attn_ext( const short TS = nsg*SH; // shared memory size per query in (s_t == float) const short T = D + 2*TS; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t - threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation - threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t - threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix - threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory - threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t - threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory - threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) o8x8_t lo[D8]; @@ -3063,7 +3063,7 @@ kernel void kernel_flash_attn_ext_vec( device const char * v, device const char * mask, device char * dst, - threadgroup half * shared [[threadgroup(0)]], + threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -3082,13 +3082,13 @@ kernel void kernel_flash_attn_ext_vec( const short T = D + nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t - threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t - threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask - threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t + threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask + threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) o4x4_t lo[D16/NL]; @@ -3931,7 +3931,7 @@ void kernel_mul_mv_q2_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4036,7 +4036,7 @@ void kernel_mul_mv_q3_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4199,7 +4199,7 @@ void kernel_mul_mv_q4_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4317,7 +4317,7 @@ void kernel_mul_mv_q5_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4449,7 +4449,7 @@ void kernel_mul_mv_q6_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4542,7 +4542,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4568,15 +4568,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int nb32 = nb * (QK_K / 32); - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); { int nval = 4; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4606,8 +4606,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( float sum = 0; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; for (int j = 0; j < 8; ++j) { sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } @@ -4637,12 +4637,11 @@ kernel void kernel_mul_mv_iq2_xxs_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -4651,7 +4650,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4677,15 +4676,15 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int nb32 = nb * (QK_K / 32); - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512); { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4717,15 +4716,15 @@ void kernel_mul_mv_iq2_xs_f32_impl( float sum1 = 0, sum2 = 0; for (int l = 0; l < 2; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; for (int j = 0; j < 8; ++j) { sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } for (int l = 2; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; for (int j = 0; j < 8; ++j) { sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } @@ -4756,12 +4755,12 @@ kernel void kernel_mul_mv_iq2_xs_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -4770,7 +4769,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4796,15 +4795,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int nb32 = nb * (QK_K / 32); - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); { int nval = 4; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4834,9 +4833,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; for (int j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); @@ -4868,12 +4867,12 @@ kernel void kernel_mul_mv_iq3_xxs_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -4882,7 +4881,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4908,11 +4907,11 @@ void kernel_mul_mv_iq3_s_f32_impl( const int nb32 = nb * (QK_K / 32); - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem; { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4943,8 +4942,8 @@ void kernel_mul_mv_iq3_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; - const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues; const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); for (int j = 0; j < 4; ++j) { @@ -4980,12 +4979,12 @@ kernel void kernel_mul_mv_iq3_s_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -4994,7 +4993,7 @@ void kernel_mul_mv_iq2_s_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -5020,11 +5019,11 @@ void kernel_mul_mv_iq2_s_f32_impl( const int nb32 = nb * (QK_K / 32); - //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem; //{ // int nval = 32; // int pos = (32*sgitg + tiisg)*nval; - // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; + // for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i]; // threadgroup_barrier(mem_flags::mem_threadgroup); //} @@ -5056,8 +5055,8 @@ void kernel_mul_mv_iq2_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 2; ++l) { - //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); - //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); for (int j = 0; j < 8; ++j) { @@ -5093,12 +5092,12 @@ kernel void kernel_mul_mv_iq2_s_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -5107,7 +5106,7 @@ void kernel_mul_mv_iq1_s_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_value, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -5193,7 +5192,7 @@ void kernel_mul_mv_iq1_m_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_value, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -5288,12 +5287,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values_i8, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -5312,7 +5311,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int ix = tiisg/2; // 0...15 const int it = tiisg%2; // 0 or 1 - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; @@ -5340,16 +5339,16 @@ void kernel_mul_mv_iq4_nl_f32_impl( aux32[0] = q4[0] | (q4[1] << 16); aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; aux32[0] = q4[2] | (q4[3] << 16); aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; @@ -5378,12 +5377,12 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values_i8, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -5404,7 +5403,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int ib = it/2; const int il = it%2; - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; @@ -5431,15 +5430,15 @@ void kernel_mul_mv_iq4_xs_f32_impl( aux32[0] = q4[0] & 0x0f0f0f0f; aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; aux32[0] = q4[1] & 0x0f0f0f0f; aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; @@ -5495,12 +5494,12 @@ kernel void kernel_mul_mv_iq4_nl_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] @@ -5509,12 +5508,12 @@ kernel void kernel_mul_mv_iq4_xs_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -6066,7 +6065,7 @@ typedef void (kernel_mul_mv2_impl_t)( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg); @@ -6077,11 +6076,11 @@ void mmv_fn( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { + threadgroup char * shmem, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { impl_fn(args, src0, src1, dst, tgpig, tiisg); } @@ -6091,12 +6090,12 @@ void mmv_fn( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + threadgroup char * shmem, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; @@ -6108,11 +6107,11 @@ kernel void kernel_mul_mv_id( device const char * src1, device char * dst, device const char * ids, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int iid1 = tgpig.z/args.nei0; const int idx = tgpig.z%args.nei0; @@ -6157,7 +6156,7 @@ kernel void kernel_mul_mv_id( /* src0 */ src0_cur, /* src1 */ src1_cur, /* dst */ dst_cur, - shared_values, + shmem, tgpig, tiitg, tiisg,