mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
clean-up
This commit is contained in:
parent
022e5e90e9
commit
8f0ef15265
@ -3251,6 +3251,9 @@ static void ggml_metal_encode_node(
|
||||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
// 2*(2*ncpsg + nqptg)*(nsg)
|
||||
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
|
||||
//
|
||||
// 16*32*(nsg)
|
||||
// the shared memory needed for the simdgroups to load the KV cache
|
||||
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
||||
|
@ -2846,14 +2846,14 @@ kernel void kernel_flash_attn_ext(
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short SH = (2*C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
const short TS = nsg*SH; // shared memory size per query in (s_t)
|
||||
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
||||
const short T = D + 2*TS; // shared memory size per query in (half)
|
||||
|
||||
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
|
||||
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
|
||||
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // reuse query data for accumulation
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention and diagonal matrix
|
||||
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
|
||||
|
||||
threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
||||
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
||||
|
Loading…
Reference in New Issue
Block a user