mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 21:39:52 +00:00
ggml: implement quantized KV cache for FA (#7372)
This commit is contained in:
parent
1b01f06db0
commit
5ca49cbecd
115
ggml.c
115
ggml.c
@ -15882,9 +15882,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
GGML_ASSERT(ne0 == D);
|
||||
GGML_ASSERT(ne2 == N);
|
||||
|
||||
GGML_ASSERT(nbq0 == sizeof(float));
|
||||
GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
|
||||
// input tensor rows must be contiguous
|
||||
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
||||
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
||||
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
||||
|
||||
GGML_ASSERT(neq0 == D);
|
||||
GGML_ASSERT(nek0 == D);
|
||||
@ -15938,6 +15939,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
enum ggml_type const k_vec_dot_type = type_traits[k->type].vec_dot_type;
|
||||
ggml_from_float_t const q_to_vec_dot = type_traits[k_vec_dot_type].from_float;
|
||||
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
|
||||
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
|
||||
|
||||
// loop over n_batch and n_head
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// q indices
|
||||
@ -15945,17 +15951,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
||||
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
||||
|
||||
const uint32_t h = iq2; // head
|
||||
const uint32_t h = iq2; // head index
|
||||
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
||||
|
||||
float S = 0.0f;
|
||||
float M = -INFINITY;
|
||||
float S = 0.0f; // sum
|
||||
float M = -INFINITY; // maximum KQ value
|
||||
|
||||
float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
|
||||
ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
|
||||
ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
|
||||
float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
|
||||
float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
|
||||
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
|
||||
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
|
||||
|
||||
memset(V16, 0, D*sizeof(ggml_fp16_t));
|
||||
if (v->type == GGML_TYPE_F16) {
|
||||
memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
|
||||
} else {
|
||||
memset(VKQ32, 0, D*sizeof(float));
|
||||
}
|
||||
|
||||
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
|
||||
|
||||
@ -15967,6 +15978,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const int iv3 = iq3 / rv3;
|
||||
const int iv2 = iq2 / rv2;
|
||||
|
||||
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
q_to_vec_dot(pq, Q_q, D);
|
||||
|
||||
// online softmax / attention
|
||||
// loop over n_kv and n_head_kv
|
||||
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
||||
@ -15976,51 +15990,66 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
continue;
|
||||
}
|
||||
|
||||
float s;
|
||||
float s; // KQ value
|
||||
|
||||
// convert Q to F16 in V32
|
||||
{
|
||||
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
|
||||
|
||||
for (int64_t d = 0; d < D; ++d) {
|
||||
Q16[d] = GGML_FP32_TO_FP16(pq[d]);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_vec_dot_f16(D,
|
||||
&s, 0,
|
||||
(ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
|
||||
Q16, 0, 1);
|
||||
|
||||
s = s*scale + mv;
|
||||
s = s*scale + mv; // scale KQ value and apply mask
|
||||
|
||||
const float Mold = M;
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
|
||||
float vs = 1.0f; // post-softmax KQ value, expf(s - M)
|
||||
|
||||
if (s > M) {
|
||||
M = s;
|
||||
ms = expf(Mold - M);
|
||||
const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
|
||||
// V = V*expf(Mold - M)
|
||||
ggml_vec_scale_f16(D, V16, ms);
|
||||
if (v->type== GGML_TYPE_F16) {
|
||||
if (s > M) {
|
||||
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
|
||||
M = s;
|
||||
ms = expf(Mold - M);
|
||||
|
||||
// V = V*expf(Mold - M)
|
||||
ggml_vec_scale_f16(D, VKQ16, ms);
|
||||
} else {
|
||||
// no new maximum, ms == 1.0f, vs != 1.0f
|
||||
vs = expf(s - M);
|
||||
}
|
||||
|
||||
// V += v*expf(s - M)
|
||||
ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
|
||||
} else {
|
||||
vs = expf(s - M);
|
||||
if (s > M) {
|
||||
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
|
||||
M = s;
|
||||
ms = expf(Mold - M);
|
||||
|
||||
// V = V*expf(Mold - M)
|
||||
ggml_vec_scale_f32(D, VKQ32, ms);
|
||||
} else {
|
||||
// no new maximum, ms == 1.0f, vs != 1.0f
|
||||
vs = expf(s - M);
|
||||
}
|
||||
|
||||
v_to_float(v_data, V32, D);
|
||||
|
||||
// V += v*expf(s - M)
|
||||
ggml_vec_mad_f32(D, VKQ32, V32, vs);
|
||||
}
|
||||
|
||||
const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
S = S*ms + vs; // scale and increment sum with partial sum
|
||||
}
|
||||
|
||||
// V += v*expf(s - M)
|
||||
ggml_vec_mad_f16(D, V16, v16, vs);
|
||||
|
||||
S = S*ms + vs;
|
||||
if (v->type == GGML_TYPE_F16) {
|
||||
for (int64_t d = 0; d < D; ++d) {
|
||||
VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
|
||||
}
|
||||
}
|
||||
|
||||
// V /= S
|
||||
for (int64_t d = 0; d < D; ++d) {
|
||||
V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
|
||||
}
|
||||
const float S_inv = 1.0f/S;
|
||||
ggml_vec_scale_f32(D, VKQ32, S_inv);
|
||||
|
||||
// dst indices
|
||||
const int i1 = iq1;
|
||||
@ -16031,7 +16060,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
|
||||
|
||||
// permute(0, 2, 1, 3)
|
||||
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
|
||||
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
|
||||
}
|
||||
}
|
||||
|
||||
@ -19972,7 +20001,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
{
|
||||
const int64_t ne00 = node->src[0]->ne[0]; // D
|
||||
|
||||
cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
|
||||
cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
|
||||
} break;
|
||||
case GGML_OP_FLASH_FF:
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user