From 40ea8cd1aca61294e1987bcb1051317827f1b145 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 16:31:39 +0200 Subject: [PATCH] metal : fix comment --- ggml-metal.metal | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index ae8f5caea..9ab9e16c3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, heads per threadgroup, queries per threadgroup +template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,16 +2031,12 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups + const uint nsg = ntg.y; // number of simdgroups const int64_t iq3 = tgpig[2]; const int64_t iq2 = tgpig[1]; const int64_t iq1 = tgpig[0]*Q; - if (iq2 >= ne02) { - return; - } - const int64_t D4 = D/4; const int64_t N4 = N_SIMDWIDTH; const int64_t L4 = (D4 + N4 - 1)/N4;