mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 23:34:35 +00:00
metal : fix build and some more comments
This commit is contained in:
parent
5b359bb1e3
commit
bd1198a67a
@ -3046,6 +3046,8 @@ static void ggml_metal_encode_node(
|
|||||||
|
|
||||||
bool use_vec_kernel = false;
|
bool use_vec_kernel = false;
|
||||||
|
|
||||||
|
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
||||||
|
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
||||||
if (ne01 >= 4 || (ne00%128 != 0)) {
|
if (ne01 >= 4 || (ne00%128 != 0)) {
|
||||||
switch (src1->type) {
|
switch (src1->type) {
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
@ -3356,8 +3356,8 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
const short D4 = D/4;
|
const short D4 = D/4;
|
||||||
const short D16 = D/16;
|
const short D16 = D/16;
|
||||||
const short NW = N_SIMDWIDTH;
|
const short NW = N_SIMDWIDTH;
|
||||||
const short NL = NW/4;
|
const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
|
||||||
const short SH = 2*C; // shared memory per simdgroup
|
const short SH = 2*C; // shared memory per simdgroup
|
||||||
|
|
||||||
const short T = D + nsg*SH; // shared memory size per query in (half)
|
const short T = D + nsg*SH; // shared memory size per query in (half)
|
||||||
|
|
||||||
@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
|
|
||||||
// Q*K^T
|
// Q*K^T
|
||||||
{
|
{
|
||||||
// each simdgroup processes 1 query and 4 keys
|
// each simdgroup processes 1 query and 4 (NW/NL) keys
|
||||||
for (short cc = 0; cc < C/4; ++cc) {
|
for (short cc = 0; cc < C/4; ++cc) {
|
||||||
qk_t mqk = 0.0;
|
qk_t mqk = 0.0;
|
||||||
|
|
||||||
@ -3645,7 +3645,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
half, half4, half4x4, \
|
half, half4, half4x4, \
|
||||||
half4x4
|
half4x4
|
||||||
|
|
||||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
|
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
|
||||||
|
|
||||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
|
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
|
||||||
#if defined(GGML_METAL_USE_BF16)
|
#if defined(GGML_METAL_USE_BF16)
|
||||||
|
Loading…
Reference in New Issue
Block a user