diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 85e94df06..aecd6bc02 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -3251,6 +3251,9 @@ static void ggml_metal_encode_node(
     GGML_ASSERT(nqptg % 8  == 0);
     GGML_ASSERT(ncpsg % 32 == 0);
 
+    // 2*(2*ncpsg + nqptg)*(nsg)
+    // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+    //
     // 16*32*(nsg)
     // the shared memory needed for the simdgroups to load the KV cache
     // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 675cba9b6..8eb3faa86 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2846,14 +2846,14 @@ kernel void kernel_flash_attn_ext(
     const short NW = N_SIMDWIDTH;
     const short SH = (2*C + Q); // shared memory per simdgroup in (half)
 
-    const short TS = nsg*SH;   // shared memory size per query in (s_t)
+    const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
     const short T  = D + 2*TS; // shared memory size per query in (half)
 
     threadgroup q_t  * sq  = (threadgroup q_t  *) (shared +              0*D); // holds the query data
     threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared +              0*D); // same as above but in q4_t
     threadgroup o_t  * so  = (threadgroup o_t  *) (shared +              0*D); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shared +              0*D); // reuse query data for accumulation
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention and diagonal matrix
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shared +              0*D); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
     threadgroup k_t    * sk    = (threadgroup k_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
     threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
 
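
For context, here is a minimal C sketch of the threadgroup memory budget the new comments describe. `fa_smem_bytes` and the head-size parameter `D` are hypothetical names introduced for illustration (the latter mirrors the kernel's template constant); the formulas come straight from the comment block: `2*(2*ncpsg + nqptg)*(nsg)` half units of float scratch per query, plus `16*32*(nsg)` half units to stage the KV cache (which the kernel places after the `Q*T` query region, see the `sk` pointer). The actual sizing code in `ggml_metal_encode_node` may pad or round differently.

```c
// Hedged sketch of the flash-attention threadgroup memory footprint implied
// by the comments in this diff. Not the actual ggml-metal.m computation.
#include <stddef.h>
#include <stdio.h>

static size_t fa_smem_bytes(int D, int nqptg, int ncpsg, int nsg) {
    // per query row, in half (2-byte) units:
    //   D    -- the query data itself (reused for the output accumulators)
    //   2*TS -- TS = nsg*(2*ncpsg + nqptg) float scratch values per query:
    //           ncpsg soft_max values + ncpsg mask values + a diagonal
    //           scaling matrix per simdgroup, counted here in half units
    const size_t per_query_halfs = (size_t) D + 2*(size_t) nsg*(2*ncpsg + nqptg);

    // staging area for the KV cache: each of the 32 threads in a simdgroup
    // loads (dequantizes) 16 head elements, for every simdgroup
    const size_t kv_stage_halfs = (size_t) nsg*16*32;

    return 2*((size_t) nqptg*per_query_halfs + kv_stage_halfs); // halfs -> bytes
}

int main(void) {
    // example: head size 128, 8 queries per threadgroup,
    // 32 cache values per simdgroup, 4 simdgroups
    printf("%zu bytes\n", fa_smem_bytes(128, 8, 32, 4)); // 15360 bytes
    return 0;
}
```

The example parameters satisfy the asserts in the first hunk (`nqptg % 8 == 0`, `ncpsg % 32 == 0`) and land well under the 32 KiB threadgroup memory limit typical of Apple GPUs.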