From a75cdcca60f58ed97609dce9ae50f8cea47131d1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 7 Nov 2024 16:40:29 +0200 Subject: [PATCH] remove inner if mask --- ggml/src/ggml-metal.metal | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index ebce329a9..bd09f8271 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2834,7 +2834,6 @@ kernel void kernel_flash_attn_ext( constant float & logit_softcap, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2981,7 +2980,7 @@ kernel void kernel_flash_attn_ext( qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); // this is compile-time check, so it does not have runtime overhead - if constexpr (is_same::value) { + if (is_same::value) { // we can read directly from global memory device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2996,7 +2995,7 @@ kernel void kernel_flash_attn_ext( for (short ii = 0; ii < D16; ii += 4) { device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13)); - if constexpr (D16%4 == 0) { + if (D16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks { k4x4_t tmp; @@ -3038,15 +3037,10 @@ kernel void kernel_flash_attn_ext( } } - if constexpr (is_same::value) { - // same type - store directly - simdgroup_store(mqk, ss + 8*cc, TS, 0, false); - } else { - // cast qk_t -> s_t - s8x8_t mqks(1.0f); - simdgroup_multiply(mqks, mqk, mqks); - simdgroup_store(mqks, ss + 8*cc, TS, 0, false); - } + // cast qk_t -> s_t + s8x8_t mqks(1.0f); + simdgroup_multiply(mqks, mqk, mqks); + simdgroup_store(mqks, ss + 8*cc, TS, 0, false); } } @@ -3062,11 +3056,8 @@ kernel void kernel_flash_attn_ext( s = logit_softcap*precise::tanh(s); } - if (has_mask) { - // mqk = mqk + mask*slope - //s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30 - s += slope*ss[j*TS + C + tiisg]; - } + // mqk = mqk + mask*slope + s += slope*ss[j*TS + C + tiisg]; M[j] = simd_max(max(M[j], s)); @@ -3078,6 +3069,7 @@ kernel void kernel_flash_attn_ext( // the P matrix from the paper (Q rows, C columns) ss[j*TS + tiisg] = vs; + // create a QxQ diagonal matrix for rescaling the output if (tiisg == j) { ss[j*TS + 2*C + j] = ms; } @@ -3101,7 +3093,7 @@ kernel void kernel_flash_attn_ext( s8x8_t ms; simdgroup_load(ms, ss + 8*cc, TS, 0, false); - if constexpr (is_same::value) { + if (is_same::value) { // we can read directly from global memory device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); #pragma unroll @@ -3115,7 +3107,7 @@ kernel void kernel_flash_attn_ext( for (short ii = 0; ii < D16; ii += 4) { device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23)); - if constexpr (D16%4 == 0) { + if (D16%4 == 0) { // no need for bound checks { v4x4_t tmp;