mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
wip
This commit is contained in:
parent
0f7e8f389d
commit
01c7f11224
@ -2756,11 +2756,24 @@ kernel void kernel_leaky_relu_f32(
|
||||
|
||||
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
||||
template<
|
||||
typename block_q,
|
||||
short nl,
|
||||
void (*dequantize_func)(device const block_q *, short, thread half4x4 &),
|
||||
typename q_t,
|
||||
typename q4_t,
|
||||
typename q8x8_t,
|
||||
typename k_t,
|
||||
typename k4x4_t,
|
||||
typename k8x8_t,
|
||||
typename v_t,
|
||||
typename v4x4_t,
|
||||
typename v8x8_t,
|
||||
typename s_t, // attention accumulation types
|
||||
typename s8x8_t,
|
||||
typename o_t,
|
||||
typename o8x8_t,
|
||||
typename block_q,
|
||||
short nl_k,
|
||||
void (*deq_k)(device const block_q *, short, thread k4x4_t &),
|
||||
short nl_v,
|
||||
void (*deq_v)(device const block_q *, short, thread v4x4_t &),
|
||||
short D, // head size
|
||||
short Q = 8, // queries per threadgroup
|
||||
short KV = 8, // key/value processed per each simdgroup
|
||||
@ -2819,15 +2832,19 @@ kernel void kernel_flash_attn_ext(
|
||||
const short TS = T/SF; // shared memory size per query in (s_t)
|
||||
const short T4 = T/4; // shared memory size per query in (half4)
|
||||
|
||||
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
|
||||
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
|
||||
threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
|
||||
threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
|
||||
threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
||||
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
||||
|
||||
threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
|
||||
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
simdgroup_half8x8 lo[D8];
|
||||
o8x8_t lo[D8];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
for (short j = sgitg; j < Q; j += nsg) {
|
||||
@ -2835,7 +2852,7 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
if (iq1 + j < ne01) {
|
||||
sq4[j*T4 + i] = (half4) q4[i];
|
||||
sq4[j*T4 + i] = (q4_t) q4[i];
|
||||
} else {
|
||||
sq4[j*T4 + i] = 0.0h;
|
||||
}
|
||||
@ -2844,7 +2861,7 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
// zero out lo
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
|
||||
lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
|
||||
}
|
||||
|
||||
// zero out shared memory SH
|
||||
@ -2883,7 +2900,7 @@ kernel void kernel_flash_attn_ext(
|
||||
const short iv3 = iq3/rv3;
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
simdgroup_half8x8 mq[D8];
|
||||
q8x8_t mq[D8];
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_load(mq[i], sq + i*8, T);
|
||||
@ -2915,16 +2932,16 @@ kernel void kernel_flash_attn_ext(
|
||||
// Q*K^T
|
||||
{
|
||||
for (short cc = 0; cc < C/8; ++cc) {
|
||||
s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>(0.0f);
|
||||
s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>((s_t) 0.0f);
|
||||
|
||||
// this is compile-time check, so it does not have runtime overhead
|
||||
if (is_same<block_q, half4x4>::value) {
|
||||
if (is_same<block_q, k4x4_t>::value) {
|
||||
// we can read directly from global memory
|
||||
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_half8x8 mk;
|
||||
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
||||
k8x8_t mk;
|
||||
simdgroup_load(mk, pk + i*8, nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
||||
|
||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||
}
|
||||
@ -2934,38 +2951,38 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
if (D16%4 == 0) {
|
||||
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
||||
half4x4 tmp;
|
||||
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
||||
skv4[4*ty + tx] = tmp;
|
||||
k4x4_t tmp;
|
||||
deq_k(pk4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
|
||||
sk4x4[4*ty + tx] = tmp;
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
#pragma unroll
|
||||
for (short k = 0; k < 4; ++k) {
|
||||
simdgroup_half8x8 mk;
|
||||
k8x8_t mk;
|
||||
|
||||
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
|
||||
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
||||
|
||||
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
|
||||
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
||||
}
|
||||
} else {
|
||||
if (ii + tx < D16) {
|
||||
half4x4 tmp;
|
||||
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
||||
skv4[4*ty + tx] = tmp;
|
||||
k4x4_t tmp;
|
||||
deq_k(pk4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
|
||||
sk4x4[4*ty + tx] = tmp;
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
||||
simdgroup_half8x8 mk;
|
||||
k8x8_t mk;
|
||||
|
||||
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
|
||||
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
||||
|
||||
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
|
||||
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
||||
}
|
||||
}
|
||||
@ -2995,7 +3012,7 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
if (mask != q) {
|
||||
// mqk = mqk + mask*slope
|
||||
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
|
||||
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
|
||||
}
|
||||
|
||||
smax = simd_max(max(smax, s));
|
||||
@ -3037,13 +3054,13 @@ kernel void kernel_flash_attn_ext(
|
||||
s8x8_t ms;
|
||||
simdgroup_load(ms, ss + 8*cc, TS, 0, false);
|
||||
|
||||
if (is_same<block_q, half4x4>::value) {
|
||||
if (is_same<block_q, v4x4_t>::value) {
|
||||
// we can read directly from global memory
|
||||
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
#pragma unroll
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_half8x8 mv;
|
||||
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
|
||||
v8x8_t mv;
|
||||
simdgroup_load(mv, pv + i*8, nb21/sizeof(v_t), 0, false); // TODO: use ne20
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
||||
}
|
||||
@ -3053,38 +3070,38 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
if (D16%4 == 0) {
|
||||
// no need for bound checks
|
||||
half4x4 tmp;
|
||||
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
||||
skv4[4*ty + tx] = tmp;
|
||||
v4x4_t tmp;
|
||||
deq_v(pv4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
|
||||
sv4x4[4*ty + tx] = tmp;
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
#pragma unroll
|
||||
for (short k = 0; k < 4; ++k) {
|
||||
simdgroup_half8x8 mv;
|
||||
v8x8_t mv;
|
||||
|
||||
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
|
||||
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
||||
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
||||
|
||||
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
|
||||
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
|
||||
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
||||
}
|
||||
} else {
|
||||
if (ii + tx < D16) {
|
||||
half4x4 tmp;
|
||||
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
||||
skv4[4*ty + tx] = tmp;
|
||||
v4x4_t tmp;
|
||||
deq_v(pv4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
|
||||
sv4x4[4*ty + tx] = tmp;
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
||||
simdgroup_half8x8 mv;
|
||||
v8x8_t mv;
|
||||
|
||||
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
|
||||
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
||||
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
||||
|
||||
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
|
||||
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
|
||||
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
||||
}
|
||||
}
|
||||
@ -3113,7 +3130,7 @@ kernel void kernel_flash_attn_ext(
|
||||
// each simdgroup stores its output to shared memory, reusing sq
|
||||
if (sgitg == sg) {
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
||||
simdgroup_store(lo[i], so + i*8, T, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
@ -3146,7 +3163,7 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
{
|
||||
simdgroup_half8x8 t;
|
||||
o8x8_t t;
|
||||
|
||||
s8x8_t ms0;
|
||||
s8x8_t ms1;
|
||||
@ -3155,7 +3172,7 @@ kernel void kernel_flash_attn_ext(
|
||||
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_load (t, sq + i*8, T, 0, false);
|
||||
simdgroup_load (t, so + i*8, T, 0, false);
|
||||
simdgroup_multiply(t, ms1, t);
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
||||
@ -3167,7 +3184,7 @@ kernel void kernel_flash_attn_ext(
|
||||
// store result to shared memory (reuse sq)
|
||||
if (sgitg == 0) {
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
||||
simdgroup_store(lo[i], so + i*8, T, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
@ -3187,55 +3204,75 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
|
||||
#define S_T half
|
||||
#define S4_T half4
|
||||
#define S4x4_T half4x4
|
||||
#define S8x8_T simdgroup_half8x8
|
||||
|
||||
#define FA_TYPES \
|
||||
half, half4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
half, simdgroup_half8x8, \
|
||||
half, simdgroup_half8x8
|
||||
#else
|
||||
#define S_T float
|
||||
#define S4_T float4
|
||||
#define S4x4_T float4x4
|
||||
#define S8x8_T simdgroup_float8x8
|
||||
|
||||
#define FA_TYPES \
|
||||
half, half4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
float, simdgroup_float8x8, \
|
||||
half, simdgroup_half8x8
|
||||
#endif
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>) flash_attn_ext_t;
|
||||
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>) flash_attn_ext_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 256>;
|
||||
|
||||
#undef FA_TYPES
|
||||
|
||||
// NOTE: can use half instead of float precision for some extra perf
|
||||
// however, by default use F32 since the op should be mostly memory bandwidth bound
|
||||
|
Loading…
Reference in New Issue
Block a user