ggml : fix GQA support in ggml_flash_attn_ext

Georgi Gerganov 2024-01-19 20:06:26 +02:00
parent a1c004ef2e
commit fa7ebcca99
3 changed files with 23 additions and 12 deletions
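Background for the fix: in grouped-query attention (GQA) several query heads share a single K/V head, so K and V have fewer heads along dimension 2 than Q. The broadcast factor rk2 = neq2/nek2 introduced in ggml.c below maps each query head to its K/V head by integer division instead of the previous modulo. A minimal standalone sketch of that mapping, assuming illustrative head counts (32 query heads, 8 K/V heads; these numbers are not from the commit):

#include <stdio.h>

// Toy illustration of the GQA head mapping fixed by this commit.
// The head counts are made up for the example.
int main(void) {
    const int neq2 = 32;           // query heads
    const int nek2 =  8;           // K/V heads (GQA: fewer than query heads)
    const int rk2  = neq2 / nek2;  // broadcast factor: query heads per K/V head

    for (int iq2 = 0; iq2 < neq2; ++iq2) {
        const int ik2_new = iq2 / rk2;  // mapping after this commit: grouped
        const int ik2_old = iq2 % nek2; // previous mapping: interleaved, wrong for GQA
        printf("query head %2d -> k/v head %d (old mapping: %d)\n", iq2, ik2_new, ik2_old);
    }
    return 0;
}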

ggml-metal.metal

@@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32(
}
kernel void kernel_flash_attn_ext_f16(
- device const half * q,
- device const half * k,
- device const half * v,
- device const half * mask,
+ device const half * q,
+ device const half * k,
+ device const half * v,
+ device const float * mask,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,

ggml.c

@@ -13307,6 +13307,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
+ // broadcast factors
+ const int64_t rk2 = neq2/nek2;
+ const int64_t rk3 = neq3/nek3;
+ const int64_t rv2 = neq2/nev2;
+ const int64_t rv3 = neq3/nev3;
if (params->type == GGML_TASK_INIT) {
return;
}
@@ -13347,8 +13354,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
for (int64_t ic = 0; ic < nek1; ++ic) {
// k indices
- const int ik3 = iq3;
- const int ik2 = iq2 % nek2;
+ const int ik3 = iq3 / rk3;
+ const int ik2 = iq2 / rk2;
const int ik1 = ic;
// S indices
@@ -13362,8 +13369,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
} else {
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
// k indices
- const int ik3 = iq3;
- const int ik2 = iq2 % nek2;
+ const int ik3 = iq3 / rk3;
+ const int ik2 = iq2 / rk2;
const int ik1 = ic;
// S indices
@@ -13452,8 +13459,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int i3 = iq3;
// v indices
- const int iv2 = iq2 % nev2;
- const int iv3 = iq3;
+ const int iv2 = iq2 / rv2;
+ const int iv3 = iq3 / rv3;
ggml_vec_dot_f16(nev0,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
@@ -13468,8 +13475,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int i3 = iq3;
// v indices
- const int iv2 = iq2 % nev2;
- const int iv3 = iq3;
+ const int iv2 = iq2 / rv2;
+ const int iv3 = iq3 / rv3;
ggml_vec_dot_f16_unroll(nev0, nbv1,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
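Taken together, the ggml.c hunks above replace the modulo mapping with a division by the broadcast factors for both the K and the V indices. A rough self-contained sketch of how the new indices select a K row inside the dot-product loop; k_data and the nbk1..nbk3 byte strides stand in for the fields ggml.c reads from the K tensor and are parameters here only to keep the sketch compilable:

#include <stddef.h>
#include <stdint.h>

// Sketch of the per-row addressing used with the new GQA indices.
static const char * k_row_for(const char * k_data,
                              size_t nbk1, size_t nbk2, size_t nbk3,
                              int64_t ic, int64_t iq2, int64_t iq3,
                              int64_t rk2, int64_t rk3) {
    const int64_t ik1 = ic;        // key position
    const int64_t ik2 = iq2 / rk2; // several query heads share one K head (GQA)
    const int64_t ik3 = iq3 / rk3; // broadcast over the outermost dimension
    return k_data + ik1*nbk1 + ik2*nbk2 + ik3*nbk3;
}
// ggml_vec_dot_f16(...) then dots the query row against this K row to fill S[ic];
// the V indices iv2 = iq2 / rv2 and iv3 = iq3 / rv3 are used the same way.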

llama.cpp

@@ -4220,6 +4220,10 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kqv;
if (supports_flash_attn) {
//printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
//printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
//printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
//printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]);
kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
} else {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
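For reference, a hedged sketch of a GQA-shaped call matching the code path above: Q is built with more heads than K/V, cast to F16, and passed to ggml_flash_attn_ext with the 6-argument signature visible in the hunk. The helper name build_flash_attn_gqa and all dimensions (head size 128, 32 query heads, 8 K/V heads) are illustrative assumptions, not taken from this commit:

#include "ggml.h"

// Hypothetical helper: build GQA-shaped Q/K/V and call ggml_flash_attn_ext
// the same way llm_build_kqv does above. Dimensions are example values.
static struct ggml_tensor * build_flash_attn_gqa(struct ggml_context * ctx,
                                                 struct ggml_tensor  * kq_mask,
                                                 float                 kq_scale,
                                                 int64_t n_tokens, int64_t n_kv) {
    const int64_t head_dim  = 128; // assumed head size
    const int64_t n_head    = 32;  // query heads
    const int64_t n_head_kv = 8;   // K/V heads; n_head must be a multiple of this

    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_head,    1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, n_kv,     n_head_kv, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, n_kv,     n_head_kv, 1);

    // With this fix the differing head counts of Q vs K/V are broadcast
    // inside ggml_flash_attn_ext via rk2 = n_head/n_head_kv.
    return ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
}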