mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
metal : add GGML_METAL_FORCE_FATTN_PREC_F16
ggml-ci
This commit is contained in:
parent
eefc132bb7
commit
0f7e8f389d
5
Makefile
5
Makefile
@ -876,6 +876,11 @@ endif # GGML_HIPBLAS
|
||||
|
||||
ifdef GGML_METAL
|
||||
MK_CPPFLAGS += -DGGML_USE_METAL
|
||||
|
||||
ifdef GGML_METAL_FORCE_FATTN_PREC_F16
|
||||
MK_CPPFLAGS += -DGGML_METAL_FORCE_FATTN_PREC_F16
|
||||
endif # GGML_METAL_FORCE_FATTN_PREC_F16
|
||||
|
||||
MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
|
||||
OBJ_GGML += ggml/src/ggml-metal.o
|
||||
ifdef GGML_METAL_NDEBUG
|
||||
|
@ -153,6 +153,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
|
||||
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
|
||||
option(GGML_KOMPUTE "ggml: use Kompute" OFF)
|
||||
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
|
||||
option(GGML_METAL_FORCE_FATTN_PREC_F16 "ggml: force F16 accumulators for FA kernels" OFF)
|
||||
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
|
||||
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
|
||||
option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})
|
||||
|
@ -58,6 +58,10 @@ if (GGML_METAL)
|
||||
add_compile_definitions(GGML_METAL_NDEBUG)
|
||||
endif()
|
||||
|
||||
if (GGML_METAL_FORCE_FATTN_PREC_F16)
|
||||
add_compile_definitions(GGML_METAL_FORCE_FATTN_PREC_F16)
|
||||
endif()
|
||||
|
||||
# copy ggml-common.h and ggml-metal.metal to bin directory
|
||||
configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
|
||||
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
|
||||
|
@ -12,9 +12,6 @@
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
|
||||
// TODO: for now, always use F32 for flash attention to avoid compiling 2 sets of kernels
|
||||
#define GGML_METAL_FORCE_FATTN_PREC_F32
|
||||
|
||||
// max memory buffers that can be mapped to the device
|
||||
#define GGML_METAL_MAX_BUFFERS 64
|
||||
|
||||
@ -499,9 +496,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
// dictionary of preprocessor macros
|
||||
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
||||
|
||||
// add GGML_METAL_FORCE_FATTN_PREC_F32
|
||||
#if defined(GGML_METAL_FORCE_FATTN_PREC_F32)
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F32"];
|
||||
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F16"];
|
||||
#endif
|
||||
|
||||
MTLCompileOptions * options = [MTLCompileOptions new];
|
||||
@ -554,6 +550,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
|
||||
GGML_LOG_INFO("%s: GGML_METAL_FORCE_FATTN_PREC_F16 = yes\n", __func__);
|
||||
#else
|
||||
GGML_LOG_INFO("%s: GGML_METAL_FORCE_FATTN_PREC_F16 = no\n", __func__);
|
||||
#endif
|
||||
GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
|
||||
@ -3224,10 +3225,12 @@ static void ggml_metal_encode_node(
|
||||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
#ifdef GGML_METAL_FORCE_FATTN_PREC_F32
|
||||
#ifdef GGML_METAL_FORCE_FATTN_PREC_F16
|
||||
const enum ggml_prec prec = GGML_PREC_DEFAULT;
|
||||
#else
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
|
||||
// TODO: support both precisions
|
||||
const enum ggml_prec prec = GGML_PREC_F32;
|
||||
//const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
|
||||
#endif
|
||||
|
||||
const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2;
|
||||
|
@ -2755,8 +2755,16 @@ kernel void kernel_leaky_relu_f32(
|
||||
}
|
||||
|
||||
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
||||
// D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short KV = 8, short C = 32>
|
||||
template<
|
||||
typename block_q,
|
||||
short nl,
|
||||
void (*dequantize_func)(device const block_q *, short, thread half4x4 &),
|
||||
typename s_t, // attention accumulation types
|
||||
typename s8x8_t,
|
||||
short D, // head size
|
||||
short Q = 8, // queries per threadgroup
|
||||
short KV = 8, // key/value processed per each simdgroup
|
||||
short C = 32> // cache items per threadgroup
|
||||
kernel void kernel_flash_attn_ext(
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
@ -2805,13 +2813,15 @@ kernel void kernel_flash_attn_ext(
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
const short T = D + nsg*SH; // shared memory size per query in (half)
|
||||
const short TF = T; // shared memory size per query in (float)
|
||||
const short SF = sizeof(s_t)/sizeof(half);
|
||||
|
||||
const short T = D + SF*nsg*SH; // shared memory size per query in (half)
|
||||
const short TS = T/SF; // shared memory size per query in (s_t)
|
||||
const short T4 = T/4; // shared memory size per query in (half4)
|
||||
|
||||
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
|
||||
threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
|
||||
threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
|
||||
@ -2840,7 +2850,7 @@ kernel void kernel_flash_attn_ext(
|
||||
// zero out shared memory SH
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
for (short i = tiisg; i < SH; i += NW) {
|
||||
ss[j*TF + i] = 0.0h;
|
||||
ss[j*TS + i] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
@ -2905,7 +2915,7 @@ kernel void kernel_flash_attn_ext(
|
||||
// Q*K^T
|
||||
{
|
||||
for (short cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, 8>(0.h);
|
||||
s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>(0.0f);
|
||||
|
||||
// this is compile-time check, so it does not have runtime overhead
|
||||
if (is_same<block_q, half4x4>::value) {
|
||||
@ -2962,7 +2972,7 @@ kernel void kernel_flash_attn_ext(
|
||||
}
|
||||
}
|
||||
|
||||
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
||||
simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
@ -2977,7 +2987,7 @@ kernel void kernel_flash_attn_ext(
|
||||
const float m = M[j];
|
||||
|
||||
// scale and apply the logitcap / mask
|
||||
float s = ((float)(ss[j*TF + tiisg]))*scale;
|
||||
float s = ((float)(ss[j*TS + tiisg]))*scale;
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
s = logit_softcap*precise::tanh(s);
|
||||
@ -2997,12 +3007,12 @@ kernel void kernel_flash_attn_ext(
|
||||
S[j] = S[j]*ms[j] + simd_sum(vs);
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[j*TF + tiisg] = vs;
|
||||
ss[j*TS + tiisg] = vs;
|
||||
}
|
||||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (tiisg < Q) {
|
||||
ss[tiisg*TF + C + tiisg] = ms[tiisg];
|
||||
ss[tiisg*TS + C + tiisg] = ms[tiisg];
|
||||
}
|
||||
}
|
||||
|
||||
@ -3013,8 +3023,8 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
// O = diag(ms)*O
|
||||
{
|
||||
simdgroup_half8x8 mm;
|
||||
simdgroup_load(mm, ss + C, TF, 0, false);
|
||||
s8x8_t mm;
|
||||
simdgroup_load(mm, ss + C, TS, 0, false);
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_multiply(lo[i], mm, lo[i]);
|
||||
@ -3024,8 +3034,8 @@ kernel void kernel_flash_attn_ext(
|
||||
// O = O + (Q*K^T)*V
|
||||
{
|
||||
for (short cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_half8x8 ms;
|
||||
simdgroup_load(ms, ss + 8*cc, TF, 0, false);
|
||||
s8x8_t ms;
|
||||
simdgroup_load(ms, ss + 8*cc, TS, 0, false);
|
||||
|
||||
if (is_same<block_q, half4x4>::value) {
|
||||
// we can read directly from global memory
|
||||
@ -3087,8 +3097,8 @@ kernel void kernel_flash_attn_ext(
|
||||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
if (tiisg == 0) {
|
||||
ss[j*TF + 0] = S[j];
|
||||
ss[j*TF + 1] = M[j];
|
||||
ss[j*TS + 0] = S[j];
|
||||
ss[j*TS + 1] = M[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3112,11 +3122,11 @@ kernel void kernel_flash_attn_ext(
|
||||
// the first simdgroup accumulates the results from the other simdgroups
|
||||
if (sgitg == 0) {
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
const float S0 = ss[j*TF + 0];
|
||||
const float S1 = ss[j*TF + sg*SH + 0];
|
||||
const float S0 = ss[j*TS + 0];
|
||||
const float S1 = ss[j*TS + sg*SH + 0];
|
||||
|
||||
const float M0 = ss[j*TF + 1];
|
||||
const float M1 = ss[j*TF + sg*SH + 1];
|
||||
const float M0 = ss[j*TS + 1];
|
||||
const float M1 = ss[j*TS + sg*SH + 1];
|
||||
|
||||
M = max(M0, M1);
|
||||
|
||||
@ -3126,22 +3136,23 @@ kernel void kernel_flash_attn_ext(
|
||||
S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[j*TF + 0] = S;
|
||||
ss[j*TF + 1] = M;
|
||||
ss[j*TS + 0] = S;
|
||||
ss[j*TS + 1] = M;
|
||||
|
||||
ss[j*TF + C + j ] = ms0;
|
||||
ss[j*TF + C + j + sg*SH] = ms1;
|
||||
ss[j*TS + C + j ] = ms0;
|
||||
ss[j*TS + C + j + sg*SH] = ms1;
|
||||
}
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
{
|
||||
simdgroup_half8x8 t;
|
||||
simdgroup_half8x8 ms0;
|
||||
simdgroup_half8x8 ms1;
|
||||
|
||||
simdgroup_load(ms0, ss + C, TF, 0, false);
|
||||
simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
|
||||
s8x8_t ms0;
|
||||
s8x8_t ms1;
|
||||
|
||||
simdgroup_load(ms0, ss + C, TS, 0, false);
|
||||
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_load (t, sq + i*8, T, 0, false);
|
||||
@ -3165,7 +3176,7 @@ kernel void kernel_flash_attn_ext(
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
const float S = ss[j*TF + 0];
|
||||
const float S = ss[j*TS + 0];
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
|
||||
@ -3174,49 +3185,57 @@ kernel void kernel_flash_attn_ext(
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
|
||||
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
|
||||
#define S_T half
|
||||
#define S8x8_T simdgroup_half8x8
|
||||
#else
|
||||
#define S_T float
|
||||
#define S8x8_T simdgroup_float8x8
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 256>;
|
||||
typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>) flash_attn_ext_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 256>;
|
||||
|
||||
// NOTE: can use half instead of float precision for some extra perf
|
||||
// however, by default use F32 since the op should be mostly memory bandwidth bound
|
||||
|
Loading…
Reference in New Issue
Block a user