From bb821e48548d1ffbe5d5235d87650d5acd1d78fe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 11:05:10 +0200 Subject: [PATCH] cont : int safety + register optimizations ggml-ci --- ggml/src/ggml-metal.metal | 226 +++++++++++++++++++------------------- 1 file changed, 113 insertions(+), 113 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 561476551..d86458b1a 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1645,8 +1645,8 @@ void mul_vec_q_n_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -1654,7 +1654,7 @@ void mul_vec_q_n_f32_impl( // pointers to src0 rows device const block_q_type * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -1662,10 +1662,10 @@ void mul_vec_q_n_f32_impl( float yl[16]; // src1 vector cache float sumf[nr] = {0.f}; - const int ix = (tiisg/2); - const int il = (tiisg%2)*8; + const short ix = (tiisg/2); + const short il = (tiisg%2)*8; - device const float * yb = y + ix * QK4_0 + il; + device const float * yb = y + ix*QK4_0 + il; // each thread in a SIMD group deals with half a block. for (int ib = ix; ib < nb; ib += nw/2) { @@ -1771,8 +1771,8 @@ void kernel_mul_mv_q8_0_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -1780,7 +1780,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -1788,21 +1788,21 @@ void kernel_mul_mv_q8_0_f32_impl( float yl[NB_Q8_0]; float sumf[nr] = { 0.f }; - const int ix = tiisg/4; - const int il = tiisg%4; + const short ix = tiisg/4; + const short il = tiisg%4; device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; // each thread in a SIMD group deals with NB_Q8_0 quants at a time for (int ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { + for (short i = 0; i < NB_Q8_0; ++i) { yl[i] = yb[i]; } for (int row = 0; row < nr; row++) { device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { + for (short iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } sumf[row] += sumq*ax[row][ib].d; @@ -1811,7 +1811,7 @@ void kernel_mul_mv_q8_0_f32_impl( yb += nw*NB_Q8_0; } - device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); @@ -1844,18 +1844,18 @@ void kernel_mul_mv_impl( device char * dst, uint3 tgpig, ushort tiisg) { - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_MV_T_T; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int rb = tgpig.y*N_MV_T_T; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); - device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; if (args.ne00 < 128) { for (int row = 0; row < N_MV_T_T; ++row) { @@ -1864,7 +1864,7 @@ void kernel_mul_mv_impl( break; } - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); @@ -1875,7 +1875,7 @@ void kernel_mul_mv_impl( float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst_f32[r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } else { @@ -1886,20 +1886,20 @@ void kernel_mul_mv_impl( break; } - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); device const T14 * y4 = (device const T14 *) y; float sumf = 0; for (int i = tiisg; i < args.ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + sumf += dot((T14) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst_f32[r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } @@ -1935,25 +1935,27 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv< template kernel void kernel_mul_mv_1row( constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T * x = (device const T *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + float sumf = 0; if (args.ne00 < 128) { for (int i = tiisg; i < args.ne00; i += 32) { @@ -1961,21 +1963,21 @@ kernel void kernel_mul_mv_1row( } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[r0] = all_sum; } } else { device const T4 * x4 = (device const T4 *) x; device const float4 * y4 = (device const float4 *) y; for (int i = tiisg; i < args.ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + sumf += dot((float4) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[r0] = all_sum; } } } @@ -1991,36 +1993,38 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne template kernel void kernel_mul_mv_l4( constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { const int nrows = args.ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T4 * x4 = (device const T4 *) (src0 + offset0); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + for (int r1 = 0; r1 < nrows; ++r1) { - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const float4 * y4 = (device const float4 *) (src1 + offset1); float sumf = 0; for (int i = tiisg; i < args.ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + sumf += dot((float4) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } @@ -2969,7 +2973,7 @@ kernel void kernel_flash_attn_ext( const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; } } } @@ -3361,7 +3365,7 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } } @@ -3956,8 +3960,8 @@ void kernel_mul_mv_q2_K_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4017,7 +4021,7 @@ void kernel_mul_mv_q2_K_f32_impl( y4 += 4 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4053,17 +4057,17 @@ void kernel_mul_mv_q3_K_f32_impl( const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); @@ -4096,9 +4100,10 @@ void kernel_mul_mv_q3_K_f32_impl( const ushort4 hm = mm[2*ip + il/2]; - const int shift = 2*il; - const float v1 = il == 0 ? 4.f : 64.f; - const float v2 = 4.f * v1; + const short shift = 2*il; + + const float v1 = il == 0 ? 4.f : 64.f; + const float v2 = 4.f * v1; const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + il; @@ -4181,7 +4186,7 @@ void kernel_mul_mv_q3_K_f32_impl( sumf1[row] = simd_sum(sumf); } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; if (tiisg == 0) { for (int row = 0; row < 2; ++row) { @@ -4233,8 +4238,8 @@ void kernel_mul_mv_q4_K_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4298,7 +4303,7 @@ void kernel_mul_mv_q4_K_f32_impl( y4 += 4 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4334,8 +4339,8 @@ void kernel_mul_mv_q5_K_f32_impl( const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; + const int r0 = tgpig.x; + const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; @@ -4343,8 +4348,8 @@ void kernel_mul_mv_q5_K_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); @@ -4430,7 +4435,7 @@ void kernel_mul_mv_q5_K_f32_impl( y1 += 4 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); @@ -4471,17 +4476,17 @@ void kernel_mul_mv_q6_K_f32_impl( const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; - const int row = 2 * r0 + sgitg; + const int row = 2*r0 + sgitg; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); @@ -4501,7 +4506,6 @@ void kernel_mul_mv_q6_K_f32_impl( const int q_offset_h = 32*ip + l0; for (int i = ix; i < nb; i += 2) { - device const uint8_t * q1 = x[i].ql + q_offset_l; device const uint8_t * q2 = q1 + 32; device const uint8_t * qh = x[i].qh + q_offset_h; @@ -4523,7 +4527,7 @@ void kernel_mul_mv_q6_K_f32_impl( } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; const float tot = simd_sum(sumf); if (tiisg == 0) { @@ -4567,8 +4571,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4631,7 +4635,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4675,8 +4679,8 @@ void kernel_mul_mv_iq2_xs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4749,7 +4753,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4794,8 +4798,8 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4822,7 +4826,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -4836,7 +4839,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const half * dh = &xr->d; for (int row = 0; row < N_DST; row++) { - const float db = dh[0]; const uint32_t aux32 = gas[0] | (gas[1] << 16); const float d = db * (0.5f + (aux32 >> 28)); @@ -4861,7 +4863,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4906,8 +4908,8 @@ void kernel_mul_mv_iq3_s_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4973,7 +4975,7 @@ void kernel_mul_mv_iq3_s_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5018,8 +5020,8 @@ void kernel_mul_mv_iq2_s_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5086,7 +5088,7 @@ void kernel_mul_mv_iq2_s_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5131,8 +5133,8 @@ void kernel_mul_mv_iq1_s_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5186,7 +5188,7 @@ void kernel_mul_mv_iq1_s_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5217,8 +5219,8 @@ void kernel_mul_mv_iq1_m_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5281,7 +5283,7 @@ void kernel_mul_mv_iq1_m_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5312,8 +5314,8 @@ void kernel_mul_mv_iq4_nl_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5371,7 +5373,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( yb += 16 * QK4_NL; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { all_sum = simd_sum(sumf[row]); @@ -5402,8 +5404,8 @@ void kernel_mul_mv_iq4_xs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5427,25 +5429,23 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; for (int ibl = ix; ibl < nb; ibl += 2) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; for (int row = 0; row < 2; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; - aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[0] = (q4[0] ) & 0x0f0f0f0f; aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; - aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[0] = (q4[1] ) & 0x0f0f0f0f; aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; @@ -5462,7 +5462,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( yb += 2 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < 2; ++row) { all_sum = simd_sum(sumf[row]); @@ -5665,8 +5665,8 @@ kernel void kernel_mul_mm( const int i12 = im%args.ne12; const int i13 = im/args.ne12; - int offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - short offset1 = il/nl; + uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + short offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1; device const float * y = (device const float *)(src1 @@ -5791,10 +5791,10 @@ void kernel_mul_mm_id_impl( threadgroup half * sa = (threadgroup half *)(shmem); threadgroup float * sb = (threadgroup float *)(shmem + 4096); - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; + const int r0 = tgpig.y; + const int r1 = tgpig.x; - if (r1 * BLOCK_SIZE_N >= ne1) return; + if (r1*BLOCK_SIZE_N >= ne1) return; // if this block is of 64x32 shape or smaller short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; @@ -5925,7 +5925,7 @@ kernel void kernel_mul_mm_id( threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); // TODO: parallelize this loop - int64_t _ne1 = 0; + int32_t _ne1 = 0; for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];