mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-07 17:21:46 +00:00)

remove inner if mask

This commit is contained in:
  parent 61d05b57d9
  commit a75cdcca60

@@ -2834,7 +2834,6 @@ kernel void kernel_flash_attn_ext(
         constant float & logit_softcap,
         threadgroup half * shared [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {

@@ -2981,7 +2980,7 @@ kernel void kernel_flash_attn_ext(
                 qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);

                 // this is compile-time check, so it does not have runtime overhead
-                if constexpr (is_same<kd4x4_t, k4x4_t>::value) {
+                if (is_same<kd4x4_t, k4x4_t>::value) {
                     // we can read directly from global memory
                     device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));

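Note on the `if constexpr` -> plain `if` swap above: `is_same<kd4x4_t, k4x4_t>::value` is a compile-time constant, so the untaken branch is dead code the compiler can drop; the practical difference is that with a plain `if` both branches must still type-check for every instantiation. A minimal sketch of the same pattern in plain C++ (illustrative types and names, not the kernel's):

    #include <type_traits>
    #include <cstdio>

    template <typename KD, typename K>
    const char * load_path() {
        // the condition is a compile-time constant; the dead branch is expected to be
        // optimized away, but unlike `if constexpr` both branches are still compiled
        if (std::is_same<KD, K>::value) {
            return "direct load from global memory";
        } else {
            return "dequantize into registers first";
        }
    }

    int main() {
        std::printf("%s\n", load_path<float, float>()); // same type      -> direct path
        std::printf("%s\n", load_path<int,   float>()); // different type -> dequant path
    }
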
@@ -2996,7 +2995,7 @@ kernel void kernel_flash_attn_ext(
                     for (short ii = 0; ii < D16; ii += 4) {
                         device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));

-                        if constexpr (D16%4 == 0) {
+                        if (D16%4 == 0) {
                             // the head is evenly divisible by 4*16 = 64, so no need for bound checks
                             {
                                 k4x4_t tmp;

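The `D16%4 == 0` condition above is the head-size check the comment refers to: when the head dimension splits into whole groups of 4x16 elements, the 4-wide steps never run past the end and the guarded path can be skipped. A hedged sketch of that idea in plain C++ (names are illustrative):

    #include <cstdio>

    // process D16 rows in chunks of 4; take the unguarded fast path only when the
    // chunking is exact, otherwise guard the tail
    void process_rows(int D16) {
        if (D16 % 4 == 0) {
            // no need for bound checks: every chunk of 4 is fully in range
            for (int ii = 0; ii < D16; ii += 4) {
                std::printf("unguarded chunk %d..%d\n", ii, ii + 3);
            }
        } else {
            for (int ii = 0; ii < D16; ii += 4) {
                for (int i = ii; i < ii + 4 && i < D16; ++i) { // guarded tail
                    std::printf("guarded row %d\n", i);
                }
            }
        }
    }

    int main() { process_rows(6); process_rows(8); }
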
@@ -3038,15 +3037,10 @@ kernel void kernel_flash_attn_ext(
                     }
                 }

-                if constexpr (is_same<qk_t, s_t>::value) {
-                    // same type - store directly
-                    simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
-                } else {
-                    // cast qk_t -> s_t
-                    s8x8_t mqks(1.0f);
-                    simdgroup_multiply(mqks, mqk, mqks);
-                    simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
-                }
+                // cast qk_t -> s_t
+                s8x8_t mqks(1.0f);
+                simdgroup_multiply(mqks, mqk, mqks);
+                simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
             }
         }

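Dropping the `is_same<qk_t, s_t>` special case above works because the cast path is also valid when the types already match: `s8x8_t mqks(1.0f)` acts as a diagonal of ones in the destination type, so the multiply leaves the values unchanged and only converts precision, at the cost of one redundant multiply in the same-type case. A scalar-level analogue in plain C++ (hypothetical helper, not the kernel API):

    #include <cstdio>

    // "multiply by 1 held in the destination type" converts the value's type
    // without changing it, so one code path covers both the mixed- and same-type cases
    template <typename Dst, typename Src>
    Dst cast_via_identity(Src x) {
        const Dst one = Dst(1);
        return Dst(x) * one;
    }

    int main() {
        float  qk = 3.25f;
        double s  = cast_via_identity<double>(qk); // qk_t -> s_t
        float  t  = cast_via_identity<float >(qk); // same type: redundant but harmless
        std::printf("%f %f\n", s, t);
    }
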
@@ -3062,11 +3056,8 @@ kernel void kernel_flash_attn_ext(
                     s = logit_softcap*precise::tanh(s);
                 }

-                if (has_mask) {
-                    // mqk = mqk + mask*slope
-                    //s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
-                    s += slope*ss[j*TS + C + tiisg];
-                }
+                // mqk = mqk + mask*slope
+                s += slope*ss[j*TS + C + tiisg];

                 M[j] = simd_max(max(M[j], s));

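This is the change the commit title refers to: the mask term is now added unconditionally. Presumably this is safe because the mask slots in shared memory (`ss[j*TS + C + tiisg]`) hold zeros when no mask is provided, so the add is a no-op, and the per-iteration `has_mask` branch disappears from the hot loop. A minimal sketch of that reasoning in plain C++ (illustrative names; the zero-filled vector stands in for the shared-memory mask):

    #include <cassert>
    #include <vector>

    float score_branchy(float s, bool has_mask, float slope, float m) {
        if (has_mask) {
            s += slope*m; // mqk = mqk + mask*slope
        }
        return s;
    }

    float score_branchless(float s, float slope, float m) {
        return s + slope*m; // relies on m == 0 when no mask was supplied
    }

    int main() {
        std::vector<float> mask(8, 0.0f); // stand-in: mask slots zeroed when there is no mask
        for (float m : mask) {
            assert(score_branchy(1.5f, false, 0.5f, m) == score_branchless(1.5f, 0.5f, m));
        }
        return 0;
    }
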
@@ -3078,6 +3069,7 @@ kernel void kernel_flash_attn_ext(
                 // the P matrix from the paper (Q rows, C columns)
                 ss[j*TS + tiisg] = vs;

                 // create a QxQ diagonal matrix for rescaling the output
                 if (tiisg == j) {
                     ss[j*TS + 2*C + j] = ms;
                 }

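For context on the `vs` and `ms` values stored above: in flash attention's online softmax, `vs` plays the role of exp(s - M_new) (the new P entry) and `ms` the role of exp(M_old - M_new), which rescales everything accumulated under the previous running maximum; placing `ms` on a QxQ diagonal lets that rescale be applied with a single simdgroup matrix multiply. A scalar sketch of the recurrence in plain C++ (one query row, one value column, illustrative data):

    #include <cmath>
    #include <cstdio>

    int main() {
        // running state for a single query row
        float M = -INFINITY; // running maximum of the scores
        float S = 0.0f;      // running sum of exp(s - M)
        float O = 0.0f;      // running (unnormalized) output for one value column

        const float s[3] = {0.5f, 2.0f, 1.0f}; // scores from successive K/V blocks
        const float v[3] = {1.0f, 3.0f, 2.0f}; // matching V entries

        for (int i = 0; i < 3; ++i) {
            const float Mnew = fmaxf(M, s[i]);
            const float ms   = expf(M - Mnew);    // rescale factor for prior accumulation
            const float vs   = expf(s[i] - Mnew); // new P entry
            S = S*ms + vs;
            O = O*ms + vs*v[i];
            M = Mnew;
        }

        // equals dot(softmax(s), v), computed in one pass over the blocks
        std::printf("attention output: %f\n", O/S);
    }
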
@@ -3101,7 +3093,7 @@ kernel void kernel_flash_attn_ext(
                 s8x8_t ms;
                 simdgroup_load(ms, ss + 8*cc, TS, 0, false);

-                if constexpr (is_same<vd4x4_t, v4x4_t>::value) {
+                if (is_same<vd4x4_t, v4x4_t>::value) {
                     // we can read directly from global memory
                     device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
                     #pragma unroll

@@ -3115,7 +3107,7 @@ kernel void kernel_flash_attn_ext(
                     for (short ii = 0; ii < D16; ii += 4) {
                         device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));

-                        if constexpr (D16%4 == 0) {
+                        if (D16%4 == 0) {
                             // no need for bound checks
                             {
                                 v4x4_t tmp;