metal : fix build and some more comments

This commit is contained in:
Georgi Gerganov 2024-11-09 10:09:50 +02:00
parent 5b359bb1e3
commit bd1198a67a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 6 additions and 4 deletions

View File

@ -3046,6 +3046,8 @@ static void ggml_metal_encode_node(
bool use_vec_kernel = false;
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
// for now avoiding mainly to keep the number of templates/kernels a bit lower
if (ne01 >= 4 || (ne00%128 != 0)) {
switch (src1->type) {
case GGML_TYPE_F16:

View File

@ -3356,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec(
const short D4 = D/4;
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
const short NL = NW/4;
const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
const short SH = 2*C; // shared memory per simdgroup
const short T = D + nsg*SH; // shared memory size per query in (half)
@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec(
// Q*K^T
{
// each simdgroup processes 1 query and 4 keys
// each simdgroup processes 1 query and 4 (NW/NL) keys
for (short cc = 0; cc < C/4; ++cc) {
qk_t mqk = 0.0;
@ -3645,7 +3645,7 @@ kernel void kernel_flash_attn_ext_vec(
half, half4, half4x4, \
half4x4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
#if defined(GGML_METAL_USE_BF16)