diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index e7873a143..1eddfd0db 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -256,6 +256,9 @@ static ggml_type ggml_type_from_name(const std::string & s) { if (s == "f16") { return GGML_TYPE_F16; } + if (s == "bf16") { + return GGML_TYPE_BF16; + } if (s == "q8_0") { return GGML_TYPE_Q8_0; } diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index df77eb96e..99587fb98 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -269,6 +269,12 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, @@ -300,12 +306,14 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, @@ -594,6 +602,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \ + GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ [metal_function release]; \ if (error) { \ GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ @@ -786,6 +797,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); @@ -817,12 +834,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); @@ -3042,6 +3061,23 @@ static void ggml_metal_encode_node( } } } break; + case GGML_TYPE_BF16: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } break; case GGML_TYPE_Q4_0: { switch (ne00) { @@ -3142,6 +3178,7 @@ static void ggml_metal_encode_node( { switch (src1->type) { case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break; @@ -3159,6 +3196,7 @@ static void ggml_metal_encode_node( { switch (src1->type) { case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index cf36eaab5..512a06d74 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2776,25 +2776,28 @@ kernel void kernel_leaky_relu_f32( // ref: https://arxiv.org/pdf/2307.08691.pdf template< - typename q_t, + typename q_t, // query types in shared memory typename q4_t, typename q8x8_t, - typename k_t, + typename k_t, // key types in shared memory typename k4x4_t, typename k8x8_t, - typename v_t, + typename v_t, // value types in shared memory typename v4x4_t, typename v8x8_t, - typename s_t, // attention accumulation types + typename qk_t, // Q*K types + typename qk8x8_t, + typename s_t, // soft-max types typename s8x8_t, - typename o_t, + typename o_t, // attention accumulation types typename o4_t, typename o8x8_t, - typename block_q, + typename kd4x4_t, // key type in device memory short nl_k, - void (*deq_k)(device const block_q *, short, thread k4x4_t &), + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // key type in device memory short nl_v, - void (*deq_v)(device const block_q *, short, thread v4x4_t &), + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short D, // head size short Q = 8, // queries per threadgroup short KV = 8, // key/value processed per each simdgroup @@ -2900,6 +2903,8 @@ kernel void kernel_flash_attn_ext( float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; // thread indices inside the simdgroup + // TODO: see if we can utilize quad-group functions for better performance + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3) const short tx = tiisg%4; const short ty = tiisg/4; @@ -2931,6 +2936,8 @@ kernel void kernel_flash_attn_ext( // pointer to the mask device const half * mp = (device const half *) (mask + iq1*nb31); + const bool has_mask = mask != q; + float slope = 1.0f; // ALiBi @@ -2954,13 +2961,14 @@ kernel void kernel_flash_attn_ext( // Q*K^T { for (short cc = 0; cc < C/8; ++cc) { - s8x8_t mqk = make_filled_simdgroup_matrix((s_t) 0.0f); + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); // this is compile-time check, so it does not have runtime overhead - if (is_same::value) { + if constexpr (is_same::value) { // we can read directly from global memory device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); +#pragma unroll for (short i = 0; i < D8; ++i) { k8x8_t mk; simdgroup_load(mk, pk + i*8, nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 @@ -2969,13 +2977,15 @@ kernel void kernel_flash_attn_ext( } } else { for (short ii = 0; ii < D16; ii += 4) { - device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13)); + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13)); - if (D16%4 == 0) { + if constexpr (D16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks - k4x4_t tmp; - deq_k(pk4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; + { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2992,7 +3002,7 @@ kernel void kernel_flash_attn_ext( } else { if (ii + tx < D16) { k4x4_t tmp; - deq_k(pk4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); sk4x4[4*ty + tx] = tmp; } @@ -3011,7 +3021,15 @@ kernel void kernel_flash_attn_ext( } } - simdgroup_store(mqk, ss + 8*cc, TS, 0, false); + if constexpr (is_same::value) { + // same type - store directly + simdgroup_store(mqk, ss + 8*cc, TS, 0, false); + } else { + // cast qk_t -> s_t + s8x8_t mqks(1.0f); + simdgroup_multiply(mqks, mqk, mqks); + simdgroup_store(mqks, ss + 8*cc, TS, 0, false); + } } } @@ -3032,7 +3050,7 @@ kernel void kernel_flash_attn_ext( s = logit_softcap*precise::tanh(s); } - if (mask != q) { + if (has_mask) { // mqk = mqk + mask*slope s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30 } @@ -3065,6 +3083,7 @@ kernel void kernel_flash_attn_ext( s8x8_t mm; simdgroup_load(mm, ss + C, TS, 0, false); +#pragma unroll for (short i = 0; i < D8; ++i) { simdgroup_multiply(lo[i], mm, lo[i]); } @@ -3076,7 +3095,7 @@ kernel void kernel_flash_attn_ext( s8x8_t ms; simdgroup_load(ms, ss + 8*cc, TS, 0, false); - if (is_same::value) { + if constexpr (is_same::value) { // we can read directly from global memory device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); #pragma unroll @@ -3088,13 +3107,15 @@ kernel void kernel_flash_attn_ext( } } else { for (short ii = 0; ii < D16; ii += 4) { - device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23)); + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23)); - if (D16%4 == 0) { + if constexpr (D16%4 == 0) { // no need for bound checks - v4x4_t tmp; - deq_v(pv4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); - sv4x4[4*ty + tx] = tmp; + { + v4x4_t tmp; + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + sv4x4[4*ty + tx] = tmp; + } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -3111,7 +3132,7 @@ kernel void kernel_flash_attn_ext( } else { if (ii + tx < D16) { v4x4_t tmp; - deq_v(pv4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); sv4x4[4*ty + tx] = tmp; } @@ -3225,82 +3246,77 @@ kernel void kernel_flash_attn_ext( } #if defined(GGML_METAL_FORCE_FATTN_PREC_F16) -#define S_T half -#define S4_T half4 -#define S4x4_T half4x4 -#define S8x8_T simdgroup_half8x8 - #define FA_TYPES \ half, half4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ + float, simdgroup_float8x8, \ half, simdgroup_half8x8, \ half, half4, simdgroup_half8x8 #else -#define S_T float -#define S4_T float4 -#define S4x4_T float4x4 -#define S8x8_T simdgroup_float8x8 - #define FA_TYPES \ half, half4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ half, half4, simdgroup_half8x8 #endif // TOOD: static_assert -typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; +typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES -#undef S8x8_T -#undef S4x4_T -#undef S4_T -#undef S_T template< typename q4_t, @@ -3639,10 +3655,6 @@ kernel void kernel_flash_attn_ext_vec( // NOTE: can use half instead of float precision for some extra perf // however, by default use F32 since the op should be mostly memory bandwidth bound -#define S_T float -#define S4_T float4 -#define S4x4_T float4x4 - #define FA_TYPES \ half4, half4x4, \ float4x4, \ @@ -3653,6 +3665,7 @@ kernel void kernel_flash_attn_ext_vec( typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; @@ -3660,6 +3673,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; @@ -3667,9 +3681,6 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES -#undef S4x4_T -#undef S4_T -#undef S_T template kernel void kernel_cpy( diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6cc77edab..c14db8912 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3745,7 +3745,7 @@ static std::vector> make_test_cases_eval() { for (int nh : { 32, }) { for (int kv : { 512, 1024, }) { for (int nb : { 1, 3, 32, 35, }) { - for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { + for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV)); } }