diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 10d59cb9f..aa7acc05c 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -3046,6 +3046,8 @@ static void ggml_metal_encode_node( bool use_vec_kernel = false; + // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) + // for now avoiding mainly to keep the number of templates/kernels a bit lower if (ne01 >= 4 || (ne00%128 != 0)) { switch (src1->type) { case GGML_TYPE_F16: diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 7e1517414..a60e8d21c 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3356,8 +3356,8 @@ kernel void kernel_flash_attn_ext_vec( const short D4 = D/4; const short D16 = D/16; const short NW = N_SIMDWIDTH; - const short NL = NW/4; - const short SH = 2*C; // shared memory per simdgroup + const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0 + const short SH = 2*C; // shared memory per simdgroup const short T = D + nsg*SH; // shared memory size per query in (half) @@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec( // Q*K^T { - // each simdgroup processes 1 query and 4 keys + // each simdgroup processes 1 query and 4 (NW/NL) keys for (short cc = 0; cc < C/4; ++cc) { qk_t mqk = 0.0; @@ -3645,7 +3645,7 @@ kernel void kernel_flash_attn_ext_vec( half, half4, half4x4, \ half4x4 -typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; +typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16)