From eea1f7e532422afe70da6acd1959240886ada682 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 09:57:41 +0200 Subject: [PATCH] cont : thread counters style --- ggml/src/ggml-metal.metal | 281 +++++++++++++++++++------------------- 1 file changed, 141 insertions(+), 140 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index bdf49610a..82e9c937b 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1631,9 +1631,9 @@ void mul_vec_q_n_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK4_0; const int r0 = tgpig.x; @@ -1706,9 +1706,9 @@ kernel void kernel_mul_mv_q4_0_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1717,9 +1717,9 @@ kernel void kernel_mul_mv_q4_1_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1728,9 +1728,9 @@ kernel void kernel_mul_mv_q5_0_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1739,9 +1739,9 @@ kernel void kernel_mul_mv_q5_1_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1754,9 +1754,9 @@ void kernel_mul_mv_q8_0_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nr = N_DST; const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; @@ -1766,7 +1766,7 @@ void kernel_mul_mv_q8_0_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr; + const int first_row = (r0*nsg + sgitg)*nr; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -1786,12 +1786,12 @@ void kernel_mul_mv_q8_0_f32_impl( } float yl[NB_Q8_0]; - float sumf[nr]={0.f}; + float sumf[nr] = { 0.f }; const int ix = tiisg/4; const int il = tiisg%4; - device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; + device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; // each thread in a SIMD group deals with NB_Q8_0 quants at a time for (int ib = ix; ib < nb; ib += nw/4) { @@ -1800,7 +1800,7 @@ void kernel_mul_mv_q8_0_f32_impl( } for (int row = 0; row < nr; row++) { - device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il; + device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; for (int iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; @@ -1808,13 +1808,14 @@ void kernel_mul_mv_q8_0_f32_impl( sumf[row] += sumq*ax[row][ib].d; } - yb += NB_Q8_0 * nw; + yb += nw*NB_Q8_0; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < args.ne01) { dst_f32[first_row + row] = tot; } @@ -1827,9 +1828,9 @@ kernel void kernel_mul_mv_q8_0_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1841,8 +1842,8 @@ void kernel_mul_mv_impl( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig, - uint tiisg) { + uint3 tgpig, + ushort tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_MV_T_T; const int64_t im = tgpig.z; @@ -1910,8 +1911,8 @@ kernel void kernel_mul_mv( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_impl( args, src0, @@ -1937,8 +1938,8 @@ kernel void kernel_mul_mv_1row( device const char * src0, device const char * src1, device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -1993,8 +1994,8 @@ kernel void kernel_mul_mv_l4( device const char * src0, device const char * src1, device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { const int nrows = args.ne11; const int64_t r0 = tgpig.x; @@ -3932,9 +3933,9 @@ void kernel_mul_mv_q2_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4023,9 +4024,9 @@ kernel void kernel_mul_mv_q2_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4037,9 +4038,9 @@ void kernel_mul_mv_q3_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; @@ -4186,9 +4187,9 @@ kernel void kernel_mul_mv_q3_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4200,9 +4201,9 @@ void kernel_mul_mv_q4_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -4304,9 +4305,9 @@ kernel void kernel_mul_mv_q4_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4318,9 +4319,9 @@ void kernel_mul_mv_q5_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; @@ -4436,9 +4437,9 @@ kernel void kernel_mul_mv_q5_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4450,9 +4451,9 @@ void kernel_mul_mv_q6_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; @@ -4527,9 +4528,9 @@ kernel void kernel_mul_mv_q6_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4543,9 +4544,9 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4638,9 +4639,9 @@ kernel void kernel_mul_mv_iq2_xxs_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -4651,9 +4652,9 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4756,9 +4757,9 @@ kernel void kernel_mul_mv_iq2_xs_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -4770,9 +4771,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4868,9 +4869,9 @@ kernel void kernel_mul_mv_iq3_xxs_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -4882,9 +4883,9 @@ void kernel_mul_mv_iq3_s_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4980,9 +4981,9 @@ kernel void kernel_mul_mv_iq3_s_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -4994,9 +4995,9 @@ void kernel_mul_mv_iq2_s_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -5093,9 +5094,9 @@ kernel void kernel_mul_mv_iq2_s_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -5107,9 +5108,9 @@ void kernel_mul_mv_iq1_s_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -5193,9 +5194,9 @@ void kernel_mul_mv_iq1_m_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -5288,9 +5289,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK4_NL; @@ -5378,9 +5379,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK_K; @@ -5468,9 +5469,9 @@ kernel void kernel_mul_mv_iq1_s_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -5481,9 +5482,9 @@ kernel void kernel_mul_mv_iq1_m_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -5495,9 +5496,9 @@ kernel void kernel_mul_mv_iq4_nl_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -5509,9 +5510,9 @@ kernel void kernel_mul_mv_iq4_xs_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -5622,13 +5623,13 @@ kernel void kernel_mul_mm( device const char * src0, device const char * src1, device char * dst, - threadgroup char * shared_memory [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup T * sa = (threadgroup T *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + threadgroup T * sa = (threadgroup T *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; @@ -5723,7 +5724,7 @@ kernel void kernel_mul_mm( } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *) shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M; for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); @@ -6057,8 +6058,8 @@ typedef void (kernel_mul_mv_impl_t)( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig, - uint tiisg); + uint3 tgpig, + ushort tiisg); typedef void (kernel_mul_mv2_impl_t)( ggml_metal_kargs_mul_mv args, @@ -6066,9 +6067,9 @@ typedef void (kernel_mul_mv2_impl_t)( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg); + uint3 tgpig, + ushort tiisg, + ushort sgitg); template void mmv_fn( @@ -6077,10 +6078,10 @@ void mmv_fn( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { impl_fn(args, src0, src1, dst, tgpig, tiisg); } @@ -6091,10 +6092,10 @@ void mmv_fn( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -6108,10 +6109,10 @@ kernel void kernel_mul_mv_id( device char * dst, device const char * ids, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const int iid1 = tgpig.z/args.nei0; const int idx = tgpig.z%args.nei0;