mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 23:34:35 +00:00
metal : fattn args
ggml-ci
This commit is contained in:
parent
996e479780
commit
089404f3a1
@ -446,6 +446,30 @@ typedef struct {
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
} ggml_metal_kargs_rope;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne11;
|
||||
int32_t ne_12_2; // assume K and V are same shape
|
||||
int32_t ne_12_3;
|
||||
uint64_t nb_12_1;
|
||||
uint64_t nb_12_2;
|
||||
uint64_t nb_12_3;
|
||||
uint64_t nb31;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
uint16_t n_head_log2;
|
||||
float logit_softcap;
|
||||
} ggml_metal_kargs_flash_attn_ext;
|
||||
#endif
|
||||
|
||||
#endif // GGML_COMMON_DECL
|
||||
|
@ -3228,37 +3228,41 @@ static void ggml_metal_encode_node(
|
||||
}
|
||||
}
|
||||
|
||||
ggml_metal_kargs_flash_attn_ext args = {
|
||||
.ne01 = ne01,
|
||||
.ne02 = ne02,
|
||||
.ne03 = ne03,
|
||||
.nb01 = nb01,
|
||||
.nb02 = nb02,
|
||||
.nb03 = nb03,
|
||||
.ne11 = ne11,
|
||||
.ne_12_2 = ne12,
|
||||
.ne_12_3 = ne13,
|
||||
.nb_12_1 = nb11,
|
||||
.nb_12_2 = nb12,
|
||||
.nb_12_3 = nb13,
|
||||
.nb31 = nb31,
|
||||
.ne1 = ne1,
|
||||
.ne2 = ne2,
|
||||
.scale = scale,
|
||||
.max_bias = max_bias,
|
||||
.m0 = m0,
|
||||
.m1 = m1,
|
||||
.n_head_log2 = n_head_log2,
|
||||
.logit_softcap = logit_softcap,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
if (id_src3) {
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
} else {
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
||||
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
|
||||
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
|
||||
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
||||
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
||||
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
||||
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:20];
|
||||
[encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
|
||||
[encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
|
||||
[encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
|
||||
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
|
||||
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:5];
|
||||
|
||||
if (!use_vec_kernel) {
|
||||
// half8x8 kernel
|
||||
|
@ -2267,9 +2267,9 @@ kernel void kernel_rope_norm(
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
@ -2320,9 +2320,9 @@ kernel void kernel_rope_neox(
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg[[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
@ -2761,32 +2761,12 @@ kernel void kernel_flash_attn_ext(
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device float * dst,
|
||||
constant int32_t & ne01,
|
||||
constant int32_t & ne02,
|
||||
constant int32_t & ne03,
|
||||
constant uint32_t & nb01,
|
||||
constant uint32_t & nb02,
|
||||
constant uint32_t & nb03,
|
||||
constant int32_t & ne11,
|
||||
constant int32_t & ne_12_2, // assume K and V are same shape
|
||||
constant int32_t & ne_12_3,
|
||||
constant uint32_t & nb_12_1,
|
||||
constant uint32_t & nb_12_2,
|
||||
constant uint32_t & nb_12_3,
|
||||
constant uint32_t & nb31,
|
||||
constant int32_t & ne1,
|
||||
constant int32_t & ne2,
|
||||
constant float & scale,
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint16_t & n_head_log2,
|
||||
constant float & logit_softcap,
|
||||
constant ggml_metal_kargs_flash_attn_ext & args,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
ushort3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short nsg = ntg.y; // number of simdgroups
|
||||
|
||||
const int iq3 = tgpig[2];
|
||||
@ -2819,10 +2799,10 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
// load heads from Q to shared memory
|
||||
for (short j = sgitg; j < Q; j += nsg) {
|
||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
if (iq1 + j < ne01) {
|
||||
if (iq1 + j < args.ne01) {
|
||||
sq4[j*D4 + i] = (q4_t) q4[i];
|
||||
} else {
|
||||
sq4[j*D4 + i] = (q4_t) 0.0f;
|
||||
@ -2855,11 +2835,11 @@ kernel void kernel_flash_attn_ext(
|
||||
const short ty = tiisg/4;
|
||||
|
||||
// broadcast kv
|
||||
//const short rk2 = ne02/ne12;
|
||||
//const short rk3 = ne03/ne13;
|
||||
//const short rk2 = args.ne02/args.ne12;
|
||||
//const short rk3 = args.ne03/args.ne13;
|
||||
|
||||
const short ikv2 = iq2/(ne02/ne_12_2);
|
||||
const short ikv3 = iq3/(ne03/ne_12_3);
|
||||
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
||||
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
q8x8_t mq[D8];
|
||||
@ -2873,20 +2853,20 @@ kernel void kernel_flash_attn_ext(
|
||||
half slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
if (args.max_bias > 0.0f) {
|
||||
const short h = iq2;
|
||||
|
||||
const half base = h < n_head_log2 ? m0 : m1;
|
||||
const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
const half base = h < args.n_head_log2 ? args.m0 : args.m1;
|
||||
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exph);
|
||||
}
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
||||
for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
|
||||
const int ic = ic0 + C*sgitg;
|
||||
if (ic >= ne11) {
|
||||
if (ic >= args.ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
@ -2896,7 +2876,7 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
// load the mask in shared memory
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
|
||||
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
|
||||
|
||||
const half m = pm[ic + tiisg];
|
||||
|
||||
@ -2919,18 +2899,18 @@ kernel void kernel_flash_attn_ext(
|
||||
// this is compile-time check, so it does not have runtime overhead
|
||||
if (is_same<kd4x4_t, k4x4_t>::value) {
|
||||
// we can read directly from global memory
|
||||
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||
|
||||
#pragma unroll
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
k8x8_t mk;
|
||||
simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
||||
simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
||||
|
||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||
}
|
||||
} else {
|
||||
for (short ii = 0; ii < D16; ii += 4) {
|
||||
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||
|
||||
if (D16%4 == 0) {
|
||||
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
||||
@ -2989,10 +2969,10 @@ kernel void kernel_flash_attn_ext(
|
||||
const half m = M[j];
|
||||
|
||||
// scale and apply the logitcap / mask
|
||||
half s = ss[j*TS + tiisg]*scale;
|
||||
half s = ss[j*TS + tiisg]*args.scale;
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
s = logit_softcap*precise::tanh(s);
|
||||
if (args.logit_softcap != 0.0f) {
|
||||
s = args.logit_softcap*precise::tanh(s);
|
||||
}
|
||||
|
||||
// mqk = mqk + mask*slope
|
||||
@ -3034,17 +3014,17 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
if (is_same<vd4x4_t, v4x4_t>::value) {
|
||||
// we can read directly from global memory
|
||||
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||
#pragma unroll
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
v8x8_t mv;
|
||||
simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
|
||||
simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
||||
}
|
||||
} else {
|
||||
for (short ii = 0; ii < D16; ii += 4) {
|
||||
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||
|
||||
if (D16%4 == 0) {
|
||||
// no need for bound checks
|
||||
@ -3172,11 +3152,11 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
|
||||
const float S = ss[j*TS + 0];
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
||||
dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3273,33 +3253,13 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device float * dst,
|
||||
constant int32_t & ne01,
|
||||
constant int32_t & ne02,
|
||||
constant int32_t & ne03,
|
||||
constant uint32_t & nb01,
|
||||
constant uint32_t & nb02,
|
||||
constant uint32_t & nb03,
|
||||
constant int32_t & ne11,
|
||||
constant int32_t & ne_12_2, // assume K and V are same shape
|
||||
constant int32_t & ne_12_3,
|
||||
constant uint32_t & nb_12_1,
|
||||
constant uint32_t & nb_12_2,
|
||||
constant uint32_t & nb_12_3,
|
||||
constant uint32_t & nb31,
|
||||
constant int32_t & ne1,
|
||||
constant int32_t & ne2,
|
||||
constant float & scale,
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint16_t & n_head_log2,
|
||||
constant float & logit_softcap,
|
||||
constant ggml_metal_kargs_flash_attn_ext & args,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
ushort3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short nsg = ntg.y; // number of simdgroups
|
||||
|
||||
const int iq3 = tgpig[2];
|
||||
@ -3326,10 +3286,10 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
o4x4_t lo[D16/NL];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
|
||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
if (iq1 < ne01) {
|
||||
if (iq1 < args.ne01) {
|
||||
sq4[i] = (q4_t) q4[i];
|
||||
} else {
|
||||
sq4[i] = (q4_t) 0.0f;
|
||||
@ -3357,11 +3317,11 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const short ty = tiisg/NL;
|
||||
|
||||
// broadcast kv
|
||||
//const short rk2 = ne02/ne12;
|
||||
//const short rk3 = ne03/ne13;
|
||||
//const short rk2 = args.ne02/args.ne12;
|
||||
//const short rk3 = args.ne03/args.ne13;
|
||||
|
||||
const short ikv2 = iq2/(ne02/ne_12_2);
|
||||
const short ikv3 = iq3/(ne03/ne_12_3);
|
||||
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
||||
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
q4x4_t mq[D16/NL];
|
||||
@ -3373,25 +3333,25 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const bool has_mask = mask != q;
|
||||
|
||||
// pointer to the mask
|
||||
device const half * pm = (device const half *) (mask + iq1*nb31);
|
||||
device const half * pm = (device const half *) (mask + iq1*args.nb31);
|
||||
|
||||
half slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
if (args.max_bias > 0.0f) {
|
||||
const short h = iq2;
|
||||
|
||||
const half base = h < n_head_log2 ? m0 : m1;
|
||||
const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
const half base = h < args.n_head_log2 ? args.m0 : args.m1;
|
||||
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exph);
|
||||
}
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
||||
for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
|
||||
const int ic = ic0 + C*sgitg;
|
||||
if (ic >= ne11) {
|
||||
if (ic >= args.ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
@ -3405,7 +3365,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
|
||||
|
||||
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D16; ii += NL) {
|
||||
@ -3435,10 +3395,10 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
// mqk = mqk*scale + mask*slope
|
||||
if (tx == 0) {
|
||||
mqk *= scale;
|
||||
mqk *= args.scale;
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
mqk = logit_softcap*precise::tanh(mqk);
|
||||
if (args.logit_softcap != 0.0f) {
|
||||
mqk = args.logit_softcap*precise::tanh(mqk);
|
||||
}
|
||||
|
||||
mqk += sm[4*cc + ty]*slope;
|
||||
@ -3478,7 +3438,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
{
|
||||
#pragma unroll
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||
|
||||
const s4x4_t ms(ss[4*cc + ty]);
|
||||
|
||||
@ -3583,7 +3543,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const float S = ss[0];
|
||||
|
||||
for (short i = tiisg; i < D16; i += NW) {
|
||||
dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
||||
dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user