mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-08 09:41:45 +00:00
ggml : add ggml_flash_attn_ext_get_prec
This commit is contained in:
parent
76c6e7f105
commit
25e877309a
@ -1746,6 +1746,9 @@ extern "C" {
|
|||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
enum ggml_prec prec);
|
enum ggml_prec prec);
|
||||||
|
|
||||||
|
GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
|
||||||
|
const struct ggml_tensor * a);
|
||||||
|
|
||||||
// TODO: needs to be adapted to ggml_flash_attn_ext
|
// TODO: needs to be adapted to ggml_flash_attn_ext
|
||||||
GGML_API struct ggml_tensor * ggml_flash_attn_back(
|
GGML_API struct ggml_tensor * ggml_flash_attn_back(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
|
@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
|||||||
const ggml_tensor * KQV = dst;
|
const ggml_tensor * KQV = dst;
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
|
||||||
const int32_t precision = KQV->op_params[3];
|
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||||
|
|
||||||
if (precision != GGML_PREC_DEFAULT) {
|
if (prec != GGML_PREC_DEFAULT) {
|
||||||
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
||||||
constexpr int cols_per_block = 16;
|
constexpr int cols_per_block = 16;
|
||||||
switch (Q->ne[0]) {
|
switch (Q->ne[0]) {
|
||||||
@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||||||
|
|
||||||
ggml_cuda_set_device(ctx.device);
|
ggml_cuda_set_device(ctx.device);
|
||||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||||
const int32_t precision = KQV->op_params[3];
|
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||||
|
|
||||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||||
if (cc >= CC_OFFSET_AMD) {
|
if (cc >= CC_OFFSET_AMD) {
|
||||||
if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||||
} else {
|
} else {
|
||||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||||
@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
||||||
if (precision == GGML_PREC_DEFAULT) {
|
if (prec == GGML_PREC_DEFAULT) {
|
||||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||||
return;
|
return;
|
||||||
} else if(Q->ne[0] <= 128) {
|
} else if(Q->ne[0] <= 128) {
|
||||||
|
@ -4228,6 +4228,15 @@ void ggml_flash_attn_ext_set_prec(
|
|||||||
ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
|
ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum ggml_prec ggml_flash_attn_ext_get_prec(
|
||||||
|
const struct ggml_tensor * a) {
|
||||||
|
GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
|
||||||
|
|
||||||
|
const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
|
||||||
|
|
||||||
|
return (enum ggml_prec) prec_i32;
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_flash_attn_back
|
// ggml_flash_attn_back
|
||||||
|
|
||||||
struct ggml_tensor * ggml_flash_attn_back(
|
struct ggml_tensor * ggml_flash_attn_back(
|
||||||
|
Loading…
Reference in New Issue
Block a user