mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-08 09:41:45 +00:00
cont : thread counters style
This commit is contained in:
parent
0e323856a6
commit
eea1f7e532
@ -1632,8 +1632,8 @@ void mul_vec_q_n_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
const int nb = args.ne00/QK4_0;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
@ -1707,8 +1707,8 @@ kernel void kernel_mul_mv_q4_0_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
@ -1718,8 +1718,8 @@ kernel void kernel_mul_mv_q4_1_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
@ -1729,8 +1729,8 @@ kernel void kernel_mul_mv_q5_0_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
@ -1740,8 +1740,8 @@ kernel void kernel_mul_mv_q5_1_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
@ -1755,8 +1755,8 @@ void kernel_mul_mv_q8_0_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
const int nr = N_DST;
|
||||
const int nsg = N_SIMDGROUP;
|
||||
const int nw = N_SIMDWIDTH;
|
||||
@ -1791,7 +1791,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
||||
const int ix = tiisg/4;
|
||||
const int il = tiisg%4;
|
||||
|
||||
device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
|
||||
device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
|
||||
|
||||
// each thread in a SIMD group deals with NB_Q8_0 quants at a time
|
||||
for (int ib = ix; ib < nb; ib += nw/4) {
|
||||
@ -1800,7 +1800,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
||||
}
|
||||
|
||||
for (int row = 0; row < nr; row++) {
|
||||
device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il;
|
||||
device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
|
||||
float sumq = 0.f;
|
||||
for (int iq = 0; iq < NB_Q8_0; ++iq) {
|
||||
sumq += qs[iq] * yl[iq];
|
||||
@ -1808,13 +1808,14 @@ void kernel_mul_mv_q8_0_f32_impl(
|
||||
sumf[row] += sumq*ax[row][ib].d;
|
||||
}
|
||||
|
||||
yb += NB_Q8_0 * nw;
|
||||
yb += nw*NB_Q8_0;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr; ++row) {
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
|
||||
if (tiisg == 0 && first_row + row < args.ne01) {
|
||||
dst_f32[first_row + row] = tot;
|
||||
}
|
||||
@ -1828,8 +1829,8 @@ kernel void kernel_mul_mv_q8_0_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
@ -1842,7 +1843,7 @@ void kernel_mul_mv_impl(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig,
|
||||
uint tiisg) {
|
||||
ushort tiisg) {
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t rb = tgpig.y*N_MV_T_T;
|
||||
const int64_t im = tgpig.z;
|
||||
@ -1911,7 +1912,7 @@ kernel void kernel_mul_mv(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]]) {
|
||||
kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
|
||||
args,
|
||||
src0,
|
||||
@ -1938,7 +1939,7 @@ kernel void kernel_mul_mv_1row(
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]]) {
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t r1 = tgpig.y;
|
||||
@ -1994,7 +1995,7 @@ kernel void kernel_mul_mv_l4(
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]]) {
|
||||
|
||||
const int nrows = args.ne11;
|
||||
const int64_t r0 = tgpig.x;
|
||||
@ -3933,8 +3934,8 @@ void kernel_mul_mv_q2_K_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
@ -4024,8 +4025,8 @@ kernel void kernel_mul_mv_q2_K_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -4038,8 +4039,8 @@ void kernel_mul_mv_q3_K_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
|
||||
@ -4187,8 +4188,8 @@ kernel void kernel_mul_mv_q3_K_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -4201,8 +4202,8 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
const uint16_t kmask1 = 0x3f3f;
|
||||
const uint16_t kmask2 = 0x0f0f;
|
||||
@ -4305,8 +4306,8 @@ kernel void kernel_mul_mv_q4_K_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -4319,8 +4320,8 @@ void kernel_mul_mv_q5_K_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
|
||||
@ -4437,8 +4438,8 @@ kernel void kernel_mul_mv_q5_K_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -4451,8 +4452,8 @@ void kernel_mul_mv_q6_K_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
const uint8_t kmask1 = 0x03;
|
||||
const uint8_t kmask2 = 0x0C;
|
||||
@ -4528,8 +4529,8 @@ kernel void kernel_mul_mv_q6_K_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -4544,8 +4545,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
@ -4639,8 +4640,8 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
@ -4652,8 +4653,8 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
@ -4757,8 +4758,8 @@ kernel void kernel_mul_mv_iq2_xs_f32(
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -4771,8 +4772,8 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
@ -4869,8 +4870,8 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -4883,8 +4884,8 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
@ -4981,8 +4982,8 @@ kernel void kernel_mul_mv_iq3_s_f32(
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -4995,8 +4996,8 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
@ -5094,8 +5095,8 @@ kernel void kernel_mul_mv_iq2_s_f32(
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -5108,8 +5109,8 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
@ -5194,8 +5195,8 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
@ -5289,8 +5290,8 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
const int nb = args.ne00/QK4_NL;
|
||||
@ -5379,8 +5380,8 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
const int nb = args.ne00/QK_K;
|
||||
@ -5469,8 +5470,8 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -5482,8 +5483,8 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -5496,8 +5497,8 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -5510,8 +5511,8 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
@ -5622,13 +5623,13 @@ kernel void kernel_mul_mm(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shared_memory [[threadgroup(0)]],
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup T * sa = (threadgroup T *)(shared_memory);
|
||||
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
||||
threadgroup T * sa = (threadgroup T *)(shmem);
|
||||
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
|
||||
|
||||
const int r0 = tgpig.y;
|
||||
const int r1 = tgpig.x;
|
||||
@ -5723,7 +5724,7 @@ kernel void kernel_mul_mm(
|
||||
} else {
|
||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup float * temp_str = ((threadgroup float *) shared_memory) \
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
||||
+ 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M;
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
|
||||
@ -6058,7 +6059,7 @@ typedef void (kernel_mul_mv_impl_t)(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig,
|
||||
uint tiisg);
|
||||
ushort tiisg);
|
||||
|
||||
typedef void (kernel_mul_mv2_impl_t)(
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
@ -6067,8 +6068,8 @@ typedef void (kernel_mul_mv2_impl_t)(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg);
|
||||
ushort tiisg,
|
||||
ushort sgitg);
|
||||
|
||||
template<kernel_mul_mv_impl_t impl_fn>
|
||||
void mmv_fn(
|
||||
@ -6078,9 +6079,9 @@ void mmv_fn(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiitg,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiitg,
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
impl_fn(args, src0, src1, dst, tgpig, tiisg);
|
||||
}
|
||||
|
||||
@ -6092,9 +6093,9 @@ void mmv_fn(
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiitg,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
ushort tiitg,
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
@ -6109,9 +6110,9 @@ kernel void kernel_mul_mv_id(
|
||||
device const char * ids,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const int iid1 = tgpig.z/args.nei0;
|
||||
const int idx = tgpig.z%args.nei0;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user