llama : add SEA-LION support (#6448)

* initial commit for sealion support

* add sealion support

* minor fix

* q/k ln and pos_embd only if required

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* minor : clear whitespaces

---------

Co-authored-by: bryan <bryansiow@aisingapore.org>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
bryanSwk 2024-04-04 02:05:10 +08:00 committed by GitHub
parent 9f62c0173d
commit bb43cf7e9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 65 additions and 4 deletions

View File

@ -118,6 +118,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [Mamba](https://github.com/state-spaces/mamba) - [x] [Mamba](https://github.com/state-spaces/mamba)
- [x] [Xverse](https://huggingface.co/models?search=xverse) - [x] [Xverse](https://huggingface.co/models?search=xverse)
- [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01) - [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
- [x] [SEA-LION](https://huggingface.co/models?search=sea-lion)
**Multimodal models:** **Multimodal models:**

View File

@ -510,6 +510,16 @@ class BloomModel(Model):
class MPTModel(Model): class MPTModel(Model):
model_arch = gguf.MODEL_ARCH.MPT model_arch = gguf.MODEL_ARCH.MPT
def set_vocab(self):
try:
self._set_vocab_gpt2()
except:
self._set_vocab_sentencepiece()
self.gguf_writer.add_add_bos_token(False)
self.gguf_writer.add_pad_token_id(3)
self.gguf_writer.add_eos_token_id(1)
self.gguf_writer.add_unk_token_id(0)
def set_gguf_parameters(self): def set_gguf_parameters(self):
block_count = self.hparams["n_layers"] block_count = self.hparams["n_layers"]
self.gguf_writer.add_name(self.dir_model.name) self.gguf_writer.add_name(self.dir_model.name)
@ -523,7 +533,10 @@ class MPTModel(Model):
self.gguf_writer.add_layer_norm_eps(1e-5) self.gguf_writer.add_layer_norm_eps(1e-5)
if self.hparams["attn_config"]["clip_qkv"] is not None: if self.hparams["attn_config"]["clip_qkv"] is not None:
self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"]) self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"]) if self.hparams["attn_config"]["alibi"]:
self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])
else:
self.gguf_writer.add_max_alibi_bias(0.0)
def write_tensors(self): def write_tensors(self):
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers")) block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers"))

View File

@ -367,6 +367,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_ACT, MODEL_TENSOR.FFN_ACT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.POS_EMBD,
], ],
MODEL_ARCH.GPTJ: [ MODEL_ARCH.GPTJ: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,

View File

@ -285,11 +285,13 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_Q_NORM: ( MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm", "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon "model.layers.{bid}.self_attn.q_layernorm", # persimmon
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
), ),
MODEL_TENSOR.ATTN_K_NORM: ( MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm", "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon "model.layers.{bid}.self_attn.k_layernorm", # persimmon
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
), ),
MODEL_TENSOR.ROPE_FREQS: ( MODEL_TENSOR.ROPE_FREQS: (

View File

@ -594,6 +594,9 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" },
{ LLM_TENSOR_POS_EMBD, "position_embd" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"},
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"},
}, },
}, },
{ {
@ -4867,6 +4870,7 @@ static bool llm_load_tensors(
case LLM_ARCH_MPT: case LLM_ARCH_MPT:
{ {
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, false);
// output // output
{ {
@ -4905,6 +4909,12 @@ static bool llm_load_tensors(
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, false); layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, false);
layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, false);
layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, false);
layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, false);
layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, false);
// AWQ ScaleActivation layer // AWQ ScaleActivation layer
layer.ffn_act = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, false); layer.ffn_act = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, false);
} }
@ -7721,6 +7731,7 @@ struct llm_build_context {
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * pos;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
@ -7731,6 +7742,16 @@ struct llm_build_context {
// positions of the tokens in the KV cache // positions of the tokens in the KV cache
struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
if (model.pos_embd) {
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
cb(pos, "pos_embd", -1);
inpL = ggml_add(ctx0, inpL, pos);
cb(inpL, "inpL", -1);
}
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * attn_norm; struct ggml_tensor * attn_norm;
@ -7765,11 +7786,32 @@ struct llm_build_context {
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); // Q/K Layernorm
if (model.layers[il].attn_q_norm) {
Qcur = llm_build_norm(ctx0, Qcur, hparams,
model.layers[il].attn_q_norm,
model.layers[il].attn_q_norm_b,
LLM_NORM, cb, il);
cb(Qcur, "Qcur", il);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, Kcur = llm_build_norm(ctx0, Kcur, hparams,
model.layers[il].attn_k_norm,
model.layers[il].attn_k_norm_b,
LLM_NORM, cb, il);
cb(Kcur, "Kcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
} else {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
} }
if (il == n_layer - 1) { if (il == n_layer - 1) {