cont : int safety + register optimizations

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-10 11:05:10 +02:00
parent c5cf1d74f0
commit bb821e4854
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -1645,8 +1645,8 @@ void mul_vec_q_n_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
//const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -1654,7 +1654,7 @@ void mul_vec_q_n_f32_impl(
// pointers to src0 rows
device const block_q_type * ax[nr];
for (int row = 0; row < nr; ++row) {
const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
}
@ -1662,10 +1662,10 @@ void mul_vec_q_n_f32_impl(
float yl[16]; // src1 vector cache
float sumf[nr] = {0.f};
const int ix = (tiisg/2);
const int il = (tiisg%2)*8;
const short ix = (tiisg/2);
const short il = (tiisg%2)*8;
device const float * yb = y + ix * QK4_0 + il;
device const float * yb = y + ix*QK4_0 + il;
// each thread in a SIMD group deals with half a block.
for (int ib = ix; ib < nb; ib += nw/2) {
@ -1771,8 +1771,8 @@ void kernel_mul_mv_q8_0_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
//const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -1780,7 +1780,7 @@ void kernel_mul_mv_q8_0_f32_impl(
// pointers to src0 rows
device const block_q8_0 * ax[nr];
for (int row = 0; row < nr; ++row) {
const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
}
@ -1788,21 +1788,21 @@ void kernel_mul_mv_q8_0_f32_impl(
float yl[NB_Q8_0];
float sumf[nr] = { 0.f };
const int ix = tiisg/4;
const int il = tiisg%4;
const short ix = tiisg/4;
const short il = tiisg%4;
device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
// each thread in a SIMD group deals with NB_Q8_0 quants at a time
for (int ib = ix; ib < nb; ib += nw/4) {
for (int i = 0; i < NB_Q8_0; ++i) {
for (short i = 0; i < NB_Q8_0; ++i) {
yl[i] = yb[i];
}
for (int row = 0; row < nr; row++) {
device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
float sumq = 0.f;
for (int iq = 0; iq < NB_Q8_0; ++iq) {
for (short iq = 0; iq < NB_Q8_0; ++iq) {
sumq += qs[iq] * yl[iq];
}
sumf[row] += sumq*ax[row][ib].d;
@ -1811,7 +1811,7 @@ void kernel_mul_mv_q8_0_f32_impl(
yb += nw*NB_Q8_0;
}
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]);
@ -1844,18 +1844,18 @@ void kernel_mul_mv_impl(
device char * dst,
uint3 tgpig,
ushort tiisg) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_MV_T_T;
const int64_t im = tgpig.z;
const int r0 = tgpig.x;
const int rb = tgpig.y*N_MV_T_T;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
device const T0 * x = (device const T0 *) (src0 + offset0);
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
if (args.ne00 < 128) {
for (int row = 0; row < N_MV_T_T; ++row) {
@ -1864,7 +1864,7 @@ void kernel_mul_mv_impl(
break;
}
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T1 * y = (device const T1 *) (src1 + offset1);
@ -1875,7 +1875,7 @@ void kernel_mul_mv_impl(
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst_f32[r1*args.ne0 + r0] = all_sum;
dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
}
}
} else {
@ -1886,20 +1886,20 @@ void kernel_mul_mv_impl(
break;
}
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T1 * y = (device const T1 *) (src1 + offset1);
device const T14 * y4 = (device const T14 *) y;
float sumf = 0;
for (int i = tiisg; i < args.ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
sumf += dot((T14) x4[i], y4[i]);
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
dst_f32[r1*args.ne0 + r0] = all_sum;
dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
}
}
}
@ -1935,25 +1935,27 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<
template<typename T, typename T4>
kernel void kernel_mul_mv_1row(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T * x = (device const T *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
float sumf = 0;
if (args.ne00 < 128) {
for (int i = tiisg; i < args.ne00; i += 32) {
@ -1961,21 +1963,21 @@ kernel void kernel_mul_mv_1row(
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum;
dst_f32[r0] = all_sum;
}
} else {
device const T4 * x4 = (device const T4 *) x;
device const float4 * y4 = (device const float4 *) y;
for (int i = tiisg; i < args.ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
sumf += dot((float4) x4[i], y4[i]);
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum;
dst_f32[r0] = all_sum;
}
}
}
@ -1991,36 +1993,38 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne
template<typename T, typename T4>
kernel void kernel_mul_mv_l4(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
const int nrows = args.ne11;
const int64_t r0 = tgpig.x;
const int64_t im = tgpig.z;
const int r0 = tgpig.x;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
device const T4 * x4 = (device const T4 *) (src0 + offset0);
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
for (int r1 = 0; r1 < nrows; ++r1) {
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const float4 * y4 = (device const float4 *) (src1 + offset1);
float sumf = 0;
for (int i = tiisg; i < args.ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
sumf += dot((float4) x4[i], y4[i]);
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum;
dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
}
}
}
@ -2969,7 +2973,7 @@ kernel void kernel_flash_attn_ext(
const float S = ss[j*TS + 0];
for (short i = tiisg; i < D4; i += NW) {
dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
}
}
}
@ -3361,7 +3365,7 @@ kernel void kernel_flash_attn_ext_vec(
const float S = ss[0];
for (short i = tiisg; i < D16; i += NW) {
dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
}
}
}
@ -3956,8 +3960,8 @@ void kernel_mul_mv_q2_K_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -4017,7 +4021,7 @@ void kernel_mul_mv_q2_K_f32_impl(
y4 += 4 * QK_K;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
@ -4053,17 +4057,17 @@ void kernel_mul_mv_q3_K_f32_impl(
const int nb = args.ne00/QK_K;
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
device const float * yy = (device const float *) (src1 + offset1);
@ -4096,9 +4100,10 @@ void kernel_mul_mv_q3_K_f32_impl(
const ushort4 hm = mm[2*ip + il/2];
const int shift = 2*il;
const float v1 = il == 0 ? 4.f : 64.f;
const float v2 = 4.f * v1;
const short shift = 2*il;
const float v1 = il == 0 ? 4.f : 64.f;
const float v2 = 4.f * v1;
const uint16_t s_shift1 = 4*ip;
const uint16_t s_shift2 = s_shift1 + il;
@ -4181,7 +4186,7 @@ void kernel_mul_mv_q3_K_f32_impl(
sumf1[row] = simd_sum(sumf);
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
if (tiisg == 0) {
for (int row = 0; row < 2; ++row) {
@ -4233,8 +4238,8 @@ void kernel_mul_mv_q4_K_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -4298,7 +4303,7 @@ void kernel_mul_mv_q4_K_f32_impl(
y4 += 4 * QK_K;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
@ -4334,8 +4339,8 @@ void kernel_mul_mv_q5_K_f32_impl(
const int nb = args.ne00/QK_K;
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
@ -4343,8 +4348,8 @@ void kernel_mul_mv_q5_K_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
device const float * yy = (device const float *) (src1 + offset1);
@ -4430,7 +4435,7 @@ void kernel_mul_mv_q5_K_f32_impl(
y1 += 4 * QK_K;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < 2; ++row) {
const float tot = simd_sum(sumf[row]);
@ -4471,17 +4476,17 @@ void kernel_mul_mv_q6_K_f32_impl(
const int nb = args.ne00/QK_K;
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const int im = tgpig.z;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int row = 2 * r0 + sgitg;
const int row = 2*r0 + sgitg;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
device const float * yy = (device const float *) (src1 + offset1);
@ -4501,7 +4506,6 @@ void kernel_mul_mv_q6_K_f32_impl(
const int q_offset_h = 32*ip + l0;
for (int i = ix; i < nb; i += 2) {
device const uint8_t * q1 = x[i].ql + q_offset_l;
device const uint8_t * q2 = q1 + 32;
device const uint8_t * qh = x[i].qh + q_offset_h;
@ -4523,7 +4527,7 @@ void kernel_mul_mv_q6_K_f32_impl(
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
const float tot = simd_sum(sumf);
if (tiisg == 0) {
@ -4567,8 +4571,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -4631,7 +4635,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
@ -4675,8 +4679,8 @@ void kernel_mul_mv_iq2_xs_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -4749,7 +4753,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
@ -4794,8 +4798,8 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -4822,7 +4826,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
for (int i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
@ -4836,7 +4839,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
device const half * dh = &xr->d;
for (int row = 0; row < N_DST; row++) {
const float db = dh[0];
const uint32_t aux32 = gas[0] | (gas[1] << 16);
const float d = db * (0.5f + (aux32 >> 28));
@ -4861,7 +4863,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
@ -4906,8 +4908,8 @@ void kernel_mul_mv_iq3_s_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -4973,7 +4975,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
@ -5018,8 +5020,8 @@ void kernel_mul_mv_iq2_s_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -5086,7 +5088,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
@ -5131,8 +5133,8 @@ void kernel_mul_mv_iq1_s_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -5186,7 +5188,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
@ -5217,8 +5219,8 @@ void kernel_mul_mv_iq1_m_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -5281,7 +5283,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
@ -5312,8 +5314,8 @@ void kernel_mul_mv_iq4_nl_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -5371,7 +5373,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
yb += 16 * QK4_NL;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
all_sum = simd_sum(sumf[row]);
@ -5402,8 +5404,8 @@ void kernel_mul_mv_iq4_xs_f32_impl(
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
@ -5427,25 +5429,23 @@ void kernel_mul_mv_iq4_xs_f32_impl(
float4 qf1, qf2;
for (int ibl = ix; ibl < nb; ibl += 2) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
for (int row = 0; row < 2; ++row) {
device const block_iq4_xs & xb = x[row*nb + ibl];
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
float4 acc1 = {0.f}, acc2 = {0.f};
aux32[0] = q4[0] & 0x0f0f0f0f;
aux32[0] = (q4[0] ) & 0x0f0f0f0f;
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
acc1 += yl[0] * qf1;
acc2 += yl[1] * qf2;
aux32[0] = q4[1] & 0x0f0f0f0f;
aux32[0] = (q4[1] ) & 0x0f0f0f0f;
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
@ -5462,7 +5462,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
yb += 2 * QK_K;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < 2; ++row) {
all_sum = simd_sum(sumf[row]);
@ -5665,8 +5665,8 @@ kernel void kernel_mul_mm(
const int i12 = im%args.ne12;
const int i13 = im/args.ne12;
int offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
short offset1 = il/nl;
uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
short offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1;
device const float * y = (device const float *)(src1
@ -5791,10 +5791,10 @@ void kernel_mul_mm_id_impl(
threadgroup half * sa = (threadgroup half *)(shmem);
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
const uint r0 = tgpig.y;
const uint r1 = tgpig.x;
const int r0 = tgpig.y;
const int r1 = tgpig.x;
if (r1 * BLOCK_SIZE_N >= ne1) return;
if (r1*BLOCK_SIZE_N >= ne1) return;
// if this block is of 64x32 shape or smaller
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
@ -5925,7 +5925,7 @@ kernel void kernel_mul_mm_id(
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
// TODO: parallelize this loop
int64_t _ne1 = 0;
int32_t _ne1 = 0;
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];