mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
llama : add llm_build_k_shift helper
ggml-ci
This commit is contained in:
parent
dbf836bb64
commit
38728a0be0
130
llama.cpp
130
llama.cpp
@ -3230,6 +3230,65 @@ static struct ggml_tensor * llm_build_ffn(
|
||||
return cur;
|
||||
}
|
||||
|
||||
enum llm_rope_type {
|
||||
LLM_ROPE,
|
||||
LLM_ROPE_NEOX,
|
||||
LLM_ROPE_GLM,
|
||||
};
|
||||
|
||||
// Persimmon: n_rot = n_embd_head/2
|
||||
// Other: n_rot = n_embd_head
|
||||
static void llm_build_k_shift(
|
||||
const llama_context & lctx,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * graph,
|
||||
int64_t n_rot,
|
||||
llm_rope_type type,
|
||||
const llm_build_cb & cb) {
|
||||
const auto & model = lctx.model;
|
||||
const auto & kv_self = lctx.kv_self;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
|
||||
const int64_t n_ctx = lctx.cparams.n_ctx;
|
||||
|
||||
const float freq_base = cparams.rope_freq_base;
|
||||
const float freq_scale = cparams.rope_freq_scale;
|
||||
|
||||
GGML_ASSERT(n_embd_head % n_rot == 0);
|
||||
|
||||
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
|
||||
cb(K_shift, "K_shift", -1);
|
||||
|
||||
int rope_type = 0;
|
||||
|
||||
switch (type) {
|
||||
case LLM_ROPE: rope_type = 0; break;
|
||||
case LLM_ROPE_NEOX: rope_type = 2; break;
|
||||
case LLM_ROPE_GLM: rope_type = 4; break;
|
||||
};
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * tmp =
|
||||
// we rotate only the first n_rot dimensions
|
||||
ggml_rope_custom_inplace(ctx,
|
||||
ggml_view_3d(ctx, kv_self.k,
|
||||
n_rot, n_head, n_ctx,
|
||||
ggml_element_size(kv_self.k)*n_embd_head,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
|
||||
K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
|
||||
cb(tmp, "K_shifted", il);
|
||||
ggml_build_forward_expand(graph, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
static struct ggml_cgraph * llm_build_llama(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch,
|
||||
@ -3308,21 +3367,7 @@ static struct ggml_cgraph * llm_build_llama(
|
||||
|
||||
// shift the entire K-cache if needed
|
||||
if (do_rope_shift) {
|
||||
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
|
||||
cb(K_shift, "K_shift", -1);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * tmp =
|
||||
ggml_rope_custom_inplace(ctx0,
|
||||
ggml_view_3d(ctx0, kv_self.k,
|
||||
n_embd_head, n_head_kv, n_ctx,
|
||||
ggml_element_size(kv_self.k)*n_embd_head,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
|
||||
K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
cb(tmp, "K_shifted", il);
|
||||
ggml_build_forward_expand(gf, tmp);
|
||||
}
|
||||
llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
@ -3557,21 +3602,7 @@ static struct ggml_cgraph * llm_build_baichaun(
|
||||
|
||||
// shift the entire K-cache if needed
|
||||
if (do_rope_shift) {
|
||||
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
|
||||
cb(K_shift, "K_shift", -1);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * tmp =
|
||||
ggml_rope_custom_inplace(ctx0,
|
||||
ggml_view_3d(ctx0, kv_self.k,
|
||||
n_embd_head, n_head_kv, n_ctx,
|
||||
ggml_element_size(kv_self.k)*n_embd_head,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
|
||||
K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
cb(tmp, "K_shifted", il);
|
||||
ggml_build_forward_expand(gf, tmp);
|
||||
}
|
||||
llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
@ -3830,21 +3861,7 @@ static struct ggml_cgraph * llm_build_falcon(
|
||||
|
||||
// shift the entire K-cache if needed
|
||||
if (do_rope_shift) {
|
||||
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
|
||||
cb(K_shift, "K_shift", -1);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * tmp =
|
||||
ggml_rope_custom_inplace(ctx0,
|
||||
ggml_view_3d(ctx0, kv_self.k,
|
||||
n_embd_head, n_head_kv, n_ctx,
|
||||
ggml_element_size(kv_self.k)*n_embd_head,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
|
||||
K_shift, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
cb(tmp, "K_shifted", il);
|
||||
ggml_build_forward_expand(gf, tmp);
|
||||
}
|
||||
llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE_NEOX, cb);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
@ -4243,6 +4260,7 @@ static struct ggml_cgraph * llm_build_persimmon(
|
||||
GGML_ASSERT(!!kv_self.ctx);
|
||||
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
const int64_t n_ctx = cparams.n_ctx;
|
||||
@ -4250,7 +4268,7 @@ static struct ggml_cgraph * llm_build_persimmon(
|
||||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
const size_t n_rot = n_embd_head / 2;
|
||||
const int64_t n_rot = n_embd_head / 2;
|
||||
|
||||
const float freq_base = cparams.rope_freq_base;
|
||||
const float freq_scale = cparams.rope_freq_scale;
|
||||
@ -4297,23 +4315,7 @@ static struct ggml_cgraph * llm_build_persimmon(
|
||||
cb(KQ_mask, "KQ_mask", -1);
|
||||
|
||||
if (do_rope_shift) {
|
||||
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
|
||||
cb(K_shift, "K_shift", -1);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * tmp =
|
||||
// we rotate only the first n_rot dimensions.
|
||||
ggml_rope_custom_inplace(ctx0,
|
||||
ggml_view_3d(ctx0, kv_self.k,
|
||||
n_rot, n_head, n_ctx,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa,
|
||||
ggml_element_size(kv_self.k)*n_embd_head,
|
||||
ggml_element_size(kv_self.k)*(n_embd_head*n_ctx*il)
|
||||
),
|
||||
K_shift, n_rot, 2, 0, freq_base, freq_scale);
|
||||
cb(tmp, "K_shifted", il);
|
||||
ggml_build_forward_expand(gf, tmp);
|
||||
}
|
||||
llm_build_k_shift(lctx, ctx0, gf, n_rot, LLM_ROPE_NEOX, cb);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
@ -5534,7 +5536,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
const bool do_offload = true;
|
||||
#else
|
||||
const bool do_offload = false;
|
||||
const bool do_offload = true; // TODO: set to false after finishing refactoring
|
||||
#endif
|
||||
|
||||
if (!do_offload) {
|
||||
|
Loading…
Reference in New Issue
Block a user