diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 922d0d67a..d3497233e 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1977,10 +1977,10 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; @@ -2181,10 +2181,10 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || @@ -2499,11 +2499,11 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&args length:sizeof(args) atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; const int tgz = dst_rows; @@ -2748,15 +2748,15 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&args length:sizeof(args) atIndex:4]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -3266,16 +3266,16 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&args length:sizeof(args) atIndex:5]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 5a669ef45..5e8d0f4b5 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1624,12 +1624,12 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre // quantizations where the block size is 32. It also does not // guard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. -template +template void mul_vec_q_n_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1699,57 +1699,57 @@ void mul_vec_q_n_f32_impl( } kernel void kernel_mul_mv_q4_0_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } #define NB_Q8_0 8 -template +template void kernel_mul_mv_q8_0_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1818,24 +1818,24 @@ void kernel_mul_mv_q8_0_f32_impl( [[host_name("kernel_mul_mv_q8_0_f32")]] kernel void kernel_mul_mv_q8_0_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } #define N_MV_T_T 4 -template +template void kernel_mul_mv_impl( + args_t args, device const char * src0, device const char * src1, device float * dst, - A args, uint3 tgpig, uint tiisg) { const int64_t r0 = tgpig.x; @@ -1899,17 +1899,17 @@ void kernel_mul_mv_impl( template kernel void kernel_mul_mv( + constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_impl( + args, src0, src1, dst, - args, tgpig, tiisg); } @@ -1926,10 +1926,10 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv< template kernel void kernel_mul_mv_1row( + constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1982,10 +1982,10 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne // Assumes row size (ne00) is a multiple of 4 template kernel void kernel_mul_mv_l4( + constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -2064,11 +2064,11 @@ static void rope_yarn_corr_dims( template kernel void kernel_rope_norm( + constant ggml_metal_kargs_rope & args, device const char * src0, device const char * src1, device const char * src2, device char * dst, - constant ggml_metal_kargs_rope & args, ushort tiitg[[thread_index_in_threadgroup]], ushort3 tptg [[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { @@ -2117,11 +2117,11 @@ kernel void kernel_rope_norm( template kernel void kernel_rope_neox( + constant ggml_metal_kargs_rope & args, device const char * src0, device const char * src1, device const char * src2, device char * dst, - constant ggml_metal_kargs_rope & args, ushort tiitg[[thread_index_in_threadgroup]], ushort3 tptg [[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { @@ -2558,13 +2558,13 @@ template< short KV = 8, // key/value processed per each simdgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, device const char * k, device const char * v, device const char * mask, device char * dst, - constant ggml_metal_kargs_flash_attn_ext & args, - threadgroup half * shared [[threadgroup(0)]], + threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -3050,13 +3050,13 @@ template< short Q = 1, // queries per threadgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, device const char * k, device const char * v, device const char * mask, device char * dst, - constant ggml_metal_kargs_flash_attn_ext & args, - threadgroup half * shared [[threadgroup(0)]], + threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -3918,12 +3918,12 @@ kernel void kernel_concat( } } -template +template void kernel_mul_mv_q2_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4010,23 +4010,23 @@ void kernel_mul_mv_q2_K_f32_impl( [[host_name("kernel_mul_mv_q2_K_f32")]] kernel void kernel_mul_mv_q2_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q3_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4170,23 +4170,23 @@ void kernel_mul_mv_q3_K_f32_impl( [[host_name("kernel_mul_mv_q3_K_f32")]] kernel void kernel_mul_mv_q3_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q4_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4286,23 +4286,23 @@ void kernel_mul_mv_q4_K_f32_impl( [[host_name("kernel_mul_mv_q4_K_f32")]] kernel void kernel_mul_mv_q4_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q5_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4416,23 +4416,23 @@ void kernel_mul_mv_q5_K_f32_impl( [[host_name("kernel_mul_mv_q5_K_f32")]] kernel void kernel_mul_mv_q5_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q6_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4505,25 +4505,25 @@ void kernel_mul_mv_q6_K_f32_impl( [[host_name("kernel_mul_mv_q6_K_f32")]] kernel void kernel_mul_mv_q6_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit -template +template void kernel_mul_mv_iq2_xxs_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4613,24 +4613,24 @@ void kernel_mul_mv_iq2_xxs_f32_impl( [[host_name("kernel_mul_mv_iq2_xxs_f32")]] kernel void kernel_mul_mv_iq2_xxs_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_xs_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4730,24 +4730,24 @@ void kernel_mul_mv_iq2_xs_f32_impl( [[host_name("kernel_mul_mv_iq2_xs_f32")]] kernel void kernel_mul_mv_iq2_xs_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_xxs_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4840,24 +4840,24 @@ void kernel_mul_mv_iq3_xxs_f32_impl( [[host_name("kernel_mul_mv_iq3_xxs_f32")]] kernel void kernel_mul_mv_iq3_xxs_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_s_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4950,24 +4950,24 @@ void kernel_mul_mv_iq3_s_f32_impl( [[host_name("kernel_mul_mv_iq3_s_f32")]] kernel void kernel_mul_mv_iq3_s_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_s_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -5061,24 +5061,24 @@ void kernel_mul_mv_iq2_s_f32_impl( [[host_name("kernel_mul_mv_iq2_s_f32")]] kernel void kernel_mul_mv_iq2_s_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_s_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5157,12 +5157,12 @@ void kernel_mul_mv_iq1_s_f32_impl( } } -template +template void kernel_mul_mv_iq1_m_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5250,12 +5250,12 @@ void kernel_mul_mv_iq1_m_f32_impl( } } -template +template void kernel_mul_mv_iq4_nl_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5338,12 +5338,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( } } -template +template void kernel_mul_mv_iq4_xs_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5429,56 +5429,56 @@ void kernel_mul_mv_iq4_xs_f32_impl( [[host_name("kernel_mul_mv_iq1_s_f32")]] kernel void kernel_mul_mv_iq1_s_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] kernel void kernel_mul_mv_iq1_m_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] kernel void kernel_mul_mv_iq4_nl_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] kernel void kernel_mul_mv_iq4_xs_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } template @@ -5583,10 +5583,10 @@ kernel void kernel_get_rows_i32( // each block_q contains 16*nl weights template kernel void kernel_mul_mm( + constant ggml_metal_kargs_mul_mm & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_mul_mm & args, threadgroup char * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], @@ -6018,18 +6018,18 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel // typedef void (kernel_mul_mv_impl_t)( + ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, device float * dst, - ggml_metal_kargs_mul_mv args, uint3 tgpig, uint tiisg); typedef void (kernel_mul_mv2_impl_t)( + ggml_metal_kargs_mul_mv args, device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -6037,41 +6037,41 @@ typedef void (kernel_mul_mv2_impl_t)( template void mmv_fn( + ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, device float * dst, - ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0, src1, dst, args, tgpig, tiisg); + impl_fn(args, src0, src1, dst, tgpig, tiisg); } template void mmv_fn( + ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, device float * dst, - ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,(const device float *) src1, dst, args, shared_values, tgpig, tiisg, sgitg); + impl_fn(args, src0,(const device float *) src1, dst, shared_values, tgpig, tiisg, sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id( + constant ggml_metal_kargs_mul_mv_id & args, device const char * src0s, device const char * src1, device float * dst, device const char * ids, - constant ggml_metal_kargs_mul_mv_id & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -6116,10 +6116,10 @@ kernel void kernel_mul_mv_id( }; impl_fn( + args0, /* src0 */ src0_cur, /* src1 */ src1_cur, /* dst */ dst_cur, - args0, shared_values, tgpig, tiitg,