From 593bc1aef55e462245764bfa44c7e3bd9bace9b7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Nov 2024 16:39:36 +0200 Subject: [PATCH] metal : cont + avoid potential int overflow [no ci] --- ggml/src/ggml-metal.metal | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c21ff5b2d..e3631b08d 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2756,11 +2756,11 @@ template< short KV = 8, // key/value processed per each simdgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -3156,7 +3156,7 @@ kernel void kernel_flash_attn_ext( const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; } } } @@ -3248,11 +3248,11 @@ template< short Q = 1, // queries per threadgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -3543,7 +3543,7 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } }