diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 0d143d2fe..73ede1813 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1746,6 +1746,9 @@ extern "C" { struct ggml_tensor * a, enum ggml_prec prec); + GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec( + const struct ggml_tensor * a); + // TODO: needs to be adapted to ggml_flash_attn_ext GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 83e5589a1..0e7ebbc53 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const int32_t precision = KQV->op_params[3]; + const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); - if (precision != GGML_PREC_DEFAULT) { + if (prec != GGML_PREC_DEFAULT) { if (Q->ne[1] <= 32 || Q->ne[0] > 128) { constexpr int cols_per_block = 16; switch (Q->ne[0]) { @@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const int32_t precision = KQV->op_params[3]; + const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); // On AMD the tile kernels perform poorly, use the vec kernel instead: if (cc >= CC_OFFSET_AMD) { - if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { + if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); } else { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); @@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { - if (precision == GGML_PREC_DEFAULT) { + if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); return; } else if(Q->ne[0] <= 128) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index bc034015f..cd26a361b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4228,6 +4228,15 @@ void ggml_flash_attn_ext_set_prec( ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second } +enum ggml_prec ggml_flash_attn_ext_get_prec( + const struct ggml_tensor * a) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = ggml_get_op_params_i32(a, 3); + + return (enum ggml_prec) prec_i32; +} + // ggml_flash_attn_back struct ggml_tensor * ggml_flash_attn_back(