ggml : online attention (CPU)

Georgi Gerganov 2024-01-20 12:26:49 +02:00
parent c3cdfffa88
commit a9681febd6
6 changed files with 231 additions and 198 deletions
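
The GGML_OP_FLASH_ATTN_EXT CPU path is rewritten to compute attention in a single pass over the KV positions (online softmax, ref: https://arxiv.org/pdf/2112.05682.pdf), the Metal kernel gains a threadgroup memory buffer, the result tensor is produced directly in the permuted [n_embd, n_head, n_batch, 1] layout, and llama.cpp stores the V cache non-transposed behind a new LLAMA_FLASH_ATTN define. Roughly, the recurrence the new CPU kernel follows per query row is (notation mine, not from the diff):

    s_i = scale * dot(q, k_i) + mask_i
    M'  = max(M, s_i)
    V   = V * exp(M - M') + v_i * exp(s_i - M')
    S   = S * exp(M - M') +       exp(s_i - M')
    M   = M'
    out = V / S      (after the last KV position)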

ggml-metal.m

@@ -2207,9 +2207,15 @@ static bool ggml_metal_graph_compute(
                 [encoder setBytes:&nb3   length:sizeof(uint64_t) atIndex:20];
                 [encoder setBytes:&scale length:sizeof(   float) atIndex:21];
 
+                const int nwarps = 4;
+
+                // each warp needs n_embd_head elements
+                GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength);
+                [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0];
+
                 const int nth = MIN(1024, ne0);
 
-                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)];
             } break;
         case GGML_OP_DUP:
         case GGML_OP_CPY:
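
For a sense of scale (example numbers, not from the diff): with nwarps = 4 and a head size of ne00 = 128, the kernel requests 4*128*4 = 2048 bytes of threadgroup memory per threadgroup, far below maxThreadgroupMemoryLength (typically 32 KB on current Apple GPUs), so the GGML_ASSERT above only trips for unusually large head sizes.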

ggml-metal.metal

@@ -1982,6 +1982,7 @@ kernel void kernel_flash_attn_ext_f16(
         constant uint64_t & nb2,
         constant uint64_t & nb3,
         constant    float & scale,
+        threadgroup float  * shared [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {

ggml.c

@@ -817,7 +817,7 @@ do { \
 #if defined(__F16C__)
 // the  _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
 #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
 #else
 static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
@@ -1323,6 +1323,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
 #endif
 }
 
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#endif
+}
+
 // xs and vs are byte strides of x and v
 inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
@@ -1407,6 +1438,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
 #endif
 }
 
+inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#endif
+}
+
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
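
Together these two new FP16 helpers implement the rescale-and-accumulate step of the streaming softmax further down in this file: if V16 currently holds sum_j exp(s_j - Mold) * v_j and the running maximum rises from Mold to M, then

    ggml_vec_scale_f16(D, V16, expf(Mold - M));   // re-reference the old terms to the new maximum
    ggml_vec_mad_f16 (D, V16, v16, expf(s - M));  // add the current value row with its softmax numerator

leaves every accumulated term scaled by exp(s_j - M), which is the invariant the online attention loop in ggml_compute_forward_flash_attn_ext_f16 maintains.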
@@ -5704,8 +5764,9 @@ struct ggml_tensor * ggml_flash_attn_ext(
         is_node = true;
     }
 
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
+    // permute(0, 2, 1, 3)
+    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne);
 
     float params[] = { scale };
     ggml_set_op_params(result, params, sizeof(params));
@@ -13281,12 +13342,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int64_t D = neq0;
     const int64_t N = neq1;
     const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
 
     GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(ne2 == N);
     GGML_ASSERT(P >= 0);
 
     GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
@@ -13295,11 +13353,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
     GGML_ASSERT(neq0 == D);
     GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(nev0 == D);
 
     GGML_ASSERT(neq1 == N);
     GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(nev0 == D);
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -13339,151 +13397,87 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
 
+    // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
         const int iq3 = ir/(neq2*neq1);
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
-        float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
-            for (int64_t ic = 0; ic < nek1; ++ic) {
-                // k indices
-                const int ik3 = iq3 / rk3;
-                const int ik2 = iq2 / rk2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16(neq0,
-                        S + i1,
-                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
-            }
-        } else {
-            for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
-                // k indices
-                const int ik3 = iq3 / rk3;
-                const int ik2 = iq2 / rk2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16_unroll(neq0, nbk1,
-                        S + i1,
-                        ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
-            }
-        }
-
-        // scale
-        ggml_vec_scale_f32(nek1, S, scale);
-
-        if (mask) {
-            const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]);
-            ggml_vec_acc_f32(M, S, mp);
-        }
-
-        // softmax
-        // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
-        // dont forget to set their S values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(M, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
-                ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
-
-                for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
-                    float * SS = S + i;
-
-                    for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
-                        if (SS[j] == -INFINITY) {
-                            SS[j] = 0.0f;
-                        } else {
-                            ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
-                            memcpy(&scvt[j], &s, sizeof(uint16_t));
-                            const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
-                            sump[j] += (ggml_float)val;
-                            SS[j] = val;
-                        }
-                    }
-                }
-
-                for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
-                    sum += sump[i];
-                }
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(M, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < M; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
-        if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
-            for (int64_t ic = 0; ic < nev1; ++ic) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 / rv2;
-                const int iv3 = iq3 / rv3;
-
-                ggml_vec_dot_f16(nev0,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                        (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
-                        S16);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 / rv2;
-                const int iv3 = iq3 / rv3;
-
-                ggml_vec_dot_f16_unroll(nev0, nbv1,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                        ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
-                        S16);
-            }
-        }
+        float S = 0.0f;
+        float M = -INFINITY;
+
+        float       * V32 = (float       *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
+        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+
+        memset(V16, 0, D*sizeof(ggml_fp16_t));
+
+        const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL;
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv2 = iq2 / rv2;
+        const int iv3 = iq3 / rv3;
+
+        // online softmax / attention
+        // loop over n_kv and n_head_kv
+        // ref: https://arxiv.org/pdf/2112.05682.pdf
+        for (int64_t ic = 0; ic < nek1; ++ic) {
+            const float mv = mp ? mp[ic] : 0.0f;
+            if (mv == -INFINITY) {
+                continue;
+            }
+
+            float s;
+
+            ggml_vec_dot_f16(D,
+                    &s,
+                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                    (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+
+            s = s*scale + mv;
+
+            const float Mold = M;
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                M = s;
+                ms = expf(Mold - M);
+
+                // V = V*expf(Mold - M)
+                ggml_vec_scale_f16(D, V16, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+            // V += v*expf(s - M)
+            ggml_vec_mad_f16(D, V16, v16, vs);
+
+            S = S*ms + vs;
+        }
+
+        // V /= S
+        for (int64_t d = 0; d < D; ++d) {
+            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+        }
+
+        // dst indices
+        const int i1 = iq1;
+        const int i2 = iq2;
+        const int i3 = iq3;
+
+        // original
+        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+        // permute(0, 2, 1, 3)
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
     }
 }
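
For reference, a minimal scalar FP32 re-statement of the per-row recurrence above (a sketch only: the function name, contiguous [n_kv, D] row layout and FP32 types are assumptions made here for clarity; the real kernel operates on strided FP16 data via the ggml_vec_* helpers):

    #include <math.h>
    #include <stdint.h>

    static void attn_row_online_ref(const float * q, const float * k, const float * v,
                                    const float * mask, float * out,
                                    int64_t D, int64_t n_kv, float scale) {
        float S = 0.0f;      // running softmax denominator
        float M = -INFINITY; // running maximum of the masked, scaled scores

        for (int64_t d = 0; d < D; ++d) {
            out[d] = 0.0f;   // running value accumulator ("V" in the kernel above)
        }

        for (int64_t ic = 0; ic < n_kv; ++ic) {
            const float mv = mask ? mask[ic] : 0.0f;
            if (mv == -INFINITY) {
                continue;    // fully masked position: contributes nothing
            }

            float s = 0.0f;
            for (int64_t d = 0; d < D; ++d) {
                s += q[d]*k[ic*D + d];
            }
            s = s*scale + mv;

            const float Mold = M;

            float ms = 1.0f; // rescale factor for the old accumulator
            float vs = 1.0f; // softmax numerator of the current position

            if (s > M) {
                M  = s;
                ms = expf(Mold - M);
                for (int64_t d = 0; d < D; ++d) {
                    out[d] *= ms;             // re-reference old terms to the new maximum
                }
            } else {
                vs = expf(s - M);
            }

            for (int64_t d = 0; d < D; ++d) {
                out[d] += v[ic*D + d]*vs;     // accumulate the current value row
            }

            S = S*ms + vs;
        }

        // normalize: out = softmax(scale*q*K^T + mask) * V
        for (int64_t d = 0; d < D; ++d) {
            out[d] /= S;
        }
    }

Note that the first accepted position sets M = s, so ms = expf(-INFINITY) = 0 and the zero-initialized accumulator is simply overwritten; no special case is needed for the initial step.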
@@ -17069,7 +17063,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                 } break;
             case GGML_OP_FLASH_ATTN:
-            case GGML_OP_FLASH_ATTN_EXT:
                 {
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
@@ -17081,6 +17074,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                         cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
                     }
                 } break;
+            case GGML_OP_FLASH_ATTN_EXT:
+                {
+                    const int64_t ne00 = node->src[0]->ne[0]; // D
+
+                    cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+                } break;
             case GGML_OP_FLASH_FF:
                 {
                     if (node->src[1]->type == GGML_TYPE_F32) {
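
The new work-buffer estimate is per thread: 2*D floats, covering the FP32 output row V32 plus the FP16 accumulator V16 that ggml_compute_forward_flash_attn_ext_f16 carves out of wdata. As a worked example (values assumed for illustration), a head size of ne00 = 128 with n_tasks = 8 gives cur = 2*4*128*8 = 8192 bytes.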

ggml.h

@@ -1620,6 +1620,11 @@ extern "C" {
             struct ggml_tensor  * v,
             bool                  masked);
 
+    // q:    [n_embd, n_batch, n_head,    1]
+    // k:    [n_embd, n_kv,    n_head_kv, 1]
+    // v:    [n_embd, n_kv,    n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,   n_batch, 1,         1]
+    // res:  [n_embd, n_head,  n_batch,   1] !! permuted !!
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
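
A minimal sketch of a call matching the shapes documented above (sizes are example values mirroring the test case at the bottom of this commit; ctx is an existing ggml_context, <math.h> provides sqrtf, and graph build/compute steps are omitted):

    const int64_t n_embd_head = 128, n_head = 32, n_batch = 8, n_kv = 96;

    struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_head, n_batch, n_head, 1);
    struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_head, n_kv,    n_head, 1);
    struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_head, n_kv,    n_head, 1); // same layout as k, not transposed
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_batch, 1, 1);

    // the result comes back as [n_embd_head, n_head, n_batch, 1], i.e. already permuted
    struct ggml_tensor * out  = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) n_embd_head));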

llama.cpp

@@ -95,6 +95,8 @@
 #define LLAMA_MAX_NODES   8192
 #define LLAMA_MAX_EXPERTS 8
 
+#define LLAMA_FLASH_ATTN
+
 //
 // logging
 //
@@ -4167,23 +4169,34 @@ static void llm_build_kv_store(
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
 
-    // compute the transposed [n_tokens, n_embd] V matrix
-    struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
-    //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
-    cb(v_cur_t, "v_cur_t", il);
-
     struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
             (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
     cb(k_cache_view, "k_cache_view", il);
 
+    // important: storing RoPE-ed version of K in the KV cache!
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
+
+#if defined(LLAMA_FLASH_ATTN)
+    // NOTE: the V cache is not transposed when using FLASH attention !!
+    struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
+            (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head);
+    cb(v_cache_view, "v_cache_view", il);
+
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
+
+    GGML_UNUSED(n_ctx);
+#else
+    // compute the transposed [n_tokens, n_embd] V matrix
+    //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
+    struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
+    cb(v_cur_t, "v_cur_t", il);
+
     struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
             (  n_ctx)*ggml_element_size(kv.v_l[il]),
             (kv_head)*ggml_element_size(kv.v_l[il]));
     cb(v_cache_view, "v_cache_view", il);
 
-    // important: storing RoPE-ed version of K in the KV cache!
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
     ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
+#endif
 }
 
 static struct ggml_tensor * llm_build_norm(
@@ -4343,27 +4356,27 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
     cb(k, "k", il);
 
-    // split cached v into n_head heads
+    struct ggml_tensor * cur;
+
+#if defined(LLAMA_FLASH_ATTN)
+    // split cached v into n_head heads (not transposed)
     struct ggml_tensor * v =
         ggml_view_3d(ctx, kv.v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                ggml_element_size(kv.v_l[il])*n_ctx,
-                ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
+                ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
                 0);
     cb(v, "v", il);
 
-    // TODO: determine if we can use flash attention
-    const bool supports_flash_attn = true;
-
-    struct ggml_tensor * kqv;
-
-    if (supports_flash_attn) {
+    cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
+
     //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
     //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
     //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
     //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]);
-
-        kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
-    } else {
+    //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]);
+
+    cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
+#else
     struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
     cb(kq, "kq", il);
@@ -4396,15 +4409,24 @@ static struct ggml_tensor * llm_build_kqv(
         cb(kq, "kq_soft_max_ext", il);
     }
 
-        kqv = ggml_mul_mat(ctx, v, kq);
+    // split cached v into n_head heads (transposed)
+    struct ggml_tensor * v =
+        ggml_view_3d(ctx, kv.v_l[il],
+                n_kv, n_embd_head_v, n_head_kv,
+                ggml_element_size(kv.v_l[il])*n_ctx,
+                ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
+                0);
+    cb(v, "v", il);
+
+    struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
     cb(kqv, "kqv", il);
-    }
 
     struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
     cb(kqv_merged, "kqv_merged", il);
 
-    struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
+    cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
     cb(cur, "kqv_merged_cont", il);
+#endif
 
     cur = ggml_mul_mat(ctx, wo, cur);
 
     if (wo_b) {

tests/test-backend-ops.cpp

@@ -1390,21 +1390,21 @@ struct test_flash_attn_ext : public test_case {
     const int64_t hs; // head size
     const int64_t nh; // num heads
     const int64_t kv; // kv size
-    const int64_t nt; // tokens
+    const int64_t nb; // batch size
 
     std::string vars() override {
-        return VARS_TO_STR5(typeq, hs, nh, kv, nt);
+        return VARS_TO_STR5(typeq, hs, nh, kv, nb);
     }
 
     test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16,
-            int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8)
-        : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {}
+            int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
+        : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1);
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
-        ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1);
-        ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1);
+        ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
+        ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1);
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
         return out;
     }