mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 23:34:35 +00:00
metal : fattn args
ggml-ci
This commit is contained in:
parent
996e479780
commit
089404f3a1
@ -446,6 +446,30 @@ typedef struct {
|
|||||||
float beta_fast;
|
float beta_fast;
|
||||||
float beta_slow;
|
float beta_slow;
|
||||||
} ggml_metal_kargs_rope;
|
} ggml_metal_kargs_rope;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int32_t ne01;
|
||||||
|
int32_t ne02;
|
||||||
|
int32_t ne03;
|
||||||
|
uint64_t nb01;
|
||||||
|
uint64_t nb02;
|
||||||
|
uint64_t nb03;
|
||||||
|
int32_t ne11;
|
||||||
|
int32_t ne_12_2; // assume K and V are same shape
|
||||||
|
int32_t ne_12_3;
|
||||||
|
uint64_t nb_12_1;
|
||||||
|
uint64_t nb_12_2;
|
||||||
|
uint64_t nb_12_3;
|
||||||
|
uint64_t nb31;
|
||||||
|
int32_t ne1;
|
||||||
|
int32_t ne2;
|
||||||
|
float scale;
|
||||||
|
float max_bias;
|
||||||
|
float m0;
|
||||||
|
float m1;
|
||||||
|
uint16_t n_head_log2;
|
||||||
|
float logit_softcap;
|
||||||
|
} ggml_metal_kargs_flash_attn_ext;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif // GGML_COMMON_DECL
|
#endif // GGML_COMMON_DECL
|
||||||
|
@ -3228,37 +3228,41 @@ static void ggml_metal_encode_node(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_kargs_flash_attn_ext args = {
|
||||||
|
.ne01 = ne01,
|
||||||
|
.ne02 = ne02,
|
||||||
|
.ne03 = ne03,
|
||||||
|
.nb01 = nb01,
|
||||||
|
.nb02 = nb02,
|
||||||
|
.nb03 = nb03,
|
||||||
|
.ne11 = ne11,
|
||||||
|
.ne_12_2 = ne12,
|
||||||
|
.ne_12_3 = ne13,
|
||||||
|
.nb_12_1 = nb11,
|
||||||
|
.nb_12_2 = nb12,
|
||||||
|
.nb_12_3 = nb13,
|
||||||
|
.nb31 = nb31,
|
||||||
|
.ne1 = ne1,
|
||||||
|
.ne2 = ne2,
|
||||||
|
.scale = scale,
|
||||||
|
.max_bias = max_bias,
|
||||||
|
.m0 = m0,
|
||||||
|
.m1 = m1,
|
||||||
|
.n_head_log2 = n_head_log2,
|
||||||
|
.logit_softcap = logit_softcap,
|
||||||
|
};
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||||
if (id_src3) {
|
if (id_src3) {
|
||||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||||
} else {
|
} else {
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
||||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
[encoder setBytes:&args length:sizeof(args) atIndex:5];
|
||||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
|
||||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
|
||||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
|
||||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
|
||||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
|
||||||
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
|
|
||||||
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
|
|
||||||
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
|
|
||||||
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
|
||||||
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
|
||||||
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
|
||||||
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
|
|
||||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
|
|
||||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
|
|
||||||
[encoder setBytes:&scale length:sizeof( float) atIndex:20];
|
|
||||||
[encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
|
|
||||||
[encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
|
|
||||||
[encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
|
|
||||||
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
|
|
||||||
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
|
|
||||||
|
|
||||||
if (!use_vec_kernel) {
|
if (!use_vec_kernel) {
|
||||||
// half8x8 kernel
|
// half8x8 kernel
|
||||||
|
@ -2267,9 +2267,9 @@ kernel void kernel_rope_norm(
|
|||||||
device const char * src2,
|
device const char * src2,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
constant ggml_metal_kargs_rope & args,
|
constant ggml_metal_kargs_rope & args,
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
ushort tiitg[[thread_index_in_threadgroup]],
|
||||||
uint3 tptg [[threads_per_threadgroup]],
|
ushort3 tptg [[threads_per_threadgroup]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||||
const int i3 = tgpig[2];
|
const int i3 = tgpig[2];
|
||||||
const int i2 = tgpig[1];
|
const int i2 = tgpig[1];
|
||||||
const int i1 = tgpig[0];
|
const int i1 = tgpig[0];
|
||||||
@ -2320,9 +2320,9 @@ kernel void kernel_rope_neox(
|
|||||||
device const char * src2,
|
device const char * src2,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
constant ggml_metal_kargs_rope & args,
|
constant ggml_metal_kargs_rope & args,
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
ushort tiitg[[thread_index_in_threadgroup]],
|
||||||
uint3 tptg[[threads_per_threadgroup]],
|
ushort3 tptg [[threads_per_threadgroup]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||||
const int i3 = tgpig[2];
|
const int i3 = tgpig[2];
|
||||||
const int i2 = tgpig[1];
|
const int i2 = tgpig[1];
|
||||||
const int i1 = tgpig[0];
|
const int i1 = tgpig[0];
|
||||||
@ -2761,32 +2761,12 @@ kernel void kernel_flash_attn_ext(
|
|||||||
device const char * v,
|
device const char * v,
|
||||||
device const char * mask,
|
device const char * mask,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int32_t & ne01,
|
constant ggml_metal_kargs_flash_attn_ext & args,
|
||||||
constant int32_t & ne02,
|
|
||||||
constant int32_t & ne03,
|
|
||||||
constant uint32_t & nb01,
|
|
||||||
constant uint32_t & nb02,
|
|
||||||
constant uint32_t & nb03,
|
|
||||||
constant int32_t & ne11,
|
|
||||||
constant int32_t & ne_12_2, // assume K and V are same shape
|
|
||||||
constant int32_t & ne_12_3,
|
|
||||||
constant uint32_t & nb_12_1,
|
|
||||||
constant uint32_t & nb_12_2,
|
|
||||||
constant uint32_t & nb_12_3,
|
|
||||||
constant uint32_t & nb31,
|
|
||||||
constant int32_t & ne1,
|
|
||||||
constant int32_t & ne2,
|
|
||||||
constant float & scale,
|
|
||||||
constant float & max_bias,
|
|
||||||
constant float & m0,
|
|
||||||
constant float & m1,
|
|
||||||
constant uint16_t & n_head_log2,
|
|
||||||
constant float & logit_softcap,
|
|
||||||
threadgroup half * shared [[threadgroup(0)]],
|
threadgroup half * shared [[threadgroup(0)]],
|
||||||
ushort3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort3 ntg[[threads_per_threadgroup]],
|
ushort3 ntg[[threads_per_threadgroup]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]],
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
const short nsg = ntg.y; // number of simdgroups
|
const short nsg = ntg.y; // number of simdgroups
|
||||||
|
|
||||||
const int iq3 = tgpig[2];
|
const int iq3 = tgpig[2];
|
||||||
@ -2819,10 +2799,10 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
// load heads from Q to shared memory
|
// load heads from Q to shared memory
|
||||||
for (short j = sgitg; j < Q; j += nsg) {
|
for (short j = sgitg; j < Q; j += nsg) {
|
||||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
|
||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
if (iq1 + j < ne01) {
|
if (iq1 + j < args.ne01) {
|
||||||
sq4[j*D4 + i] = (q4_t) q4[i];
|
sq4[j*D4 + i] = (q4_t) q4[i];
|
||||||
} else {
|
} else {
|
||||||
sq4[j*D4 + i] = (q4_t) 0.0f;
|
sq4[j*D4 + i] = (q4_t) 0.0f;
|
||||||
@ -2855,11 +2835,11 @@ kernel void kernel_flash_attn_ext(
|
|||||||
const short ty = tiisg/4;
|
const short ty = tiisg/4;
|
||||||
|
|
||||||
// broadcast kv
|
// broadcast kv
|
||||||
//const short rk2 = ne02/ne12;
|
//const short rk2 = args.ne02/args.ne12;
|
||||||
//const short rk3 = ne03/ne13;
|
//const short rk3 = args.ne03/args.ne13;
|
||||||
|
|
||||||
const short ikv2 = iq2/(ne02/ne_12_2);
|
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
||||||
const short ikv3 = iq3/(ne03/ne_12_3);
|
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
||||||
|
|
||||||
// load the queries from shared memory into local memory
|
// load the queries from shared memory into local memory
|
||||||
q8x8_t mq[D8];
|
q8x8_t mq[D8];
|
||||||
@ -2873,20 +2853,20 @@ kernel void kernel_flash_attn_ext(
|
|||||||
half slope = 1.0f;
|
half slope = 1.0f;
|
||||||
|
|
||||||
// ALiBi
|
// ALiBi
|
||||||
if (max_bias > 0.0f) {
|
if (args.max_bias > 0.0f) {
|
||||||
const short h = iq2;
|
const short h = iq2;
|
||||||
|
|
||||||
const half base = h < n_head_log2 ? m0 : m1;
|
const half base = h < args.n_head_log2 ? args.m0 : args.m1;
|
||||||
const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
||||||
|
|
||||||
slope = pow(base, exph);
|
slope = pow(base, exph);
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop over the KV cache
|
// loop over the KV cache
|
||||||
// each simdgroup handles blocks of Q rows and C columns
|
// each simdgroup handles blocks of Q rows and C columns
|
||||||
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
|
||||||
const int ic = ic0 + C*sgitg;
|
const int ic = ic0 + C*sgitg;
|
||||||
if (ic >= ne11) {
|
if (ic >= args.ne11) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2896,7 +2876,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
// load the mask in shared memory
|
// load the mask in shared memory
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
|
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
|
||||||
|
|
||||||
const half m = pm[ic + tiisg];
|
const half m = pm[ic + tiisg];
|
||||||
|
|
||||||
@ -2919,18 +2899,18 @@ kernel void kernel_flash_attn_ext(
|
|||||||
// this is compile-time check, so it does not have runtime overhead
|
// this is compile-time check, so it does not have runtime overhead
|
||||||
if (is_same<kd4x4_t, k4x4_t>::value) {
|
if (is_same<kd4x4_t, k4x4_t>::value) {
|
||||||
// we can read directly from global memory
|
// we can read directly from global memory
|
||||||
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
k8x8_t mk;
|
k8x8_t mk;
|
||||||
simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
||||||
|
|
||||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (short ii = 0; ii < D16; ii += 4) {
|
for (short ii = 0; ii < D16; ii += 4) {
|
||||||
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||||
|
|
||||||
if (D16%4 == 0) {
|
if (D16%4 == 0) {
|
||||||
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
||||||
@ -2989,10 +2969,10 @@ kernel void kernel_flash_attn_ext(
|
|||||||
const half m = M[j];
|
const half m = M[j];
|
||||||
|
|
||||||
// scale and apply the logitcap / mask
|
// scale and apply the logitcap / mask
|
||||||
half s = ss[j*TS + tiisg]*scale;
|
half s = ss[j*TS + tiisg]*args.scale;
|
||||||
|
|
||||||
if (logit_softcap != 0.0f) {
|
if (args.logit_softcap != 0.0f) {
|
||||||
s = logit_softcap*precise::tanh(s);
|
s = args.logit_softcap*precise::tanh(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
// mqk = mqk + mask*slope
|
// mqk = mqk + mask*slope
|
||||||
@ -3034,17 +3014,17 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
if (is_same<vd4x4_t, v4x4_t>::value) {
|
if (is_same<vd4x4_t, v4x4_t>::value) {
|
||||||
// we can read directly from global memory
|
// we can read directly from global memory
|
||||||
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
v8x8_t mv;
|
v8x8_t mv;
|
||||||
simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
|
simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
|
||||||
|
|
||||||
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (short ii = 0; ii < D16; ii += 4) {
|
for (short ii = 0; ii < D16; ii += 4) {
|
||||||
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||||
|
|
||||||
if (D16%4 == 0) {
|
if (D16%4 == 0) {
|
||||||
// no need for bound checks
|
// no need for bound checks
|
||||||
@ -3172,11 +3152,11 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
// final rescale with 1/S and store to global memory
|
// final rescale with 1/S and store to global memory
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
|
for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
|
||||||
const float S = ss[j*TS + 0];
|
const float S = ss[j*TS + 0];
|
||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3273,33 +3253,13 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
device const char * v,
|
device const char * v,
|
||||||
device const char * mask,
|
device const char * mask,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int32_t & ne01,
|
constant ggml_metal_kargs_flash_attn_ext & args,
|
||||||
constant int32_t & ne02,
|
|
||||||
constant int32_t & ne03,
|
|
||||||
constant uint32_t & nb01,
|
|
||||||
constant uint32_t & nb02,
|
|
||||||
constant uint32_t & nb03,
|
|
||||||
constant int32_t & ne11,
|
|
||||||
constant int32_t & ne_12_2, // assume K and V are same shape
|
|
||||||
constant int32_t & ne_12_3,
|
|
||||||
constant uint32_t & nb_12_1,
|
|
||||||
constant uint32_t & nb_12_2,
|
|
||||||
constant uint32_t & nb_12_3,
|
|
||||||
constant uint32_t & nb31,
|
|
||||||
constant int32_t & ne1,
|
|
||||||
constant int32_t & ne2,
|
|
||||||
constant float & scale,
|
|
||||||
constant float & max_bias,
|
|
||||||
constant float & m0,
|
|
||||||
constant float & m1,
|
|
||||||
constant uint16_t & n_head_log2,
|
|
||||||
constant float & logit_softcap,
|
|
||||||
threadgroup half * shared [[threadgroup(0)]],
|
threadgroup half * shared [[threadgroup(0)]],
|
||||||
ushort3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
ushort3 ntg[[threads_per_threadgroup]],
|
ushort3 ntg[[threads_per_threadgroup]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]],
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
const short nsg = ntg.y; // number of simdgroups
|
const short nsg = ntg.y; // number of simdgroups
|
||||||
|
|
||||||
const int iq3 = tgpig[2];
|
const int iq3 = tgpig[2];
|
||||||
@ -3326,10 +3286,10 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
o4x4_t lo[D16/NL];
|
o4x4_t lo[D16/NL];
|
||||||
|
|
||||||
// load heads from Q to shared memory
|
// load heads from Q to shared memory
|
||||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
|
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
|
||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
if (iq1 < ne01) {
|
if (iq1 < args.ne01) {
|
||||||
sq4[i] = (q4_t) q4[i];
|
sq4[i] = (q4_t) q4[i];
|
||||||
} else {
|
} else {
|
||||||
sq4[i] = (q4_t) 0.0f;
|
sq4[i] = (q4_t) 0.0f;
|
||||||
@ -3357,11 +3317,11 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
const short ty = tiisg/NL;
|
const short ty = tiisg/NL;
|
||||||
|
|
||||||
// broadcast kv
|
// broadcast kv
|
||||||
//const short rk2 = ne02/ne12;
|
//const short rk2 = args.ne02/args.ne12;
|
||||||
//const short rk3 = ne03/ne13;
|
//const short rk3 = args.ne03/args.ne13;
|
||||||
|
|
||||||
const short ikv2 = iq2/(ne02/ne_12_2);
|
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
||||||
const short ikv3 = iq3/(ne03/ne_12_3);
|
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
||||||
|
|
||||||
// load the queries from shared memory into local memory
|
// load the queries from shared memory into local memory
|
||||||
q4x4_t mq[D16/NL];
|
q4x4_t mq[D16/NL];
|
||||||
@ -3373,25 +3333,25 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
const bool has_mask = mask != q;
|
const bool has_mask = mask != q;
|
||||||
|
|
||||||
// pointer to the mask
|
// pointer to the mask
|
||||||
device const half * pm = (device const half *) (mask + iq1*nb31);
|
device const half * pm = (device const half *) (mask + iq1*args.nb31);
|
||||||
|
|
||||||
half slope = 1.0f;
|
half slope = 1.0f;
|
||||||
|
|
||||||
// ALiBi
|
// ALiBi
|
||||||
if (max_bias > 0.0f) {
|
if (args.max_bias > 0.0f) {
|
||||||
const short h = iq2;
|
const short h = iq2;
|
||||||
|
|
||||||
const half base = h < n_head_log2 ? m0 : m1;
|
const half base = h < args.n_head_log2 ? args.m0 : args.m1;
|
||||||
const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
||||||
|
|
||||||
slope = pow(base, exph);
|
slope = pow(base, exph);
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop over the KV cache
|
// loop over the KV cache
|
||||||
// each simdgroup handles blocks of Q rows and C columns
|
// each simdgroup handles blocks of Q rows and C columns
|
||||||
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
|
||||||
const int ic = ic0 + C*sgitg;
|
const int ic = ic0 + C*sgitg;
|
||||||
if (ic >= ne11) {
|
if (ic >= args.ne11) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3405,7 +3365,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
for (short cc = 0; cc < C/4; ++cc) {
|
for (short cc = 0; cc < C/4; ++cc) {
|
||||||
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
|
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
|
||||||
|
|
||||||
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (short ii = 0; ii < D16; ii += NL) {
|
for (short ii = 0; ii < D16; ii += NL) {
|
||||||
@ -3435,10 +3395,10 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
|
|
||||||
// mqk = mqk*scale + mask*slope
|
// mqk = mqk*scale + mask*slope
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
mqk *= scale;
|
mqk *= args.scale;
|
||||||
|
|
||||||
if (logit_softcap != 0.0f) {
|
if (args.logit_softcap != 0.0f) {
|
||||||
mqk = logit_softcap*precise::tanh(mqk);
|
mqk = args.logit_softcap*precise::tanh(mqk);
|
||||||
}
|
}
|
||||||
|
|
||||||
mqk += sm[4*cc + ty]*slope;
|
mqk += sm[4*cc + ty]*slope;
|
||||||
@ -3478,7 +3438,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
{
|
{
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (short cc = 0; cc < C/4; ++cc) {
|
for (short cc = 0; cc < C/4; ++cc) {
|
||||||
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||||
|
|
||||||
const s4x4_t ms(ss[4*cc + ty]);
|
const s4x4_t ms(ss[4*cc + ty]);
|
||||||
|
|
||||||
@ -3583,7 +3543,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
const float S = ss[0];
|
const float S = ss[0];
|
||||||
|
|
||||||
for (short i = tiisg; i < D16; i += NW) {
|
for (short i = tiisg; i < D16; i += NW) {
|
||||||
dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user