diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 7ceb32417..24dd523dc 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2898,8 +2898,8 @@ kernel void kernel_flash_attn_ext( threadgroup_barrier(mem_flags::mem_threadgroup); { - float S[Q] = { [0 ... Q-1] = 0.0f }; - float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; + half S[Q] = { [0 ... Q-1] = 0.0f }; + half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; // thread indices inside the simdgroup // TODO: see if we can utilize quad-group functions for better performance @@ -2934,14 +2934,14 @@ kernel void kernel_flash_attn_ext( const bool has_mask = mask != q; - float slope = 1.0f; + half slope = 1.0f; // ALiBi if (max_bias > 0.0f) { - const uint32_t h = iq2; + const short h = iq2; - const float base = h < n_head_log2 ? m0 : m1; - const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const half base = h < n_head_log2 ? m0 : m1; + const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; slope = pow(base, exph); } @@ -3047,10 +3047,10 @@ kernel void kernel_flash_attn_ext( // online softmax { for (short j = 0; j < Q; ++j) { - const float m = M[j]; + const half m = M[j]; // scale and apply the logitcap / mask - float s = ss[j*TS + tiisg]*scale; + half s = ss[j*TS + tiisg]*scale; if (logit_softcap != 0.0f) { s = logit_softcap*precise::tanh(s); @@ -3061,8 +3061,8 @@ kernel void kernel_flash_attn_ext( M[j] = simd_max(max(M[j], s)); - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); + const half ms = exp(m - M[j]); + const half vs = exp(s - M[j]); S[j] = S[j]*ms + simd_sum(vs); @@ -3163,8 +3163,8 @@ kernel void kernel_flash_attn_ext( // reduce the warps sequentially for (short sg = 1; sg < nsg; ++sg) { - float S = { 0.0f }; - float M = { -FLT_MAX/2 }; + half S = { 0.0f }; + half M = { -__FLT16_MAX__/2 }; threadgroup_barrier(mem_flags::mem_threadgroup); @@ -3180,16 +3180,16 @@ kernel void kernel_flash_attn_ext( // the first simdgroup accumulates the results from the other simdgroups if (sgitg == 0) { for (short j = 0; j < Q; ++j) { - const float S0 = ss[j*TS + 0]; - const float S1 = ss[j*TS + sg*SH + 0]; + const half S0 = ss[j*TS + 0]; + const half S1 = ss[j*TS + sg*SH + 0]; - const float M0 = ss[j*TS + 1]; - const float M1 = ss[j*TS + sg*SH + 1]; + const half M0 = ss[j*TS + 1]; + const half M1 = ss[j*TS + sg*SH + 1]; M = max(M0, M1); - const float ms0 = exp(M0 - M); - const float ms1 = exp(M1 - M); + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); S = S0*ms0 + S1*ms1;