mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-10 18:51:45 +00:00
cont : shmem style
This commit is contained in:
parent
a1a201c1a9
commit
cacc4c225f
@ -1630,7 +1630,7 @@ void mul_vec_q_n_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -1753,7 +1753,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -2571,7 +2571,7 @@ kernel void kernel_flash_attn_ext(
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device char * dst,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
@ -2591,17 +2591,17 @@ kernel void kernel_flash_attn_ext(
|
||||
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
||||
const short T = D + 2*TS; // shared memory size per query in (half)
|
||||
|
||||
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
|
||||
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
|
||||
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
|
||||
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
|
||||
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation
|
||||
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
|
||||
|
||||
threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
||||
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
||||
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
||||
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
||||
|
||||
threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
|
||||
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
|
||||
threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
|
||||
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
o8x8_t lo[D8];
|
||||
@ -3066,7 +3066,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device char * dst,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
@ -3085,13 +3085,13 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
const short T = D + nsg*SH; // shared memory size per query in (half)
|
||||
|
||||
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
|
||||
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention
|
||||
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t
|
||||
threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
|
||||
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
|
||||
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
|
||||
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention
|
||||
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t
|
||||
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
|
||||
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
o4x4_t lo[D16/NL];
|
||||
@ -3940,7 +3940,7 @@ void kernel_mul_mv_q2_K_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -4045,7 +4045,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -4208,7 +4208,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -4326,7 +4326,7 @@ void kernel_mul_mv_q5_K_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -4458,7 +4458,7 @@ void kernel_mul_mv_q6_K_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -4551,7 +4551,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -4577,15 +4577,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||
|
||||
const int nb32 = nb * (QK_K / 32);
|
||||
|
||||
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
||||
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
|
||||
threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
|
||||
threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
|
||||
{
|
||||
int nval = 4;
|
||||
int pos = (32*sgitg + tiisg)*nval;
|
||||
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
|
||||
for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i];
|
||||
nval = 2;
|
||||
pos = (32*sgitg + tiisg)*nval;
|
||||
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
||||
for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
@ -4615,8 +4615,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||
|
||||
float sum = 0;
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
|
||||
const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
|
||||
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
|
||||
const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||
}
|
||||
@ -4646,12 +4646,11 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename args_t>
|
||||
@ -4660,7 +4659,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -4686,15 +4685,15 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
||||
|
||||
const int nb32 = nb * (QK_K / 32);
|
||||
|
||||
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
||||
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
|
||||
threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
|
||||
threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512);
|
||||
{
|
||||
int nval = 8;
|
||||
int pos = (32*sgitg + tiisg)*nval;
|
||||
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
|
||||
for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i];
|
||||
nval = 2;
|
||||
pos = (32*sgitg + tiisg)*nval;
|
||||
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
||||
for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
@ -4726,15 +4725,15 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
||||
|
||||
float sum1 = 0, sum2 = 0;
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
|
||||
const uint8_t signs = shared_signs[(q2[l] >> 9)];
|
||||
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
|
||||
const uint8_t signs = ssigns[(q2[l] >> 9)];
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||
}
|
||||
}
|
||||
for (int l = 2; l < 4; ++l) {
|
||||
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
|
||||
const uint8_t signs = shared_signs[(q2[l] >> 9)];
|
||||
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
|
||||
const uint8_t signs = ssigns[(q2[l] >> 9)];
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||
}
|
||||
@ -4765,12 +4764,12 @@ kernel void kernel_mul_mv_iq2_xs_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template <typename args_t>
|
||||
@ -4779,7 +4778,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -4805,15 +4804,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||
|
||||
const int nb32 = nb * (QK_K / 32);
|
||||
|
||||
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
|
||||
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
|
||||
threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem);
|
||||
threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
|
||||
{
|
||||
int nval = 4;
|
||||
int pos = (32*sgitg + tiisg)*nval;
|
||||
for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
|
||||
for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i];
|
||||
nval = 2;
|
||||
pos = (32*sgitg + tiisg)*nval;
|
||||
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
||||
for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
@ -4843,9 +4842,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||
|
||||
float2 sum = {0};
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
|
||||
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
|
||||
const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
|
||||
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
|
||||
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
|
||||
const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
||||
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
||||
@ -4877,12 +4876,12 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename args_t>
|
||||
@ -4891,7 +4890,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -4917,11 +4916,11 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
||||
|
||||
const int nb32 = nb * (QK_K / 32);
|
||||
|
||||
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
|
||||
threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem;
|
||||
{
|
||||
int nval = 8;
|
||||
int pos = (32*sgitg + tiisg)*nval;
|
||||
for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
|
||||
for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
@ -4952,8 +4951,8 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
||||
|
||||
float2 sum = {0};
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
|
||||
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
|
||||
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
|
||||
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
|
||||
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
|
||||
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
@ -4989,12 +4988,12 @@ kernel void kernel_mul_mv_iq3_s_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template <typename args_t>
|
||||
@ -5003,7 +5002,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -5029,11 +5028,11 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
||||
|
||||
const int nb32 = nb * (QK_K / 32);
|
||||
|
||||
//threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
||||
//threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem;
|
||||
//{
|
||||
// int nval = 32;
|
||||
// int pos = (32*sgitg + tiisg)*nval;
|
||||
// for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
|
||||
// for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i];
|
||||
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
//}
|
||||
|
||||
@ -5065,8 +5064,8 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
||||
|
||||
float2 sum = {0};
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
//const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
||||
//const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
||||
//const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
||||
//const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
@ -5102,12 +5101,12 @@ kernel void kernel_mul_mv_iq2_s_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename args_t>
|
||||
@ -5116,7 +5115,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_value,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -5202,7 +5201,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_value,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
@ -5297,12 +5296,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values_i8,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
|
||||
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
const int nb = args.ne00/QK4_NL;
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
@ -5321,7 +5320,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
const int ix = tiisg/2; // 0...15
|
||||
const int it = tiisg%2; // 0 or 1
|
||||
|
||||
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
||||
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
@ -5349,16 +5348,16 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
aux32[0] = q4[0] | (q4[1] << 16);
|
||||
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
||||
aux32[0] &= 0x0f0f0f0f;
|
||||
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
||||
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
||||
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
|
||||
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
|
||||
acc1 += yl[0] * qf1;
|
||||
acc2 += yl[1] * qf2;
|
||||
|
||||
aux32[0] = q4[2] | (q4[3] << 16);
|
||||
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
||||
aux32[0] &= 0x0f0f0f0f;
|
||||
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
||||
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
||||
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
|
||||
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
|
||||
acc1 += yl[2] * qf1;
|
||||
acc2 += yl[3] * qf2;
|
||||
|
||||
@ -5387,12 +5386,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values_i8,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
|
||||
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
const int nb = args.ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
@ -5413,7 +5412,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
const int ib = it/2;
|
||||
const int il = it%2;
|
||||
|
||||
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
||||
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
@ -5440,15 +5439,15 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
|
||||
aux32[0] = q4[0] & 0x0f0f0f0f;
|
||||
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
|
||||
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
||||
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
||||
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
|
||||
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
|
||||
acc1 += yl[0] * qf1;
|
||||
acc2 += yl[1] * qf2;
|
||||
|
||||
aux32[0] = q4[1] & 0x0f0f0f0f;
|
||||
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
|
||||
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
||||
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
||||
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
|
||||
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
|
||||
acc1 += yl[2] * qf1;
|
||||
acc2 += yl[3] * qf2;
|
||||
|
||||
@ -5504,12 +5503,12 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
|
||||
@ -5518,12 +5517,12 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||
@ -6075,7 +6074,7 @@ typedef void (kernel_mul_mv2_impl_t)(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
uint sgitg);
|
||||
@ -6086,7 +6085,7 @@ void mmv_fn(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiitg,
|
||||
uint tiisg,
|
||||
@ -6100,12 +6099,12 @@ void mmv_fn(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
uint tiitg,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
impl_fn(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||
impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
|
||||
@ -6117,7 +6116,7 @@ kernel void kernel_mul_mv_id(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
device const char * ids,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
@ -6166,7 +6165,7 @@ kernel void kernel_mul_mv_id(
|
||||
/* src0 */ src0_cur,
|
||||
/* src1 */ src1_cur,
|
||||
/* dst */ dst_cur,
|
||||
shared_values,
|
||||
shmem,
|
||||
tgpig,
|
||||
tiitg,
|
||||
tiisg,
|
||||
|
Loading…
Reference in New Issue
Block a user