Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-07 17:21:46 +00:00)
metal : use F16 precision in FA kernel
commit eefc132bb7
parent 22a9311a1a
@@ -12,6 +12,9 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+// TODO: for now, always use F32 for flash attention to avoid compiling 2 sets of kernels
+#define GGML_METAL_FORCE_FATTN_PREC_F32
+
 // max memory buffers that can be mapped to the device
 #define GGML_METAL_MAX_BUFFERS 64
 
@@ -496,6 +499,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         // dictionary of preprocessor macros
         NSMutableDictionary * prep = [NSMutableDictionary dictionary];
 
+        // add GGML_METAL_FORCE_FATTN_PREC_F32
+#if defined(GGML_METAL_FORCE_FATTN_PREC_F32)
+        [prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F32"];
+#endif
+
         MTLCompileOptions * options = [MTLCompileOptions new];
         options.preprocessorMacros = prep;
 
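For context: the macro added above is handed to the Metal compiler via MTLCompileOptions.preprocessorMacros, so inside the .metal source it behaves like a -D flag that can flip the kernels between precisions at compile time. Below is a minimal plain-C++ stand-in for that switch, not the actual shader code; the alias accum_t is hypothetical, and _Float16 stands in for Metal's half (requires a compiler/target with _Float16 support, e.g. clang on Apple silicon).

```cpp
// Plain C++ stand-in for the compile-time precision switch.
// Build with:  c++ -DGGML_METAL_FORCE_FATTN_PREC_F32 switch.cpp   (forced F32 path)
//         or:  c++ switch.cpp                                      (F16 path)
#include <cstdio>

#if defined(GGML_METAL_FORCE_FATTN_PREC_F32)
typedef float    accum_t;  // forced F32 accumulation (the default set up by this commit)
#else
typedef _Float16 accum_t;  // F16 accumulation when the force-F32 macro is absent
#endif

int main() {
    printf("accumulator element size: %zu bytes\n", sizeof(accum_t));
    return 0;
}
```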
@@ -3216,11 +3224,19 @@ static void ggml_metal_encode_node(
                GGML_ASSERT(nqptg % 8  == 0);
                GGML_ASSERT(ncpsg % 32 == 0);
 
+#ifdef GGML_METAL_FORCE_FATTN_PREC_F32
+                const enum ggml_prec prec = GGML_PREC_DEFAULT;
+#else
+                const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
+#endif
+
+                const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2;
+
                // 16*32*(nsg)
                // the shared memory needed for the simdgroups to load the KV cache
                // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
                //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 
                int64_t nsgmax = 2;
 
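To see what the new nhalfs factor changes, the sketch below recomputes FATTN_SMEM for both precisions with illustrative sizes. The concrete values (ne00 = 128, nqptg = 8, ncpsg = 32, nsg = 4) are example numbers, not taken from the commit, and GGML_PAD is assumed to round its first argument up to a multiple of the second.

```cpp
// Recomputes the threadgroup-memory formula from the diff for both precisions.
// All concrete sizes below are hypothetical examples.
#include <cstdio>

// assumed to match ggml's GGML_PAD: round x up to a multiple of n
#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main() {
    const int ne00  = 128; // head size (example)
    const int nqptg = 8;   // queries per threadgroup, multiple of 8 (per the assert)
    const int ncpsg = 32;  // cache items per simdgroup, multiple of 32 (per the assert)
    const int nsg   = 4;   // simdgroups per threadgroup (example)

    for (int nhalfs = 1; nhalfs <= 2; ++nhalfs) {
        // sizeof(float)/2 == 2: the buffer is counted in half (2-byte) elements
        const size_t smem = GGML_PAD((nqptg*(ne00 + nhalfs*(ncpsg + nqptg)*nsg) + 16*32*nsg)*(sizeof(float)/2), 16);
        printf("nhalfs = %d (%s scratch): %zu bytes of threadgroup memory\n",
               nhalfs, nhalfs == 1 ? "F16" : "F32", smem);
    }
    return 0;
}
```

With these example sizes the F16 path needs 8704 bytes per threadgroup instead of 11264, which is the point of storing the attention scratch as half.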
@@ -2805,13 +2805,13 @@ kernel void kernel_flash_attn_ext(
    const short NW = N_SIMDWIDTH;
    const short SH = (C + Q); // shared memory per simdgroup in (half)
 
-   const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
-   const short TF = T/2;          // shared memory size per query in (float)
+   const short T  = D + nsg*SH; // shared memory size per query in (half)
+   const short TF = T;          // shared memory size per query in (float)
    const short T4 = T/4;          // shared memory size per query in (half4)
 
-   threadgroup half  * sq  = (threadgroup half  *) (shared +               0*D); // holds the query data
-   threadgroup half4 * sq4 = (threadgroup half4 *) (shared +               0*D); // same as above but in half4
-   threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+   threadgroup half  * sq  = (threadgroup half  *) (shared +             0*D); // holds the query data
+   threadgroup half4 * sq4 = (threadgroup half4 *) (shared +             0*D); // same as above but in half4
+   threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
 
    threadgroup half    * skv  = (threadgroup half    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
    threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
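The layout change above can be sanity-checked by printing the per-query row: the query data keeps the first D halves, and each simdgroup's attention/diagonal scratch now occupies SH halves directly after it, so the row shrinks from D + 2*nsg*SH to D + nsg*SH halves. A small sketch with hypothetical sizes (D, Q, C, nsg are not taken from any particular kernel launch):

```cpp
// Prints per-query shared-memory offsets for the F16 layout in the diff.
#include <cstdio>

int main() {
    const short D   = 128;  // head size (example)
    const short Q   = 8;    // queries per threadgroup (example)
    const short C   = 32;   // cache items per threadgroup (example)
    const short nsg = 4;    // simdgroups per threadgroup (example)

    const short SH = C + Q;       // scratch per simdgroup, in half elements
    const short T  = D + nsg*SH;  // row length per query in halves (was D + 2*nsg*SH)
    const short TF = T;           // scratch row stride, now also in halves (was T/2 floats)

    printf("row: %d halves total, scratch stride TF = %d\n", T, TF);
    printf("sq   : halves [0, %d)\n", D);
    for (short sgitg = 0; sgitg < nsg; ++sgitg) {
        // mirrors: ss = shared + sgitg*SH + 1*D (the old layout used 2*sgitg*SH + 1*D, in floats)
        printf("ss[%d]: halves [%d, %d)\n", sgitg, 1*D + sgitg*SH, 1*D + (sgitg + 1)*SH);
    }
    return 0;
}
```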
@@ -2840,7 +2840,7 @@ kernel void kernel_flash_attn_ext(
    // zero out shared memory SH
    for (short j = 0; j < Q; ++j) {
        for (short i = tiisg; i < SH; i += NW) {
-           ss[j*TF + i] = 0.0f;
+           ss[j*TF + i] = 0.0h;
        }
    }
 
@@ -2905,7 +2905,7 @@ kernel void kernel_flash_attn_ext(
        // Q*K^T
        {
            for (short cc = 0; cc < C/8; ++cc) {
-               simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+               simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, 8>(0.h);
 
                // this is compile-time check, so it does not have runtime overhead
                if (is_same<block_q, half4x4>::value) {
@@ -2977,7 +2977,7 @@ kernel void kernel_flash_attn_ext(
            const float m = M[j];
 
            // scale and apply the logitcap / mask
-           float s = ss[j*TF + tiisg]*scale;
+           float s = ((float)(ss[j*TF + tiisg]))*scale;
 
            if (logit_softcap != 0.0f) {
                s = logit_softcap*precise::tanh(s);
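The score path itself is unchanged here apart from the explicit half-to-float up-cast before scaling. For reference, a standalone C++ version of the scale + logit-softcap step: softcap_score is a hypothetical helper, precise::tanh is replaced by std::tanh, and the sample numbers are made up for illustration.

```cpp
// Standalone version of the per-score "scale and apply the logitcap" step.
#include <cmath>
#include <cstdio>

static float softcap_score(float qk, float scale, float logit_softcap) {
    float s = qk*scale;                    // scale the raw Q*K^T value
    if (logit_softcap != 0.0f) {
        s = logit_softcap*std::tanh(s);    // squash the logit into (-logit_softcap, logit_softcap)
    }
    return s;
}

int main() {
    const float scale = 1.0f/std::sqrt(128.0f);  // typical 1/sqrt(head_size) scaling (example)
    printf("capped:   %f\n", softcap_score(400.0f, scale, 50.0f));
    printf("uncapped: %f\n", softcap_score(400.0f, scale, 0.0f));
    return 0;
}
```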
@@ -3013,7 +3013,7 @@ kernel void kernel_flash_attn_ext(
 
        // O = diag(ms)*O
        {
-           simdgroup_float8x8 mm;
+           simdgroup_half8x8 mm;
            simdgroup_load(mm, ss + C, TF, 0, false);
 
            for (short i = 0; i < D8; ++i) {
@@ -3024,7 +3024,7 @@ kernel void kernel_flash_attn_ext(
        // O = O + (Q*K^T)*V
        {
            for (short cc = 0; cc < C/8; ++cc) {
-               simdgroup_float8x8 ms;
+               simdgroup_half8x8 ms;
                simdgroup_load(ms, ss + 8*cc, TF, 0, false);
 
                if (is_same<block_q, half4x4>::value) {
@@ -3137,8 +3137,8 @@ kernel void kernel_flash_attn_ext(
        // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
        {
            simdgroup_half8x8 t;
-           simdgroup_float8x8 ms0;
-           simdgroup_float8x8 ms1;
+           simdgroup_half8x8 ms0;
+           simdgroup_half8x8 ms1;
 
            simdgroup_load(ms0, ss + C,         TF, 0, false);
            simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
@@ -3219,6 +3219,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_
 template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
 
 // NOTE: can use half instead of float precision for some extra perf
+// however, by default use F32 since the op should be mostly memory bandwidth bound
 // D - head size, Q - queries per threadgroup, C - cache items per threadgroup
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
 kernel void kernel_flash_attn_ext_vec(
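The "memory bandwidth bound" remark can be sanity-checked with a rough arithmetic-intensity estimate: per query and per cached position the kernel streams roughly 2*D half values of K and V while doing on the order of 4*D FLOPs, i.e. about 1 FLOP per byte, well below the compute-to-bandwidth balance point of recent Apple GPUs. A back-of-the-envelope sketch, where every number is an assumption rather than a measurement:

```cpp
// Rough roofline-style estimate for single-query (decode) flash attention.
// All device numbers are illustrative assumptions, not measurements.
#include <cstdio>

int main() {
    const double D     = 128;                        // head size (example)
    const double n_kv  = 4096;                       // cached tokens attended per query (example)
    const double bytes = 2.0*D*n_kv*sizeof(short);   // stream K and V once, as 2-byte halves
    const double flops = 4.0*D*n_kv;                 // ~2*D for Q*K^T plus ~2*D for the V accumulation

    const double gpu_flops_per_byte = 30.0;          // assumed GPU compute/bandwidth ratio

    printf("arithmetic intensity: %.2f FLOP/byte\n", flops/bytes);
    printf("GPU balance point   : %.2f FLOP/byte (assumed)\n", gpu_flops_per_byte);
    printf("=> %s-bound for this shape\n", flops/bytes < gpu_flops_per_byte ? "bandwidth" : "compute");
    return 0;
}
```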