cont : shmem style

This commit is contained in:
Georgi Gerganov 2024-11-10 09:45:06 +02:00
parent a1a201c1a9
commit cacc4c225f
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -1630,7 +1630,7 @@ void mul_vec_q_n_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -1753,7 +1753,7 @@ void kernel_mul_mv_q8_0_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -2571,7 +2571,7 @@ kernel void kernel_flash_attn_ext(
device const char * v,
device const char * mask,
device char * dst,
threadgroup half * shared [[threadgroup(0)]],
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
@ -2591,17 +2591,17 @@ kernel void kernel_flash_attn_ext(
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
const short T = D + 2*TS; // shared memory size per query in (half)
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
o8x8_t lo[D8];
@ -3066,7 +3066,7 @@ kernel void kernel_flash_attn_ext_vec(
device const char * v,
device const char * mask,
device char * dst,
threadgroup half * shared [[threadgroup(0)]],
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
@ -3085,13 +3085,13 @@ kernel void kernel_flash_attn_ext_vec(
const short T = D + nsg*SH; // shared memory size per query in (half)
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t
threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t
threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
o4x4_t lo[D16/NL];
@ -3940,7 +3940,7 @@ void kernel_mul_mv_q2_K_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -4045,7 +4045,7 @@ void kernel_mul_mv_q3_K_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -4208,7 +4208,7 @@ void kernel_mul_mv_q4_K_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -4326,7 +4326,7 @@ void kernel_mul_mv_q5_K_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -4458,7 +4458,7 @@ void kernel_mul_mv_q6_K_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -4551,7 +4551,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -4577,15 +4577,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
const int nb32 = nb * (QK_K / 32);
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
{
int nval = 4;
int pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i];
nval = 2;
pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
@ -4615,8 +4615,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
float sum = 0;
for (int l = 0; l < 4; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
for (int j = 0; j < 8; ++j) {
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
@ -4646,12 +4646,11 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<typename args_t>
@ -4660,7 +4659,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -4686,15 +4685,15 @@ void kernel_mul_mv_iq2_xs_f32_impl(
const int nb32 = nb * (QK_K / 32);
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512);
{
int nval = 8;
int pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i];
nval = 2;
pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
@ -4726,15 +4725,15 @@ void kernel_mul_mv_iq2_xs_f32_impl(
float sum1 = 0, sum2 = 0;
for (int l = 0; l < 2; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
const uint8_t signs = shared_signs[(q2[l] >> 9)];
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
const uint8_t signs = ssigns[(q2[l] >> 9)];
for (int j = 0; j < 8; ++j) {
sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
for (int l = 2; l < 4; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
const uint8_t signs = shared_signs[(q2[l] >> 9)];
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
const uint8_t signs = ssigns[(q2[l] >> 9)];
for (int j = 0; j < 8; ++j) {
sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
@ -4765,12 +4764,12 @@ kernel void kernel_mul_mv_iq2_xs_f32(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template <typename args_t>
@ -4779,7 +4778,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -4805,15 +4804,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
const int nb32 = nb * (QK_K / 32);
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem);
threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
{
int nval = 4;
int pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i];
nval = 2;
pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
@ -4843,9 +4842,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
float2 sum = {0};
for (int l = 0; l < 4; ++l) {
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
for (int j = 0; j < 4; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
@ -4877,12 +4876,12 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<typename args_t>
@ -4891,7 +4890,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -4917,11 +4916,11 @@ void kernel_mul_mv_iq3_s_f32_impl(
const int nb32 = nb * (QK_K / 32);
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem;
{
int nval = 8;
int pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
@ -4952,8 +4951,8 @@ void kernel_mul_mv_iq3_s_f32_impl(
float2 sum = {0};
for (int l = 0; l < 4; ++l) {
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
for (int j = 0; j < 4; ++j) {
@ -4989,12 +4988,12 @@ kernel void kernel_mul_mv_iq3_s_f32(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template <typename args_t>
@ -5003,7 +5002,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -5029,11 +5028,11 @@ void kernel_mul_mv_iq2_s_f32_impl(
const int nb32 = nb * (QK_K / 32);
//threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
//threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem;
//{
// int nval = 32;
// int pos = (32*sgitg + tiisg)*nval;
// for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
// for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i];
// threadgroup_barrier(mem_flags::mem_threadgroup);
//}
@ -5065,8 +5064,8 @@ void kernel_mul_mv_iq2_s_f32_impl(
float2 sum = {0};
for (int l = 0; l < 2; ++l) {
//const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
//const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
//const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
//const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
for (int j = 0; j < 8; ++j) {
@ -5102,12 +5101,12 @@ kernel void kernel_mul_mv_iq2_s_f32(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<typename args_t>
@ -5116,7 +5115,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_value,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -5202,7 +5201,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_value,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
@ -5297,12 +5296,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values_i8,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK4_NL;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
@ -5321,7 +5320,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
const int ix = tiisg/2; // 0...15
const int it = tiisg%2; // 0 or 1
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
@ -5349,16 +5348,16 @@ void kernel_mul_mv_iq4_nl_f32_impl(
aux32[0] = q4[0] | (q4[1] << 16);
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
aux32[0] &= 0x0f0f0f0f;
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
acc1 += yl[0] * qf1;
acc2 += yl[1] * qf2;
aux32[0] = q4[2] | (q4[3] << 16);
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
aux32[0] &= 0x0f0f0f0f;
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
acc1 += yl[2] * qf1;
acc2 += yl[3] * qf2;
@ -5387,12 +5386,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values_i8,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg) {
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
@ -5413,7 +5412,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
const int ib = it/2;
const int il = it%2;
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
@ -5440,15 +5439,15 @@ void kernel_mul_mv_iq4_xs_f32_impl(
aux32[0] = q4[0] & 0x0f0f0f0f;
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
acc1 += yl[0] * qf1;
acc2 += yl[1] * qf2;
aux32[0] = q4[1] & 0x0f0f0f0f;
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
acc1 += yl[2] * qf1;
acc2 += yl[3] * qf2;
@ -5504,12 +5503,12 @@ kernel void kernel_mul_mv_iq4_nl_f32(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
@ -5518,12 +5517,12 @@ kernel void kernel_mul_mv_iq4_xs_f32(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@ -6075,7 +6074,7 @@ typedef void (kernel_mul_mv2_impl_t)(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
threadgroup char * shmem,
uint3 tgpig,
uint tiisg,
uint sgitg);
@ -6086,11 +6085,11 @@ void mmv_fn(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiitg,
uint tiisg,
uint sgitg) {
threadgroup char * shmem,
uint3 tgpig,
uint tiitg,
uint tiisg,
uint sgitg) {
impl_fn(args, src0, src1, dst, tgpig, tiisg);
}
@ -6100,12 +6099,12 @@ void mmv_fn(
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiitg,
uint tiisg,
uint sgitg) {
impl_fn(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
threadgroup char * shmem,
uint3 tgpig,
uint tiitg,
uint tiisg,
uint sgitg) {
impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
@ -6117,11 +6116,11 @@ kernel void kernel_mul_mv_id(
device const char * src1,
device char * dst,
device const char * ids,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int iid1 = tgpig.z/args.nei0;
const int idx = tgpig.z%args.nei0;
@ -6166,7 +6165,7 @@ kernel void kernel_mul_mv_id(
/* src0 */ src0_cur,
/* src1 */ src1_cur,
/* dst */ dst_cur,
shared_values,
shmem,
tgpig,
tiitg,
tiisg,