mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
wip
This commit is contained in:
parent
4abeb60a1a
commit
a6c8dbfa5d
@ -1139,7 +1139,7 @@ static void ggml_metal_encode_node(
|
||||
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
||||
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
||||
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
||||
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
||||
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
||||
|
||||
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
||||
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
||||
@ -3241,18 +3241,15 @@ static void ggml_metal_encode_node(
|
||||
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
||||
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
||||
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
||||
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
|
||||
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
|
||||
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
|
||||
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
|
||||
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
|
||||
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
|
||||
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
|
||||
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
|
||||
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:20];
|
||||
[encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
|
||||
[encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
|
||||
[encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
|
||||
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
|
||||
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
|
||||
|
||||
if (!use_vec_kernel) {
|
||||
// half8x8 kernel
|
||||
@ -3263,21 +3260,11 @@ static void ggml_metal_encode_node(
|
||||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
#ifdef GGML_METAL_FORCE_FATTN_PREC_F16
|
||||
const enum ggml_prec prec = GGML_PREC_DEFAULT;
|
||||
#else
|
||||
// TODO: support both precisions
|
||||
const enum ggml_prec prec = GGML_PREC_F32;
|
||||
//const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
|
||||
#endif
|
||||
|
||||
const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2;
|
||||
|
||||
// 16*32*(nsg)
|
||||
// the shared memory needed for the simdgroups to load the KV cache
|
||||
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
|
||||
//
|
||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
||||
|
||||
int64_t nsgmax = 2;
|
||||
|
||||
|
@ -2815,14 +2815,11 @@ kernel void kernel_flash_attn_ext(
|
||||
constant uint32_t & nb02,
|
||||
constant uint32_t & nb03,
|
||||
constant int32_t & ne11,
|
||||
constant int32_t & ne12,
|
||||
constant int32_t & ne13,
|
||||
constant uint32_t & nb11,
|
||||
constant uint32_t & nb12,
|
||||
constant uint32_t & nb13,
|
||||
constant uint32_t & nb21,
|
||||
constant uint32_t & nb22,
|
||||
constant uint32_t & nb23,
|
||||
constant int32_t & ne_12_2, // assume K and V are same shape
|
||||
constant int32_t & ne_12_3,
|
||||
constant uint32_t & nb_12_1,
|
||||
constant uint32_t & nb_12_2,
|
||||
constant uint32_t & nb_12_3,
|
||||
constant uint32_t & nb31,
|
||||
constant int32_t & ne1,
|
||||
constant int32_t & ne2,
|
||||
@ -2830,13 +2827,13 @@ kernel void kernel_flash_attn_ext(
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
constant uint16_t & n_head_log2,
|
||||
constant float & logit_softcap,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 ntg[[threads_per_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short nsg = ntg.y; // number of simdgroups
|
||||
|
||||
const int iq3 = tgpig[2];
|
||||
@ -2849,17 +2846,14 @@ kernel void kernel_flash_attn_ext(
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short SH = (2*C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
const short SF = sizeof(s_t)/sizeof(half);
|
||||
const short TS = nsg*SH; // shared memory size per query in (s_t)
|
||||
const short T = D + 2*TS; // shared memory size per query in (half)
|
||||
|
||||
const short T = D + SF*nsg*SH; // shared memory size per query in (half)
|
||||
const short TS = T/SF; // shared memory size per query in (s_t)
|
||||
const short T4 = T/4; // shared memory size per query in (half4)
|
||||
|
||||
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
|
||||
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
|
||||
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // reuse query data for accumulation
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
|
||||
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
|
||||
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // reuse query data for accumulation
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention and diagonal matrix
|
||||
|
||||
threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
||||
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
||||
@ -2876,9 +2870,9 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
if (iq1 + j < ne01) {
|
||||
sq4[j*T4 + i] = (q4_t) q4[i];
|
||||
sq4[j*D4 + i] = (q4_t) q4[i];
|
||||
} else {
|
||||
sq4[j*T4 + i] = (q4_t) 0.0f;
|
||||
sq4[j*D4 + i] = (q4_t) 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2907,29 +2901,18 @@ kernel void kernel_flash_attn_ext(
|
||||
const short tx = tiisg%4;
|
||||
const short ty = tiisg/4;
|
||||
|
||||
// assume K and V are same shape
|
||||
const short ne22 = ne12;
|
||||
const short ne23 = ne13;
|
||||
// broadcast kv
|
||||
//const short rk2 = ne02/ne12;
|
||||
//const short rk3 = ne03/ne13;
|
||||
|
||||
// broadcast k
|
||||
const short rk2 = ne02/ne12;
|
||||
const short rk3 = ne03/ne13;
|
||||
|
||||
const short ik2 = iq2/rk2;
|
||||
const short ik3 = iq3/rk3;
|
||||
|
||||
// broadcast v
|
||||
const short rv2 = ne02/ne22;
|
||||
const short rv3 = ne03/ne23;
|
||||
|
||||
const short iv2 = iq2/rv2;
|
||||
const short iv3 = iq3/rv3;
|
||||
const short ikv2 = iq2/(ne02/ne_12_2);
|
||||
const short ikv3 = iq3/(ne03/ne_12_3);
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
q8x8_t mq[D8];
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_load(mq[i], sq + i*8, T);
|
||||
simdgroup_load(mq[i], sq + i*8, D);
|
||||
}
|
||||
|
||||
const bool has_mask = mask != q;
|
||||
@ -2982,18 +2965,18 @@ kernel void kernel_flash_attn_ext(
|
||||
// this is compile-time check, so it does not have runtime overhead
|
||||
if (is_same<kd4x4_t, k4x4_t>::value) {
|
||||
// we can read directly from global memory
|
||||
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
|
||||
#pragma unroll
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
k8x8_t mk;
|
||||
simdgroup_load(mk, pk + i*8, nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
||||
simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
||||
|
||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||
}
|
||||
} else {
|
||||
for (short ii = 0; ii < D16; ii += 4) {
|
||||
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
|
||||
if (D16%4 == 0) {
|
||||
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
||||
@ -3046,7 +3029,7 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
// online softmax
|
||||
{
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
for (ushort j = 0; j < Q; ++j) {
|
||||
const half m = M[j];
|
||||
|
||||
// scale and apply the logitcap / mask
|
||||
@ -3095,17 +3078,17 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
if (is_same<vd4x4_t, v4x4_t>::value) {
|
||||
// we can read directly from global memory
|
||||
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
#pragma unroll
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
v8x8_t mv;
|
||||
simdgroup_load(mv, pv + i*8, nb21/sizeof(v_t), 0, false); // TODO: use ne20
|
||||
simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
||||
}
|
||||
} else {
|
||||
for (short ii = 0; ii < D16; ii += 4) {
|
||||
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
|
||||
if (D16%4 == 0) {
|
||||
// no need for bound checks
|
||||
@ -3162,7 +3145,7 @@ kernel void kernel_flash_attn_ext(
|
||||
}
|
||||
|
||||
// reduce the warps sequentially
|
||||
for (short sg = 1; sg < nsg; ++sg) {
|
||||
for (ushort sg = 1; sg < nsg; ++sg) {
|
||||
half S = { 0.0f };
|
||||
half M = { -__FLT16_MAX__/2 };
|
||||
|
||||
@ -3171,7 +3154,7 @@ kernel void kernel_flash_attn_ext(
|
||||
// each simdgroup stores its output to shared memory, reusing sq
|
||||
if (sgitg == sg) {
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[i], so + i*8, T, 0, false);
|
||||
simdgroup_store(lo[i], so + i*8, D, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
@ -3213,7 +3196,7 @@ kernel void kernel_flash_attn_ext(
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
o8x8_t t;
|
||||
|
||||
simdgroup_load (t, so + i*8, T, 0, false);
|
||||
simdgroup_load (t, so + i*8, D, 0, false);
|
||||
simdgroup_multiply(t, ms1, t);
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
||||
@ -3225,7 +3208,7 @@ kernel void kernel_flash_attn_ext(
|
||||
// store result to shared memory (reuse sq)
|
||||
if (sgitg == 0) {
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[i], so + i*8, T, 0, false);
|
||||
simdgroup_store(lo[i], so + i*8, D, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
@ -3237,21 +3220,12 @@ kernel void kernel_flash_attn_ext(
|
||||
const half S = ss[j*TS + 0];
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*T4 + i]/S;
|
||||
dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
|
||||
#define FA_TYPES \
|
||||
half, half4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
float, simdgroup_float8x8, \
|
||||
half, simdgroup_half8x8, \
|
||||
half, half4, simdgroup_half8x8
|
||||
#else
|
||||
#define FA_TYPES \
|
||||
half, half4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
@ -3259,7 +3233,6 @@ kernel void kernel_flash_attn_ext(
|
||||
float, simdgroup_float8x8, \
|
||||
float, simdgroup_float8x8, \
|
||||
half, half4, simdgroup_half8x8
|
||||
#endif
|
||||
|
||||
// TODO: static_assert
|
||||
|
||||
@ -3339,24 +3312,21 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device float * dst,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant uint64_t & nb21,
|
||||
constant uint64_t & nb22,
|
||||
constant uint64_t & nb23,
|
||||
constant uint64_t & nb31,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int32_t & ne01,
|
||||
constant int32_t & ne02,
|
||||
constant int32_t & ne03,
|
||||
constant uint32_t & nb01,
|
||||
constant uint32_t & nb02,
|
||||
constant uint32_t & nb03,
|
||||
constant int32_t & ne11,
|
||||
constant int32_t & ne_12_2, // assume K and V are same shape
|
||||
constant int32_t & ne_12_3,
|
||||
constant uint32_t & nb_12_1,
|
||||
constant uint32_t & nb_12_2,
|
||||
constant uint32_t & nb_12_3,
|
||||
constant uint32_t & nb31,
|
||||
constant int32_t & ne1,
|
||||
constant int32_t & ne2,
|
||||
constant float & scale,
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
@ -3426,23 +3396,12 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const short tx = tiisg%8;
|
||||
const short ty = tiisg/8;
|
||||
|
||||
// assume K and V are same shape
|
||||
const short ne22 = ne12;
|
||||
const short ne23 = ne13;
|
||||
// broadcast kv
|
||||
//const short rk2 = ne02/ne12;
|
||||
//const short rk3 = ne03/ne13;
|
||||
|
||||
// broadcast k
|
||||
const short rk2 = ne02/ne12;
|
||||
const short rk3 = ne03/ne13;
|
||||
|
||||
const short ik2 = iq2/rk2;
|
||||
const short ik3 = iq3/rk3;
|
||||
|
||||
// broadcast v
|
||||
const short rv2 = ne02/ne22;
|
||||
const short rv3 = ne03/ne23;
|
||||
|
||||
const short iv2 = iq2/rv2;
|
||||
const short iv3 = iq3/rv3;
|
||||
const short ikv2 = iq2/(ne02/ne_12_2);
|
||||
const short ikv3 = iq3/(ne03/ne_12_3);
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
k4x4_t mq[D16/NW4];
|
||||
@ -3480,7 +3439,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
s_t mqk = 0.0;
|
||||
|
||||
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D16; ii += NW4) {
|
||||
@ -3554,7 +3513,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
{
|
||||
#pragma unroll
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
|
||||
const s4x4_t ms(ss[4*cc + ty]);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user