diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index f9bd6faa4..4e585f18e 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -257,7 +257,17 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, - //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F16_F16, @@ -712,7 +722,17 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm); //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); @@ -869,13 +889,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_LEAKY_RELU: return true; case GGML_OP_FLASH_ATTN_EXT: - if (op->src[1]->type != GGML_TYPE_F16) { - return false; - } - if (op->src[2]->type != GGML_TYPE_F16) { - return false; - } - if (op->src[0]->ne[0] == 256) { + if (op->src[1]->type != GGML_TYPE_F16 && op->src[0]->ne[0] % 128 != 0) { return false; } return support_simdgroup_mm; // TODO: over-restricted for vec-kernels @@ -2868,14 +2882,14 @@ static void ggml_metal_encode_node( bool use_vec_kernel = false; - if (ne01 >= 4 || (ne00%128 != 0)) { + if (src1->type == GGML_TYPE_F16 && ne00 < 256 && (ne01 >= 4 || (ne00%128 != 0))) { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; default: { GGML_LOG_ERROR("unsupported size: %lld\n", ne00); @@ -2887,8 +2901,40 @@ static void ggml_metal_encode_node( use_vec_kernel = true; switch (ne00) { - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + case 128: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; + case 256: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; default: { GGML_LOG_ERROR("unsupported size: %lld\n", ne00); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index ff9d37490..a81c55974 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2723,7 +2723,7 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } -typedef void (flash_attn_ext_f16_t)( +typedef void (flash_attn_ext_t)( device const char * q, device const char * k, device const char * v, @@ -2800,9 +2800,9 @@ kernel void kernel_flash_attn_ext_f16( ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups - const short iq3 = tgpig[2]; - const short iq2 = tgpig[1]; - const short iq1 = tgpig[0]*Q; + const int iq3 = tgpig[2]; + const int iq2 = tgpig[1]; + const int iq1 = tgpig[0]*Q; const short D4 = D/4; const short D8 = D/8; @@ -2979,13 +2979,14 @@ kernel void kernel_flash_attn_ext_f16( for (short cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*cc, TF, 0, false); + +#pragma unroll for (short i = 0; i < D8; ++i) { simdgroup_half8x8 mk; simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - simdgroup_float8x8 mv; - simdgroup_load(mv, ss + 8*cc, TF, 0, false); - simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); } } @@ -3082,15 +3083,16 @@ kernel void kernel_flash_attn_ext_f16( } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; -//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<96>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<112>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<128>; +//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<256>; -template // head size, queries per threadgroup, cache items per threadgroup -kernel void kernel_flash_attn_ext_vec_f16( +// D - head size, Q - queries per threadgroup, C - cache items per threadgroup +template +kernel void flash_attn_ext_vec( device const char * q, device const char * k, device const char * v, @@ -3128,13 +3130,15 @@ kernel void kernel_flash_attn_ext_vec_f16( ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups - const short iq3 = tgpig[2]; - const short iq2 = tgpig[1]; - const short iq1 = tgpig[0]; + const int iq3 = tgpig[2]; + const int iq2 = tgpig[1]; + const int iq1 = tgpig[0]; - const short D4 = D/4; - const short NW = N_SIMDWIDTH; - const short SH = (C + Q); // shared memory per simdgroup in (half) + const short D4 = D/4; + const short D16 = D/16; + const short NW = N_SIMDWIDTH; + const short NW4 = NW/4; + const short SH = (C + Q); // shared memory per simdgroup in (half) const short T = D + 2*nsg*SH; // shared memory size per query in (half) @@ -3157,7 +3161,7 @@ kernel void kernel_flash_attn_ext_vec_f16( threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - half4 lo[D4/NW]; + half4x4 lo[D16/NW4]; // load heads from Q to shared memory device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); @@ -3171,8 +3175,8 @@ kernel void kernel_flash_attn_ext_vec_f16( } // zero out lo - for (short i = tiisg; i < D4; i += NW) { - lo[i/NW] = 0.0h; + for (short i = 0; i < D16/NW4; i += NW4) { + lo[i] = half4x4(0.0h); } // zero out shared memory SH @@ -3206,15 +3210,18 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iv3 = iq3 / rv3; // load the queries from shared memory into local memory - float4 mq[D4/NW]; + float4x4 mq[D16/NW4]; - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - mq[ii/NW] = (float4) sq4[i]; + for (short ii = 0; ii < D16; ii += NW4) { + short i = ii + tiisg%8; + mq[ii/NW4][0] = (float4) sq4[4*i + 0]; + mq[ii/NW4][1] = (float4) sq4[4*i + 1]; + mq[ii/NW4][2] = (float4) sq4[4*i + 2]; + mq[ii/NW4][3] = (float4) sq4[4*i + 3]; } // pointer to the mask - device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + device const half * mp = (device const half *) (mask + iq1*nb31); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -3226,47 +3233,56 @@ kernel void kernel_flash_attn_ext_vec_f16( // Q*K^T { + // each simdgroup processes 1 query and 4 keys + const short j = tiisg/8; #pragma unroll for (short cc = 0; cc < C/4; ++cc) { - float4 mqk = { 0.0h }; + float mqk = 0.0; - device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + j)*nb11 + ik2*nb12 + ik3*nb13)); + float4x4 mk; #pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; + for (short ii = 0; ii < D16; ii += NW4) { + const short i = ii + tiisg%8; // 0..7 - float4x4 mk; - mk[0] = (float4) pk4[i + 0*(nb11/8)]; - mk[1] = (float4) pk4[i + 1*(nb11/8)]; - mk[2] = (float4) pk4[i + 2*(nb11/8)]; - mk[3] = (float4) pk4[i + 3*(nb11/8)]; + dequantize_func(pk + i/nl, i%nl, mk); - mqk += (float4) (mq[ii/NW] * mk); + mqk += + dot(mq[ii/NW4][0], mk[0]) + + dot(mq[ii/NW4][1], mk[1]) + + dot(mq[ii/NW4][2], mk[2]) + + dot(mq[ii/NW4][3], mk[3]); } - // reduce the results from the threads in the simdgroup - mqk += simd_shuffle_down(mqk, 16); - mqk += simd_shuffle_down(mqk, 8); + // simdgroup reduce + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 31] -> [24] + //mqk += simd_shuffle_down(mqk, 16); + //mqk += simd_shuffle_down(mqk, 8); mqk += simd_shuffle_down(mqk, 4); mqk += simd_shuffle_down(mqk, 2); mqk += simd_shuffle_down(mqk, 1); // mqk = mqk*scale + mask*slope - if (tiisg == 0) { + if (tiisg%8 == 0) { mqk *= scale; if (logit_softcap != 0.0f) { mqk = logit_softcap*precise::tanh(mqk); } - mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f; + mqk += (mask != q) ? ((float) mp[ic + 4*cc + j])*slope : (float) 0.0f; - ss4[cc] = mqk; + ss[4*cc + j] = mqk; } } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // online softmax { const short p = tiisg; @@ -3286,29 +3302,32 @@ kernel void kernel_flash_attn_ext_vec_f16( // O = diag(ms)*O #pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - lo[ii/NW] *= ms; + for (short ii = 0; ii < D16; ii += NW4) { + lo[ii/NW4] *= ms; } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // O = O + (Q*K^T)*V { + const short j = tiisg/8; #pragma unroll for (short cc = 0; cc < C/4; ++cc) { - device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + j)*nb21 + iv2*nb22 + iv3*nb23)); + float4x4 mv; + const float4x4 lss(ss[4*cc + j]); #pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; + for (short ii = 0; ii < D16; ii += NW4) { + const short i = ii + tiisg%8; - lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; - lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; - lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; - lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + dequantize_func(pv4 + i/nl, i%nl, mv); + + lo[ii/NW4] += (half4x4)(mv*lss); } } } - } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) @@ -3318,10 +3337,38 @@ kernel void kernel_flash_attn_ext_vec_f16( } } + // simdgroup reduce + // [ 0, 8, 16, 24] -> [ 0] + // [ 1, 9, 17, 25] -> [ 1] + // [ 2, 10, 18, 26] -> [ 2] + // [ 3, 11, 19, 27] -> [ 3] + // [ 4, 12, 20, 28] -> [ 4] + // [ 5, 13, 21, 29] -> [ 5] + // [ 6, 14, 22, 30] -> [ 6] + // [ 7, 15, 23, 31] -> [ 7] + for (short ii = 0; ii < D16; ii += NW4) { + lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16); + lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 8); + + lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16); + lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 8); + + lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16); + lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 8); + + lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16); + lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 8); + } + // store results to shared memory - for (short ii = 0; ii < D4; ii += NW) { + for (short ii = 0; ii < D16; ii += NW4) { short i = ii + tiisg; - sr4[i] = lo[ii/NW]; + if (tiisg < 8) { + sr4[4*i + 0] = lo[ii/NW4][0]; + sr4[4*i + 1] = lo[ii/NW4][1]; + sr4[4*i + 2] = lo[ii/NW4][2]; + sr4[4*i + 3] = lo[ii/NW4][3]; + } } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -3370,8 +3417,22 @@ kernel void kernel_flash_attn_ext_vec_f16( } } -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; -//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; +//template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<128>; +//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<256>; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec; template kernel void kernel_cpy(