mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 04:44:34 +00:00
metal : fix comment
This commit is contained in:
parent
432ad04ffa
commit
40ea8cd1ac
@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)(
|
|||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]);
|
uint sgitg[[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
template<int64_t D, int64_t Q, int64_t C> // head size, heads per threadgroup, queries per threadgroup
|
template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
|
||||||
kernel void kernel_flash_attn_ext_f16(
|
kernel void kernel_flash_attn_ext_f16(
|
||||||
device const char * q,
|
device const char * q,
|
||||||
device const char * k,
|
device const char * k,
|
||||||
@ -2031,16 +2031,12 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
uint3 ntg[[threads_per_threadgroup]],
|
uint3 ntg[[threads_per_threadgroup]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
const uint nsg = ntg.y; // number of simdgroups
|
const uint nsg = ntg.y; // number of simdgroups
|
||||||
|
|
||||||
const int64_t iq3 = tgpig[2];
|
const int64_t iq3 = tgpig[2];
|
||||||
const int64_t iq2 = tgpig[1];
|
const int64_t iq2 = tgpig[1];
|
||||||
const int64_t iq1 = tgpig[0]*Q;
|
const int64_t iq1 = tgpig[0]*Q;
|
||||||
|
|
||||||
if (iq2 >= ne02) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64_t D4 = D/4;
|
const int64_t D4 = D/4;
|
||||||
const int64_t N4 = N_SIMDWIDTH;
|
const int64_t N4 = N_SIMDWIDTH;
|
||||||
const int64_t L4 = (D4 + N4 - 1)/N4;
|
const int64_t L4 = (D4 + N4 - 1)/N4;
|
||||||
|
Loading…
Reference in New Issue
Block a user