llama : compute BERT graph with F16 K, V

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-03-05 21:22:20 +02:00
parent 6cdabe6526
commit 0ba20ed97a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6175,7 +6175,7 @@ struct llm_build_context {
} }
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); struct ggml_tensor * k = ggml_cast(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3), GGML_TYPE_F16);
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
cb(kq, "kq", il); cb(kq, "kq", il);
@ -6183,7 +6183,7 @@ struct llm_build_context {
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il); cb(kq, "kq_soft_max_ext", il);
struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); struct ggml_tensor * v = ggml_cast(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)), GGML_TYPE_F16);
cb(v, "v", il); cb(v, "v", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq); struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);