metal : fattn args

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-09 16:09:31 +02:00
parent 996e479780
commit 089404f3a1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 113 additions and 125 deletions

View File

@ -446,6 +446,30 @@ typedef struct {
float beta_fast; float beta_fast;
float beta_slow; float beta_slow;
} ggml_metal_kargs_rope; } ggml_metal_kargs_rope;
typedef struct {
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb_12_1;
uint64_t nb_12_2;
uint64_t nb_12_3;
uint64_t nb31;
int32_t ne1;
int32_t ne2;
float scale;
float max_bias;
float m0;
float m1;
uint16_t n_head_log2;
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
#endif #endif
#endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL

View File

@ -3228,37 +3228,41 @@ static void ggml_metal_encode_node(
} }
} }
ggml_metal_kargs_flash_attn_ext args = {
.ne01 = ne01,
.ne02 = ne02,
.ne03 = ne03,
.nb01 = nb01,
.nb02 = nb02,
.nb03 = nb03,
.ne11 = ne11,
.ne_12_2 = ne12,
.ne_12_3 = ne13,
.nb_12_1 = nb11,
.nb_12_2 = nb12,
.nb_12_3 = nb13,
.nb31 = nb31,
.ne1 = ne1,
.ne2 = ne2,
.scale = scale,
.max_bias = max_bias,
.m0 = m0,
.m1 = m1,
.n_head_log2 = n_head_log2,
.logit_softcap = logit_softcap,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
if (id_src3) { if (id_src3) {
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
} else { } else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
} }
[encoder setBuffer:id_dst offset:offs_dst atIndex:4]; [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; [encoder setBytes:&args length:sizeof(args) atIndex:5];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
[encoder setBytes:&scale length:sizeof( float) atIndex:20];
[encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
if (!use_vec_kernel) { if (!use_vec_kernel) {
// half8x8 kernel // half8x8 kernel

View File

@ -2267,9 +2267,9 @@ kernel void kernel_rope_norm(
device const char * src2, device const char * src2,
device char * dst, device char * dst,
constant ggml_metal_kargs_rope & args, constant ggml_metal_kargs_rope & args,
uint tiitg[[thread_index_in_threadgroup]], ushort tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]], ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) { uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2]; const int i3 = tgpig[2];
const int i2 = tgpig[1]; const int i2 = tgpig[1];
const int i1 = tgpig[0]; const int i1 = tgpig[0];
@ -2320,9 +2320,9 @@ kernel void kernel_rope_neox(
device const char * src2, device const char * src2,
device char * dst, device char * dst,
constant ggml_metal_kargs_rope & args, constant ggml_metal_kargs_rope & args,
uint tiitg[[thread_index_in_threadgroup]], ushort tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]], ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) { uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2]; const int i3 = tgpig[2];
const int i2 = tgpig[1]; const int i2 = tgpig[1];
const int i1 = tgpig[0]; const int i1 = tgpig[0];
@ -2761,32 +2761,12 @@ kernel void kernel_flash_attn_ext(
device const char * v, device const char * v,
device const char * mask, device const char * mask,
device float * dst, device float * dst,
constant int32_t & ne01, constant ggml_metal_kargs_flash_attn_ext & args,
constant int32_t & ne02,
constant int32_t & ne03,
constant uint32_t & nb01,
constant uint32_t & nb02,
constant uint32_t & nb03,
constant int32_t & ne11,
constant int32_t & ne_12_2, // assume K and V are same shape
constant int32_t & ne_12_3,
constant uint32_t & nb_12_1,
constant uint32_t & nb_12_2,
constant uint32_t & nb_12_3,
constant uint32_t & nb31,
constant int32_t & ne1,
constant int32_t & ne2,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint16_t & n_head_log2,
constant float & logit_softcap,
threadgroup half * shared [[threadgroup(0)]], threadgroup half * shared [[threadgroup(0)]],
ushort3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 ntg[[threads_per_threadgroup]], ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]], ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) { ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups const short nsg = ntg.y; // number of simdgroups
const int iq3 = tgpig[2]; const int iq3 = tgpig[2];
@ -2819,10 +2799,10 @@ kernel void kernel_flash_attn_ext(
// load heads from Q to shared memory // load heads from Q to shared memory
for (short j = sgitg; j < Q; j += nsg) { for (short j = sgitg; j < Q; j += nsg) {
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
for (short i = tiisg; i < D4; i += NW) { for (short i = tiisg; i < D4; i += NW) {
if (iq1 + j < ne01) { if (iq1 + j < args.ne01) {
sq4[j*D4 + i] = (q4_t) q4[i]; sq4[j*D4 + i] = (q4_t) q4[i];
} else { } else {
sq4[j*D4 + i] = (q4_t) 0.0f; sq4[j*D4 + i] = (q4_t) 0.0f;
@ -2855,11 +2835,11 @@ kernel void kernel_flash_attn_ext(
const short ty = tiisg/4; const short ty = tiisg/4;
// broadcast kv // broadcast kv
//const short rk2 = ne02/ne12; //const short rk2 = args.ne02/args.ne12;
//const short rk3 = ne03/ne13; //const short rk3 = args.ne03/args.ne13;
const short ikv2 = iq2/(ne02/ne_12_2); const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(ne03/ne_12_3); const short ikv3 = iq3/(args.ne03/args.ne_12_3);
// load the queries from shared memory into local memory // load the queries from shared memory into local memory
q8x8_t mq[D8]; q8x8_t mq[D8];
@ -2873,20 +2853,20 @@ kernel void kernel_flash_attn_ext(
half slope = 1.0f; half slope = 1.0f;
// ALiBi // ALiBi
if (max_bias > 0.0f) { if (args.max_bias > 0.0f) {
const short h = iq2; const short h = iq2;
const half base = h < n_head_log2 ? m0 : m1; const half base = h < args.n_head_log2 ? args.m0 : args.m1;
const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exph); slope = pow(base, exph);
} }
// loop over the KV cache // loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns // each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
const int ic = ic0 + C*sgitg; const int ic = ic0 + C*sgitg;
if (ic >= ne11) { if (ic >= args.ne11) {
break; break;
} }
@ -2896,7 +2876,7 @@ kernel void kernel_flash_attn_ext(
// load the mask in shared memory // load the mask in shared memory
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31); device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
const half m = pm[ic + tiisg]; const half m = pm[ic + tiisg];
@ -2919,18 +2899,18 @@ kernel void kernel_flash_attn_ext(
// this is compile-time check, so it does not have runtime overhead // this is compile-time check, so it does not have runtime overhead
if (is_same<kd4x4_t, k4x4_t>::value) { if (is_same<kd4x4_t, k4x4_t>::value) {
// we can read directly from global memory // we can read directly from global memory
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
#pragma unroll #pragma unroll
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
k8x8_t mk; k8x8_t mk;
simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
} }
} else { } else {
for (short ii = 0; ii < D16; ii += 4) { for (short ii = 0; ii < D16; ii += 4) {
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
if (D16%4 == 0) { if (D16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks // the head is evenly divisible by 4*16 = 64, so no need for bound checks
@ -2989,10 +2969,10 @@ kernel void kernel_flash_attn_ext(
const half m = M[j]; const half m = M[j];
// scale and apply the logitcap / mask // scale and apply the logitcap / mask
half s = ss[j*TS + tiisg]*scale; half s = ss[j*TS + tiisg]*args.scale;
if (logit_softcap != 0.0f) { if (args.logit_softcap != 0.0f) {
s = logit_softcap*precise::tanh(s); s = args.logit_softcap*precise::tanh(s);
} }
// mqk = mqk + mask*slope // mqk = mqk + mask*slope
@ -3034,17 +3014,17 @@ kernel void kernel_flash_attn_ext(
if (is_same<vd4x4_t, v4x4_t>::value) { if (is_same<vd4x4_t, v4x4_t>::value) {
// we can read directly from global memory // we can read directly from global memory
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
#pragma unroll #pragma unroll
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
v8x8_t mv; v8x8_t mv;
simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
} }
} else { } else {
for (short ii = 0; ii < D16; ii += 4) { for (short ii = 0; ii < D16; ii += 4) {
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
if (D16%4 == 0) { if (D16%4 == 0) {
// no need for bound checks // no need for bound checks
@ -3172,11 +3152,11 @@ kernel void kernel_flash_attn_ext(
// final rescale with 1/S and store to global memory // final rescale with 1/S and store to global memory
if (sgitg == 0) { if (sgitg == 0) {
for (short j = 0; j < Q && iq1 + j < ne01; ++j) { for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
const float S = ss[j*TS + 0]; const float S = ss[j*TS + 0];
for (short i = tiisg; i < D4; i += NW) { for (short i = tiisg; i < D4; i += NW) {
dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
} }
} }
} }
@ -3273,33 +3253,13 @@ kernel void kernel_flash_attn_ext_vec(
device const char * v, device const char * v,
device const char * mask, device const char * mask,
device float * dst, device float * dst,
constant int32_t & ne01, constant ggml_metal_kargs_flash_attn_ext & args,
constant int32_t & ne02,
constant int32_t & ne03,
constant uint32_t & nb01,
constant uint32_t & nb02,
constant uint32_t & nb03,
constant int32_t & ne11,
constant int32_t & ne_12_2, // assume K and V are same shape
constant int32_t & ne_12_3,
constant uint32_t & nb_12_1,
constant uint32_t & nb_12_2,
constant uint32_t & nb_12_3,
constant uint32_t & nb31,
constant int32_t & ne1,
constant int32_t & ne2,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint16_t & n_head_log2,
constant float & logit_softcap,
threadgroup half * shared [[threadgroup(0)]], threadgroup half * shared [[threadgroup(0)]],
ushort3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]], ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]], ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) { ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups const short nsg = ntg.y; // number of simdgroups
const int iq3 = tgpig[2]; const int iq3 = tgpig[2];
@ -3326,10 +3286,10 @@ kernel void kernel_flash_attn_ext_vec(
o4x4_t lo[D16/NL]; o4x4_t lo[D16/NL];
// load heads from Q to shared memory // load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
for (short i = tiisg; i < D4; i += NW) { for (short i = tiisg; i < D4; i += NW) {
if (iq1 < ne01) { if (iq1 < args.ne01) {
sq4[i] = (q4_t) q4[i]; sq4[i] = (q4_t) q4[i];
} else { } else {
sq4[i] = (q4_t) 0.0f; sq4[i] = (q4_t) 0.0f;
@ -3357,11 +3317,11 @@ kernel void kernel_flash_attn_ext_vec(
const short ty = tiisg/NL; const short ty = tiisg/NL;
// broadcast kv // broadcast kv
//const short rk2 = ne02/ne12; //const short rk2 = args.ne02/args.ne12;
//const short rk3 = ne03/ne13; //const short rk3 = args.ne03/args.ne13;
const short ikv2 = iq2/(ne02/ne_12_2); const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(ne03/ne_12_3); const short ikv3 = iq3/(args.ne03/args.ne_12_3);
// load the queries from shared memory into local memory // load the queries from shared memory into local memory
q4x4_t mq[D16/NL]; q4x4_t mq[D16/NL];
@ -3373,25 +3333,25 @@ kernel void kernel_flash_attn_ext_vec(
const bool has_mask = mask != q; const bool has_mask = mask != q;
// pointer to the mask // pointer to the mask
device const half * pm = (device const half *) (mask + iq1*nb31); device const half * pm = (device const half *) (mask + iq1*args.nb31);
half slope = 1.0f; half slope = 1.0f;
// ALiBi // ALiBi
if (max_bias > 0.0f) { if (args.max_bias > 0.0f) {
const short h = iq2; const short h = iq2;
const half base = h < n_head_log2 ? m0 : m1; const half base = h < args.n_head_log2 ? args.m0 : args.m1;
const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exph); slope = pow(base, exph);
} }
// loop over the KV cache // loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns // each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
const int ic = ic0 + C*sgitg; const int ic = ic0 + C*sgitg;
if (ic >= ne11) { if (ic >= args.ne11) {
break; break;
} }
@ -3405,7 +3365,7 @@ kernel void kernel_flash_attn_ext_vec(
for (short cc = 0; cc < C/4; ++cc) { for (short cc = 0; cc < C/4; ++cc) {
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 }; qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
#pragma unroll #pragma unroll
for (short ii = 0; ii < D16; ii += NL) { for (short ii = 0; ii < D16; ii += NL) {
@ -3435,10 +3395,10 @@ kernel void kernel_flash_attn_ext_vec(
// mqk = mqk*scale + mask*slope // mqk = mqk*scale + mask*slope
if (tx == 0) { if (tx == 0) {
mqk *= scale; mqk *= args.scale;
if (logit_softcap != 0.0f) { if (args.logit_softcap != 0.0f) {
mqk = logit_softcap*precise::tanh(mqk); mqk = args.logit_softcap*precise::tanh(mqk);
} }
mqk += sm[4*cc + ty]*slope; mqk += sm[4*cc + ty]*slope;
@ -3478,7 +3438,7 @@ kernel void kernel_flash_attn_ext_vec(
{ {
#pragma unroll #pragma unroll
for (short cc = 0; cc < C/4; ++cc) { for (short cc = 0; cc < C/4; ++cc) {
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
const s4x4_t ms(ss[4*cc + ty]); const s4x4_t ms(ss[4*cc + ty]);
@ -3583,7 +3543,7 @@ kernel void kernel_flash_attn_ext_vec(
const float S = ss[0]; const float S = ss[0];
for (short i = tiisg; i < D16; i += NW) { for (short i = tiisg; i < D16; i += NW) {
dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S; dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
} }
} }
} }