diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 7fc4792c4..daa645833 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1645,8 +1645,8 @@ void mul_vec_q_n_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -1654,7 +1654,7 @@ void mul_vec_q_n_f32_impl( // pointers to src0 rows device const block_q_type * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -1662,10 +1662,10 @@ void mul_vec_q_n_f32_impl( float yl[16]; // src1 vector cache float sumf[nr] = {0.f}; - const int ix = (tiisg/2); - const int il = (tiisg%2)*8; + const short ix = (tiisg/2); + const short il = (tiisg%2)*8; - device const float * yb = y + ix * QK4_0 + il; + device const float * yb = y + ix*QK4_0 + il; // each thread in a SIMD group deals with half a block. for (int ib = ix; ib < nb; ib += nw/2) { @@ -1771,8 +1771,8 @@ void kernel_mul_mv_q8_0_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -1780,7 +1780,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -1788,21 +1788,21 @@ void kernel_mul_mv_q8_0_f32_impl( float yl[NB_Q8_0]; float sumf[nr] = { 0.f }; - const int ix = tiisg/4; - const int il = tiisg%4; + const short ix = tiisg/4; + const short il = tiisg%4; device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; // each thread in a SIMD group deals with NB_Q8_0 quants at a time for (int ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { + for (short i = 0; i < NB_Q8_0; ++i) { yl[i] = yb[i]; } for (int row = 0; row < nr; row++) { device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { + for (short iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } sumf[row] += sumq*ax[row][ib].d; @@ -1811,7 +1811,7 @@ void kernel_mul_mv_q8_0_f32_impl( yb += nw*NB_Q8_0; } - device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); @@ -1844,18 +1844,18 @@ void kernel_mul_mv_impl( device char * dst, uint3 tgpig, ushort tiisg) { - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_MV_T_T; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int rb = tgpig.y*N_MV_T_T; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); - device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; if (args.ne00 < 128) { for (int row = 0; row < N_MV_T_T; ++row) { @@ -1864,7 +1864,7 @@ void kernel_mul_mv_impl( break; } - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); @@ -1875,7 +1875,7 @@ void kernel_mul_mv_impl( float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst_f32[r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } else { @@ -1886,20 +1886,20 @@ void kernel_mul_mv_impl( break; } - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); device const T14 * y4 = (device const T14 *) y; float sumf = 0; for (int i = tiisg; i < args.ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + sumf += dot((T14) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst_f32[r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } @@ -1935,25 +1935,27 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv< template kernel void kernel_mul_mv_1row( constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T * x = (device const T *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + float sumf = 0; if (args.ne00 < 128) { for (int i = tiisg; i < args.ne00; i += 32) { @@ -1961,21 +1963,21 @@ kernel void kernel_mul_mv_1row( } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[r0] = all_sum; } } else { device const T4 * x4 = (device const T4 *) x; device const float4 * y4 = (device const float4 *) y; for (int i = tiisg; i < args.ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + sumf += dot((float4) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[r0] = all_sum; } } } @@ -1991,36 +1993,38 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne template kernel void kernel_mul_mv_l4( constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { const int nrows = args.ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T4 * x4 = (device const T4 *) (src0 + offset0); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + for (int r1 = 0; r1 < nrows; ++r1) { - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const float4 * y4 = (device const float4 *) (src1 + offset1); float sumf = 0; for (int i = tiisg; i < args.ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + sumf += dot((float4) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } @@ -2966,7 +2970,7 @@ kernel void kernel_flash_attn_ext( const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; } } } @@ -3352,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } } @@ -3947,8 +3951,8 @@ void kernel_mul_mv_q2_K_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4008,7 +4012,7 @@ void kernel_mul_mv_q2_K_f32_impl( y4 += 4 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4044,17 +4048,17 @@ void kernel_mul_mv_q3_K_f32_impl( const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); @@ -4087,9 +4091,10 @@ void kernel_mul_mv_q3_K_f32_impl( const ushort4 hm = mm[2*ip + il/2]; - const int shift = 2*il; - const float v1 = il == 0 ? 4.f : 64.f; - const float v2 = 4.f * v1; + const short shift = 2*il; + + const float v1 = il == 0 ? 4.f : 64.f; + const float v2 = 4.f * v1; const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + il; @@ -4172,7 +4177,7 @@ void kernel_mul_mv_q3_K_f32_impl( sumf1[row] = simd_sum(sumf); } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; if (tiisg == 0) { for (int row = 0; row < 2; ++row) { @@ -4224,8 +4229,8 @@ void kernel_mul_mv_q4_K_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4289,7 +4294,7 @@ void kernel_mul_mv_q4_K_f32_impl( y4 += 4 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4325,8 +4330,8 @@ void kernel_mul_mv_q5_K_f32_impl( const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; + const int r0 = tgpig.x; + const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; @@ -4334,8 +4339,8 @@ void kernel_mul_mv_q5_K_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); @@ -4421,7 +4426,7 @@ void kernel_mul_mv_q5_K_f32_impl( y1 += 4 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); @@ -4462,17 +4467,17 @@ void kernel_mul_mv_q6_K_f32_impl( const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; - const int row = 2 * r0 + sgitg; + const int row = 2*r0 + sgitg; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); @@ -4492,7 +4497,6 @@ void kernel_mul_mv_q6_K_f32_impl( const int q_offset_h = 32*ip + l0; for (int i = ix; i < nb; i += 2) { - device const uint8_t * q1 = x[i].ql + q_offset_l; device const uint8_t * q2 = q1 + 32; device const uint8_t * qh = x[i].qh + q_offset_h; @@ -4514,7 +4518,7 @@ void kernel_mul_mv_q6_K_f32_impl( } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; const float tot = simd_sum(sumf); if (tiisg == 0) { @@ -4558,8 +4562,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4622,7 +4626,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4666,8 +4670,8 @@ void kernel_mul_mv_iq2_xs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4740,7 +4744,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4785,8 +4789,8 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4813,7 +4817,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -4827,7 +4830,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const half * dh = &xr->d; for (int row = 0; row < N_DST; row++) { - const float db = dh[0]; const uint32_t aux32 = gas[0] | (gas[1] << 16); const float d = db * (0.5f + (aux32 >> 28)); @@ -4852,7 +4854,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4897,8 +4899,8 @@ void kernel_mul_mv_iq3_s_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4964,7 +4966,7 @@ void kernel_mul_mv_iq3_s_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5009,8 +5011,8 @@ void kernel_mul_mv_iq2_s_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5077,7 +5079,7 @@ void kernel_mul_mv_iq2_s_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5122,8 +5124,8 @@ void kernel_mul_mv_iq1_s_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5177,7 +5179,7 @@ void kernel_mul_mv_iq1_s_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5208,8 +5210,8 @@ void kernel_mul_mv_iq1_m_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5272,7 +5274,7 @@ void kernel_mul_mv_iq1_m_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5303,8 +5305,8 @@ void kernel_mul_mv_iq4_nl_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5362,7 +5364,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( yb += 16 * QK4_NL; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { all_sum = simd_sum(sumf[row]); @@ -5393,8 +5395,8 @@ void kernel_mul_mv_iq4_xs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5418,25 +5420,23 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; for (int ibl = ix; ibl < nb; ibl += 2) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; for (int row = 0; row < 2; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; - aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[0] = (q4[0] ) & 0x0f0f0f0f; aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; - aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[0] = (q4[1] ) & 0x0f0f0f0f; aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; @@ -5453,7 +5453,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( yb += 2 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < 2; ++row) { all_sum = simd_sum(sumf[row]); @@ -5656,8 +5656,8 @@ kernel void kernel_mul_mm( const int i12 = im%args.ne12; const int i13 = im/args.ne12; - int offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - short offset1 = il/nl; + uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + short offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1; device const float * y = (device const float *)(src1 @@ -5782,10 +5782,10 @@ void kernel_mul_mm_id_impl( threadgroup half * sa = (threadgroup half *)(shmem); threadgroup float * sb = (threadgroup float *)(shmem + 4096); - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; + const int r0 = tgpig.y; + const int r1 = tgpig.x; - if (r1 * BLOCK_SIZE_N >= ne1) return; + if (r1*BLOCK_SIZE_N >= ne1) return; // if this block is of 64x32 shape or smaller short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; @@ -5916,7 +5916,7 @@ kernel void kernel_mul_mm_id( threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); // TODO: parallelize this loop - int64_t _ne1 = 0; + int32_t _ne1 = 0; for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];