metal : use F16 precision in FA kernels

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-06 15:33:30 +02:00
parent 25e877309a
commit 7facc29d69
No known key found for this signature in database
GPG Key ID: BF970631944C16B7
7 changed files with 476 additions and 333 deletions

View File

@ -876,6 +876,11 @@ endif # GGML_HIPBLAS
ifdef GGML_METAL ifdef GGML_METAL
MK_CPPFLAGS += -DGGML_USE_METAL MK_CPPFLAGS += -DGGML_USE_METAL
ifdef GGML_METAL_FORCE_FATTN_PREC_F16
MK_CPPFLAGS += -DGGML_METAL_FORCE_FATTN_PREC_F16
endif # GGML_METAL_FORCE_FATTN_PREC_F16
MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
OBJ_GGML += ggml/src/ggml-metal.o OBJ_GGML += ggml/src/ggml-metal.o
ifdef GGML_METAL_NDEBUG ifdef GGML_METAL_NDEBUG

View File

@ -256,6 +256,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
if (s == "f16") { if (s == "f16") {
return GGML_TYPE_F16; return GGML_TYPE_F16;
} }
if (s == "bf16") {
return GGML_TYPE_BF16;
}
if (s == "q8_0") { if (s == "q8_0") {
return GGML_TYPE_Q8_0; return GGML_TYPE_Q8_0;
} }

View File

@ -153,6 +153,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
option(GGML_KOMPUTE "ggml: use Kompute" OFF) option(GGML_KOMPUTE "ggml: use Kompute" OFF)
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
option(GGML_METAL_FORCE_FATTN_PREC_F16 "ggml: force F16 accumulators for FA kernels" OFF)
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF) option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL}) option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})

View File

@ -58,6 +58,10 @@ if (GGML_METAL)
add_compile_definitions(GGML_METAL_NDEBUG) add_compile_definitions(GGML_METAL_NDEBUG)
endif() endif()
if (GGML_METAL_FORCE_FATTN_PREC_F16)
add_compile_definitions(GGML_METAL_FORCE_FATTN_PREC_F16)
endif()
# copy ggml-common.h and ggml-metal.metal to bin directory # copy ggml-common.h and ggml-metal.metal to bin directory
configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)

View File

@ -269,6 +269,12 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@ -300,12 +306,14 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
@ -585,6 +593,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \ kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
(int) kernel->pipeline.threadExecutionWidth); \
[metal_function release]; \ [metal_function release]; \
if (error) { \ if (error) { \
GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
@ -777,6 +788,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
@ -808,12 +825,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
@ -1111,7 +1130,7 @@ static void ggml_metal_encode_node(
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
const uint64_t nb21 = src2 ? src2->nb[1] : 0; const uint64_t nb21 = src2 ? src2->nb[1] : 0;
const uint64_t nb22 = src2 ? src2->nb[2] : 0; const uint64_t nb22 = src2 ? src2->nb[2] : 0;
const uint64_t nb23 = src2 ? src2->nb[3] : 0; const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
const int64_t ne0 = dst ? dst->ne[0] : 0; const int64_t ne0 = dst ? dst->ne[0] : 0;
const int64_t ne1 = dst ? dst->ne[1] : 0; const int64_t ne1 = dst ? dst->ne[1] : 0;
@ -3033,6 +3052,23 @@ static void ggml_metal_encode_node(
} }
} }
} break; } break;
case GGML_TYPE_BF16:
{
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
} break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
{ {
switch (ne00) { switch (ne00) {
@ -3133,6 +3169,7 @@ static void ggml_metal_encode_node(
{ {
switch (src1->type) { switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
@ -3150,6 +3187,7 @@ static void ggml_metal_encode_node(
{ {
switch (src1->type) { switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
@ -3194,18 +3232,15 @@ static void ggml_metal_encode_node(
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; [encoder setBytes:&scale length:sizeof( float) atIndex:20];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; [encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; [encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
[encoder setBytes:&scale length:sizeof( float) atIndex:23]; [encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
if (!use_vec_kernel) { if (!use_vec_kernel) {
// half8x8 kernel // half8x8 kernel
@ -3216,11 +3251,14 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0); GGML_ASSERT(ncpsg % 32 == 0);
// 2*(2*ncpsg + nqptg)*(nsg)
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
//
// 16*32*(nsg) // 16*32*(nsg)
// the shared memory needed for the simdgroups to load the KV cache // the shared memory needed for the simdgroups to load the KV cache
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
// //
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
int64_t nsgmax = 2; int64_t nsgmax = 2;
@ -3256,10 +3294,10 @@ static void ggml_metal_encode_node(
// for each query, we load it as f16 in shared memory (ne00) // for each query, we load it as f16 in shared memory (ne00)
// and store the attention scores (nqptg x ncpsg) as f32 // and store the attention scores (nqptg x ncpsg) as f32
// //
// 2*ne00*(nsg) // ne00*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results // each simdgroup has a full f16 head vector in shared mem to accumulate results
// //
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16)) #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 4*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
int64_t nsgmax = 2; int64_t nsgmax = 2;

File diff suppressed because it is too large Load Diff

View File

@ -3745,7 +3745,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nh : { 32, }) { for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) { for (int kv : { 512, 1024, }) {
for (int nb : { 1, 3, 32, 35, }) { for (int nb : { 1, 3, 32, 35, }) {
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV)); test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
} }
} }