remove inner if mask

This commit is contained in:
Georgi Gerganov 2024-11-07 16:40:29 +02:00
parent 61d05b57d9
commit a75cdcca60
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -2834,7 +2834,6 @@ kernel void kernel_flash_attn_ext(
constant float & logit_softcap,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -2981,7 +2980,7 @@ kernel void kernel_flash_attn_ext(
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
// this is compile-time check, so it does not have runtime overhead
if constexpr (is_same<kd4x4_t, k4x4_t>::value) {
if (is_same<kd4x4_t, k4x4_t>::value) {
// we can read directly from global memory
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
@ -2996,7 +2995,7 @@ kernel void kernel_flash_attn_ext(
for (short ii = 0; ii < D16; ii += 4) {
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
if constexpr (D16%4 == 0) {
if (D16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
{
k4x4_t tmp;
@ -3038,15 +3037,10 @@ kernel void kernel_flash_attn_ext(
}
}
if constexpr (is_same<qk_t, s_t>::value) {
// same type - store directly
simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
} else {
// cast qk_t -> s_t
s8x8_t mqks(1.0f);
simdgroup_multiply(mqks, mqk, mqks);
simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
}
// cast qk_t -> s_t
s8x8_t mqks(1.0f);
simdgroup_multiply(mqks, mqk, mqks);
simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
}
}
@ -3062,11 +3056,8 @@ kernel void kernel_flash_attn_ext(
s = logit_softcap*precise::tanh(s);
}
if (has_mask) {
// mqk = mqk + mask*slope
//s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
s += slope*ss[j*TS + C + tiisg];
}
// mqk = mqk + mask*slope
s += slope*ss[j*TS + C + tiisg];
M[j] = simd_max(max(M[j], s));
@ -3078,6 +3069,7 @@ kernel void kernel_flash_attn_ext(
// the P matrix from the paper (Q rows, C columns)
ss[j*TS + tiisg] = vs;
// create a QxQ diagonal matrix for rescaling the output
if (tiisg == j) {
ss[j*TS + 2*C + j] = ms;
}
@ -3101,7 +3093,7 @@ kernel void kernel_flash_attn_ext(
s8x8_t ms;
simdgroup_load(ms, ss + 8*cc, TS, 0, false);
if constexpr (is_same<vd4x4_t, v4x4_t>::value) {
if (is_same<vd4x4_t, v4x4_t>::value) {
// we can read directly from global memory
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
#pragma unroll
@ -3115,7 +3107,7 @@ kernel void kernel_flash_attn_ext(
for (short ii = 0; ii < D16; ii += 4) {
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
if constexpr (D16%4 == 0) {
if (D16%4 == 0) {
// no need for bound checks
{
v4x4_t tmp;