mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
metal : improve decoding speed for batches of 2-16
This commit is contained in:
parent
f1782c68de
commit
99ed03a24a
22
ggml-metal.m
22
ggml-metal.m
@ -993,6 +993,26 @@ void ggml_metal_graph_compute(
|
|||||||
uint gqa = ne12/ne02;
|
uint gqa = ne12/ne02;
|
||||||
GGML_ASSERT(ne03 == ne13);
|
GGML_ASSERT(ne03 == ne13);
|
||||||
|
|
||||||
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||||
|
// to the matrix-vector kernel. the numbers below are measure on M2 Ultra
|
||||||
|
// not sure if this translates across all chips
|
||||||
|
int ne11_mm_min = 1;
|
||||||
|
|
||||||
|
switch (src0t) {
|
||||||
|
case GGML_TYPE_F16: ne11_mm_min = 2; break;
|
||||||
|
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
|
||||||
|
case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
|
||||||
|
case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
|
||||||
|
case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
|
||||||
|
case GGML_TYPE_Q5_0: // not tested yet
|
||||||
|
case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
|
||||||
|
case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
|
||||||
|
case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
|
||||||
|
default: ne11_mm_min = 1; break;
|
||||||
|
}
|
||||||
|
|
||||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||||
if (!ggml_is_transposed(src0) &&
|
if (!ggml_is_transposed(src0) &&
|
||||||
@ -1000,7 +1020,7 @@ void ggml_metal_graph_compute(
|
|||||||
src1t == GGML_TYPE_F32 &&
|
src1t == GGML_TYPE_F32 &&
|
||||||
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
||||||
ne00%32 == 0 &&
|
ne00%32 == 0 &&
|
||||||
ne11 > 2) {
|
ne11 > ne11_mm_min) {
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
||||||
|
Loading…
Reference in New Issue
Block a user