cont : use char ptr

This commit is contained in:
Georgi Gerganov 2024-11-10 09:26:53 +02:00
parent c81640a5fc
commit a1a201c1a9
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -1627,9 +1627,9 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
void mul_vec_q_n_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -1648,8 +1648,8 @@ void mul_vec_q_n_f32_impl(
//const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
//device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
// pointers to src0 rows
device const block_q_type * ax[nr];
@ -1690,19 +1690,22 @@ void mul_vec_q_n_f32_impl(
yb += QK4_0 * 16;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < args.ne01) {
dst[im*args.ne0*args.ne1 + r1*args.ne0 + first_row + row] = tot;
dst_f32[first_row + row] = tot;
}
}
}
kernel void kernel_mul_mv_q4_0_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -1711,9 +1714,9 @@ kernel void kernel_mul_mv_q4_0_f32(
kernel void kernel_mul_mv_q4_1_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -1722,9 +1725,9 @@ kernel void kernel_mul_mv_q4_1_f32(
kernel void kernel_mul_mv_q5_0_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -1733,9 +1736,9 @@ kernel void kernel_mul_mv_q5_0_f32(
kernel void kernel_mul_mv_q5_1_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -1747,9 +1750,9 @@ kernel void kernel_mul_mv_q5_1_f32(
template<typename args_t>
void kernel_mul_mv_q8_0_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -1771,8 +1774,8 @@ void kernel_mul_mv_q8_0_f32_impl(
//const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
//device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
// pointers to src0 rows
device const block_q8_0 * ax[nr];
@ -1808,10 +1811,12 @@ void kernel_mul_mv_q8_0_f32_impl(
yb += NB_Q8_0 * nw;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < args.ne01) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = tot;
dst_f32[first_row + row] = tot;
}
}
}
@ -1819,9 +1824,9 @@ void kernel_mul_mv_q8_0_f32_impl(
[[host_name("kernel_mul_mv_q8_0_f32")]]
kernel void kernel_mul_mv_q8_0_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -1835,7 +1840,7 @@ void kernel_mul_mv_impl(
args_t args,
device const char * src0,
device const char * src1,
device float * dst,
device char * dst,
uint3 tgpig,
uint tiisg) {
const int64_t r0 = tgpig.x;
@ -1849,6 +1854,8 @@ void kernel_mul_mv_impl(
device const T0 * x = (device const T0 *) (src0 + offset0);
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1;
if (args.ne00 < 128) {
for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row;
@ -1867,7 +1874,7 @@ void kernel_mul_mv_impl(
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum;
dst_f32[r1*args.ne0 + r0] = all_sum;
}
}
} else {
@ -1891,7 +1898,7 @@ void kernel_mul_mv_impl(
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum;
dst_f32[r1*args.ne0 + r0] = all_sum;
}
}
}
@ -1902,7 +1909,7 @@ kernel void kernel_mul_mv(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device float * dst,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
@ -3930,9 +3937,9 @@ kernel void kernel_concat(
template<typename args_t>
void kernel_mul_mv_q2_K_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -3951,8 +3958,8 @@ void kernel_mul_mv_q2_K_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
@ -4009,10 +4016,12 @@ void kernel_mul_mv_q2_K_f32_impl(
y4 += 4 * QK_K;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
dst_f32[first_row + row] = all_sum;
}
}
}
@ -4020,9 +4029,9 @@ void kernel_mul_mv_q2_K_f32_impl(
[[host_name("kernel_mul_mv_q2_K_f32")]]
kernel void kernel_mul_mv_q2_K_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -4033,9 +4042,9 @@ kernel void kernel_mul_mv_q2_K_f32(
template<typename args_t>
void kernel_mul_mv_q3_K_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4055,8 +4064,8 @@ void kernel_mul_mv_q3_K_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0);
device const float * yy = (device const float *) ((device char *) src1 + offset1);
device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
device const float * yy = (device const float *) (src1 + offset1);
float yl[32];
@ -4170,9 +4179,12 @@ void kernel_mul_mv_q3_K_f32_impl(
const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
sumf1[row] = simd_sum(sumf);
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
if (tiisg == 0) {
for (int row = 0; row < 2; ++row) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = sumf1[row];
dst_f32[first_row + row] = sumf1[row];
}
}
}
@ -4180,9 +4192,9 @@ void kernel_mul_mv_q3_K_f32_impl(
[[host_name("kernel_mul_mv_q3_K_f32")]]
kernel void kernel_mul_mv_q3_K_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -4193,9 +4205,9 @@ kernel void kernel_mul_mv_q3_K_f32(
template<typename args_t>
void kernel_mul_mv_q4_K_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4223,8 +4235,8 @@ void kernel_mul_mv_q4_K_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
float yl[16];
float yh[16];
@ -4285,10 +4297,12 @@ void kernel_mul_mv_q4_K_f32_impl(
y4 += 4 * QK_K;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
dst_f32[first_row + row] = all_sum;
}
}
}
@ -4296,9 +4310,9 @@ void kernel_mul_mv_q4_K_f32_impl(
[[host_name("kernel_mul_mv_q4_K_f32")]]
kernel void kernel_mul_mv_q4_K_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -4309,9 +4323,9 @@ kernel void kernel_mul_mv_q4_K_f32(
template<typename args_t>
void kernel_mul_mv_q5_K_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4331,8 +4345,8 @@ void kernel_mul_mv_q5_K_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0);
device const float * yy = (device const float *) ((device char *) src1 + offset1);
device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
device const float * yy = (device const float *) (src1 + offset1);
float sumf[2]={0.f};
@ -4415,10 +4429,12 @@ void kernel_mul_mv_q5_K_f32_impl(
y1 += 4 * QK_K;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < 2; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = tot;
dst_f32[first_row + row] = tot;
}
}
}
@ -4426,9 +4442,9 @@ void kernel_mul_mv_q5_K_f32_impl(
[[host_name("kernel_mul_mv_q5_K_f32")]]
kernel void kernel_mul_mv_q5_K_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -4439,9 +4455,9 @@ kernel void kernel_mul_mv_q5_K_f32(
template <typename args_t>
void kernel_mul_mv_q6_K_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4466,8 +4482,8 @@ void kernel_mul_mv_q6_K_f32_impl(
const uint offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0);
device const float * yy = (device const float *) ((device char *) src1 + offset1);
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
device const float * yy = (device const float *) (src1 + offset1);
float sumf = 0;
@ -4506,18 +4522,20 @@ void kernel_mul_mv_q6_K_f32_impl(
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
const float tot = simd_sum(sumf);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + row] = tot;
dst_f32[row] = tot;
}
}
[[host_name("kernel_mul_mv_q6_K_f32")]]
kernel void kernel_mul_mv_q6_K_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -4530,9 +4548,9 @@ kernel void kernel_mul_mv_q6_K_f32(
template<typename args_t>
void kernel_mul_mv_iq2_xxs_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4551,8 +4569,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
@ -4612,10 +4630,12 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f;
dst_f32[first_row + row] = all_sum * 0.25f;
}
}
}
@ -4623,9 +4643,9 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
kernel void kernel_mul_mv_iq2_xxs_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
@ -4637,9 +4657,9 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
template<typename args_t>
void kernel_mul_mv_iq2_xs_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4658,8 +4678,8 @@ void kernel_mul_mv_iq2_xs_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
@ -4729,10 +4749,12 @@ void kernel_mul_mv_iq2_xs_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f;
dst_f32[first_row + row] = all_sum * 0.25f;
}
}
}
@ -4740,9 +4762,9 @@ void kernel_mul_mv_iq2_xs_f32_impl(
[[host_name("kernel_mul_mv_iq2_xs_f32")]]
kernel void kernel_mul_mv_iq2_xs_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
@ -4754,9 +4776,9 @@ kernel void kernel_mul_mv_iq2_xs_f32(
template <typename args_t>
void kernel_mul_mv_iq3_xxs_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4775,8 +4797,8 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
@ -4839,10 +4861,12 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.5f;
dst_f32[first_row + row] = all_sum * 0.5f;
}
}
}
@ -4850,9 +4874,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
kernel void kernel_mul_mv_iq3_xxs_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
@ -4864,9 +4888,9 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
template<typename args_t>
void kernel_mul_mv_iq3_s_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4885,8 +4909,8 @@ void kernel_mul_mv_iq3_s_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
@ -4949,10 +4973,12 @@ void kernel_mul_mv_iq3_s_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
dst_f32[first_row + row] = all_sum;
}
}
}
@ -4960,9 +4986,9 @@ void kernel_mul_mv_iq3_s_f32_impl(
[[host_name("kernel_mul_mv_iq3_s_f32")]]
kernel void kernel_mul_mv_iq3_s_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
@ -4974,9 +5000,9 @@ kernel void kernel_mul_mv_iq3_s_f32(
template <typename args_t>
void kernel_mul_mv_iq2_s_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -4995,8 +5021,8 @@ void kernel_mul_mv_iq2_s_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
@ -5060,10 +5086,12 @@ void kernel_mul_mv_iq2_s_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f;
dst_f32[first_row + row] = all_sum * 0.25f;
}
}
}
@ -5071,9 +5099,9 @@ void kernel_mul_mv_iq2_s_f32_impl(
[[host_name("kernel_mul_mv_iq2_s_f32")]]
kernel void kernel_mul_mv_iq2_s_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
@ -5085,9 +5113,9 @@ kernel void kernel_mul_mv_iq2_s_f32(
template<typename args_t>
void kernel_mul_mv_iq1_s_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_value,
uint3 tgpig,
uint tiisg,
@ -5106,8 +5134,8 @@ void kernel_mul_mv_iq1_s_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
@ -5158,10 +5186,12 @@ void kernel_mul_mv_iq1_s_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
dst_f32[first_row + row] = all_sum;
}
}
}
@ -5169,9 +5199,9 @@ void kernel_mul_mv_iq1_s_f32_impl(
template <typename args_t>
void kernel_mul_mv_iq1_m_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_value,
uint3 tgpig,
uint tiisg,
@ -5190,8 +5220,8 @@ void kernel_mul_mv_iq1_m_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
@ -5251,10 +5281,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
y4 += 32 * 32;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
dst_f32[first_row + row] = all_sum;
}
}
}
@ -5262,9 +5294,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
template<typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values_i8,
uint3 tgpig,
uint tiisg,
@ -5283,8 +5315,8 @@ void kernel_mul_mv_iq4_nl_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const int ix = tiisg/2; // 0...15
const int it = tiisg%2; // 0 or 1
@ -5339,10 +5371,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
yb += 16 * QK4_NL;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
dst_f32[first_row + row] = all_sum;
}
}
}
@ -5350,9 +5384,9 @@ void kernel_mul_mv_iq4_nl_f32_impl(
template<typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
args_t args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values_i8,
uint3 tgpig,
uint tiisg,
@ -5371,8 +5405,8 @@ void kernel_mul_mv_iq4_xs_f32_impl(
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0);
device const float * y = (device const float *) ((device char *) src1 + offset1);
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const int ix = tiisg/16; // 0 or 1
const int it = tiisg%16; // 0...15
@ -5428,10 +5462,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
yb += 2 * QK_K;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < 2; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
dst_f32[first_row + row] = all_sum;
}
}
}
@ -5439,9 +5475,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
[[host_name("kernel_mul_mv_iq1_s_f32")]]
kernel void kernel_mul_mv_iq1_s_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -5452,9 +5488,9 @@ kernel void kernel_mul_mv_iq1_s_f32(
[[host_name("kernel_mul_mv_iq1_m_f32")]]
kernel void kernel_mul_mv_iq1_m_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -5465,9 +5501,9 @@ kernel void kernel_mul_mv_iq1_m_f32(
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
kernel void kernel_mul_mv_iq4_nl_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
@ -5479,9 +5515,9 @@ kernel void kernel_mul_mv_iq4_nl_f32(
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
kernel void kernel_mul_mv_iq4_xs_f32(
constant ggml_metal_kargs_mul_mv & args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
@ -6030,15 +6066,15 @@ typedef void (kernel_mul_mv_impl_t)(
ggml_metal_kargs_mul_mv args,
device const char * src0,
device const char * src1,
device float * dst,
device char * dst,
uint3 tgpig,
uint tiisg);
typedef void (kernel_mul_mv2_impl_t)(
ggml_metal_kargs_mul_mv args,
device const void * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiisg,
@ -6049,7 +6085,7 @@ void mmv_fn(
ggml_metal_kargs_mul_mv args,
device const char * src0,
device const char * src1,
device float * dst,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiitg,
@ -6063,13 +6099,13 @@ void mmv_fn(
ggml_metal_kargs_mul_mv args,
device const char * src0,
device const char * src1,
device float * dst,
device char * dst,
threadgroup int8_t * shared_values,
uint3 tgpig,
uint tiitg,
uint tiisg,
uint sgitg) {
impl_fn(args, src0,(const device float *) src1, dst, shared_values, tgpig, tiisg, sgitg);
impl_fn(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
}
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
@ -6079,7 +6115,7 @@ kernel void kernel_mul_mv_id(
constant ggml_metal_kargs_mul_mv_id & args,
device const char * src0s,
device const char * src1,
device float * dst,
device char * dst,
device const char * ids,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@ -6101,7 +6137,8 @@ kernel void kernel_mul_mv_id(
device const char * src0_cur = src0s + i02*args.nb02;
device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12;
device float * dst_cur = dst + i1*args.ne0 + i2*args.ne1*args.ne0;
device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float);
ggml_metal_kargs_mul_mv args0 = {
/*.ne00 =*/ args.ne00,