mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
metal : reduce the kernel launches for ggml_mul_mat_id
This commit is contained in:
parent
7e2006b0c0
commit
8c5b66eeaa
50
ggml-metal.m
50
ggml-metal.m
@ -1495,6 +1495,9 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
const int idx = ((int32_t *) dst->op_params)[0];
|
const int idx = ((int32_t *) dst->op_params)[0];
|
||||||
|
|
||||||
|
// batch size
|
||||||
|
GGML_ASSERT(ne01 == ne11);
|
||||||
|
|
||||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||||
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
||||||
@ -1515,19 +1518,25 @@ void ggml_metal_graph_compute(
|
|||||||
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
||||||
}
|
}
|
||||||
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
|
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
|
||||||
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
||||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
||||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
||||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
||||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
||||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
||||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
|
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
||||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
||||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
||||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:15];
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
||||||
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
||||||
|
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
|
||||||
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
||||||
|
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
||||||
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
||||||
|
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
||||||
// TODO: how to make this an array? read Metal docs
|
// TODO: how to make this an array? read Metal docs
|
||||||
for (int j = 0; j < n_as; ++j) {
|
for (int j = 0; j < n_as; ++j) {
|
||||||
struct ggml_tensor * src_cur = dst->src[2 + j];
|
struct ggml_tensor * src_cur = dst->src[2 + j];
|
||||||
@ -1535,18 +1544,19 @@ void ggml_metal_graph_compute(
|
|||||||
size_t offs_src_cur = 0;
|
size_t offs_src_cur = 0;
|
||||||
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
||||||
|
|
||||||
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
|
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
||||||
}
|
}
|
||||||
|
|
||||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||||
|
|
||||||
for (int64_t i01 = 0; i01 < src0->ne[1]; i01++) {
|
[encoder dispatchThreadgroups:MTLSizeMake( (1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 + i01*nb01 atIndex:0];
|
//[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 + i01*nb11 atIndex:1];
|
//for (int64_t i01 = 0; i01 < src0->ne[1]; i01++) {
|
||||||
[encoder setBuffer:id_dst offset:offs_dst + i01*nb1 atIndex:2];
|
// [encoder setBuffer:id_src0 offset:offs_src0 + i01*nb01 atIndex:0];
|
||||||
|
// [encoder setBuffer:id_src1 offset:offs_src1 + i01*nb11 atIndex:1];
|
||||||
|
// [encoder setBuffer:id_dst offset:offs_dst + i01*nb1 atIndex:2];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
//}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
|
@ -3474,19 +3474,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
|
|
||||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||||
kernel void kernel_mul_mm_id(
|
kernel void kernel_mul_mm_id(
|
||||||
device const int32_t * ids,
|
device const uchar * ids,
|
||||||
device const uchar * src1,
|
device const uchar * src1,
|
||||||
device float * dst,
|
device uchar * dst,
|
||||||
|
constant int64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant int64_t & nb01,
|
constant int64_t & nb01,
|
||||||
constant int64_t & nb02,
|
constant int64_t & nb02,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
|
constant int64_t & ne13,
|
||||||
constant int64_t & nb10,
|
constant int64_t & nb10,
|
||||||
constant int64_t & nb11,
|
constant int64_t & nb11,
|
||||||
constant int64_t & nb12,
|
constant int64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
|
constant int64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
@ -3504,10 +3507,16 @@ kernel void kernel_mul_mm_id(
|
|||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||||
|
|
||||||
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||||
|
|
||||||
|
tgpig.z = tgpig.z%(ne12*ne13);
|
||||||
|
|
||||||
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||||
|
|
||||||
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
||||||
src0[ids[idx]],
|
src0[id],
|
||||||
src1,
|
src1 + bid*nb11,
|
||||||
dst,
|
(device float *) (dst + bid*nb1),
|
||||||
ne00,
|
ne00,
|
||||||
ne02,
|
ne02,
|
||||||
nb01,
|
nb01,
|
||||||
@ -3589,19 +3598,22 @@ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|||||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||||
|
|
||||||
typedef void (mat_mm_id_t)(
|
typedef void (mat_mm_id_t)(
|
||||||
device const int32_t * ids,
|
device const uchar * ids,
|
||||||
device const uchar * src1,
|
device const uchar * src1,
|
||||||
device float * dst,
|
device uchar * dst,
|
||||||
|
constant int64_t & nbi1,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant int64_t & nb01,
|
constant int64_t & nb01,
|
||||||
constant int64_t & nb02,
|
constant int64_t & nb02,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
|
constant int64_t & ne13,
|
||||||
constant int64_t & nb10,
|
constant int64_t & nb10,
|
||||||
constant int64_t & nb11,
|
constant int64_t & nb11,
|
||||||
constant int64_t & nb12,
|
constant int64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
|
constant int64_t & nb1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
constant int & idx,
|
constant int & idx,
|
||||||
|
Loading…
Reference in New Issue
Block a user