mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-10 02:31:46 +00:00
metal : add quantized FA (non-vec) support
This commit is contained in:
parent
6c484f35b0
commit
e9565ccf9a
@ -255,7 +255,37 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
||||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
||||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
||||||
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
||||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
||||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
||||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
|
||||||
@ -720,7 +750,37 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
|
||||||
@ -889,9 +949,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
if (op->src[1]->type != GGML_TYPE_F16 && op->src[0]->ne[0] % 128 != 0) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||||
case GGML_OP_SSM_CONV:
|
case GGML_OP_SSM_CONV:
|
||||||
case GGML_OP_SSM_SCAN:
|
case GGML_OP_SSM_SCAN:
|
||||||
@ -2882,20 +2939,116 @@ static void ggml_metal_encode_node(
|
|||||||
|
|
||||||
bool use_vec_kernel = false;
|
bool use_vec_kernel = false;
|
||||||
|
|
||||||
if (src1->type == GGML_TYPE_F16 && ne00 < 256 && (ne01 >= 4 || (ne00%128 != 0))) {
|
if (ne01 >= 4 || (ne00%128 != 0)) {
|
||||||
switch (ne00) {
|
switch (src1->type) {
|
||||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
case GGML_TYPE_F16:
|
||||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
{
|
||||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
switch (ne00) {
|
||||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||||
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||||
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
||||||
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
||||||
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||||
|
GGML_ABORT("add template specialization for this size");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
{
|
||||||
|
switch (ne00) {
|
||||||
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
||||||
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
|
||||||
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
|
||||||
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
|
||||||
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
|
||||||
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||||
|
GGML_ABORT("add template specialization for this size");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
{
|
||||||
|
switch (ne00) {
|
||||||
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
||||||
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
|
||||||
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
|
||||||
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
|
||||||
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
|
||||||
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||||
|
GGML_ABORT("add template specialization for this size");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
{
|
||||||
|
switch (ne00) {
|
||||||
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
||||||
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
|
||||||
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
|
||||||
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
|
||||||
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
|
||||||
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||||
|
GGML_ABORT("add template specialization for this size");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
{
|
||||||
|
switch (ne00) {
|
||||||
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
||||||
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
|
||||||
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
|
||||||
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
|
||||||
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
|
||||||
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||||
|
GGML_ABORT("add template specialization for this size");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
{
|
||||||
|
switch (ne00) {
|
||||||
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
||||||
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
|
||||||
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
|
||||||
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
|
||||||
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
|
||||||
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||||
|
GGML_ABORT("add template specialization for this size");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
||||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
GGML_LOG_ERROR("add template specialization for this type\n");
|
||||||
GGML_ABORT("add template specialization for this size");
|
GGML_ABORT("add template specialization for this type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
use_vec_kernel = true;
|
use_vec_kernel = true;
|
||||||
@ -2982,6 +3135,7 @@ static void ggml_metal_encode_node(
|
|||||||
if (!use_vec_kernel) {
|
if (!use_vec_kernel) {
|
||||||
// half8x8 kernel
|
// half8x8 kernel
|
||||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||||
|
const int64_t nkpsg = 8; // keys per simdgroup !! sync with kernel template arguments !!
|
||||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||||
|
|
||||||
GGML_ASSERT(nqptg <= 32);
|
GGML_ASSERT(nqptg <= 32);
|
||||||
@ -2991,7 +3145,7 @@ static void ggml_metal_encode_node(
|
|||||||
int64_t nsgmax = 2;
|
int64_t nsgmax = 2;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
|
const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 4*16*nkpsg*nsgmax)*(sizeof(float)/2);
|
||||||
if (smem > device.maxThreadgroupMemoryLength) {
|
if (smem > device.maxThreadgroupMemoryLength) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -3002,12 +3156,12 @@ static void ggml_metal_encode_node(
|
|||||||
// simdgroups per threadgroup (a.k.a. warps)
|
// simdgroups per threadgroup (a.k.a. warps)
|
||||||
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||||
|
|
||||||
const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 4*16*nkpsg*nsg)*(sizeof(float)/2); // counted in halves, then scaled by sizeof(float)/2 bytes per half; the 4*16*nkpsg*nsg term is the per-simdgroup scratch the kernel places at (shared + sgitg*(4*16*K) + Q*T) to load K and V
|
||||||
|
|
||||||
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
|
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
||||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||||
|
|
||||||
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 128) atIndex:0];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||||
} else {
|
} else {
|
||||||
|
@ -2761,8 +2761,8 @@ typedef void (flash_attn_ext_t)(
|
|||||||
ushort sgitg[[simdgroup_index_in_threadgroup]]);
|
ushort sgitg[[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
||||||
template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short K = 8, short C = 32> // block_q: K/V block type (half4x4 is read directly from global memory), nl + dequantize_func: select and expand one half4x4 from a block_q, D: head size, Q: queries per threadgroup, K: sizes the per-simdgroup K/V scratch buffer (4*16*K halves; host side: keys per simdgroup), C: cache items per threadgroup
|
||||||
kernel void kernel_flash_attn_ext_f16(
|
kernel void kernel_flash_attn_ext(
|
||||||
device const char * q,
|
device const char * q,
|
||||||
device const char * k,
|
device const char * k,
|
||||||
device const char * v,
|
device const char * v,
|
||||||
@ -2804,11 +2804,11 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
const int iq2 = tgpig[1];
|
const int iq2 = tgpig[1];
|
||||||
const int iq1 = tgpig[0]*Q;
|
const int iq1 = tgpig[0]*Q;
|
||||||
|
|
||||||
const short D4 = D/4;
|
const short D4 = D/4;
|
||||||
const short D8 = D/8;
|
const short D8 = D/8;
|
||||||
//const short Q8 = Q/8;
|
const short D16 = D/16;
|
||||||
const short NW = N_SIMDWIDTH;
|
const short NW = N_SIMDWIDTH;
|
||||||
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
||||||
|
|
||||||
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
||||||
const short TF = T/2; // shared memory size per query in (float)
|
const short TF = T/2; // shared memory size per query in (float)
|
||||||
@ -2818,6 +2818,9 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||||
|
|
||||||
|
threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*K) + Q*T); // scratch buffer to load K and V in shared memory
|
||||||
|
threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*K) + Q*T); // same as above but in half4x4
|
||||||
|
|
||||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||||
simdgroup_half8x8 lo[D8];
|
simdgroup_half8x8 lo[D8];
|
||||||
|
|
||||||
@ -2906,13 +2909,60 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
for (short cc = 0; cc < C/8; ++cc) {
|
for (short cc = 0; cc < C/8; ++cc) {
|
||||||
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
|
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
|
||||||
|
|
||||||
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
if (is_same<block_q, half4x4>::value) {
|
||||||
|
// we can read directly from global memory
|
||||||
|
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
simdgroup_half8x8 mk;
|
simdgroup_half8x8 mk;
|
||||||
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
||||||
|
|
||||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (short ii = 0; ii < D16; ii += 4) {
|
||||||
|
const short i = tiisg%4;
|
||||||
|
const short j = tiisg/4;
|
||||||
|
|
||||||
|
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + j)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
|
||||||
|
if (D16%4 == 0) {
|
||||||
|
half4x4 tmp;
|
||||||
|
dequantize_func(pk4 + (ii + i)/nl, (ii + i)%nl, tmp);
|
||||||
|
skv4[4*j + i] = tmp;
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (short k = 0; k < 4; ++k) {
|
||||||
|
simdgroup_half8x8 mk;
|
||||||
|
|
||||||
|
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
|
||||||
|
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
||||||
|
|
||||||
|
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
|
||||||
|
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (ii + i < D16) {
|
||||||
|
half4x4 tmp;
|
||||||
|
dequantize_func(pk4 + (ii + i)/nl, (ii + i)%nl, tmp);
|
||||||
|
skv4[4*j + i] = tmp;
|
||||||
|
}
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
||||||
|
simdgroup_half8x8 mk;
|
||||||
|
|
||||||
|
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
|
||||||
|
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
||||||
|
|
||||||
|
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
|
||||||
|
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
||||||
@ -2977,17 +3027,63 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
// O = O + (Q*K^T)*V
|
// O = O + (Q*K^T)*V
|
||||||
{
|
{
|
||||||
for (short cc = 0; cc < C/8; ++cc) {
|
for (short cc = 0; cc < C/8; ++cc) {
|
||||||
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
simdgroup_float8x8 ms;
|
||||||
|
simdgroup_load(ms, ss + 8*cc, TF, 0, false);
|
||||||
|
|
||||||
simdgroup_float8x8 mv;
|
if (is_same<block_q, half4x4>::value) {
|
||||||
simdgroup_load(mv, ss + 8*cc, TF, 0, false);
|
// we can read directly from global memory
|
||||||
|
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
#pragma unroll
|
||||||
|
for (short i = 0; i < D8; ++i) {
|
||||||
|
simdgroup_half8x8 mv;
|
||||||
|
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
|
||||||
|
|
||||||
|
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (short ii = 0; ii < D16; ii += 4) {
|
||||||
|
const short i = tiisg%4;
|
||||||
|
const short j = tiisg/4;
|
||||||
|
|
||||||
|
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + j)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
|
if (D16%4 == 0) {
|
||||||
|
half4x4 tmp;
|
||||||
|
dequantize_func(pv4 + (ii + i)/nl, (ii + i)%nl, tmp);
|
||||||
|
skv4[4*j + i] = tmp;
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short k = 0; k < 4; ++k) {
|
||||||
simdgroup_half8x8 mk;
|
simdgroup_half8x8 mv;
|
||||||
simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
|
|
||||||
|
|
||||||
simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
|
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
|
||||||
|
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
||||||
|
|
||||||
|
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
|
||||||
|
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (ii + i < D16) {
|
||||||
|
half4x4 tmp;
|
||||||
|
dequantize_func(pv4 + (ii + i)/nl, (ii + i)%nl, tmp);
|
||||||
|
skv4[4*j + i] = tmp;
|
||||||
|
}
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
||||||
|
simdgroup_half8x8 mv;
|
||||||
|
|
||||||
|
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
|
||||||
|
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
||||||
|
|
||||||
|
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
|
||||||
|
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3083,15 +3179,50 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<64>;
|
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<80>;
|
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<96>;
|
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<112>;
|
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 112>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<128>;
|
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 128>;
|
||||||
//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<256>;
|
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 256>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 64>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 80>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 96>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 112>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 128>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 256>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 64>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 80>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 96>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 112>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 128>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 256>;
|
||||||
|
|
||||||
|
// Explicit instantiations of kernel_flash_attn_ext for block_q5_0 data, dequantized
// with dequantize_q5_0. One instantiation per value of the last template argument
// (64, 80, 96, 112, 128, 256); each is exported under the host-visible name given
// in its [[host_name(...)]] attribute.
// NOTE(review): the last argument is presumably the head size D and the 2 the `nl`
// argument, by analogy with flash_attn_ext_vec - confirm against the
// kernel_flash_attn_ext template declaration.
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 64>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 80>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 96>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 112>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 128>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 256>;
|
||||||
|
|
||||||
|
// Explicit instantiations of kernel_flash_attn_ext for block_q5_1 data, dequantized
// with dequantize_q5_1. One instantiation per value of the last template argument
// (64, 80, 96, 112, 128, 256); each is exported under the host-visible name given
// in its [[host_name(...)]] attribute.
// NOTE(review): the last argument is presumably the head size D and the 2 the `nl`
// argument, by analogy with flash_attn_ext_vec - confirm against the
// kernel_flash_attn_ext template declaration.
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 64>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 80>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 96>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 112>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 128>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 256>;
|
||||||
|
|
||||||
|
// Explicit instantiations of kernel_flash_attn_ext for block_q8_0 data, dequantized
// with dequantize_q8_0. One instantiation per value of the last template argument
// (64, 80, 96, 112, 128, 256); each is exported under the host-visible name given
// in its [[host_name(...)]] attribute.
// NOTE(review): the last argument is presumably the head size D and the 2 the `nl`
// argument, by analogy with flash_attn_ext_vec - confirm against the
// kernel_flash_attn_ext template declaration.
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 64>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 80>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 96>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 112>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
|
||||||
|
|
||||||
// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
|
// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
|
||||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), int64_t D, int64_t Q = 1, int64_t C = 32>
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
|
||||||
kernel void flash_attn_ext_vec(
|
kernel void flash_attn_ext_vec(
|
||||||
device const char * q,
|
device const char * q,
|
||||||
device const char * k,
|
device const char * k,
|
||||||
@ -3158,7 +3289,7 @@ kernel void flash_attn_ext_vec(
|
|||||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||||
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same scratch buffer as ss, viewed as float4
|
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same scratch buffer as ss, viewed as float4
|
||||||
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
|
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
|
||||||
|
|
||||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||||
half4x4 lo[D16/NW4];
|
half4x4 lo[D16/NW4];
|
||||||
|
Loading…
Reference in New Issue
Block a user