mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 20:14:29 +00:00
llama : add llm_build_kv_store helper
ggml-ci
This commit is contained in:
parent
909d64471b
commit
3e0462594b
362
llama.cpp
362
llama.cpp
@ -3291,6 +3291,44 @@ static void llm_build_k_shift(
|
||||
}
|
||||
}
|
||||
|
||||
static void llm_build_kv_store(
|
||||
const llama_context & lctx,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * graph,
|
||||
struct ggml_tensor * k_cur,
|
||||
struct ggml_tensor * v_cur,
|
||||
int32_t n_tokens,
|
||||
int32_t kv_head,
|
||||
const llm_build_cb & cb,
|
||||
int64_t il) {
|
||||
const auto & model = lctx.model;
|
||||
const auto & kv_self = lctx.kv_self;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int64_t n_ctx = cparams.n_ctx;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
|
||||
// compute the transposed [n_tokens, n_embd] V matrix
|
||||
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens));
|
||||
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
|
||||
cb(v_cur_t, "v_cur_t", il);
|
||||
|
||||
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv_self.k, n_tokens*n_embd_gqa,
|
||||
(ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
|
||||
cb(k_cache_view, "k_cache_view", il);
|
||||
|
||||
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv_self.v, n_tokens, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
|
||||
cb(v_cache_view, "v_cache_view", il);
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
|
||||
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
|
||||
}
|
||||
|
||||
static struct ggml_cgraph * llm_build_llama(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch,
|
||||
@ -3385,40 +3423,22 @@ static struct ggml_cgraph * llm_build_llama(
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
cb(tmpk, "tmpk", il);
|
||||
|
||||
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
cb(tmpq, "tmpq", il);
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
// compute the transposed [n_tokens, n_embd] V matrix
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
cb(tmpv, "tmpv", il);
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
|
||||
cb(Vcur, "Vcur", il);
|
||||
Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
|
||||
cb(k, "k", il);
|
||||
Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
|
||||
cb(v, "v", il);
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
cb(Q, "Q", il);
|
||||
@ -3619,53 +3639,31 @@ static struct ggml_cgraph * llm_build_baichaun(
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
cb(tmpk, "tmpk", il);
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
cb(tmpq, "tmpq", il);
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur;
|
||||
struct ggml_tensor * Qcur;
|
||||
switch (model.type) {
|
||||
case MODEL_7B:
|
||||
Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
break;
|
||||
case MODEL_13B:
|
||||
Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, n_tokens);
|
||||
Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, n_tokens);
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
// compute the transposed [n_tokens, n_embd] V matrix
|
||||
|
||||
struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
cb(tmpv, "tmpv", il);
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
|
||||
cb(k, "k", il);
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
|
||||
cb(v, "v", il);
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
cb(Q, "Q", il);
|
||||
@ -3865,14 +3863,14 @@ static struct ggml_cgraph * llm_build_falcon(
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * attn_norm;
|
||||
|
||||
attn_norm = llm_build_norm(ctx0, inpL,
|
||||
model.layers[il].attn_norm,
|
||||
model.layers[il].attn_norm_b,
|
||||
LLM_NORM, norm_eps, cb, il);
|
||||
cb(attn_norm, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
attn_norm = llm_build_norm(ctx0, inpL,
|
||||
model.layers[il].attn_norm,
|
||||
model.layers[il].attn_norm_b,
|
||||
LLM_NORM, norm_eps, cb, il);
|
||||
cb(attn_norm, "attn_norm", il);
|
||||
|
||||
if (model.layers[il].attn_norm_2) {
|
||||
// Falcon-40B
|
||||
cur = llm_build_norm(ctx0, attn_norm,
|
||||
@ -3885,7 +3883,6 @@ static struct ggml_cgraph * llm_build_falcon(
|
||||
}
|
||||
|
||||
// compute QKV
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
@ -3902,52 +3899,35 @@ static struct ggml_cgraph * llm_build_falcon(
|
||||
|
||||
// TODO: these 2 ggml_conts are technically not needed, but we add them until CUDA support for
|
||||
// non-contiguous views is added for the rope operator
|
||||
struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d(
|
||||
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head, n_tokens,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
0));
|
||||
cb(tmpq, "tmpq", il);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_3d(
|
||||
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, n_tokens,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
wsize * n_embd_head * n_head));
|
||||
cb(tmpk, "tmpk", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * tmpv = ggml_view_3d(
|
||||
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, n_tokens,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
wsize * n_embd_head * (n_head + n_head_kv));
|
||||
cb(tmpv, "tmpv", il);
|
||||
wsize * n_embd_head * (n_head + n_head_kv)));
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// using mode = 2 for neox mode
|
||||
struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, tmpq, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, tmpk, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
{
|
||||
struct ggml_tensor * Vcur = ggml_cont(ctx0, tmpv);
|
||||
cb(Vcur, "Vcur_0", il);
|
||||
|
||||
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens));
|
||||
cb(Vcur, "Vcur_1", il);
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
|
||||
cb(k, "k", il);
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
|
||||
cb(v, "v", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
cb(Q, "Q", il);
|
||||
@ -4118,40 +4098,25 @@ static struct ggml_cgraph * llm_build_starcoder(
|
||||
LLM_NORM, norm_eps, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// Self Attention
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
struct ggml_tensor * tmpv = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
|
||||
cb(tmpq, "tmpq", il);
|
||||
cb(tmpk, "tmpk", il);
|
||||
cb(tmpv, "tmpv", il);
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
|
||||
struct ggml_tensor * Kcur = tmpk;
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
|
||||
{
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, tmpv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
|
||||
cb(k, "k", il);
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
|
||||
cb(v, "v", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
cb(Q, "Q", il);
|
||||
@ -4441,34 +4406,16 @@ static struct ggml_cgraph * llm_build_persimmon(
|
||||
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
{
|
||||
struct ggml_tensor * tmpv = ggml_view_3d(
|
||||
ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2
|
||||
struct ggml_tensor * Vcur = ggml_view_3d(
|
||||
ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
|
||||
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2
|
||||
);
|
||||
cb(tmpv, "tmpv", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// store K, V in cache
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
|
||||
cb(Vcur, "Vcur", il);
|
||||
llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(
|
||||
ctx0, kv_self.k, n_tokens*n_embd_gqa,
|
||||
(ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)
|
||||
);
|
||||
cb(k, "k", il);
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
|
||||
cb(v, "v", il);
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k,
|
||||
n_embd_head, n_kv, n_head_kv,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa,
|
||||
@ -4632,40 +4579,22 @@ static struct ggml_cgraph * llm_build_refact(
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K
|
||||
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
cb(tmpk, "tmpk", il);
|
||||
|
||||
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
cb(tmpq, "tmpq", il);
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
// compute the transposed [n_tokens, n_embd] V matrix
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
cb(tmpv, "tmpv", il);
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
|
||||
cb(Vcur, "Vcur", il);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
|
||||
cb(k, "k", il);
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
|
||||
cb(v, "v", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
cb(Q, "Q", il);
|
||||
@ -4852,48 +4781,27 @@ static struct ggml_cgraph * llm_build_bloom(
|
||||
LLM_NORM, norm_eps, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// Self Attention
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
struct ggml_tensor * tmpv = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
|
||||
cb(tmpq, "tmpq", il);
|
||||
cb(tmpk, "tmpk", il);
|
||||
cb(tmpv, "tmpv", il);
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
|
||||
struct ggml_tensor * Kcur = tmpk;
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens));
|
||||
cb(Vcur, "Vcur", il);
|
||||
llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
|
||||
cb(k, "k", il);
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
|
||||
cb(v, "v", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctx0,
|
||||
ggml_cpy(ctx0,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)),
|
||||
0, 2, 1, 3);
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
cb(Q, "Q", il);
|
||||
|
||||
struct ggml_tensor * K =
|
||||
@ -5075,8 +4983,6 @@ static struct ggml_cgraph * llm_build_mpt(
|
||||
{
|
||||
cur = attn_norm;
|
||||
|
||||
// compute QKV
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
@ -5085,47 +4991,17 @@ static struct ggml_cgraph * llm_build_mpt(
|
||||
cb(cur, "wqkv_clamped", il);
|
||||
}
|
||||
|
||||
const size_t wsize = ggml_type_size(cur->type);
|
||||
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head, n_tokens,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
0);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, n_tokens,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
wsize * n_embd_head * n_head);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
struct ggml_tensor * tmpv = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, n_tokens,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
wsize * n_embd_head * (n_head + n_head_kv));
|
||||
cb(tmpv, "tmpv", il);
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
|
||||
{
|
||||
struct ggml_tensor * Vcur = ggml_cont(ctx0, tmpv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens));
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
|
||||
cb(k, "k", il);
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
|
||||
cb(v, "v", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
cb(Q, "Q", il);
|
||||
|
Loading…
Reference in New Issue
Block a user