cont : thread counters style

This commit is contained in:
Georgi Gerganov 2024-11-10 09:57:41 +02:00
parent cacc4c225f
commit 15a7105967
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -1631,9 +1631,9 @@ void mul_vec_q_n_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.ne00/QK4_0;
const int r0 = tgpig.x;
@ -1706,9 +1706,9 @@ kernel void kernel_mul_mv_q4_0_f32(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
@ -1717,9 +1717,9 @@ kernel void kernel_mul_mv_q4_1_f32(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
@ -1728,9 +1728,9 @@ kernel void kernel_mul_mv_q5_0_f32(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
@ -1739,9 +1739,9 @@ kernel void kernel_mul_mv_q5_1_f32(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
@ -1754,9 +1754,9 @@ void kernel_mul_mv_q8_0_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nr = N_DST;
const int nsg = N_SIMDGROUP;
const int nw = N_SIMDWIDTH;
@ -1766,7 +1766,7 @@ void kernel_mul_mv_q8_0_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr;
const int first_row = (r0*nsg + sgitg)*nr;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@ -1786,12 +1786,12 @@ void kernel_mul_mv_q8_0_f32_impl(
}
float yl[NB_Q8_0];
float sumf[nr]={0.f};
float sumf[nr] = { 0.f };
const int ix = tiisg/4;
const int il = tiisg%4;
device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
// each thread in a SIMD group deals with NB_Q8_0 quants at a time
for (int ib = ix; ib < nb; ib += nw/4) {
@ -1800,7 +1800,7 @@ void kernel_mul_mv_q8_0_f32_impl(
}
for (int row = 0; row < nr; row++) {
device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il;
device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
float sumq = 0.f;
for (int iq = 0; iq < NB_Q8_0; ++iq) {
sumq += qs[iq] * yl[iq];
@ -1808,13 +1808,14 @@ void kernel_mul_mv_q8_0_f32_impl(
sumf[row] += sumq*ax[row][ib].d;
}
yb += NB_Q8_0 * nw;
yb += nw*NB_Q8_0;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < args.ne01) {
dst_f32[first_row + row] = tot;
}
@ -1827,9 +1828,9 @@ kernel void kernel_mul_mv_q8_0_f32(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
@ -1841,8 +1842,8 @@ void kernel_mul_mv_impl(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig,
uint tiisg) {
uint3 tgpig,
ushort tiisg) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_MV_T_T;
const int64_t im = tgpig.z;
@ -1910,8 +1911,8 @@ kernel void kernel_mul_mv(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
args,
src0,
@ -1937,8 +1938,8 @@ kernel void kernel_mul_mv_1row(
device const char * src0,
device const char * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
@ -1993,8 +1994,8 @@ kernel void kernel_mul_mv_l4(
device const char * src0,
device const char * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
const int nrows = args.ne11;
const int64_t r0 = tgpig.x;
@ -3941,9 +3942,9 @@ void kernel_mul_mv_q2_K_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
@ -4032,9 +4033,9 @@ kernel void kernel_mul_mv_q2_K_f32(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
@ -4046,9 +4047,9 @@ void kernel_mul_mv_q3_K_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.ne00/QK_K;
@ -4195,9 +4196,9 @@ kernel void kernel_mul_mv_q3_K_f32(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
@ -4209,9 +4210,9 @@ void kernel_mul_mv_q4_K_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
@ -4313,9 +4314,9 @@ kernel void kernel_mul_mv_q4_K_f32(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
@ -4327,9 +4328,9 @@ void kernel_mul_mv_q5_K_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.ne00/QK_K;
@ -4445,9 +4446,9 @@ kernel void kernel_mul_mv_q5_K_f32(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
@ -4459,9 +4460,9 @@ void kernel_mul_mv_q6_K_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const uint8_t kmask1 = 0x03;
const uint8_t kmask2 = 0x0C;
@ -4536,9 +4537,9 @@ kernel void kernel_mul_mv_q6_K_f32(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
@ -4552,9 +4553,9 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
@ -4647,9 +4648,9 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
@ -4660,9 +4661,9 @@ void kernel_mul_mv_iq2_xs_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
@ -4765,9 +4766,9 @@ kernel void kernel_mul_mv_iq2_xs_f32(
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
@ -4779,9 +4780,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
@ -4877,9 +4878,9 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
@ -4891,9 +4892,9 @@ void kernel_mul_mv_iq3_s_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
@ -4989,9 +4990,9 @@ kernel void kernel_mul_mv_iq3_s_f32(
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
@ -5003,9 +5004,9 @@ void kernel_mul_mv_iq2_s_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
@ -5102,9 +5103,9 @@ kernel void kernel_mul_mv_iq2_s_f32(
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
@ -5116,9 +5117,9 @@ void kernel_mul_mv_iq1_s_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
@ -5202,9 +5203,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
@ -5297,9 +5298,9 @@ void kernel_mul_mv_iq4_nl_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK4_NL;
@ -5387,9 +5388,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_K;
@ -5477,9 +5478,9 @@ kernel void kernel_mul_mv_iq1_s_f32(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
@ -5490,9 +5491,9 @@ kernel void kernel_mul_mv_iq1_m_f32(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
@ -5504,9 +5505,9 @@ kernel void kernel_mul_mv_iq4_nl_f32(
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
@ -5518,9 +5519,9 @@ kernel void kernel_mul_mv_iq4_xs_f32(
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
@ -5631,13 +5632,13 @@ kernel void kernel_mul_mm(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shared_memory [[threadgroup(0)]],
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup T * sa = (threadgroup T *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
threadgroup T * sa = (threadgroup T *)(shmem);
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
const int r0 = tgpig.y;
const int r1 = tgpig.x;
@ -5732,7 +5733,7 @@ kernel void kernel_mul_mm(
} else {
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *) shared_memory) \
threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
@ -6066,8 +6067,8 @@ typedef void (kernel_mul_mv_impl_t)(
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig,
uint tiisg);
uint3 tgpig,
ushort tiisg);
typedef void (kernel_mul_mv2_impl_t)(
ggml_metal_kargs_mul_mv args,
@ -6075,9 +6076,9 @@ typedef void (kernel_mul_mv2_impl_t)(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg);
uint3 tgpig,
ushort tiisg,
ushort sgitg);
template<kernel_mul_mv_impl_t impl_fn>
void mmv_fn(
@ -6086,10 +6087,10 @@ void mmv_fn(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiitg,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiitg,
ushort tiisg,
ushort sgitg) {
impl_fn(args, src0, src1, dst, tgpig, tiisg);
}
@ -6100,10 +6101,10 @@ void mmv_fn(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
uint tiitg,
uint tiisg,
uint sgitg) {
uint3 tgpig,
ushort tiitg,
ushort tiisg,
ushort sgitg) {
impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
@ -6117,10 +6118,10 @@ kernel void kernel_mul_mv_id(
device char * dst,
device const char * ids,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const int iid1 = tgpig.z/args.nei0;
const int idx = tgpig.z%args.nei0;