cont : args is first argument

This commit is contained in:
Georgi Gerganov 2024-11-10 08:47:30 +02:00
parent b65e4c1e10
commit c81640a5fc
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 117 additions and 117 deletions

View File

@ -1977,10 +1977,10 @@ static void ggml_metal_encode_node(
}; };
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:3]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
@ -2181,10 +2181,10 @@ static void ggml_metal_encode_node(
}; };
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:3]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
@ -2499,11 +2499,11 @@ static void ggml_metal_encode_node(
}; };
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBytes:&args length:sizeof(args) atIndex:4]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
const int64_t _ne1 = 1; const int64_t _ne1 = 1;
const int tgz = dst_rows; const int tgz = dst_rows;
@ -2748,15 +2748,15 @@ static void ggml_metal_encode_node(
}; };
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
if (id_src2 != nil) { if (id_src2 != nil) {
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
} else { } else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
} }
[encoder setBuffer:id_dst offset:offs_dst atIndex:3]; [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder setBytes:&args length:sizeof(args) atIndex:4];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
@ -3266,16 +3266,16 @@ static void ggml_metal_encode_node(
}; };
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
if (id_src3) { if (id_src3) {
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
} else { } else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
} }
[encoder setBuffer:id_dst offset:offs_dst atIndex:4]; [encoder setBuffer:id_dst offset:offs_dst atIndex:5];
[encoder setBytes:&args length:sizeof(args) atIndex:5];
if (!use_vec_kernel) { if (!use_vec_kernel) {
// half8x8 kernel // half8x8 kernel

View File

@ -1624,12 +1624,12 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
// quantizations where the block size is 32. It also does not // quantizations where the block size is 32. It also does not
// guard against the number of rows not being divisible by // guard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation. // N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw, typename A> template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
void mul_vec_q_n_f32_impl( void mul_vec_q_n_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -1699,57 +1699,57 @@ void mul_vec_q_n_f32_impl(
} }
kernel void kernel_mul_mv_q4_0_f32( kernel void kernel_mul_mv_q4_0_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
kernel void kernel_mul_mv_q4_1_f32( kernel void kernel_mul_mv_q4_1_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
kernel void kernel_mul_mv_q5_0_f32( kernel void kernel_mul_mv_q5_0_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
kernel void kernel_mul_mv_q5_1_f32( kernel void kernel_mul_mv_q5_1_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
#define NB_Q8_0 8 #define NB_Q8_0 8
template<typename A> template<typename args_t>
void kernel_mul_mv_q8_0_f32_impl( void kernel_mul_mv_q8_0_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -1818,24 +1818,24 @@ void kernel_mul_mv_q8_0_f32_impl(
[[host_name("kernel_mul_mv_q8_0_f32")]] [[host_name("kernel_mul_mv_q8_0_f32")]]
kernel void kernel_mul_mv_q8_0_f32( kernel void kernel_mul_mv_q8_0_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
#define N_MV_T_T 4 #define N_MV_T_T 4
template<typename T0, typename T04, typename T1, typename T14, typename A> template<typename T0, typename T04, typename T1, typename T14, typename args_t>
void kernel_mul_mv_impl( void kernel_mul_mv_impl(
args_t args,
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
A args,
uint3 tgpig, uint3 tgpig,
uint tiisg) { uint tiisg) {
const int64_t r0 = tgpig.x; const int64_t r0 = tgpig.x;
@ -1899,17 +1899,17 @@ void kernel_mul_mv_impl(
template<typename T0, typename T04, typename T1, typename T14> template<typename T0, typename T04, typename T1, typename T14>
kernel void kernel_mul_mv( kernel void kernel_mul_mv(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) { uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>( kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
args,
src0, src0,
src1, src1,
dst, dst,
args,
tgpig, tgpig,
tiisg); tiisg);
} }
@ -1926,10 +1926,10 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<
template<typename T, typename T4> template<typename T, typename T4>
kernel void kernel_mul_mv_1row( kernel void kernel_mul_mv_1row(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) { uint tiisg[[thread_index_in_simdgroup]]) {
@ -1982,10 +1982,10 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne
// Assumes row size (ne00) is a multiple of 4 // Assumes row size (ne00) is a multiple of 4
template<typename T, typename T4> template<typename T, typename T4>
kernel void kernel_mul_mv_l4( kernel void kernel_mul_mv_l4(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) { uint tiisg[[thread_index_in_simdgroup]]) {
@ -2064,11 +2064,11 @@ static void rope_yarn_corr_dims(
template<typename T> template<typename T>
kernel void kernel_rope_norm( kernel void kernel_rope_norm(
constant ggml_metal_kargs_rope & args,
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device const char * src2, device const char * src2,
device char * dst, device char * dst,
constant ggml_metal_kargs_rope & args,
ushort tiitg[[thread_index_in_threadgroup]], ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]], ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) { uint3 tgpig[[threadgroup_position_in_grid]]) {
@ -2117,11 +2117,11 @@ kernel void kernel_rope_norm(
template<typename T> template<typename T>
kernel void kernel_rope_neox( kernel void kernel_rope_neox(
constant ggml_metal_kargs_rope & args,
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device const char * src2, device const char * src2,
device char * dst, device char * dst,
constant ggml_metal_kargs_rope & args,
ushort tiitg[[thread_index_in_threadgroup]], ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]], ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) { uint3 tgpig[[threadgroup_position_in_grid]]) {
@ -2558,13 +2558,13 @@ template<
short KV = 8, // key/value processed per each simdgroup short KV = 8, // key/value processed per each simdgroup
short C = 32> // cache items per threadgroup short C = 32> // cache items per threadgroup
kernel void kernel_flash_attn_ext( kernel void kernel_flash_attn_ext(
constant ggml_metal_kargs_flash_attn_ext & args,
device const char * q, device const char * q,
device const char * k, device const char * k,
device const char * v, device const char * v,
device const char * mask, device const char * mask,
device char * dst, device char * dst,
constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]],
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 ntg[[threads_per_threadgroup]], ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]], ushort tiisg[[thread_index_in_simdgroup]],
@ -3053,13 +3053,13 @@ template<
short Q = 1, // queries per threadgroup short Q = 1, // queries per threadgroup
short C = 32> // cache items per threadgroup short C = 32> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec( kernel void kernel_flash_attn_ext_vec(
constant ggml_metal_kargs_flash_attn_ext & args,
device const char * q, device const char * q,
device const char * k, device const char * k,
device const char * v, device const char * v,
device const char * mask, device const char * mask,
device char * dst, device char * dst,
constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]],
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 ntg[[threads_per_threadgroup]], ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]], ushort tiisg[[thread_index_in_simdgroup]],
@ -3927,12 +3927,12 @@ kernel void kernel_concat(
} }
} }
template<typename A> template<typename args_t>
void kernel_mul_mv_q2_K_f32_impl( void kernel_mul_mv_q2_K_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -4019,23 +4019,23 @@ void kernel_mul_mv_q2_K_f32_impl(
[[host_name("kernel_mul_mv_q2_K_f32")]] [[host_name("kernel_mul_mv_q2_K_f32")]]
kernel void kernel_mul_mv_q2_K_f32( kernel void kernel_mul_mv_q2_K_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
template<typename A> template<typename args_t>
void kernel_mul_mv_q3_K_f32_impl( void kernel_mul_mv_q3_K_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -4179,23 +4179,23 @@ void kernel_mul_mv_q3_K_f32_impl(
[[host_name("kernel_mul_mv_q3_K_f32")]] [[host_name("kernel_mul_mv_q3_K_f32")]]
kernel void kernel_mul_mv_q3_K_f32( kernel void kernel_mul_mv_q3_K_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
template<typename A> template<typename args_t>
void kernel_mul_mv_q4_K_f32_impl( void kernel_mul_mv_q4_K_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -4295,23 +4295,23 @@ void kernel_mul_mv_q4_K_f32_impl(
[[host_name("kernel_mul_mv_q4_K_f32")]] [[host_name("kernel_mul_mv_q4_K_f32")]]
kernel void kernel_mul_mv_q4_K_f32( kernel void kernel_mul_mv_q4_K_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
template<typename A> template<typename args_t>
void kernel_mul_mv_q5_K_f32_impl( void kernel_mul_mv_q5_K_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -4425,23 +4425,23 @@ void kernel_mul_mv_q5_K_f32_impl(
[[host_name("kernel_mul_mv_q5_K_f32")]] [[host_name("kernel_mul_mv_q5_K_f32")]]
kernel void kernel_mul_mv_q5_K_f32( kernel void kernel_mul_mv_q5_K_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
template <typename A> template <typename args_t>
void kernel_mul_mv_q6_K_f32_impl( void kernel_mul_mv_q6_K_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -4514,25 +4514,25 @@ void kernel_mul_mv_q6_K_f32_impl(
[[host_name("kernel_mul_mv_q6_K_f32")]] [[host_name("kernel_mul_mv_q6_K_f32")]]
kernel void kernel_mul_mv_q6_K_f32( kernel void kernel_mul_mv_q6_K_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
// ======================= "True" 2-bit // ======================= "True" 2-bit
template<typename A> template<typename args_t>
void kernel_mul_mv_iq2_xxs_f32_impl( void kernel_mul_mv_iq2_xxs_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -4622,24 +4622,24 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
[[host_name("kernel_mul_mv_iq2_xxs_f32")]] [[host_name("kernel_mul_mv_iq2_xxs_f32")]]
kernel void kernel_mul_mv_iq2_xxs_f32( kernel void kernel_mul_mv_iq2_xxs_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
threadgroup int8_t * shared_values [[threadgroup(0)]], threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
} }
template<typename A> template<typename args_t>
void kernel_mul_mv_iq2_xs_f32_impl( void kernel_mul_mv_iq2_xs_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -4739,24 +4739,24 @@ void kernel_mul_mv_iq2_xs_f32_impl(
[[host_name("kernel_mul_mv_iq2_xs_f32")]] [[host_name("kernel_mul_mv_iq2_xs_f32")]]
kernel void kernel_mul_mv_iq2_xs_f32( kernel void kernel_mul_mv_iq2_xs_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
threadgroup int8_t * shared_values [[threadgroup(0)]], threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
} }
template <typename A> template <typename args_t>
void kernel_mul_mv_iq3_xxs_f32_impl( void kernel_mul_mv_iq3_xxs_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -4849,24 +4849,24 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
[[host_name("kernel_mul_mv_iq3_xxs_f32")]] [[host_name("kernel_mul_mv_iq3_xxs_f32")]]
kernel void kernel_mul_mv_iq3_xxs_f32( kernel void kernel_mul_mv_iq3_xxs_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
threadgroup int8_t * shared_values [[threadgroup(0)]], threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
} }
template<typename A> template<typename args_t>
void kernel_mul_mv_iq3_s_f32_impl( void kernel_mul_mv_iq3_s_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -4959,24 +4959,24 @@ void kernel_mul_mv_iq3_s_f32_impl(
[[host_name("kernel_mul_mv_iq3_s_f32")]] [[host_name("kernel_mul_mv_iq3_s_f32")]]
kernel void kernel_mul_mv_iq3_s_f32( kernel void kernel_mul_mv_iq3_s_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
threadgroup int8_t * shared_values [[threadgroup(0)]], threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
} }
template <typename A> template <typename args_t>
void kernel_mul_mv_iq2_s_f32_impl( void kernel_mul_mv_iq2_s_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -5070,24 +5070,24 @@ void kernel_mul_mv_iq2_s_f32_impl(
[[host_name("kernel_mul_mv_iq2_s_f32")]] [[host_name("kernel_mul_mv_iq2_s_f32")]]
kernel void kernel_mul_mv_iq2_s_f32( kernel void kernel_mul_mv_iq2_s_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
threadgroup int8_t * shared_values [[threadgroup(0)]], threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
} }
template<typename A> template<typename args_t>
void kernel_mul_mv_iq1_s_f32_impl( void kernel_mul_mv_iq1_s_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_value, threadgroup int8_t * shared_value,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -5166,12 +5166,12 @@ void kernel_mul_mv_iq1_s_f32_impl(
} }
} }
template <typename A> template <typename args_t>
void kernel_mul_mv_iq1_m_f32_impl( void kernel_mul_mv_iq1_m_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_value, threadgroup int8_t * shared_value,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -5259,12 +5259,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
} }
} }
template<typename A> template<typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl( void kernel_mul_mv_iq4_nl_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values_i8, threadgroup int8_t * shared_values_i8,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -5347,12 +5347,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
} }
} }
template<typename A> template<typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl( void kernel_mul_mv_iq4_xs_f32_impl(
args_t args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
A args,
threadgroup int8_t * shared_values_i8, threadgroup int8_t * shared_values_i8,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -5438,56 +5438,56 @@ void kernel_mul_mv_iq4_xs_f32_impl(
[[host_name("kernel_mul_mv_iq1_s_f32")]] [[host_name("kernel_mul_mv_iq1_s_f32")]]
kernel void kernel_mul_mv_iq1_s_f32( kernel void kernel_mul_mv_iq1_s_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
[[host_name("kernel_mul_mv_iq1_m_f32")]] [[host_name("kernel_mul_mv_iq1_m_f32")]]
kernel void kernel_mul_mv_iq1_m_f32( kernel void kernel_mul_mv_iq1_m_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
[[host_name("kernel_mul_mv_iq4_nl_f32")]] [[host_name("kernel_mul_mv_iq4_nl_f32")]]
kernel void kernel_mul_mv_iq4_nl_f32( kernel void kernel_mul_mv_iq4_nl_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
threadgroup int8_t * shared_values [[threadgroup(0)]], threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
} }
[[host_name("kernel_mul_mv_iq4_xs_f32")]] [[host_name("kernel_mul_mv_iq4_xs_f32")]]
kernel void kernel_mul_mv_iq4_xs_f32( kernel void kernel_mul_mv_iq4_xs_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant ggml_metal_kargs_mul_mv & args,
threadgroup int8_t * shared_values [[threadgroup(0)]], threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
} }
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)> template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@ -5592,10 +5592,10 @@ kernel void kernel_get_rows_i32(
// each block_q contains 16*nl weights // each block_q contains 16*nl weights
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)> template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_mul_mm( kernel void kernel_mul_mm(
constant ggml_metal_kargs_mul_mm & args,
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device char * dst, device char * dst,
constant ggml_metal_kargs_mul_mm & args,
threadgroup char * shared_memory [[threadgroup(0)]], threadgroup char * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]], ushort tiitg[[thread_index_in_threadgroup]],
@ -6027,18 +6027,18 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
// //
typedef void (kernel_mul_mv_impl_t)( typedef void (kernel_mul_mv_impl_t)(
ggml_metal_kargs_mul_mv args,
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
ggml_metal_kargs_mul_mv args,
uint3 tgpig, uint3 tgpig,
uint tiisg); uint tiisg);
typedef void (kernel_mul_mv2_impl_t)( typedef void (kernel_mul_mv2_impl_t)(
ggml_metal_kargs_mul_mv args,
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
device float * dst, device float * dst,
ggml_metal_kargs_mul_mv args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiisg, uint tiisg,
@ -6046,41 +6046,41 @@ typedef void (kernel_mul_mv2_impl_t)(
template<kernel_mul_mv_impl_t impl_fn> template<kernel_mul_mv_impl_t impl_fn>
void mmv_fn( void mmv_fn(
ggml_metal_kargs_mul_mv args,
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
ggml_metal_kargs_mul_mv args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiitg, uint tiitg,
uint tiisg, uint tiisg,
uint sgitg) { uint sgitg) {
impl_fn(src0, src1, dst, args, tgpig, tiisg); impl_fn(args, src0, src1, dst, tgpig, tiisg);
} }
template<kernel_mul_mv2_impl_t impl_fn> template<kernel_mul_mv2_impl_t impl_fn>
void mmv_fn( void mmv_fn(
ggml_metal_kargs_mul_mv args,
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
ggml_metal_kargs_mul_mv args,
threadgroup int8_t * shared_values, threadgroup int8_t * shared_values,
uint3 tgpig, uint3 tgpig,
uint tiitg, uint tiitg,
uint tiisg, uint tiisg,
uint sgitg) { uint sgitg) {
impl_fn(src0,(const device float *) src1, dst, args, shared_values, tgpig, tiisg, sgitg); impl_fn(args, src0,(const device float *) src1, dst, shared_values, tgpig, tiisg, sgitg);
} }
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t; typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
template<mul_mv_impl_fn_t impl_fn> template<mul_mv_impl_fn_t impl_fn>
kernel void kernel_mul_mv_id( kernel void kernel_mul_mv_id(
constant ggml_metal_kargs_mul_mv_id & args,
device const char * src0s, device const char * src0s,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
device const char * ids, device const char * ids,
constant ggml_metal_kargs_mul_mv_id & args,
threadgroup int8_t * shared_values [[threadgroup(0)]], threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
@ -6125,10 +6125,10 @@ kernel void kernel_mul_mv_id(
}; };
impl_fn( impl_fn(
args0,
/* src0 */ src0_cur, /* src0 */ src0_cur,
/* src1 */ src1_cur, /* src1 */ src1_cur,
/* dst */ dst_cur, /* dst */ dst_cur,
args0,
shared_values, shared_values,
tgpig, tgpig,
tiitg, tiitg,