This commit is contained in:
Georgi Gerganov 2024-11-07 18:20:25 +02:00
parent 4abeb60a1a
commit a6c8dbfa5d
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 70 additions and 124 deletions

View File

@ -1139,7 +1139,7 @@ static void ggml_metal_encode_node(
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
const int64_t ne0 = dst ? dst->ne[0] : 0;
const int64_t ne1 = dst ? dst->ne[1] : 0;
@ -3241,18 +3241,15 @@ static void ggml_metal_encode_node(
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
[encoder setBytes:&scale length:sizeof( float) atIndex:20];
[encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
if (!use_vec_kernel) {
// half8x8 kernel
@ -3263,21 +3260,11 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
#ifdef GGML_METAL_FORCE_FATTN_PREC_F16
const enum ggml_prec prec = GGML_PREC_DEFAULT;
#else
// TODO: support both precisions
const enum ggml_prec prec = GGML_PREC_F32;
//const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
#endif
const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2;
// 16*32*(nsg)
// the shared memory needed for the simdgroups to load the KV cache
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
int64_t nsgmax = 2;

View File

@ -2815,14 +2815,11 @@ kernel void kernel_flash_attn_ext(
constant uint32_t & nb02,
constant uint32_t & nb03,
constant int32_t & ne11,
constant int32_t & ne12,
constant int32_t & ne13,
constant uint32_t & nb11,
constant uint32_t & nb12,
constant uint32_t & nb13,
constant uint32_t & nb21,
constant uint32_t & nb22,
constant uint32_t & nb23,
constant int32_t & ne_12_2, // assume K and V are same shape
constant int32_t & ne_12_3,
constant uint32_t & nb_12_1,
constant uint32_t & nb_12_2,
constant uint32_t & nb_12_3,
constant uint32_t & nb31,
constant int32_t & ne1,
constant int32_t & ne2,
@ -2830,13 +2827,13 @@ kernel void kernel_flash_attn_ext(
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
constant uint16_t & n_head_log2,
constant float & logit_softcap,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 ntg[[threads_per_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
ushort3 tgpig[[threadgroup_position_in_grid]],
ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
const int iq3 = tgpig[2];
@ -2849,17 +2846,14 @@ kernel void kernel_flash_attn_ext(
const short NW = N_SIMDWIDTH;
const short SH = (2*C + Q); // shared memory per simdgroup in (half)
const short SF = sizeof(s_t)/sizeof(half);
const short TS = nsg*SH; // shared memory size per query in (s_t)
const short T = D + 2*TS; // shared memory size per query in (half)
const short T = D + SF*nsg*SH; // shared memory size per query in (half)
const short TS = T/SF; // shared memory size per query in (s_t)
const short T4 = T/4; // shared memory size per query in (half4)
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // reuse query data for accumulation
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // reuse query data for accumulation
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention and diagonal matrix
threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@ -2876,9 +2870,9 @@ kernel void kernel_flash_attn_ext(
for (short i = tiisg; i < D4; i += NW) {
if (iq1 + j < ne01) {
sq4[j*T4 + i] = (q4_t) q4[i];
sq4[j*D4 + i] = (q4_t) q4[i];
} else {
sq4[j*T4 + i] = (q4_t) 0.0f;
sq4[j*D4 + i] = (q4_t) 0.0f;
}
}
}
@ -2907,29 +2901,18 @@ kernel void kernel_flash_attn_ext(
const short tx = tiisg%4;
const short ty = tiisg/4;
// assume K and V are same shape
const short ne22 = ne12;
const short ne23 = ne13;
// broadcast kv
//const short rk2 = ne02/ne12;
//const short rk3 = ne03/ne13;
// broadcast k
const short rk2 = ne02/ne12;
const short rk3 = ne03/ne13;
const short ik2 = iq2/rk2;
const short ik3 = iq3/rk3;
// broadcast v
const short rv2 = ne02/ne22;
const short rv3 = ne03/ne23;
const short iv2 = iq2/rv2;
const short iv3 = iq3/rv3;
const short ikv2 = iq2/(ne02/ne_12_2);
const short ikv3 = iq3/(ne03/ne_12_3);
// load the queries from shared memory into local memory
q8x8_t mq[D8];
for (short i = 0; i < D8; ++i) {
simdgroup_load(mq[i], sq + i*8, T);
simdgroup_load(mq[i], sq + i*8, D);
}
const bool has_mask = mask != q;
@ -2982,18 +2965,18 @@ kernel void kernel_flash_attn_ext(
// this is compile-time check, so it does not have runtime overhead
if (is_same<kd4x4_t, k4x4_t>::value) {
// we can read directly from global memory
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
#pragma unroll
for (short i = 0; i < D8; ++i) {
k8x8_t mk;
simdgroup_load(mk, pk + i*8, nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
if (D16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
@ -3046,7 +3029,7 @@ kernel void kernel_flash_attn_ext(
// online softmax
{
for (short j = 0; j < Q; ++j) {
for (ushort j = 0; j < Q; ++j) {
const half m = M[j];
// scale and apply the logitcap / mask
@ -3095,17 +3078,17 @@ kernel void kernel_flash_attn_ext(
if (is_same<vd4x4_t, v4x4_t>::value) {
// we can read directly from global memory
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
#pragma unroll
for (short i = 0; i < D8; ++i) {
v8x8_t mv;
simdgroup_load(mv, pv + i*8, nb21/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
if (D16%4 == 0) {
// no need for bound checks
@ -3162,7 +3145,7 @@ kernel void kernel_flash_attn_ext(
}
// reduce the warps sequentially
for (short sg = 1; sg < nsg; ++sg) {
for (ushort sg = 1; sg < nsg; ++sg) {
half S = { 0.0f };
half M = { -__FLT16_MAX__/2 };
@ -3171,7 +3154,7 @@ kernel void kernel_flash_attn_ext(
// each simdgroup stores its output to shared memory, reusing sq
if (sgitg == sg) {
for (short i = 0; i < D8; ++i) {
simdgroup_store(lo[i], so + i*8, T, 0, false);
simdgroup_store(lo[i], so + i*8, D, 0, false);
}
}
@ -3213,7 +3196,7 @@ kernel void kernel_flash_attn_ext(
for (short i = 0; i < D8; ++i) {
o8x8_t t;
simdgroup_load (t, so + i*8, T, 0, false);
simdgroup_load (t, so + i*8, D, 0, false);
simdgroup_multiply(t, ms1, t);
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@ -3225,7 +3208,7 @@ kernel void kernel_flash_attn_ext(
// store result to shared memory (reuse sq)
if (sgitg == 0) {
for (short i = 0; i < D8; ++i) {
simdgroup_store(lo[i], so + i*8, T, 0, false);
simdgroup_store(lo[i], so + i*8, D, 0, false);
}
}
@ -3237,21 +3220,12 @@ kernel void kernel_flash_attn_ext(
const half S = ss[j*TS + 0];
for (short i = tiisg; i < D4; i += NW) {
dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*T4 + i]/S;
dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
}
}
}
}
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
#define FA_TYPES \
half, half4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
float, simdgroup_float8x8, \
half, simdgroup_half8x8, \
half, half4, simdgroup_half8x8
#else
#define FA_TYPES \
half, half4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
@ -3259,7 +3233,6 @@ kernel void kernel_flash_attn_ext(
float, simdgroup_float8x8, \
float, simdgroup_float8x8, \
half, half4, simdgroup_half8x8
#endif
// TODO: static_assert
@ -3339,24 +3312,21 @@ kernel void kernel_flash_attn_ext_vec(
device const char * v,
device const char * mask,
device float * dst,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant uint64_t & nb21,
constant uint64_t & nb22,
constant uint64_t & nb23,
constant uint64_t & nb31,
constant int64_t & ne1,
constant int64_t & ne2,
constant int32_t & ne01,
constant int32_t & ne02,
constant int32_t & ne03,
constant uint32_t & nb01,
constant uint32_t & nb02,
constant uint32_t & nb03,
constant int32_t & ne11,
constant int32_t & ne_12_2, // assume K and V are same shape
constant int32_t & ne_12_3,
constant uint32_t & nb_12_1,
constant uint32_t & nb_12_2,
constant uint32_t & nb_12_3,
constant uint32_t & nb31,
constant int32_t & ne1,
constant int32_t & ne2,
constant float & scale,
constant float & max_bias,
constant float & m0,
@ -3426,23 +3396,12 @@ kernel void kernel_flash_attn_ext_vec(
const short tx = tiisg%8;
const short ty = tiisg/8;
// assume K and V are same shape
const short ne22 = ne12;
const short ne23 = ne13;
// broadcast kv
//const short rk2 = ne02/ne12;
//const short rk3 = ne03/ne13;
// broadcast k
const short rk2 = ne02/ne12;
const short rk3 = ne03/ne13;
const short ik2 = iq2/rk2;
const short ik3 = iq3/rk3;
// broadcast v
const short rv2 = ne02/ne22;
const short rv3 = ne03/ne23;
const short iv2 = iq2/rv2;
const short iv3 = iq3/rv3;
const short ikv2 = iq2/(ne02/ne_12_2);
const short ikv3 = iq3/(ne03/ne_12_3);
// load the queries from shared memory into local memory
k4x4_t mq[D16/NW4];
@ -3480,7 +3439,7 @@ kernel void kernel_flash_attn_ext_vec(
for (short cc = 0; cc < C/4; ++cc) {
s_t mqk = 0.0;
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
#pragma unroll
for (short ii = 0; ii < D16; ii += NW4) {
@ -3554,7 +3513,7 @@ kernel void kernel_flash_attn_ext_vec(
{
#pragma unroll
for (short cc = 0; cc < C/4; ++cc) {
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
const s4x4_t ms(ss[4*cc + ty]);