llama.cpp (mirror of https://github.com/ggerganov/llama.cpp.git)

commit a75cdcca60
parent 61d05b57d9

    remove inner if mask
@@ -2834,7 +2834,6 @@ kernel void kernel_flash_attn_ext(
         constant float & logit_softcap,
         threadgroup half * shared [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2981,7 +2980,7 @@ kernel void kernel_flash_attn_ext(
                 qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);

                 // this is compile-time check, so it does not have runtime overhead
-                if constexpr (is_same<kd4x4_t, k4x4_t>::value) {
+                if (is_same<kd4x4_t, k4x4_t>::value) {
                     // we can read directly from global memory
                     device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));

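Note on this and the other `if constexpr` hunks below: the conditions (`is_same<...>::value`, `D16%4 == 0`) are still compile-time constants after the change, so the untaken branch is still eliminated during compilation, which is presumably why the "compile-time check" comment is kept. The only practical difference is that a plain `if` requires both branches to type-check for every template instantiation. A minimal host-side C++ sketch of the pattern (the names load_row, dst_t, src_t are illustrative, not part of the kernel):

    #include <cstring>
    #include <type_traits>
    #include <cstdio>

    // The condition is a compile-time constant, so the dead branch is folded
    // away by the optimizer; unlike `if constexpr`, both branches still have
    // to compile for every (dst_t, src_t) pair.
    template <typename dst_t, typename src_t>
    void load_row(dst_t * dst, const src_t * src, int n) {
        if (std::is_same<dst_t, src_t>::value) {
            // we can copy directly - only ever taken when the types match
            std::memcpy(dst, src, n*sizeof(dst_t));
        } else {
            // otherwise convert element by element
            for (int i = 0; i < n; ++i) {
                dst[i] = static_cast<dst_t>(src[i]);
            }
        }
    }

    int main() {
        const float src[4] = {1.0f, 2.0f, 3.0f, 4.0f};
        double      dst[4];
        load_row(dst, src, 4); // takes the conversion path
        printf("%f %f %f %f\n", dst[0], dst[1], dst[2], dst[3]);
        return 0;
    }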
@@ -2996,7 +2995,7 @@ kernel void kernel_flash_attn_ext(
                     for (short ii = 0; ii < D16; ii += 4) {
                         device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));

-                        if constexpr (D16%4 == 0) {
+                        if (D16%4 == 0) {
                             // the head is evenly divisible by 4*16 = 64, so no need for bound checks
                             {
                                 k4x4_t tmp;
@@ -3038,15 +3037,10 @@ kernel void kernel_flash_attn_ext(
                     }
                 }

-                if constexpr (is_same<qk_t, s_t>::value) {
-                    // same type - store directly
-                    simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
-                } else {
-                    // cast qk_t -> s_t
-                    s8x8_t mqks(1.0f);
-                    simdgroup_multiply(mqks, mqk, mqks);
-                    simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
-                }
+                // cast qk_t -> s_t
+                s8x8_t mqks(1.0f);
+                simdgroup_multiply(mqks, mqk, mqks);
+                simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
             }
         }

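The hunk above drops the `is_same<qk_t, s_t>` fast path: the S tile is now always produced by multiplying `mqk` with `mqks(1.0f)`. The scalar `simdgroup_matrix` constructor initializes a diagonal matrix, so `mqks` effectively starts out as an identity matrix in the `s_t` type, and the multiply acts as a type cast of the 8x8 tile (a mathematical no-op when `qk_t == s_t`). A host-side C++ sketch of the idea, using double -> float in place of the kernel's qk_t -> s_t (cast_via_identity is an illustrative name, not kernel code):

    #include <cstdio>

    // Multiply an 8x8 double matrix by a float identity matrix, accumulating
    // in float: the result holds the same values, but in the destination
    // type - i.e. the multiply acts as a cast.
    static void cast_via_identity(const double (&a)[8][8], float (&out)[8][8]) {
        float idty[8][8] = {};
        for (int i = 0; i < 8; ++i) {
            idty[i][i] = 1.0f;            // counterpart of s8x8_t mqks(1.0f)
        }

        for (int i = 0; i < 8; ++i) {
            for (int j = 0; j < 8; ++j) {
                float acc = 0.0f;         // accumulate in the destination type
                for (int k = 0; k < 8; ++k) {
                    acc += float(a[i][k])*idty[k][j];
                }
                out[i][j] = acc;
            }
        }
    }

    int main() {
        double a[8][8];
        float  s[8][8];
        for (int i = 0; i < 8; ++i) {
            for (int j = 0; j < 8; ++j) {
                a[i][j] = 0.25*i - 0.125*j;
            }
        }
        cast_via_identity(a, s);
        printf("%f %f\n", a[3][5], (double) s[3][5]); // same value, narrower type
        return 0;
    }

Presumably the extra identity multiply in the matching-type case is cheap enough that removing the template branch is a net simplification, though the commit itself does not say so.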
@@ -3062,11 +3056,8 @@ kernel void kernel_flash_attn_ext(
                     s = logit_softcap*precise::tanh(s);
                 }

-                if (has_mask) {
-                    // mqk = mqk + mask*slope
-                    //s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
-                    s += slope*ss[j*TS + C + tiisg];
-                }
+                // mqk = mqk + mask*slope
+                s += slope*ss[j*TS + C + tiisg];

                 M[j] = simd_max(max(M[j], s));

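This is the change the commit title refers to: the per-element `if (has_mask)` around the mask add is removed, and `s += slope*ss[j*TS + C + tiisg];` now runs unconditionally. That is only behavior-preserving if the mask slots in shared memory hold zeros when no mask is provided, which this diff does not show - treat that as an assumption. A host-side C++ sketch of the branchless form (apply_mask is an illustrative name):

    #include <cmath>
    #include <cstdio>

    // Assumption (not shown in the diff): `mask` is zero-filled when the
    // caller provides no mask, so the unconditional add is a no-op.
    static void apply_mask(float * s, const float * mask, float slope, int n) {
        for (int i = 0; i < n; ++i) {
            // previously: if (has_mask) { s[i] += slope*mask[i]; }
            s[i] += slope*mask[i];
        }
    }

    int main() {
        float s[4]    = {0.5f, 1.0f, 1.5f, 2.0f};
        float mask[4] = {0.0f, 0.0f, -INFINITY, 0.0f}; // e.g. a causal mask
        apply_mask(s, mask, 1.0f, 4);
        for (int i = 0; i < 4; ++i) {
            printf("%f\n", s[i]);
        }
        return 0;
    }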
@@ -3078,6 +3069,7 @@ kernel void kernel_flash_attn_ext(
                 // the P matrix from the paper (Q rows, C columns)
                 ss[j*TS + tiisg] = vs;

+                // create a QxQ diagonal matrix for rescaling the output
                 if (tiisg == j) {
                     ss[j*TS + 2*C + j] = ms;
                 }
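The comment added above refers to the standard online-softmax bookkeeping in flash attention: `vs = exp(s - M[j])` goes into the P matrix, while `ms = exp(M_old - M_new)` is written onto a diagonal so the already-accumulated output can later be rescaled with a single simdgroup multiply. A scalar C++ sketch of that rescaling (the scalar `acc` stands in for a row of the output tile; the kernel does this per 8x8 tile):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float s[6] = {0.2f, 1.7f, -0.3f, 2.5f, 0.9f, 3.1f}; // streamed scores

        float M   = -INFINITY; // running maximum
        float S   = 0.0f;      // running sum of exp(s - M)
        float acc = 0.0f;      // running output accumulator

        for (int i = 0; i < 6; ++i) {
            const float Mnew = fmaxf(M, s[i]);
            const float ms   = expf(M - Mnew);    // rescale factor ("ms" above)
            const float vs   = expf(s[i] - Mnew); // new probability weight ("vs" above)
            const float v    = (float) i;         // stand-in for the V row

            S   = S*ms   + vs;   // rescale the old sum, add the new term
            acc = acc*ms + vs*v; // rescale the old output, add the new contribution
            M   = Mnew;
        }

        printf("softmax-weighted value = %f\n", acc/S);
        return 0;
    }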
@@ -3101,7 +3093,7 @@ kernel void kernel_flash_attn_ext(
                 s8x8_t ms;
                 simdgroup_load(ms, ss + 8*cc, TS, 0, false);

-                if constexpr (is_same<vd4x4_t, v4x4_t>::value) {
+                if (is_same<vd4x4_t, v4x4_t>::value) {
                     // we can read directly from global memory
                     device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
                     #pragma unroll
@@ -3115,7 +3107,7 @@ kernel void kernel_flash_attn_ext(
                     for (short ii = 0; ii < D16; ii += 4) {
                         device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));

-                        if constexpr (D16%4 == 0) {
+                        if (D16%4 == 0) {
                             // no need for bound checks
                             {
                                 v4x4_t tmp;