From aa931d0375798700ea2dd441f810d1987350bb8b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 26 Aug 2024 13:02:36 +0300 Subject: [PATCH] metal : fix fa kernel --- ggml/src/ggml-metal.metal | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index aba0b9a03..afa29caa3 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2144,19 +2144,26 @@ kernel void kernel_flash_attn_ext_f16( const short tx = tiisg%4; const short ty = tiisg/4; - // mqk = mqk*scale - ss[8*cc + ty*TF + 2*tx + 0] *= scale; - ss[8*cc + ty*TF + 2*tx + 1] *= scale; - - if (logit_softcap != 0.0f) { - ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]); - ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]); - } - - if (mask != q) { - // mqk = mqk + mask*slope - ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; - ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; + if (logit_softcap == 0.0f) { + if (mask != q) { + // mqk = mqk*scale + mask*slope + ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; + ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; + } else { + // mqk = mqk*scale + ss[8*cc + ty*TF + 2*tx + 0] *= scale; + ss[8*cc + ty*TF + 2*tx + 1] *= scale; + } + } else { + if (mask != q) { + // mqk = ls*tanh(mqk*scale) + mask*slope + ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(scale*ss[8*cc + ty*TF + 2*tx + 0]) + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; + ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(scale*ss[8*cc + ty*TF + 2*tx + 1]) + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; + } else { + // mqk = ls*tanh(mqk*scale) + ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(scale*ss[8*cc + ty*TF + 2*tx + 0]); + ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(scale*ss[8*cc + ty*TF + 2*tx + 1]); + } } } }