Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-27 20:04:35 +00:00)

ggml : fix GQA support in ggml_flash_attn_ext

commit fa7ebcca99 (parent a1c004ef2e)
@@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32(
 }
 
 kernel void kernel_flash_attn_ext_f16(
-        device const half * q,
-        device const half * k,
-        device const half * v,
-        device const half * mask,
+        device const half  * q,
+        device const half  * k,
+        device const half  * v,
+        device const float * mask,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
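The kernel's mask argument switches from half to float here. For orientation only (this block is not part of the commit; the sizes and the causal pattern are assumed), an additive f32 mask is simply 0.0f where a token may attend and -INFINITY where it may not:

#include <math.h>
#include <stdio.h>

// Standalone sketch: build a small causal additive mask in f32 (the type the
// updated kernel reads). n_kv and n_tokens are made-up example values.
int main(void) {
    enum { n_kv = 8, n_tokens = 4 };
    float mask[n_tokens][n_kv];

    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_kv; ++j) {
            // token i sees the cached positions plus everything up to itself
            mask[i][j] = (j <= n_kv - n_tokens + i) ? 0.0f : -INFINITY;
        }
    }

    // print '.' for visible positions and 'x' for masked ones
    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_kv; ++j) {
            printf("%c", mask[i][j] == 0.0f ? '.' : 'x');
        }
        printf("\n");
    }
    return 0;
}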
ggml.c (23 changed lines)
@@ -13307,6 +13307,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
     if (params->type == GGML_TASK_INIT) {
         return;
     }
@@ -13347,8 +13354,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
             for (int64_t ic = 0; ic < nek1; ++ic) {
                 // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
+                const int ik3 = iq3 / rk3;
+                const int ik2 = iq2 / rk2;
                 const int ik1 = ic;
 
                 // S indices
@@ -13362,8 +13369,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         } else {
             for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
                 // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
+                const int ik3 = iq3 / rk3;
+                const int ik2 = iq2 / rk2;
                 const int ik1 = ic;
 
                 // S indices
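These two hunks swap the modulo-based K-head indices for a division by the broadcast factors defined further up, so that each group of rk2 consecutive query heads reads the same KV head (and likewise rk3 across the batch dimension). A minimal standalone sketch of that arithmetic, with head counts that are assumed example values rather than anything from the commit:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t neq2 = 32; // query heads (assumed example value)
    const int64_t nek2 =  8; // KV heads    (assumed example value)

    const int64_t rk2 = neq2/nek2; // broadcast factor: query heads per KV head

    for (int64_t iq2 = 0; iq2 < neq2; ++iq2) {
        const int64_t ik2_old = iq2 % nek2; // previous indexing: wraps around the KV heads
        const int64_t ik2_new = iq2 / rk2;  // fixed indexing: consecutive query heads share one KV head
        printf("q head %2lld -> kv head %lld (was %lld)\n",
               (long long) iq2, (long long) ik2_new, (long long) ik2_old);
    }
    return 0;
}

With 32 query heads over 8 KV heads, the old mapping sends query heads 0..7 to KV heads 0..7 and then wraps, while the new one sends heads 0..3 to KV head 0, heads 4..7 to KV head 1, and so on, which is the grouping GQA expects; the v-index hunks below apply the same fix via rv2/rv3.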
@@ -13452,8 +13459,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 const int i3 = iq3;
 
                 // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
+                const int iv2 = iq2 / rv2;
+                const int iv3 = iq3 / rv3;
 
                 ggml_vec_dot_f16(nev0,
                         (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
@@ -13468,8 +13475,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 const int i3 = iq3;
 
                 // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
+                const int iv2 = iq2 / rv2;
+                const int iv3 = iq3 / rv3;
 
                 ggml_vec_dot_f16_unroll(nev0, nbv1,
                         (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
@@ -4220,6 +4220,10 @@ static struct ggml_tensor * llm_build_kqv(
     struct ggml_tensor * kqv;
 
     if (supports_flash_attn) {
+        //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
+        //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
+        //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
+        //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]);
         kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
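For context, a hedged sketch of the call shape this branch exercises; the wrapper name, the dimension comments, and the early return on the fallback path are assumptions for illustration, and only the ggml_flash_attn_ext / ggml_cast / ggml_mul_mat calls mirror the hunk above:

#include <stdbool.h>
#include "ggml.h"

// Hypothetical wrapper, not from the commit: dispatch to the fused flash
// attention path when available, otherwise start the regular KQ path.
static struct ggml_tensor * build_kqv_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,        // assumed [head_dim, n_tokens, n_head_q,  1]
        struct ggml_tensor  * k,        // assumed [head_dim, n_kv,     n_head_kv, 1]
        struct ggml_tensor  * v,        // assumed [head_dim, n_kv,     n_head_kv, 1]
        struct ggml_tensor  * kq_mask,  // assumed [n_kv, n_tokens, 1, 1]
        float                 kq_scale,
        bool                  supports_flash_attn) {
    if (supports_flash_attn) {
        // with the GQA fix above, n_head_q may be a multiple of n_head_kv:
        // each K/V head is broadcast across its group of Q heads
        return ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
    }

    // fallback path begins as in the hunk; the rest is elided in the diff
    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
    (void) kq;
    return NULL;
}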