Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-07 17:21:46 +00:00)
ggml : add ggml_flash_attn_ext_get_prec
This commit is contained in:
parent 76c6e7f105
commit 25e877309a
@@ -1746,6 +1746,9 @@ extern "C" {
             struct ggml_tensor * a,
             enum ggml_prec prec);
 
+    GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
+            const struct ggml_tensor * a);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
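The hunk above declares the new getter in the public header, right next to the
existing ggml_flash_attn_ext_set_prec. Below is a minimal sketch of the intended
set/get round trip, assuming the usual ggml graph-building calls; the tensor
shapes, the helper name flash_attn_prec_roundtrip and the exact
ggml_flash_attn_ext argument list are illustrative, not taken from this commit.

// Sketch only: pair the existing setter with the new getter on a
// flash-attention node. Shapes and the scale/max_bias/logit_softcap values
// are placeholders.
#include <assert.h>
#include <math.h>
#include "ggml.h"

static void flash_attn_prec_roundtrip(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // metadata only, no tensor data needed
    };
    struct ggml_context * ctx = ggml_init(params);

    // Hypothetical head size 128, 32 query tokens, 256 KV tokens, 8 heads.
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128,  32, 8, 1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 256, 8, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 256, 8, 1);

    struct ggml_tensor * kqv = ggml_flash_attn_ext(ctx, q, k, v, /*mask=*/ NULL,
            /*scale=*/ 1.0f/sqrtf(128.0f), /*max_bias=*/ 0.0f, /*logit_softcap=*/ 0.0f);

    // Request f32 accumulation, then read the stored precision back.
    ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
    assert(ggml_flash_attn_ext_get_prec(kqv) == GGML_PREC_F32);

    ggml_free(ctx);
}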
@@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (precision != GGML_PREC_DEFAULT) {
+    if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
             switch (Q->ne[0]) {
@@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
@@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
-        if (precision == GGML_PREC_DEFAULT) {
+        if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
         } else if(Q->ne[0] <= 128) {
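The CUDA hunks above only change where the precision flag comes from; the
kernel-selection rule itself is unchanged. On the AMD path, for example, the
f16 vec kernel is chosen only when the op is at GGML_PREC_DEFAULT and the
device has fast fp16 math, otherwise the f32 variant runs. The same rule,
lifted out as a plain predicate (use_f16_vec_kernel and fast_fp16 are
illustrative names, not ggml API):

// Sketch of the dispatch rule from the hunks above; fast_fp16 stands in for
// fast_fp16_available(cc) in the CUDA backend.
#include <stdbool.h>
#include "ggml.h"

static bool use_f16_vec_kernel(enum ggml_prec prec, bool fast_fp16) {
    // Any explicit precision request (e.g. GGML_PREC_F32 set via
    // ggml_flash_attn_ext_set_prec) falls back to the f32 kernels.
    return prec == GGML_PREC_DEFAULT && fast_fp16;
}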
@@ -4228,6 +4228,15 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }
 
+enum ggml_prec ggml_flash_attn_ext_get_prec(
+        const struct ggml_tensor * a) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
+
+    return (enum ggml_prec) prec_i32;
+}
+
 // ggml_flash_attn_back
 
 struct ggml_tensor * ggml_flash_attn_back(
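The getter reads back the same op_params slot 3 that ggml_flash_attn_ext_set_prec
writes, so backend code can query the requested precision through the API instead
of reading op_params[3] directly, which is exactly the substitution the CUDA hunks
above make. A hedged sketch of that pattern for any backend (my_backend_flash_attn_ext
is a made-up name, not part of this commit):

// Sketch: read the per-op precision through the new accessor.
#include "ggml.h"

static void my_backend_flash_attn_ext(struct ggml_tensor * dst) {
    // Before this commit: const int32_t precision = dst->op_params[3];
    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);

    if (prec == GGML_PREC_DEFAULT) {
        // fast path, f16 accumulation is acceptable
    } else {
        // caller asked for higher precision via ggml_flash_attn_ext_set_prec()
    }
}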