This commit is contained in:
Georgi Gerganov 2024-11-06 21:06:56 +02:00
parent 0f7e8f389d
commit 01c7f11224
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -2756,11 +2756,24 @@ kernel void kernel_leaky_relu_f32(
// ref: https://arxiv.org/pdf/2307.08691.pdf
template<
typename block_q,
short nl,
void (*dequantize_func)(device const block_q *, short, thread half4x4 &),
typename q_t,
typename q4_t,
typename q8x8_t,
typename k_t,
typename k4x4_t,
typename k8x8_t,
typename v_t,
typename v4x4_t,
typename v8x8_t,
typename s_t, // attention accumulation types
typename s8x8_t,
typename o_t,
typename o8x8_t,
typename block_q,
short nl_k,
void (*deq_k)(device const block_q *, short, thread k4x4_t &),
short nl_v,
void (*deq_v)(device const block_q *, short, thread v4x4_t &),
short D, // head size
short Q = 8, // queries per threadgroup
short KV = 8, // key/value processed per each simdgroup
@ -2819,15 +2832,19 @@ kernel void kernel_flash_attn_ext(
const short TS = T/SF; // shared memory size per query in (s_t)
const short T4 = T/4; // shared memory size per query in (half4)
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
simdgroup_half8x8 lo[D8];
o8x8_t lo[D8];
// load heads from Q to shared memory
for (short j = sgitg; j < Q; j += nsg) {
@ -2835,7 +2852,7 @@ kernel void kernel_flash_attn_ext(
for (short i = tiisg; i < D4; i += NW) {
if (iq1 + j < ne01) {
sq4[j*T4 + i] = (half4) q4[i];
sq4[j*T4 + i] = (q4_t) q4[i];
} else {
sq4[j*T4 + i] = 0.0h;
}
@ -2844,7 +2861,7 @@ kernel void kernel_flash_attn_ext(
// zero out lo
for (short i = 0; i < D8; ++i) {
lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
}
// zero out shared memory SH
@ -2883,7 +2900,7 @@ kernel void kernel_flash_attn_ext(
const short iv3 = iq3/rv3;
// load the queries from shared memory into local memory
simdgroup_half8x8 mq[D8];
q8x8_t mq[D8];
for (short i = 0; i < D8; ++i) {
simdgroup_load(mq[i], sq + i*8, T);
@ -2915,16 +2932,16 @@ kernel void kernel_flash_attn_ext(
// Q*K^T
{
for (short cc = 0; cc < C/8; ++cc) {
s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>(0.0f);
s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>((s_t) 0.0f);
// this is compile-time check, so it does not have runtime overhead
if (is_same<block_q, half4x4>::value) {
if (is_same<block_q, k4x4_t>::value) {
// we can read directly from global memory
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (short i = 0; i < D8; ++i) {
simdgroup_half8x8 mk;
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
k8x8_t mk;
simdgroup_load(mk, pk + i*8, nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
}
@ -2934,38 +2951,38 @@ kernel void kernel_flash_attn_ext(
if (D16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
half4x4 tmp;
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
k4x4_t tmp;
deq_k(pk4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
sk4x4[4*ty + tx] = tmp;
simdgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll
for (short k = 0; k < 4; ++k) {
simdgroup_half8x8 mk;
k8x8_t mk;
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
}
} else {
if (ii + tx < D16) {
half4x4 tmp;
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
k4x4_t tmp;
deq_k(pk4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
sk4x4[4*ty + tx] = tmp;
}
simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < D16; ++k) {
simdgroup_half8x8 mk;
k8x8_t mk;
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
}
}
@ -2995,7 +3012,7 @@ kernel void kernel_flash_attn_ext(
if (mask != q) {
// mqk = mqk + mask*slope
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
}
smax = simd_max(max(smax, s));
@ -3037,13 +3054,13 @@ kernel void kernel_flash_attn_ext(
s8x8_t ms;
simdgroup_load(ms, ss + 8*cc, TS, 0, false);
if (is_same<block_q, half4x4>::value) {
if (is_same<block_q, v4x4_t>::value) {
// we can read directly from global memory
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
#pragma unroll
for (short i = 0; i < D8; ++i) {
simdgroup_half8x8 mv;
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
v8x8_t mv;
simdgroup_load(mv, pv + i*8, nb21/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
}
@ -3053,38 +3070,38 @@ kernel void kernel_flash_attn_ext(
if (D16%4 == 0) {
// no need for bound checks
half4x4 tmp;
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
v4x4_t tmp;
deq_v(pv4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
sv4x4[4*ty + tx] = tmp;
simdgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll
for (short k = 0; k < 4; ++k) {
simdgroup_half8x8 mv;
v8x8_t mv;
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
}
} else {
if (ii + tx < D16) {
half4x4 tmp;
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
skv4[4*ty + tx] = tmp;
v4x4_t tmp;
deq_v(pv4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
sv4x4[4*ty + tx] = tmp;
}
simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < D16; ++k) {
simdgroup_half8x8 mv;
v8x8_t mv;
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
}
}
@ -3113,7 +3130,7 @@ kernel void kernel_flash_attn_ext(
// each simdgroup stores its output to shared memory, reusing sq
if (sgitg == sg) {
for (short i = 0; i < D8; ++i) {
simdgroup_store(lo[i], sq + i*8, T, 0, false);
simdgroup_store(lo[i], so + i*8, T, 0, false);
}
}
@ -3146,7 +3163,7 @@ kernel void kernel_flash_attn_ext(
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
{
simdgroup_half8x8 t;
o8x8_t t;
s8x8_t ms0;
s8x8_t ms1;
@ -3155,7 +3172,7 @@ kernel void kernel_flash_attn_ext(
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
for (short i = 0; i < D8; ++i) {
simdgroup_load (t, sq + i*8, T, 0, false);
simdgroup_load (t, so + i*8, T, 0, false);
simdgroup_multiply(t, ms1, t);
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@ -3167,7 +3184,7 @@ kernel void kernel_flash_attn_ext(
// store result to shared memory (reuse sq)
if (sgitg == 0) {
for (short i = 0; i < D8; ++i) {
simdgroup_store(lo[i], sq + i*8, T, 0, false);
simdgroup_store(lo[i], so + i*8, T, 0, false);
}
}
@ -3187,55 +3204,75 @@ kernel void kernel_flash_attn_ext(
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
#define S_T half
#define S4_T half4
#define S4x4_T half4x4
#define S8x8_T simdgroup_half8x8
#define FA_TYPES \
half, half4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
half, simdgroup_half8x8, \
half, simdgroup_half8x8
#else
#define S_T float
#define S4_T float4
#define S4x4_T float4x4
#define S8x8_T simdgroup_float8x8
#define FA_TYPES \
half, half4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
float, simdgroup_float8x8, \
half, simdgroup_half8x8
#endif
typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>) flash_attn_ext_t;
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>) flash_attn_ext_t;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 128>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 256>;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 128>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 112>;
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 112>;
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 112>;
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 112>;
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 112>;
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 112>;
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 112>;
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 112>;
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 112>;
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 112>;
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 256>;
#undef FA_TYPES
// NOTE: can use half instead of float precision for some extra perf
// however, by default use F32 since the op should be mostly memory bandwidth bound