diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 0466f2acd..a5fbed869 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2756,11 +2756,11 @@ template< short KV = 8, // key/value processed per each simdgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -3159,7 +3159,7 @@ kernel void kernel_flash_attn_ext( const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; } } } @@ -3251,11 +3251,11 @@ template< short Q = 1, // queries per threadgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -3552,7 +3552,7 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } }