This commit is contained in:
Georgi Gerganov 2024-11-07 20:02:31 +02:00
parent 022e5e90e9
commit 8f0ef15265
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 6 additions and 3 deletions

View File

@ -3251,6 +3251,9 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
// 2*(2*ncpsg + nqptg)*(nsg)
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
//
// 16*32*(nsg)
// the shared memory needed for the simdgroups to load the KV cache
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG

View File

@ -2846,14 +2846,14 @@ kernel void kernel_flash_attn_ext(
const short NW = N_SIMDWIDTH;
const short SH = (2*C + Q); // shared memory per simdgroup in (half)
const short TS = nsg*SH; // shared memory size per query in (s_t)
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
const short T = D + 2*TS; // shared memory size per query in (half)
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // reuse query data for accumulation
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention and diagonal matrix
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t