mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 02:14:35 +00:00
llama: Add support for Gemma2ForCausalLM (#8156)
* Inference support for Gemma 2 model family * Update convert-hf-to-gguf.py, constants, and tensor mappings * cleanup * format fix * Fix special token vocab bug * Don't add space prefix * fix deleted lines * Update src/llama.cpp Co-authored-by: slaren <slarengh@gmail.com> * Add model type names * Add control vector * Fix model type identification --------- Co-authored-by: Andrei Betlen <abetlen@gmail.com> Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
a27aa50ab7
commit
e57dc62057
@ -2340,6 +2340,46 @@ class GemmaModel(Model):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("Gemma2ForCausalLM")
|
||||
class Gemma2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA2
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_llama_hf()
|
||||
self.gguf_writer.add_add_space_prefix(False)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_key_length(hparams["head_dim"])
|
||||
self.gguf_writer.add_value_length(hparams["head_dim"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unusem
|
||||
|
||||
# lm_head is not used in llama.cpp, while autoawq will include this tensor in model
|
||||
# To prevent errors, skip loading lm_head.weight.
|
||||
if name == "lm_head.weight":
|
||||
logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
|
||||
return []
|
||||
|
||||
# ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
|
||||
if name.endswith("norm.weight"):
|
||||
data_torch = data_torch + 1
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("Starcoder2ForCausalLM")
|
||||
class StarCoder2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.STARCODER2
|
||||
|
@ -150,6 +150,7 @@ class MODEL_ARCH(IntEnum):
|
||||
INTERNLM2 = auto()
|
||||
MINICPM = auto()
|
||||
GEMMA = auto()
|
||||
GEMMA2 = auto()
|
||||
STARCODER2 = auto()
|
||||
MAMBA = auto()
|
||||
XVERSE = auto()
|
||||
@ -180,10 +181,13 @@ class MODEL_TENSOR(IntEnum):
|
||||
ATTN_NORM = auto()
|
||||
ATTN_NORM_2 = auto()
|
||||
ATTN_OUT_NORM = auto()
|
||||
ATTN_POST_NORM = auto()
|
||||
ATTN_ROT_EMBD = auto()
|
||||
FFN_GATE_INP = auto()
|
||||
FFN_GATE_INP_SHEXP = auto()
|
||||
FFN_NORM = auto()
|
||||
FFN_PRE_NORM = auto()
|
||||
FFN_POST_NORM = auto()
|
||||
FFN_GATE = auto()
|
||||
FFN_DOWN = auto()
|
||||
FFN_UP = auto()
|
||||
@ -270,6 +274,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.INTERNLM2: "internlm2",
|
||||
MODEL_ARCH.MINICPM: "minicpm",
|
||||
MODEL_ARCH.GEMMA: "gemma",
|
||||
MODEL_ARCH.GEMMA2: "gemma2",
|
||||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
MODEL_ARCH.MAMBA: "mamba",
|
||||
MODEL_ARCH.XVERSE: "xverse",
|
||||
@ -303,9 +308,12 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
|
||||
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
|
||||
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
|
||||
MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm",
|
||||
MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
|
||||
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm",
|
||||
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
|
||||
@ -751,6 +759,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
],
|
||||
MODEL_ARCH.GEMMA2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_PRE_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
],
|
||||
MODEL_ARCH.STARCODER2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
@ -187,6 +187,10 @@ class TensorNameMap:
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_POST_NORM: (
|
||||
"model.layers.{bid}.post_attention_layernorm", # gemma2
|
||||
),
|
||||
|
||||
# Rotary embeddings
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD: (
|
||||
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
|
||||
@ -210,6 +214,16 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
MODEL_TENSOR.FFN_PRE_NORM: (
|
||||
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
MODEL_TENSOR.FFN_POST_NORM: (
|
||||
"model.layers.{bid}.post_feedforward_layernorm", # gemma2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP: (
|
||||
"layers.{bid}.feed_forward.gate", # mixtral
|
||||
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
|
||||
|
198
src/llama.cpp
198
src/llama.cpp
@ -217,6 +217,7 @@ enum llm_arch {
|
||||
LLM_ARCH_INTERNLM2,
|
||||
LLM_ARCH_MINICPM,
|
||||
LLM_ARCH_GEMMA,
|
||||
LLM_ARCH_GEMMA2,
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
LLM_ARCH_XVERSE,
|
||||
@ -257,6 +258,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_INTERNLM2, "internlm2" },
|
||||
{ LLM_ARCH_MINICPM, "minicpm" },
|
||||
{ LLM_ARCH_GEMMA, "gemma" },
|
||||
{ LLM_ARCH_GEMMA2, "gemma2" },
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
{ LLM_ARCH_XVERSE, "xverse" },
|
||||
@ -478,10 +480,12 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_NORM_2,
|
||||
LLM_TENSOR_ATTN_OUT_NORM,
|
||||
LLM_TENSOR_ATTN_POST_NORM,
|
||||
LLM_TENSOR_ATTN_ROT_EMBD,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_INP_SHEXP,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_FFN_POST_NORM,
|
||||
LLM_TENSOR_FFN_GATE,
|
||||
LLM_TENSOR_FFN_DOWN,
|
||||
LLM_TENSOR_FFN_UP,
|
||||
@ -1004,6 +1008,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GEMMA2,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_STARCODER2,
|
||||
{
|
||||
@ -2039,6 +2061,8 @@ enum e_model {
|
||||
MODEL_16x12B,
|
||||
MODEL_10B_128x3_66B,
|
||||
MODEL_57B_A14B,
|
||||
MODEL_9B,
|
||||
MODEL_27B,
|
||||
};
|
||||
|
||||
static const size_t kiB = 1024;
|
||||
@ -2215,6 +2239,7 @@ struct llama_layer {
|
||||
struct ggml_tensor * attn_q_a_norm;
|
||||
struct ggml_tensor * attn_kv_a_norm;
|
||||
struct ggml_tensor * attn_sub_norm;
|
||||
struct ggml_tensor * attn_post_norm;
|
||||
struct ggml_tensor * ffn_sub_norm;
|
||||
|
||||
// attention
|
||||
@ -2238,6 +2263,7 @@ struct llama_layer {
|
||||
// normalization
|
||||
struct ggml_tensor * ffn_norm;
|
||||
struct ggml_tensor * ffn_norm_b;
|
||||
struct ggml_tensor * ffn_post_norm;
|
||||
struct ggml_tensor * layer_out_norm;
|
||||
struct ggml_tensor * layer_out_norm_b;
|
||||
struct ggml_tensor * ffn_norm_exps;
|
||||
@ -4269,6 +4295,8 @@ static const char * llama_model_type_name(e_model type) {
|
||||
case MODEL_16x12B: return "16x12B";
|
||||
case MODEL_10B_128x3_66B: return "10B+128x3.66B";
|
||||
case MODEL_57B_A14B: return "57B.A14B";
|
||||
case MODEL_9B: return "9B";
|
||||
case MODEL_27B: return "27B";
|
||||
default: return "?B";
|
||||
}
|
||||
}
|
||||
@ -4671,6 +4699,16 @@ static void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 42: model.type = e_model::MODEL_9B; break;
|
||||
case 46: model.type = e_model::MODEL_27B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
@ -6512,6 +6550,40 @@ static bool llm_load_tensors(
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA2:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
|
||||
// output
|
||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
|
||||
|
||||
const int64_t n_ff = hparams.n_ff;
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
|
||||
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
|
||||
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
|
||||
layer.attn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||
layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
@ -10923,6 +10995,125 @@ struct llm_build_context {
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_gemma2() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
|
||||
cb(inpL, "inp_scaled", -1);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
||||
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
|
||||
cb(Qcur, "Qcur_scaled", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||
}
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].attn_post_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||
}
|
||||
|
||||
struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
|
||||
cb(sa_out, "sa_out", il);
|
||||
|
||||
cur = llm_build_norm(ctx0, sa_out, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
cur = llm_build_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].ffn_post_norm, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "ffn_post_norm", -1);
|
||||
|
||||
cur = ggml_add(ctx0, cur, sa_out);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// lm_head
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
|
||||
struct ggml_cgraph * build_starcoder2() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
||||
@ -12303,6 +12494,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm.build_gemma();
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA2:
|
||||
{
|
||||
result = llm.build_gemma2();
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
result = llm.build_starcoder2();
|
||||
@ -17597,6 +17792,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_PHI2:
|
||||
case LLM_ARCH_PHI3:
|
||||
case LLM_ARCH_GEMMA:
|
||||
case LLM_ARCH_GEMMA2:
|
||||
case LLM_ARCH_STARCODER2:
|
||||
case LLM_ARCH_GPTNEOX:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
@ -19486,7 +19682,7 @@ static int32_t llama_chat_apply_template_internal(
|
||||
if (add_ass) {
|
||||
ss << "<s>assistant\n";
|
||||
}
|
||||
} else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
|
||||
} else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl.find("<start_of_turn>") != std::string::npos) {
|
||||
// google/gemma-7b-it
|
||||
std::string system_prompt = "";
|
||||
for (auto message : chat) {
|
||||
|
Loading…
Reference in New Issue
Block a user