metal : add quantized FA (vec) support

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-03 10:27:48 +02:00
parent 05697f670b
commit 6c484f35b0
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 179 additions and 72 deletions

View File

@ -257,7 +257,17 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F16, GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@ -712,7 +722,17 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm); //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
@ -869,13 +889,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
return true; return true;
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
if (op->src[1]->type != GGML_TYPE_F16) { if (op->src[1]->type != GGML_TYPE_F16 && op->src[0]->ne[0] % 128 != 0) {
return false;
}
if (op->src[2]->type != GGML_TYPE_F16) {
return false;
}
if (op->src[0]->ne[0] == 256) {
return false; return false;
} }
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
@ -2868,7 +2882,7 @@ static void ggml_metal_encode_node(
bool use_vec_kernel = false; bool use_vec_kernel = false;
if (ne01 >= 4 || (ne00%128 != 0)) { if (src1->type == GGML_TYPE_F16 && ne00 < 256 && (ne01 >= 4 || (ne00%128 != 0))) {
switch (ne00) { switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
@ -2887,8 +2901,40 @@ static void ggml_metal_encode_node(
use_vec_kernel = true; use_vec_kernel = true;
switch (ne00) { switch (ne00) {
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; case 128:
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; {
switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
GGML_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}
} break;
case 256:
{
switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
GGML_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}
} break;
default: default:
{ {
GGML_LOG_ERROR("unsupported size: %lld\n", ne00); GGML_LOG_ERROR("unsupported size: %lld\n", ne00);

View File

@ -2723,7 +2723,7 @@ kernel void kernel_leaky_relu_f32(
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
} }
typedef void (flash_attn_ext_f16_t)( typedef void (flash_attn_ext_t)(
device const char * q, device const char * q,
device const char * k, device const char * k,
device const char * v, device const char * v,
@ -2800,9 +2800,9 @@ kernel void kernel_flash_attn_ext_f16(
ushort sgitg[[simdgroup_index_in_threadgroup]]) { ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups const short nsg = ntg.y; // number of simdgroups
const short iq3 = tgpig[2]; const int iq3 = tgpig[2];
const short iq2 = tgpig[1]; const int iq2 = tgpig[1];
const short iq1 = tgpig[0]*Q; const int iq1 = tgpig[0]*Q;
const short D4 = D/4; const short D4 = D/4;
const short D8 = D/8; const short D8 = D/8;
@ -2979,13 +2979,14 @@ kernel void kernel_flash_attn_ext_f16(
for (short cc = 0; cc < C/8; ++cc) { for (short cc = 0; cc < C/8; ++cc) {
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
simdgroup_float8x8 mv;
simdgroup_load(mv, ss + 8*cc, TF, 0, false);
#pragma unroll
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
simdgroup_half8x8 mk; simdgroup_half8x8 mk;
simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
simdgroup_float8x8 mv;
simdgroup_load(mv, ss + 8*cc, TF, 0, false);
simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
} }
} }
@ -3082,15 +3083,16 @@ kernel void kernel_flash_attn_ext_f16(
} }
} }
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<128>;
//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; //template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<256>;
template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup // D - head size, Q - queries per threadgroup, C - cache items per threadgroup
kernel void kernel_flash_attn_ext_vec_f16( template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), int64_t D, int64_t Q = 1, int64_t C = 32>
kernel void flash_attn_ext_vec(
device const char * q, device const char * q,
device const char * k, device const char * k,
device const char * v, device const char * v,
@ -3128,12 +3130,14 @@ kernel void kernel_flash_attn_ext_vec_f16(
ushort sgitg[[simdgroup_index_in_threadgroup]]) { ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups const short nsg = ntg.y; // number of simdgroups
const short iq3 = tgpig[2]; const int iq3 = tgpig[2];
const short iq2 = tgpig[1]; const int iq2 = tgpig[1];
const short iq1 = tgpig[0]; const int iq1 = tgpig[0];
const short D4 = D/4; const short D4 = D/4;
const short D16 = D/16;
const short NW = N_SIMDWIDTH; const short NW = N_SIMDWIDTH;
const short NW4 = NW/4;
const short SH = (C + Q); // shared memory per simdgroup in (half) const short SH = (C + Q); // shared memory per simdgroup in (half)
const short T = D + 2*nsg*SH; // shared memory size per query in (half) const short T = D + 2*nsg*SH; // shared memory size per query in (half)
@ -3157,7 +3161,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
half4 lo[D4/NW]; half4x4 lo[D16/NW4];
// load heads from Q to shared memory // load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@ -3171,8 +3175,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
} }
// zero out lo // zero out lo
for (short i = tiisg; i < D4; i += NW) { for (short i = 0; i < D16/NW4; i += NW4) {
lo[i/NW] = 0.0h; lo[i] = half4x4(0.0h);
} }
// zero out shared memory SH // zero out shared memory SH
@ -3206,15 +3210,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short iv3 = iq3 / rv3; const short iv3 = iq3 / rv3;
// load the queries from shared memory into local memory // load the queries from shared memory into local memory
float4 mq[D4/NW]; float4x4 mq[D16/NW4];
for (short ii = 0; ii < D4; ii += NW) { for (short ii = 0; ii < D16; ii += NW4) {
short i = ii + tiisg; short i = ii + tiisg%8;
mq[ii/NW] = (float4) sq4[i]; mq[ii/NW4][0] = (float4) sq4[4*i + 0];
mq[ii/NW4][1] = (float4) sq4[4*i + 1];
mq[ii/NW4][2] = (float4) sq4[4*i + 2];
mq[ii/NW4][3] = (float4) sq4[4*i + 3];
} }
// pointer to the mask // pointer to the mask
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); device const half * mp = (device const half *) (mask + iq1*nb31);
// loop over the KV cache // loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns // each simdgroup handles blocks of Q rows and C columns
@ -3226,47 +3233,56 @@ kernel void kernel_flash_attn_ext_vec_f16(
// Q*K^T // Q*K^T
{ {
// each simdgroup processes 1 query and 4 keys
const short j = tiisg/8;
#pragma unroll #pragma unroll
for (short cc = 0; cc < C/4; ++cc) { for (short cc = 0; cc < C/4; ++cc) {
float4 mqk = { 0.0h }; float mqk = 0.0;
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + j)*nb11 + ik2*nb12 + ik3*nb13));
#pragma unroll
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
float4x4 mk; float4x4 mk;
mk[0] = (float4) pk4[i + 0*(nb11/8)]; #pragma unroll
mk[1] = (float4) pk4[i + 1*(nb11/8)]; for (short ii = 0; ii < D16; ii += NW4) {
mk[2] = (float4) pk4[i + 2*(nb11/8)]; const short i = ii + tiisg%8; // 0..7
mk[3] = (float4) pk4[i + 3*(nb11/8)];
mqk += (float4) (mq[ii/NW] * mk); dequantize_func(pk + i/nl, i%nl, mk);
mqk +=
dot(mq[ii/NW4][0], mk[0]) +
dot(mq[ii/NW4][1], mk[1]) +
dot(mq[ii/NW4][2], mk[2]) +
dot(mq[ii/NW4][3], mk[3]);
} }
// reduce the results from the threads in the simdgroup // simdgroup reduce
mqk += simd_shuffle_down(mqk, 16); // [ 0 .. 7] -> [ 0]
mqk += simd_shuffle_down(mqk, 8); // [ 8 .. 15] -> [ 8]
// [16 .. 23] -> [16]
// [24 .. 31] -> [24]
//mqk += simd_shuffle_down(mqk, 16);
//mqk += simd_shuffle_down(mqk, 8);
mqk += simd_shuffle_down(mqk, 4); mqk += simd_shuffle_down(mqk, 4);
mqk += simd_shuffle_down(mqk, 2); mqk += simd_shuffle_down(mqk, 2);
mqk += simd_shuffle_down(mqk, 1); mqk += simd_shuffle_down(mqk, 1);
// mqk = mqk*scale + mask*slope // mqk = mqk*scale + mask*slope
if (tiisg == 0) { if (tiisg%8 == 0) {
mqk *= scale; mqk *= scale;
if (logit_softcap != 0.0f) { if (logit_softcap != 0.0f) {
mqk = logit_softcap*precise::tanh(mqk); mqk = logit_softcap*precise::tanh(mqk);
} }
mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f; mqk += (mask != q) ? ((float) mp[ic + 4*cc + j])*slope : (float) 0.0f;
ss4[cc] = mqk; ss[4*cc + j] = mqk;
} }
} }
} }
simdgroup_barrier(mem_flags::mem_threadgroup);
// online softmax // online softmax
{ {
const short p = tiisg; const short p = tiisg;
@ -3286,29 +3302,32 @@ kernel void kernel_flash_attn_ext_vec_f16(
// O = diag(ms)*O // O = diag(ms)*O
#pragma unroll #pragma unroll
for (short ii = 0; ii < D4; ii += NW) { for (short ii = 0; ii < D16; ii += NW4) {
lo[ii/NW] *= ms; lo[ii/NW4] *= ms;
} }
} }
simdgroup_barrier(mem_flags::mem_threadgroup);
// O = O + (Q*K^T)*V // O = O + (Q*K^T)*V
{ {
const short j = tiisg/8;
#pragma unroll #pragma unroll
for (short cc = 0; cc < C/4; ++cc) { for (short cc = 0; cc < C/4; ++cc) {
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + j)*nb21 + iv2*nb22 + iv3*nb23));
float4x4 mv;
const float4x4 lss(ss[4*cc + j]);
#pragma unroll #pragma unroll
for (short ii = 0; ii < D4; ii += NW) { for (short ii = 0; ii < D16; ii += NW4) {
const short i = ii + tiisg; const short i = ii + tiisg%8;
lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; dequantize_func(pv4 + i/nl, i%nl, mv);
lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
}
}
}
lo[ii/NW4] += (half4x4)(mv*lss);
}
}
}
} }
// these are needed for reducing the results from the simdgroups (reuse the ss buffer) // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
@ -3318,10 +3337,38 @@ kernel void kernel_flash_attn_ext_vec_f16(
} }
} }
// simdgroup reduce
// [ 0, 8, 16, 24] -> [ 0]
// [ 1, 9, 17, 25] -> [ 1]
// [ 2, 10, 18, 26] -> [ 2]
// [ 3, 11, 19, 27] -> [ 3]
// [ 4, 12, 20, 28] -> [ 4]
// [ 5, 13, 21, 29] -> [ 5]
// [ 6, 14, 22, 30] -> [ 6]
// [ 7, 15, 23, 31] -> [ 7]
for (short ii = 0; ii < D16; ii += NW4) {
lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16);
lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 8);
lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16);
lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 8);
lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16);
lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 8);
lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16);
lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 8);
}
// store results to shared memory // store results to shared memory
for (short ii = 0; ii < D4; ii += NW) { for (short ii = 0; ii < D16; ii += NW4) {
short i = ii + tiisg; short i = ii + tiisg;
sr4[i] = lo[ii/NW]; if (tiisg < 8) {
sr4[4*i + 0] = lo[ii/NW4][0];
sr4[4*i + 1] = lo[ii/NW4][1];
sr4[4*i + 2] = lo[ii/NW4][2];
sr4[4*i + 3] = lo[ii/NW4][3];
}
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
@ -3370,8 +3417,22 @@ kernel void kernel_flash_attn_ext_vec_f16(
} }
} }
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; //template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<128>;
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<256>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec<half4x4, 1, dequantize_f16, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 128>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec<half4x4, 1, dequantize_f16, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
template<typename T0, typename T1> template<typename T0, typename T1>
kernel void kernel_cpy( kernel void kernel_cpy(