cont : pass by reference

This commit is contained in:
Georgi Gerganov 2024-11-10 08:10:22 +02:00
parent c59a13d93f
commit b65e4c1e10
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -1624,12 +1624,12 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
// quantizations where the block size is 32. It also does not
// guard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
template<typename block_q_type, int nr, int nsg, int nw, typename A>
void mul_vec_q_n_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -1706,7 +1706,7 @@ kernel void kernel_mul_mv_q4_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
@ -1715,9 +1715,9 @@ kernel void kernel_mul_mv_q4_1_f32(
device float * dst,
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
@ -1728,7 +1728,7 @@ kernel void kernel_mul_mv_q5_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
@ -1739,16 +1739,17 @@ kernel void kernel_mul_mv_q5_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
}
#define NB_Q8_0 8
template<typename A>
void kernel_mul_mv_q8_0_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -1824,17 +1825,17 @@ kernel void kernel_mul_mv_q8_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
}
#define N_MV_T_T 4
template<typename T0, typename T04, typename T1, typename T14>
template<typename T0, typename T04, typename T1, typename T14, typename A>
void kernel_mul_mv_impl(
device const char * src0,
device const char * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
uint3 tgpig,
uint tiisg) {
const int64_t r0 = tgpig.x;
@ -1904,7 +1905,7 @@ kernel void kernel_mul_mv(
constant ggml_metal_kargs_mul_mv & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_impl<T0, T04, T1, T14>(
kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
src0,
src1,
dst,
@ -3926,11 +3927,12 @@ kernel void kernel_concat(
}
}
template<typename A>
void kernel_mul_mv_q2_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4025,14 +4027,15 @@ kernel void kernel_mul_mv_q2_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
}
template<typename A>
void kernel_mul_mv_q3_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4184,14 +4187,15 @@ kernel void kernel_mul_mv_q3_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
}
template<typename A>
void kernel_mul_mv_q4_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4299,14 +4303,15 @@ kernel void kernel_mul_mv_q4_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
}
template<typename A>
void kernel_mul_mv_q5_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4428,14 +4433,15 @@ kernel void kernel_mul_mv_q5_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
}
template <typename A>
void kernel_mul_mv_q6_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4516,16 +4522,17 @@ kernel void kernel_mul_mv_q6_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
}
// ======================= "True" 2-bit
template<typename A>
void kernel_mul_mv_iq2_xxs_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4624,14 +4631,15 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
}
template<typename A>
void kernel_mul_mv_iq2_xs_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4740,14 +4748,15 @@ kernel void kernel_mul_mv_iq2_xs_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
}
template <typename A>
void kernel_mul_mv_iq3_xxs_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4849,14 +4858,15 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
}
template<typename A>
void kernel_mul_mv_iq3_s_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4958,14 +4968,15 @@ kernel void kernel_mul_mv_iq3_s_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
}
template <typename A>
void kernel_mul_mv_iq2_s_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -5068,14 +5079,15 @@ kernel void kernel_mul_mv_iq2_s_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
}
template<typename A>
void kernel_mul_mv_iq1_s_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_value,
uint3 tgpig,
uint tiisg,
@ -5154,11 +5166,12 @@ void kernel_mul_mv_iq1_s_f32_impl(
}
}
template <typename A>
void kernel_mul_mv_iq1_m_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_value,
uint3 tgpig,
uint tiisg,
@ -5246,11 +5259,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
}
}
template<typename A>
void kernel_mul_mv_iq4_nl_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values_i8,
uint3 tgpig,
uint tiisg,
@ -5333,11 +5347,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
}
}
template<typename A>
void kernel_mul_mv_iq4_xs_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
ggml_metal_kargs_mul_mv args,
A args,
threadgroup int8_t * shared_values_i8,
uint3 tgpig,
uint tiisg,
@ -5431,7 +5446,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq1_m_f32")]]
@ -5444,7 +5459,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
@ -5458,7 +5473,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
@ -5472,7 +5487,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@ -6057,7 +6072,7 @@ void mmv_fn(
impl_fn(src0,(const device float *) src1, dst, args, shared_values, tgpig, tiisg, sgitg);
}
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
template<mul_mv_impl_fn_t impl_fn>
kernel void kernel_mul_mv_id(