mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 19:50:17 +00:00
metal : optimize ggml_mul_mat_id (faster Mixtral PP) (#4725)
* ggml : disable fast-math for Metal (cmake build only) ggml-ci * metal : fix Metal API debug warnings * cmake : add -fno-inline for Metal build (#4545) * metal : fix API debug warnings * metal : fix compile warnings * metal : use uint64_t for strides * cmake : rename option to LLAMA_METAL_SHADER_DEBUG * metal : fix mat-vec Q8_0 kernel for BS > 1 * metal : normalize mat-vec kernel signatures * cmake : respect LLAMA_QKK_64 option * metal : fix mat-vec Q4_K kernel for QK_K == 64 * metal : optimizing ggml_mul_mat_id (wip) * metal : minor fix * metal : opt mul_mm_id
This commit is contained in:
parent
0ef3ca2ac6
commit
f3f62f0d83
31
ggml-metal.m
31
ggml-metal.m
@ -1657,6 +1657,10 @@ void ggml_metal_graph_compute(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (ggml_is_quantized(src0t)) {
|
||||||
|
GGML_ASSERT(ne00 >= nth0*nth1);
|
||||||
|
}
|
||||||
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
@ -1715,6 +1719,9 @@ void ggml_metal_graph_compute(
|
|||||||
// TODO: make this more general
|
// TODO: make this more general
|
||||||
GGML_ASSERT(n_as <= 8);
|
GGML_ASSERT(n_as <= 8);
|
||||||
|
|
||||||
|
// max size of the src1ids array in the kernel stack
|
||||||
|
GGML_ASSERT(ne11 <= 512);
|
||||||
|
|
||||||
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
||||||
|
|
||||||
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
||||||
@ -1732,9 +1739,6 @@ void ggml_metal_graph_compute(
|
|||||||
GGML_ASSERT(!ggml_is_transposed(src2));
|
GGML_ASSERT(!ggml_is_transposed(src2));
|
||||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||||
|
|
||||||
GGML_ASSERT(ne20 % 32 == 0);
|
|
||||||
// !!!!!!!!! TODO: this assert is probably required but not sure!
|
|
||||||
//GGML_ASSERT(ne20 >= 64);
|
|
||||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||||
|
|
||||||
const uint r2 = ne12/ne22;
|
const uint r2 = ne12/ne22;
|
||||||
@ -1742,22 +1746,22 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||||
// to the matrix-vector kernel
|
// to the matrix-vector kernel
|
||||||
int ne11_mm_min = 1;
|
int ne11_mm_min = n_as;
|
||||||
|
|
||||||
const int idx = ((int32_t *) dst->op_params)[0];
|
const int idx = ((int32_t *) dst->op_params)[0];
|
||||||
|
|
||||||
// batch size
|
// batch size
|
||||||
GGML_ASSERT(ne01 == ne11);
|
GGML_ASSERT(ne01 == ne11);
|
||||||
|
|
||||||
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
|
|
||||||
|
|
||||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||||
// !!!
|
// !!!
|
||||||
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
||||||
// indirect matrix multiplication
|
// indirect matrix multiplication
|
||||||
// !!!
|
// !!!
|
||||||
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
||||||
|
ne20 % 32 == 0 && ne20 >= 64 &&
|
||||||
|
ne11 > ne11_mm_min) {
|
||||||
switch (src2->type) {
|
switch (src2->type) {
|
||||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
|
||||||
@ -1787,7 +1791,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
||||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
||||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
||||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
||||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
||||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
||||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
||||||
@ -1805,8 +1809,7 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||||
|
|
||||||
// TODO: processing one row at a time (ne11 -> 1) is not efficient
|
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
||||||
} else {
|
} else {
|
||||||
int nth0 = 32;
|
int nth0 = 32;
|
||||||
int nth1 = 1;
|
int nth1 = 1;
|
||||||
@ -1889,11 +1892,17 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
||||||
GGML_ASSERT(false && "not implemented");
|
GGML_ASSERT(false && "not implemented");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (ggml_is_quantized(src2t)) {
|
||||||
|
GGML_ASSERT(ne20 >= nth0*nth1);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
|
||||||
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
205
ggml-metal.metal
205
ggml-metal.metal
@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|||||||
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
||||||
//Note: This is a template, but strictly speaking it only applies to
|
//Note: This is a template, but strictly speaking it only applies to
|
||||||
// quantizations where the block size is 32. It also does not
|
// quantizations where the block size is 32. It also does not
|
||||||
// giard against the number of rows not being divisible by
|
// guard against the number of rows not being divisible by
|
||||||
// N_DST, so this is another explicit assumption of the implementation.
|
// N_DST, so this is another explicit assumption of the implementation.
|
||||||
template<typename block_q_type, int nr, int nsg, int nw>
|
template<typename block_q_type, int nr, int nsg, int nw>
|
||||||
void mul_vec_q_n_f32_impl(
|
void mul_vec_q_n_f32_impl(
|
||||||
@ -3973,6 +3973,131 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
|
||||||
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||||
|
void kernel_mul_mm_id_impl(
|
||||||
|
device const uchar * src0,
|
||||||
|
device const uchar * src1,
|
||||||
|
thread short * src1ids,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
int64_t ne1,
|
||||||
|
constant uint & r2,
|
||||||
|
constant uint & r3,
|
||||||
|
threadgroup uchar * shared_memory,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
|
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
||||||
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
||||||
|
|
||||||
|
const uint r0 = tgpig.y;
|
||||||
|
const uint r1 = tgpig.x;
|
||||||
|
const uint im = tgpig.z;
|
||||||
|
|
||||||
|
if (r1 * BLOCK_SIZE_N >= ne1) return;
|
||||||
|
|
||||||
|
// if this block is of 64x32 shape or smaller
|
||||||
|
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||||
|
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||||
|
|
||||||
|
// a thread shouldn't load data outside of the matrix
|
||||||
|
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||||
|
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||||
|
|
||||||
|
simdgroup_half8x8 ma[4];
|
||||||
|
simdgroup_float8x8 mb[2];
|
||||||
|
simdgroup_float8x8 c_res[8];
|
||||||
|
for (int i = 0; i < 8; i++){
|
||||||
|
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
short il = (tiitg % THREAD_PER_ROW);
|
||||||
|
|
||||||
|
const uint i12 = im%ne12;
|
||||||
|
const uint i13 = im/ne12;
|
||||||
|
|
||||||
|
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
||||||
|
ushort offset1 = il/nl;
|
||||||
|
|
||||||
|
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
||||||
|
device const float * y = (device const float *)(src1
|
||||||
|
+ nb12 * im
|
||||||
|
+ nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
|
||||||
|
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||||
|
|
||||||
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
||||||
|
// load data and store to threadgroup memory
|
||||||
|
half4x4 temp_a;
|
||||||
|
dequantize_func(x, il, temp_a);
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
for (int i = 0; i < 16; i++) {
|
||||||
|
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
||||||
|
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
||||||
|
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
||||||
|
}
|
||||||
|
|
||||||
|
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
||||||
|
|
||||||
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||||
|
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
||||||
|
y += BLOCK_SIZE_K;
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// load matrices from threadgroup memory and conduct outer products
|
||||||
|
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
||||||
|
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
||||||
|
|
||||||
|
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
|
||||||
|
}
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
|
||||||
|
}
|
||||||
|
|
||||||
|
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
||||||
|
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
||||||
|
|
||||||
|
for (int i = 0; i < 8; i++){
|
||||||
|
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
||||||
|
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
|
||||||
|
if (sgitg == 0) {
|
||||||
|
for (int i = 0; i < n_rows; i++) {
|
||||||
|
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||||
|
*(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||||
kernel void kernel_mul_mm(device const uchar * src0,
|
kernel void kernel_mul_mm(device const uchar * src0,
|
||||||
device const uchar * src1,
|
device const uchar * src1,
|
||||||
@ -4019,7 +4144,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
|
|||||||
kernel void kernel_mul_mm_id(
|
kernel void kernel_mul_mm_id(
|
||||||
device const uchar * ids,
|
device const uchar * ids,
|
||||||
device const uchar * src1,
|
device const uchar * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4048,18 +4173,28 @@ kernel void kernel_mul_mm_id(
|
|||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||||
|
|
||||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
// expert id
|
||||||
|
const int32_t id = tgpig.z/(ne12*ne13);
|
||||||
|
|
||||||
tgpig.z = tgpig.z%(ne12*ne13);
|
tgpig.z = tgpig.z%(ne12*ne13);
|
||||||
|
|
||||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
// row indices of src1 for expert id
|
||||||
|
int64_t _ne1 = 0;
|
||||||
|
short src1ids[512];
|
||||||
|
|
||||||
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||||
src0[id],
|
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
|
||||||
src1 + bid*nb11,
|
src1ids[_ne1++] = i1;
|
||||||
(device float *) (dst + bid*nb1),
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
||||||
|
src0s[id],
|
||||||
|
src1,
|
||||||
|
src1ids,
|
||||||
|
dst,
|
||||||
ne00,
|
ne00,
|
||||||
ne02,
|
ne02,
|
||||||
nb01,
|
nb01,
|
||||||
@ -4069,7 +4204,7 @@ kernel void kernel_mul_mm_id(
|
|||||||
nb11,
|
nb11,
|
||||||
nb12,
|
nb12,
|
||||||
ne0,
|
ne0,
|
||||||
ne1,
|
_ne1,
|
||||||
r2,
|
r2,
|
||||||
r3,
|
r3,
|
||||||
shared_memory,
|
shared_memory,
|
||||||
@ -4158,7 +4293,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|||||||
typedef void (mat_mm_id_t)(
|
typedef void (mat_mm_id_t)(
|
||||||
device const uchar * ids,
|
device const uchar * ids,
|
||||||
device const uchar * src1,
|
device const uchar * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -4207,7 +4342,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|||||||
kernel void kernel_mul_mv_id_f32_f32(
|
kernel void kernel_mul_mv_id_f32_f32(
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
@ -4251,7 +4386,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|||||||
kernel_mul_mv_f32_f32_impl(
|
kernel_mul_mv_f32_f32_impl(
|
||||||
src0[id],
|
src0[id],
|
||||||
src1 + bid*nb11,
|
src1 + bid*nb11,
|
||||||
(device float *) (dst + bid*nb1),
|
dst + bid*ne0,
|
||||||
ne00,
|
ne00,
|
||||||
ne01,
|
ne01,
|
||||||
ne02,
|
ne02,
|
||||||
@ -4276,7 +4411,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|||||||
kernel void kernel_mul_mv_id_f16_f32(
|
kernel void kernel_mul_mv_id_f16_f32(
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
@ -4320,7 +4455,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|||||||
kernel_mul_mv_f16_f32_impl(
|
kernel_mul_mv_f16_f32_impl(
|
||||||
src0[id],
|
src0[id],
|
||||||
src1 + bid*nb11,
|
src1 + bid*nb11,
|
||||||
(device float *) (dst + bid*nb1),
|
dst + bid*ne0,
|
||||||
ne00,
|
ne00,
|
||||||
ne01,
|
ne01,
|
||||||
ne02,
|
ne02,
|
||||||
@ -4345,7 +4480,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|||||||
kernel void kernel_mul_mv_id_q8_0_f32(
|
kernel void kernel_mul_mv_id_q8_0_f32(
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
@ -4389,7 +4524,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|||||||
kernel_mul_mv_q8_0_f32_impl(
|
kernel_mul_mv_q8_0_f32_impl(
|
||||||
src0[id],
|
src0[id],
|
||||||
(device const float *) (src1 + bid*nb11),
|
(device const float *) (src1 + bid*nb11),
|
||||||
(device float *) ( dst + bid*nb1),
|
dst + bid*ne0,
|
||||||
ne00,
|
ne00,
|
||||||
ne01,
|
ne01,
|
||||||
ne02,
|
ne02,
|
||||||
@ -4408,7 +4543,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|||||||
kernel void kernel_mul_mv_id_q4_0_f32(
|
kernel void kernel_mul_mv_id_q4_0_f32(
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
@ -4452,7 +4587,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|||||||
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
||||||
src0[id],
|
src0[id],
|
||||||
(device const float *) (src1 + bid*nb11),
|
(device const float *) (src1 + bid*nb11),
|
||||||
(device float *) ( dst + bid*nb1),
|
dst + bid*ne0,
|
||||||
ne00,
|
ne00,
|
||||||
ne01,
|
ne01,
|
||||||
ne02,
|
ne02,
|
||||||
@ -4471,7 +4606,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|||||||
kernel void kernel_mul_mv_id_q4_1_f32(
|
kernel void kernel_mul_mv_id_q4_1_f32(
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
@ -4515,7 +4650,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|||||||
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
||||||
src0[id],
|
src0[id],
|
||||||
(device const float *) (src1 + bid*nb11),
|
(device const float *) (src1 + bid*nb11),
|
||||||
(device float *) ( dst + bid*nb1),
|
dst + bid*ne0,
|
||||||
ne00,
|
ne00,
|
||||||
ne01,
|
ne01,
|
||||||
ne02,
|
ne02,
|
||||||
@ -4534,7 +4669,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|||||||
kernel void kernel_mul_mv_id_q5_0_f32(
|
kernel void kernel_mul_mv_id_q5_0_f32(
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
@ -4578,7 +4713,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|||||||
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
||||||
src0[id],
|
src0[id],
|
||||||
(device const float *) (src1 + bid*nb11),
|
(device const float *) (src1 + bid*nb11),
|
||||||
(device float *) ( dst + bid*nb1),
|
dst + bid*ne0,
|
||||||
ne00,
|
ne00,
|
||||||
ne01,
|
ne01,
|
||||||
ne02,
|
ne02,
|
||||||
@ -4597,7 +4732,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|||||||
kernel void kernel_mul_mv_id_q5_1_f32(
|
kernel void kernel_mul_mv_id_q5_1_f32(
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
@ -4641,7 +4776,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|||||||
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
||||||
src0[id],
|
src0[id],
|
||||||
(device const float *) (src1 + bid*nb11),
|
(device const float *) (src1 + bid*nb11),
|
||||||
(device float *) ( dst + bid*nb1),
|
dst + bid*ne0,
|
||||||
ne00,
|
ne00,
|
||||||
ne01,
|
ne01,
|
||||||
ne02,
|
ne02,
|
||||||
@ -4660,7 +4795,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|||||||
kernel void kernel_mul_mv_id_q2_K_f32(
|
kernel void kernel_mul_mv_id_q2_K_f32(
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
@ -4704,7 +4839,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|||||||
kernel_mul_mv_q2_K_f32_impl(
|
kernel_mul_mv_q2_K_f32_impl(
|
||||||
src0[id],
|
src0[id],
|
||||||
(device const float *) (src1 + bid*nb11),
|
(device const float *) (src1 + bid*nb11),
|
||||||
(device float *) ( dst + bid*nb1),
|
dst + bid*ne0,
|
||||||
ne00,
|
ne00,
|
||||||
ne01,
|
ne01,
|
||||||
ne02,
|
ne02,
|
||||||
@ -4723,7 +4858,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|||||||
kernel void kernel_mul_mv_id_q3_K_f32(
|
kernel void kernel_mul_mv_id_q3_K_f32(
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
@ -4767,7 +4902,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|||||||
kernel_mul_mv_q3_K_f32_impl(
|
kernel_mul_mv_q3_K_f32_impl(
|
||||||
src0[id],
|
src0[id],
|
||||||
(device const float *) (src1 + bid*nb11),
|
(device const float *) (src1 + bid*nb11),
|
||||||
(device float *) ( dst + bid*nb1),
|
dst + bid*ne0,
|
||||||
ne00,
|
ne00,
|
||||||
ne01,
|
ne01,
|
||||||
ne02,
|
ne02,
|
||||||
@ -4786,7 +4921,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|||||||
kernel void kernel_mul_mv_id_q4_K_f32(
|
kernel void kernel_mul_mv_id_q4_K_f32(
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
@ -4830,7 +4965,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|||||||
kernel_mul_mv_q4_K_f32_impl(
|
kernel_mul_mv_q4_K_f32_impl(
|
||||||
src0[id],
|
src0[id],
|
||||||
(device const float *) (src1 + bid*nb11),
|
(device const float *) (src1 + bid*nb11),
|
||||||
(device float *) ( dst + bid*nb1),
|
dst + bid*ne0,
|
||||||
ne00,
|
ne00,
|
||||||
ne01,
|
ne01,
|
||||||
ne02,
|
ne02,
|
||||||
@ -4849,7 +4984,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|||||||
kernel void kernel_mul_mv_id_q5_K_f32(
|
kernel void kernel_mul_mv_id_q5_K_f32(
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
@ -4893,7 +5028,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|||||||
kernel_mul_mv_q5_K_f32_impl(
|
kernel_mul_mv_q5_K_f32_impl(
|
||||||
src0[id],
|
src0[id],
|
||||||
(device const float *) (src1 + bid*nb11),
|
(device const float *) (src1 + bid*nb11),
|
||||||
(device float *) ( dst + bid*nb1),
|
dst + bid*ne0,
|
||||||
ne00,
|
ne00,
|
||||||
ne01,
|
ne01,
|
||||||
ne02,
|
ne02,
|
||||||
@ -4912,7 +5047,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|||||||
kernel void kernel_mul_mv_id_q6_K_f32(
|
kernel void kernel_mul_mv_id_q6_K_f32(
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device uchar * dst,
|
device float * dst,
|
||||||
constant uint64_t & nbi1,
|
constant uint64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
@ -4956,7 +5091,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|||||||
kernel_mul_mv_q6_K_f32_impl(
|
kernel_mul_mv_q6_K_f32_impl(
|
||||||
src0[id],
|
src0[id],
|
||||||
(device const float *) (src1 + bid*nb11),
|
(device const float *) (src1 + bid*nb11),
|
||||||
(device float *) ( dst + bid*nb1),
|
dst + bid*ne0,
|
||||||
ne00,
|
ne00,
|
||||||
ne01,
|
ne01,
|
||||||
ne02,
|
ne02,
|
||||||
|
Loading…
Reference in New Issue
Block a user