metal : fattn args

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-09 16:09:31 +02:00
parent 996e479780
commit 089404f3a1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 113 additions and 125 deletions

View File

@ -446,6 +446,30 @@ typedef struct {
float beta_fast;
float beta_slow;
} ggml_metal_kargs_rope;
typedef struct {
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb_12_1;
uint64_t nb_12_2;
uint64_t nb_12_3;
uint64_t nb31;
int32_t ne1;
int32_t ne2;
float scale;
float max_bias;
float m0;
float m1;
uint16_t n_head_log2;
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
#endif
#endif // GGML_COMMON_DECL

View File

@ -3228,6 +3228,30 @@ static void ggml_metal_encode_node(
}
}
ggml_metal_kargs_flash_attn_ext args = {
.ne01 = ne01,
.ne02 = ne02,
.ne03 = ne03,
.nb01 = nb01,
.nb02 = nb02,
.nb03 = nb03,
.ne11 = ne11,
.ne_12_2 = ne12,
.ne_12_3 = ne13,
.nb_12_1 = nb11,
.nb_12_2 = nb12,
.nb_12_3 = nb13,
.nb31 = nb31,
.ne1 = ne1,
.ne2 = ne2,
.scale = scale,
.max_bias = max_bias,
.m0 = m0,
.m1 = m1,
.n_head_log2 = n_head_log2,
.logit_softcap = logit_softcap,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@ -3238,27 +3262,7 @@ static void ggml_metal_encode_node(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
[encoder setBytes:&scale length:sizeof( float) atIndex:20];
[encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
[encoder setBytes:&args length:sizeof(args) atIndex:5];
if (!use_vec_kernel) {
// half8x8 kernel

View File

@ -2267,8 +2267,8 @@ kernel void kernel_rope_norm(
device const char * src2,
device char * dst,
constant ggml_metal_kargs_rope & args,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
@ -2320,8 +2320,8 @@ kernel void kernel_rope_neox(
device const char * src2,
device char * dst,
constant ggml_metal_kargs_rope & args,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
@ -2761,29 +2761,9 @@ kernel void kernel_flash_attn_ext(
device const char * v,
device const char * mask,
device float * dst,
constant int32_t & ne01,
constant int32_t & ne02,
constant int32_t & ne03,
constant uint32_t & nb01,
constant uint32_t & nb02,
constant uint32_t & nb03,
constant int32_t & ne11,
constant int32_t & ne_12_2, // assume K and V are same shape
constant int32_t & ne_12_3,
constant uint32_t & nb_12_1,
constant uint32_t & nb_12_2,
constant uint32_t & nb_12_3,
constant uint32_t & nb31,
constant int32_t & ne1,
constant int32_t & ne2,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint16_t & n_head_log2,
constant float & logit_softcap,
constant ggml_metal_kargs_flash_attn_ext & args,
threadgroup half * shared [[threadgroup(0)]],
ushort3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
@ -2819,10 +2799,10 @@ kernel void kernel_flash_attn_ext(
// load heads from Q to shared memory
for (short j = sgitg; j < Q; j += nsg) {
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
for (short i = tiisg; i < D4; i += NW) {
if (iq1 + j < ne01) {
if (iq1 + j < args.ne01) {
sq4[j*D4 + i] = (q4_t) q4[i];
} else {
sq4[j*D4 + i] = (q4_t) 0.0f;
@ -2855,11 +2835,11 @@ kernel void kernel_flash_attn_ext(
const short ty = tiisg/4;
// broadcast kv
//const short rk2 = ne02/ne12;
//const short rk3 = ne03/ne13;
//const short rk2 = args.ne02/args.ne12;
//const short rk3 = args.ne03/args.ne13;
const short ikv2 = iq2/(ne02/ne_12_2);
const short ikv3 = iq3/(ne03/ne_12_3);
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
// load the queries from shared memory into local memory
q8x8_t mq[D8];
@ -2873,20 +2853,20 @@ kernel void kernel_flash_attn_ext(
half slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
if (args.max_bias > 0.0f) {
const short h = iq2;
const half base = h < n_head_log2 ? m0 : m1;
const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
const half base = h < args.n_head_log2 ? args.m0 : args.m1;
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exph);
}
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
const int ic = ic0 + C*sgitg;
if (ic >= ne11) {
if (ic >= args.ne11) {
break;
}
@ -2896,7 +2876,7 @@ kernel void kernel_flash_attn_ext(
// load the mask in shared memory
for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
const half m = pm[ic + tiisg];
@ -2919,18 +2899,18 @@ kernel void kernel_flash_attn_ext(
// this is compile-time check, so it does not have runtime overhead
if (is_same<kd4x4_t, k4x4_t>::value) {
// we can read directly from global memory
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
#pragma unroll
for (short i = 0; i < D8; ++i) {
k8x8_t mk;
simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
if (D16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
@ -2989,10 +2969,10 @@ kernel void kernel_flash_attn_ext(
const half m = M[j];
// scale and apply the logitcap / mask
half s = ss[j*TS + tiisg]*scale;
half s = ss[j*TS + tiisg]*args.scale;
if (logit_softcap != 0.0f) {
s = logit_softcap*precise::tanh(s);
if (args.logit_softcap != 0.0f) {
s = args.logit_softcap*precise::tanh(s);
}
// mqk = mqk + mask*slope
@ -3034,17 +3014,17 @@ kernel void kernel_flash_attn_ext(
if (is_same<vd4x4_t, v4x4_t>::value) {
// we can read directly from global memory
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
#pragma unroll
for (short i = 0; i < D8; ++i) {
v8x8_t mv;
simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
if (D16%4 == 0) {
// no need for bound checks
@ -3172,11 +3152,11 @@ kernel void kernel_flash_attn_ext(
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
const float S = ss[j*TS + 0];
for (short i = tiisg; i < D4; i += NW) {
dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
}
}
}
@ -3273,29 +3253,9 @@ kernel void kernel_flash_attn_ext_vec(
device const char * v,
device const char * mask,
device float * dst,
constant int32_t & ne01,
constant int32_t & ne02,
constant int32_t & ne03,
constant uint32_t & nb01,
constant uint32_t & nb02,
constant uint32_t & nb03,
constant int32_t & ne11,
constant int32_t & ne_12_2, // assume K and V are same shape
constant int32_t & ne_12_3,
constant uint32_t & nb_12_1,
constant uint32_t & nb_12_2,
constant uint32_t & nb_12_3,
constant uint32_t & nb31,
constant int32_t & ne1,
constant int32_t & ne2,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint16_t & n_head_log2,
constant float & logit_softcap,
constant ggml_metal_kargs_flash_attn_ext & args,
threadgroup half * shared [[threadgroup(0)]],
ushort3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
@ -3326,10 +3286,10 @@ kernel void kernel_flash_attn_ext_vec(
o4x4_t lo[D16/NL];
// load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
for (short i = tiisg; i < D4; i += NW) {
if (iq1 < ne01) {
if (iq1 < args.ne01) {
sq4[i] = (q4_t) q4[i];
} else {
sq4[i] = (q4_t) 0.0f;
@ -3357,11 +3317,11 @@ kernel void kernel_flash_attn_ext_vec(
const short ty = tiisg/NL;
// broadcast kv
//const short rk2 = ne02/ne12;
//const short rk3 = ne03/ne13;
//const short rk2 = args.ne02/args.ne12;
//const short rk3 = args.ne03/args.ne13;
const short ikv2 = iq2/(ne02/ne_12_2);
const short ikv3 = iq3/(ne03/ne_12_3);
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
// load the queries from shared memory into local memory
q4x4_t mq[D16/NL];
@ -3373,25 +3333,25 @@ kernel void kernel_flash_attn_ext_vec(
const bool has_mask = mask != q;
// pointer to the mask
device const half * pm = (device const half *) (mask + iq1*nb31);
device const half * pm = (device const half *) (mask + iq1*args.nb31);
half slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
if (args.max_bias > 0.0f) {
const short h = iq2;
const half base = h < n_head_log2 ? m0 : m1;
const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
const half base = h < args.n_head_log2 ? args.m0 : args.m1;
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exph);
}
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
const int ic = ic0 + C*sgitg;
if (ic >= ne11) {
if (ic >= args.ne11) {
break;
}
@ -3405,7 +3365,7 @@ kernel void kernel_flash_attn_ext_vec(
for (short cc = 0; cc < C/4; ++cc) {
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
#pragma unroll
for (short ii = 0; ii < D16; ii += NL) {
@ -3435,10 +3395,10 @@ kernel void kernel_flash_attn_ext_vec(
// mqk = mqk*scale + mask*slope
if (tx == 0) {
mqk *= scale;
mqk *= args.scale;
if (logit_softcap != 0.0f) {
mqk = logit_softcap*precise::tanh(mqk);
if (args.logit_softcap != 0.0f) {
mqk = args.logit_softcap*precise::tanh(mqk);
}
mqk += sm[4*cc + ty]*slope;
@ -3478,7 +3438,7 @@ kernel void kernel_flash_attn_ext_vec(
{
#pragma unroll
for (short cc = 0; cc < C/4; ++cc) {
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
const s4x4_t ms(ss[4*cc + ty]);
@ -3583,7 +3543,7 @@ kernel void kernel_flash_attn_ext_vec(
const float S = ss[0];
for (short i = tiisg; i < D16; i += NW) {
dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
}
}
}