metal : fix comment

This commit is contained in:
Georgi Gerganov 2024-01-25 16:31:39 +02:00
parent 432ad04ffa
commit 40ea8cd1ac
No known key found for this signature in database
GPG Key ID: BF970631944C16B7

View File

@@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]);
template<int64_t D, int64_t Q, int64_t C> // head size, heads per threadgroup, queries per threadgroup
template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_f16(
device const char * q,
device const char * k,
@@ -2031,16 +2031,12 @@ kernel void kernel_flash_attn_ext_f16(
uint3 ntg[[threads_per_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint nsg = ntg.y; // number of simdgroups
const uint nsg = ntg.y; // number of simdgroups
const int64_t iq3 = tgpig[2];
const int64_t iq2 = tgpig[1];
const int64_t iq1 = tgpig[0]*Q;
if (iq2 >= ne02) {
return;
}
const int64_t D4 = D/4;
const int64_t N4 = N_SIMDWIDTH;
const int64_t L4 = (D4 + N4 - 1)/N4;