mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-09 02:01:44 +00:00
wip
This commit is contained in:
parent
0f7e8f389d
commit
01c7f11224
@ -2756,11 +2756,24 @@ kernel void kernel_leaky_relu_f32(
|
|||||||
|
|
||||||
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
||||||
template<
|
template<
|
||||||
typename block_q,
|
typename q_t,
|
||||||
short nl,
|
typename q4_t,
|
||||||
void (*dequantize_func)(device const block_q *, short, thread half4x4 &),
|
typename q8x8_t,
|
||||||
|
typename k_t,
|
||||||
|
typename k4x4_t,
|
||||||
|
typename k8x8_t,
|
||||||
|
typename v_t,
|
||||||
|
typename v4x4_t,
|
||||||
|
typename v8x8_t,
|
||||||
typename s_t, // attention accumulation types
|
typename s_t, // attention accumulation types
|
||||||
typename s8x8_t,
|
typename s8x8_t,
|
||||||
|
typename o_t,
|
||||||
|
typename o8x8_t,
|
||||||
|
typename block_q,
|
||||||
|
short nl_k,
|
||||||
|
void (*deq_k)(device const block_q *, short, thread k4x4_t &),
|
||||||
|
short nl_v,
|
||||||
|
void (*deq_v)(device const block_q *, short, thread v4x4_t &),
|
||||||
short D, // head size
|
short D, // head size
|
||||||
short Q = 8, // queries per threadgroup
|
short Q = 8, // queries per threadgroup
|
||||||
short KV = 8, // key/value processed per each simdgroup
|
short KV = 8, // key/value processed per each simdgroup
|
||||||
@ -2819,15 +2832,19 @@ kernel void kernel_flash_attn_ext(
|
|||||||
const short TS = T/SF; // shared memory size per query in (s_t)
|
const short TS = T/SF; // shared memory size per query in (s_t)
|
||||||
const short T4 = T/4; // shared memory size per query in (half4)
|
const short T4 = T/4; // shared memory size per query in (half4)
|
||||||
|
|
||||||
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
||||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
|
||||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
|
||||||
|
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||||
|
|
||||||
threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
|
threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
||||||
threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
|
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
||||||
|
|
||||||
|
threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
|
||||||
|
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
|
||||||
|
|
||||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||||
simdgroup_half8x8 lo[D8];
|
o8x8_t lo[D8];
|
||||||
|
|
||||||
// load heads from Q to shared memory
|
// load heads from Q to shared memory
|
||||||
for (short j = sgitg; j < Q; j += nsg) {
|
for (short j = sgitg; j < Q; j += nsg) {
|
||||||
@ -2835,7 +2852,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
if (iq1 + j < ne01) {
|
if (iq1 + j < ne01) {
|
||||||
sq4[j*T4 + i] = (half4) q4[i];
|
sq4[j*T4 + i] = (q4_t) q4[i];
|
||||||
} else {
|
} else {
|
||||||
sq4[j*T4 + i] = 0.0h;
|
sq4[j*T4 + i] = 0.0h;
|
||||||
}
|
}
|
||||||
@ -2844,7 +2861,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
// zero out lo
|
// zero out lo
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
|
lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
// zero out shared memory SH
|
// zero out shared memory SH
|
||||||
@ -2883,7 +2900,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
const short iv3 = iq3/rv3;
|
const short iv3 = iq3/rv3;
|
||||||
|
|
||||||
// load the queries from shared memory into local memory
|
// load the queries from shared memory into local memory
|
||||||
simdgroup_half8x8 mq[D8];
|
q8x8_t mq[D8];
|
||||||
|
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
simdgroup_load(mq[i], sq + i*8, T);
|
simdgroup_load(mq[i], sq + i*8, T);
|
||||||
@ -2915,16 +2932,16 @@ kernel void kernel_flash_attn_ext(
|
|||||||
// Q*K^T
|
// Q*K^T
|
||||||
{
|
{
|
||||||
for (short cc = 0; cc < C/8; ++cc) {
|
for (short cc = 0; cc < C/8; ++cc) {
|
||||||
s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>(0.0f);
|
s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>((s_t) 0.0f);
|
||||||
|
|
||||||
// this is compile-time check, so it does not have runtime overhead
|
// this is compile-time check, so it does not have runtime overhead
|
||||||
if (is_same<block_q, half4x4>::value) {
|
if (is_same<block_q, k4x4_t>::value) {
|
||||||
// we can read directly from global memory
|
// we can read directly from global memory
|
||||||
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
simdgroup_half8x8 mk;
|
k8x8_t mk;
|
||||||
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
simdgroup_load(mk, pk + i*8, nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
||||||
|
|
||||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||||
}
|
}
|
||||||
@ -2934,38 +2951,38 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
if (D16%4 == 0) {
|
if (D16%4 == 0) {
|
||||||
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
||||||
half4x4 tmp;
|
k4x4_t tmp;
|
||||||
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
deq_k(pk4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
|
||||||
skv4[4*ty + tx] = tmp;
|
sk4x4[4*ty + tx] = tmp;
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (short k = 0; k < 4; ++k) {
|
for (short k = 0; k < 4; ++k) {
|
||||||
simdgroup_half8x8 mk;
|
k8x8_t mk;
|
||||||
|
|
||||||
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
|
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
||||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
||||||
|
|
||||||
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
|
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
||||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (ii + tx < D16) {
|
if (ii + tx < D16) {
|
||||||
half4x4 tmp;
|
k4x4_t tmp;
|
||||||
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
deq_k(pk4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
|
||||||
skv4[4*ty + tx] = tmp;
|
sk4x4[4*ty + tx] = tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
||||||
simdgroup_half8x8 mk;
|
k8x8_t mk;
|
||||||
|
|
||||||
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
|
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
||||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
||||||
|
|
||||||
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
|
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
||||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2995,7 +3012,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
if (mask != q) {
|
if (mask != q) {
|
||||||
// mqk = mqk + mask*slope
|
// mqk = mqk + mask*slope
|
||||||
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
|
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
|
||||||
}
|
}
|
||||||
|
|
||||||
smax = simd_max(max(smax, s));
|
smax = simd_max(max(smax, s));
|
||||||
@ -3037,13 +3054,13 @@ kernel void kernel_flash_attn_ext(
|
|||||||
s8x8_t ms;
|
s8x8_t ms;
|
||||||
simdgroup_load(ms, ss + 8*cc, TS, 0, false);
|
simdgroup_load(ms, ss + 8*cc, TS, 0, false);
|
||||||
|
|
||||||
if (is_same<block_q, half4x4>::value) {
|
if (is_same<block_q, v4x4_t>::value) {
|
||||||
// we can read directly from global memory
|
// we can read directly from global memory
|
||||||
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
simdgroup_half8x8 mv;
|
v8x8_t mv;
|
||||||
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
|
simdgroup_load(mv, pv + i*8, nb21/sizeof(v_t), 0, false); // TODO: use ne20
|
||||||
|
|
||||||
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
||||||
}
|
}
|
||||||
@ -3053,38 +3070,38 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
if (D16%4 == 0) {
|
if (D16%4 == 0) {
|
||||||
// no need for bound checks
|
// no need for bound checks
|
||||||
half4x4 tmp;
|
v4x4_t tmp;
|
||||||
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
deq_v(pv4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
|
||||||
skv4[4*ty + tx] = tmp;
|
sv4x4[4*ty + tx] = tmp;
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (short k = 0; k < 4; ++k) {
|
for (short k = 0; k < 4; ++k) {
|
||||||
simdgroup_half8x8 mv;
|
v8x8_t mv;
|
||||||
|
|
||||||
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
|
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
||||||
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
||||||
|
|
||||||
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
|
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
|
||||||
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (ii + tx < D16) {
|
if (ii + tx < D16) {
|
||||||
half4x4 tmp;
|
v4x4_t tmp;
|
||||||
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
deq_v(pv4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
|
||||||
skv4[4*ty + tx] = tmp;
|
sv4x4[4*ty + tx] = tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
||||||
simdgroup_half8x8 mv;
|
v8x8_t mv;
|
||||||
|
|
||||||
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
|
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
||||||
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
||||||
|
|
||||||
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
|
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
|
||||||
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3113,7 +3130,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
// each simdgroup stores its output to shared memory, reusing sq
|
// each simdgroup stores its output to shared memory, reusing sq
|
||||||
if (sgitg == sg) {
|
if (sgitg == sg) {
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
simdgroup_store(lo[i], so + i*8, T, 0, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3146,7 +3163,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||||
{
|
{
|
||||||
simdgroup_half8x8 t;
|
o8x8_t t;
|
||||||
|
|
||||||
s8x8_t ms0;
|
s8x8_t ms0;
|
||||||
s8x8_t ms1;
|
s8x8_t ms1;
|
||||||
@ -3155,7 +3172,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
|
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
|
||||||
|
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
simdgroup_load (t, sq + i*8, T, 0, false);
|
simdgroup_load (t, so + i*8, T, 0, false);
|
||||||
simdgroup_multiply(t, ms1, t);
|
simdgroup_multiply(t, ms1, t);
|
||||||
|
|
||||||
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
||||||
@ -3167,7 +3184,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
// store result to shared memory (reuse sq)
|
// store result to shared memory (reuse sq)
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
simdgroup_store(lo[i], so + i*8, T, 0, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3187,55 +3204,75 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
|
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
|
||||||
#define S_T half
|
#define S_T half
|
||||||
|
#define S4_T half4
|
||||||
|
#define S4x4_T half4x4
|
||||||
#define S8x8_T simdgroup_half8x8
|
#define S8x8_T simdgroup_half8x8
|
||||||
|
|
||||||
|
#define FA_TYPES \
|
||||||
|
half, half4, simdgroup_half8x8, \
|
||||||
|
half, half4x4, simdgroup_half8x8, \
|
||||||
|
half, half4x4, simdgroup_half8x8, \
|
||||||
|
half, simdgroup_half8x8, \
|
||||||
|
half, simdgroup_half8x8
|
||||||
#else
|
#else
|
||||||
#define S_T float
|
#define S_T float
|
||||||
|
#define S4_T float4
|
||||||
|
#define S4x4_T float4x4
|
||||||
#define S8x8_T simdgroup_float8x8
|
#define S8x8_T simdgroup_float8x8
|
||||||
|
|
||||||
|
#define FA_TYPES \
|
||||||
|
half, half4, simdgroup_half8x8, \
|
||||||
|
half, half4x4, simdgroup_half8x8, \
|
||||||
|
half, half4x4, simdgroup_half8x8, \
|
||||||
|
float, simdgroup_float8x8, \
|
||||||
|
half, simdgroup_half8x8
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>) flash_attn_ext_t;
|
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>) flash_attn_ext_t;
|
||||||
|
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>;
|
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 80>;
|
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 80>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 96>;
|
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 96>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 112>;
|
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 112>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 128>;
|
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 128>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 256>;
|
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 256>;
|
||||||
|
|
||||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 64>;
|
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 64>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 80>;
|
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 80>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 96>;
|
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 96>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 112>;
|
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 112>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 128>;
|
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 128>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 256>;
|
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 256>;
|
||||||
|
|
||||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 64>;
|
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 64>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 80>;
|
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 80>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 96>;
|
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 96>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 112>;
|
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 112>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 128>;
|
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 128>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 256>;
|
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 256>;
|
||||||
|
|
||||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 64>;
|
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 64>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 80>;
|
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 80>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 96>;
|
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 96>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 112>;
|
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 112>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 128>;
|
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 128>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 256>;
|
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 256>;
|
||||||
|
|
||||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 64>;
|
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 64>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 80>;
|
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 80>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 96>;
|
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 96>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 112>;
|
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 112>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 128>;
|
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 128>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 256>;
|
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 256>;
|
||||||
|
|
||||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 64>;
|
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 64>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 80>;
|
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 80>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 96>;
|
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 96>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 112>;
|
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 112>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 128>;
|
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 128>;
|
||||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 256>;
|
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 256>;
|
||||||
|
|
||||||
|
#undef FA_TYPES
|
||||||
|
|
||||||
// NOTE: can use half instead of float precision for some extra perf
|
// NOTE: can use half instead of float precision for some extra perf
|
||||||
// however, by default use F32 since the op should be mostly memory bandwidth bound
|
// however, by default use F32 since the op should be mostly memory bandwidth bound
|
||||||
|
Loading…
Reference in New Issue
Block a user