metal : minor clean-up

This commit is contained in:
Georgi Gerganov 2024-11-07 21:29:22 +02:00
parent 7facc29d69
commit 2fccc8ac2d
No known key found for this signature in database
GPG Key ID: BF970631944C16B7
2 changed files with 20 additions and 17 deletions

View File

@ -3292,12 +3292,12 @@ static void ggml_metal_encode_node(
// ne00 + 2*ncpsg*(nsg) // ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00) // for each query, we load it as f16 in shared memory (ne00)
// and store the attention scores (nqptg x ncpsg) as f32 // and store the soft_max values and the mask
// //
// ne00*(nsg) // ne00*(nsg)
// each simdgroup has a full f16 head vector in shared mem to accumulate results // each simdgroup has a full f16 head vector in shared mem to accumulate results
// //
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 4*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16)) #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
int64_t nsgmax = 2; int64_t nsgmax = 2;

View File

@ -3355,14 +3355,14 @@ kernel void kernel_flash_attn_ext_vec(
const short NW4 = NW/4; const short NW4 = NW/4;
const short SH = 2*C; // shared memory per simdgroup const short SH = 2*C; // shared memory per simdgroup
const short T = D + 2*nsg*SH; // shared memory size per query in (half) const short T = D + nsg*SH; // shared memory size per query in (half)
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2*sgitg*SH + Q*D); // same as above but in s4_t threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t
threadgroup half * sm = (threadgroup half *) (shared + 2*sgitg*SH + SH + Q*D); // scratch buffer for mask threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
@ -3522,7 +3522,7 @@ kernel void kernel_flash_attn_ext_vec(
for (short cc = 0; cc < C/4; ++cc) { for (short cc = 0; cc < C/4; ++cc) {
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
const v4x4_t ms(ss[4*cc + ty]); const s4x4_t ms(ss[4*cc + ty]);
#pragma unroll #pragma unroll
for (short ii = 0; ii < D16; ii += NW4) { for (short ii = 0; ii < D16; ii += NW4) {
@ -3531,7 +3531,7 @@ kernel void kernel_flash_attn_ext_vec(
v4x4_t mv; v4x4_t mv;
deq_v(pv4 + i/nl_v, i%nl_v, mv); deq_v(pv4 + i/nl_v, i%nl_v, mv);
lo[ii/NW4] += (o4x4_t)(mv*ms); lo[ii/NW4] += mv*ms;
} }
} }
} }
@ -3616,12 +3616,15 @@ kernel void kernel_flash_attn_ext_vec(
} }
} }
// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
//
#define FA_TYPES \ #define FA_TYPES \
half4, half4x4, \ half4, half4x4, \
half4x4, \ half4x4, \
half4x4, \ half4x4, \
float, \ float, \
float, float4, float4x4, \ half, half4, half4x4, \
half4x4 half4x4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t; typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;