metal : improve decoding speed for batches of 2-16

Georgi Gerganov 2023-10-07 12:59:24 +03:00
parent f1782c68de
commit 99ed03a24a
GPG Key ID: 449E073F9DC10735


@@ -993,6 +993,26 @@ void ggml_metal_graph_compute(
                 uint gqa = ne12/ne02;
                 GGML_ASSERT(ne03 == ne13);
 
+                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                // to the matrix-vector kernel. the numbers below are measured on M2 Ultra
+                // not sure if this translates across all chips
+                int ne11_mm_min = 1;
+
+                switch (src0t) {
+                    case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
+                    case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
+                    case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+                    case GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+                    case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+                    case GGML_TYPE_Q5_0:                   // not tested yet
+                    case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+                    case GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
+                    case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
+                    default:             ne11_mm_min = 1;  break;
+                }
+
                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                 if (!ggml_is_transposed(src0) &&
@@ -1000,7 +1020,7 @@ void ggml_metal_graph_compute(
                     src1t == GGML_TYPE_F32 &&
                     [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                     ne00%32 == 0 &&
-                    ne11 > 2) {
+                    ne11 > ne11_mm_min) {
                     switch (src0->type) {
                         case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
                         case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
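
The change, restated outside the diff: instead of a fixed cutoff (ne11 > 2), the batch size at which the Metal matrix-matrix kernel overtakes the matrix-vector kernel is now looked up per src0 quantization type, using the thresholds measured on M2 Ultra. Below is a minimal standalone C sketch of that dispatch rule, not the ggml code itself: the thresholds mirror the diff, while the enum names, mm_min_for_type, and choose_kernel are hypothetical stand-ins used only for illustration.

#include <stdio.h>

/* hypothetical stand-ins for the ggml types referenced in the diff */
typedef enum {
    TYPE_F16, TYPE_Q8_0, TYPE_Q2_K, TYPE_Q3_K,
    TYPE_Q4_0, TYPE_Q4_1, TYPE_Q4_K,
    TYPE_Q5_0, TYPE_Q5_1, TYPE_Q5_K, TYPE_Q6_K,
    TYPE_OTHER
} src0_type_t;

typedef enum { KERNEL_MAT_VEC, KERNEL_MAT_MAT } kernel_t;

/* break-even batch size (ne11) per src0 quantization type, values taken
 * from the diff (measured on M2 Ultra); at or below this size the
 * matrix-vector kernel is still the faster choice */
static int mm_min_for_type(src0_type_t t) {
    switch (t) {
        case TYPE_F16:  return 2;
        case TYPE_Q8_0: return 7;
        case TYPE_Q2_K: return 15;
        case TYPE_Q3_K: return 7;
        case TYPE_Q4_0:
        case TYPE_Q4_1: return 15;
        case TYPE_Q4_K: return 11;
        case TYPE_Q5_0:             /* not tested yet */
        case TYPE_Q5_1: return 13;  /* not tested yet */
        case TYPE_Q5_K: return 7;
        case TYPE_Q6_K: return 7;
        default:        return 1;
    }
}

/* same decision as the patched condition: use the matrix-matrix kernel only
 * when the batch is past the per-type break-even point */
static kernel_t choose_kernel(src0_type_t t, int ne11) {
    return ne11 > mm_min_for_type(t) ? KERNEL_MAT_MAT : KERNEL_MAT_VEC;
}

int main(void) {
    /* a Q4_K batch of 8 stays on the matrix-vector kernel,
     * a batch of 12 switches to the matrix-matrix kernel */
    printf("%d\n", choose_kernel(TYPE_Q4_K, 8)  == KERNEL_MAT_MAT); /* prints 0 */
    printf("%d\n", choose_kernel(TYPE_Q4_K, 12) == KERNEL_MAT_MAT); /* prints 1 */
    return 0;
}

With the old fixed cutoff of 2, a Q4_K batch of 8 would already have been routed to the matrix-matrix kernel even though, per the measured thresholds, the matrix-vector kernel is still faster at that size; small batches like these are the 2-16 range the commit title refers to.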