This commit is contained in:
Georgi Gerganov 2024-11-06 21:06:56 +02:00
parent 0f7e8f389d
commit 01c7f11224
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -2756,11 +2756,24 @@ kernel void kernel_leaky_relu_f32(
// ref: https://arxiv.org/pdf/2307.08691.pdf // ref: https://arxiv.org/pdf/2307.08691.pdf
template< template<
typename block_q, typename q_t,
short nl, typename q4_t,
void (*dequantize_func)(device const block_q *, short, thread half4x4 &), typename q8x8_t,
typename k_t,
typename k4x4_t,
typename k8x8_t,
typename v_t,
typename v4x4_t,
typename v8x8_t,
typename s_t, // attention accumulation types typename s_t, // attention accumulation types
typename s8x8_t, typename s8x8_t,
typename o_t,
typename o8x8_t,
typename block_q,
short nl_k,
void (*deq_k)(device const block_q *, short, thread k4x4_t &),
short nl_v,
void (*deq_v)(device const block_q *, short, thread v4x4_t &),
short D, // head size short D, // head size
short Q = 8, // queries per threadgroup short Q = 8, // queries per threadgroup
short KV = 8, // key/value processed per each simdgroup short KV = 8, // key/value processed per each simdgroup
@ -2819,15 +2832,19 @@ kernel void kernel_flash_attn_ext(
const short TS = T/SF; // shared memory size per query in (s_t) const short TS = T/SF; // shared memory size per query in (s_t)
const short T4 = T/4; // shared memory size per query in (half4) const short T4 = T/4; // shared memory size per query in (half4)
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4 threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
simdgroup_half8x8 lo[D8]; o8x8_t lo[D8];
// load heads from Q to shared memory // load heads from Q to shared memory
for (short j = sgitg; j < Q; j += nsg) { for (short j = sgitg; j < Q; j += nsg) {
@ -2835,7 +2852,7 @@ kernel void kernel_flash_attn_ext(
for (short i = tiisg; i < D4; i += NW) { for (short i = tiisg; i < D4; i += NW) {
if (iq1 + j < ne01) { if (iq1 + j < ne01) {
sq4[j*T4 + i] = (half4) q4[i]; sq4[j*T4 + i] = (q4_t) q4[i];
} else { } else {
sq4[j*T4 + i] = 0.0h; sq4[j*T4 + i] = 0.0h;
} }
@ -2844,7 +2861,7 @@ kernel void kernel_flash_attn_ext(
// zero out lo // zero out lo
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h); lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
} }
// zero out shared memory SH // zero out shared memory SH
@ -2883,7 +2900,7 @@ kernel void kernel_flash_attn_ext(
const short iv3 = iq3/rv3; const short iv3 = iq3/rv3;
// load the queries from shared memory into local memory // load the queries from shared memory into local memory
simdgroup_half8x8 mq[D8]; q8x8_t mq[D8];
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
simdgroup_load(mq[i], sq + i*8, T); simdgroup_load(mq[i], sq + i*8, T);
@ -2915,16 +2932,16 @@ kernel void kernel_flash_attn_ext(
// Q*K^T // Q*K^T
{ {
for (short cc = 0; cc < C/8; ++cc) { for (short cc = 0; cc < C/8; ++cc) {
s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>(0.0f); s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>((s_t) 0.0f);
// this is compile-time check, so it does not have runtime overhead // this is compile-time check, so it does not have runtime overhead
if (is_same<block_q, half4x4>::value) { if (is_same<block_q, k4x4_t>::value) {
// we can read directly from global memory // we can read directly from global memory
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
simdgroup_half8x8 mk; k8x8_t mk;
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose simdgroup_load(mk, pk + i*8, nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
} }
@ -2934,38 +2951,38 @@ kernel void kernel_flash_attn_ext(
if (D16%4 == 0) { if (D16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks // the head is evenly divisible by 4*16 = 64, so no need for bound checks
half4x4 tmp; k4x4_t tmp;
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp); deq_k(pk4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
skv4[4*ty + tx] = tmp; sk4x4[4*ty + tx] = tmp;
simdgroup_barrier(mem_flags::mem_threadgroup); simdgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll #pragma unroll
for (short k = 0; k < 4; ++k) { for (short k = 0; k < 4; ++k) {
simdgroup_half8x8 mk; k8x8_t mk;
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk); simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk); simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
} }
} else { } else {
if (ii + tx < D16) { if (ii + tx < D16) {
half4x4 tmp; k4x4_t tmp;
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp); deq_k(pk4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
skv4[4*ty + tx] = tmp; sk4x4[4*ty + tx] = tmp;
} }
simdgroup_barrier(mem_flags::mem_threadgroup); simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < D16; ++k) { for (short k = 0; k < 4 && ii + k < D16; ++k) {
simdgroup_half8x8 mk; k8x8_t mk;
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk); simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk); simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
} }
} }
@ -2995,7 +3012,7 @@ kernel void kernel_flash_attn_ext(
if (mask != q) { if (mask != q) {
// mqk = mqk + mask*slope // mqk = mqk + mask*slope
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
} }
smax = simd_max(max(smax, s)); smax = simd_max(max(smax, s));
@ -3037,13 +3054,13 @@ kernel void kernel_flash_attn_ext(
s8x8_t ms; s8x8_t ms;
simdgroup_load(ms, ss + 8*cc, TS, 0, false); simdgroup_load(ms, ss + 8*cc, TS, 0, false);
if (is_same<block_q, half4x4>::value) { if (is_same<block_q, v4x4_t>::value) {
// we can read directly from global memory // we can read directly from global memory
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
#pragma unroll #pragma unroll
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
simdgroup_half8x8 mv; v8x8_t mv;
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_load(mv, pv + i*8, nb21/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
} }
@ -3053,38 +3070,38 @@ kernel void kernel_flash_attn_ext(
if (D16%4 == 0) { if (D16%4 == 0) {
// no need for bound checks // no need for bound checks
half4x4 tmp; v4x4_t tmp;
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp); deq_v(pv4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
skv4[4*ty + tx] = tmp; sv4x4[4*ty + tx] = tmp;
simdgroup_barrier(mem_flags::mem_threadgroup); simdgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll #pragma unroll
for (short k = 0; k < 4; ++k) { for (short k = 0; k < 4; ++k) {
simdgroup_half8x8 mv; v8x8_t mv;
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false); simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false); simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
} }
} else { } else {
if (ii + tx < D16) { if (ii + tx < D16) {
half4x4 tmp; v4x4_t tmp;
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp); deq_v(pv4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
skv4[4*ty + tx] = tmp; sv4x4[4*ty + tx] = tmp;
} }
simdgroup_barrier(mem_flags::mem_threadgroup); simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < D16; ++k) { for (short k = 0; k < 4 && ii + k < D16; ++k) {
simdgroup_half8x8 mv; v8x8_t mv;
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false); simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false); simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
} }
} }
@ -3113,7 +3130,7 @@ kernel void kernel_flash_attn_ext(
// each simdgroup stores its output to shared memory, reusing sq // each simdgroup stores its output to shared memory, reusing sq
if (sgitg == sg) { if (sgitg == sg) {
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
simdgroup_store(lo[i], sq + i*8, T, 0, false); simdgroup_store(lo[i], so + i*8, T, 0, false);
} }
} }
@ -3146,7 +3163,7 @@ kernel void kernel_flash_attn_ext(
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
{ {
simdgroup_half8x8 t; o8x8_t t;
s8x8_t ms0; s8x8_t ms0;
s8x8_t ms1; s8x8_t ms1;
@ -3155,7 +3172,7 @@ kernel void kernel_flash_attn_ext(
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false); simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
simdgroup_load (t, sq + i*8, T, 0, false); simdgroup_load (t, so + i*8, T, 0, false);
simdgroup_multiply(t, ms1, t); simdgroup_multiply(t, ms1, t);
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@ -3167,7 +3184,7 @@ kernel void kernel_flash_attn_ext(
// store result to shared memory (reuse sq) // store result to shared memory (reuse sq)
if (sgitg == 0) { if (sgitg == 0) {
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
simdgroup_store(lo[i], sq + i*8, T, 0, false); simdgroup_store(lo[i], so + i*8, T, 0, false);
} }
} }
@ -3187,55 +3204,75 @@ kernel void kernel_flash_attn_ext(
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16) #if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
#define S_T half #define S_T half
#define S4_T half4
#define S4x4_T half4x4
#define S8x8_T simdgroup_half8x8 #define S8x8_T simdgroup_half8x8
#define FA_TYPES \
half, half4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
half, simdgroup_half8x8, \
half, simdgroup_half8x8
#else #else
#define S_T float #define S_T float
#define S4_T float4
#define S4x4_T float4x4
#define S8x8_T simdgroup_float8x8 #define S8x8_T simdgroup_float8x8
#define FA_TYPES \
half, half4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
float, simdgroup_float8x8, \
half, simdgroup_half8x8
#endif #endif
typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>) flash_attn_ext_t; typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>) flash_attn_ext_t;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>; template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 80>; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 96>; template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 112>; template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 128>; template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 128>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 256>; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 64>; template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 80>; template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 96>; template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 112>; template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 112>;
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 128>; template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 256>; template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 64>; template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 80>; template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 96>; template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 112>; template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 112>;
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 128>; template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 256>; template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 64>; template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 80>; template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 96>; template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 112>; template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 112>;
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 128>; template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 256>; template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 64>; template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 80>; template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 96>; template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 112>; template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 112>;
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 128>; template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 256>; template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 64>; template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 80>; template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 96>; template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 112>; template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 112>;
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 128>; template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 256>; template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 256>;
#undef FA_TYPES
// NOTE: can use half instead of float precision for some extra perf // NOTE: can use half instead of float precision for some extra perf
// however, by default use F32 since the op should be mostly memory bandwidth bound // however, by default use F32 since the op should be mostly memory bandwidth bound