mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
metal : fix comment
This commit is contained in:
parent
432ad04ffa
commit
40ea8cd1ac
@@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)(
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
-template<int64_t D, int64_t Q, int64_t C> // head size, heads per threadgroup, queries per threadgroup
|
||||
+template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
|
||||
kernel void kernel_flash_attn_ext_f16(
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
@@ -2031,16 +2031,12 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
uint3 ntg[[threads_per_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const uint nsg = ntg.y; // number of simdgroups
|
||||
const uint nsg = ntg.y; // number of simdgroups
|
||||
|
||||
const int64_t iq3 = tgpig[2];
|
||||
const int64_t iq2 = tgpig[1];
|
||||
const int64_t iq1 = tgpig[0]*Q;
|
||||
|
||||
if (iq2 >= ne02) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t D4 = D/4;
|
||||
const int64_t N4 = N_SIMDWIDTH;
|
||||
const int64_t L4 = (D4 + N4 - 1)/N4;
|
||||
|
Loading…
Reference in New Issue
Block a user