mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
llama : add grok-1 support (#6204)
* Add support for Grok model architecture * Revert convert-hf-to-gguf to default options * Fixed f_norm_rms_eps bug * Fix whitespaces * llama : fix grok rope type * llama : minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
21cad01b6e
commit
476b0251b2
@ -93,31 +93,42 @@ class Model(ABC):
|
|||||||
|
|
||||||
if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
|
if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
|
||||||
self.gguf_writer.add_context_length(n_ctx)
|
self.gguf_writer.add_context_length(n_ctx)
|
||||||
|
print(f"gguf: context length = {n_ctx}")
|
||||||
|
|
||||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||||
self.gguf_writer.add_embedding_length(n_embd)
|
self.gguf_writer.add_embedding_length(n_embd)
|
||||||
|
print(f"gguf: embedding length = {n_embd}")
|
||||||
|
|
||||||
if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
|
if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
|
||||||
self.gguf_writer.add_feed_forward_length(n_ff)
|
self.gguf_writer.add_feed_forward_length(n_ff)
|
||||||
|
print(f"gguf: feed forward length = {n_ff}")
|
||||||
|
|
||||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||||
self.gguf_writer.add_head_count(n_head)
|
self.gguf_writer.add_head_count(n_head)
|
||||||
|
print(f"gguf: head count = {n_head}")
|
||||||
|
|
||||||
if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
|
if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
|
||||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||||
|
print(f"gguf: key-value head count = {n_head_kv}")
|
||||||
|
|
||||||
if (rope_theta := self.hparams.get("rope_theta")) is not None:
|
if (rope_theta := self.hparams.get("rope_theta")) is not None:
|
||||||
self.gguf_writer.add_rope_freq_base(rope_theta)
|
self.gguf_writer.add_rope_freq_base(rope_theta)
|
||||||
|
print(f"gguf: rope theta = {rope_theta}")
|
||||||
if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
|
if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
|
self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
|
||||||
|
print(f"gguf: rms norm epsilon = {f_rms_eps}")
|
||||||
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
|
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
|
||||||
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
|
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
|
||||||
|
print(f"gguf: layer norm epsilon = {f_norm_eps}")
|
||||||
if (n_experts := self.hparams.get("num_local_experts")) is not None:
|
if (n_experts := self.hparams.get("num_local_experts")) is not None:
|
||||||
self.gguf_writer.add_expert_count(n_experts)
|
self.gguf_writer.add_expert_count(n_experts)
|
||||||
|
print(f"gguf: expert count = {n_experts}")
|
||||||
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
|
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
|
||||||
self.gguf_writer.add_expert_used_count(n_experts_used)
|
self.gguf_writer.add_expert_used_count(n_experts_used)
|
||||||
|
print(f"gguf: experts used count = {n_experts_used}")
|
||||||
|
|
||||||
self.gguf_writer.add_file_type(self.ftype)
|
self.gguf_writer.add_file_type(self.ftype)
|
||||||
|
print(f"gguf: file type = {self.ftype}")
|
||||||
|
|
||||||
def write_tensors(self):
|
def write_tensors(self):
|
||||||
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
|
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
|
||||||
@ -1051,6 +1062,21 @@ class MixtralModel(Model):
|
|||||||
self._set_vocab_sentencepiece()
|
self._set_vocab_sentencepiece()
|
||||||
|
|
||||||
|
|
||||||
|
@Model.register("GrokForCausalLM")
|
||||||
|
class GrokModel(Model):
|
||||||
|
model_arch = gguf.MODEL_ARCH.GROK
|
||||||
|
|
||||||
|
def set_vocab(self):
|
||||||
|
self._set_vocab_sentencepiece()
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def set_gguf_parameters(self):
|
||||||
|
super().set_gguf_parameters()
|
||||||
|
self.gguf_writer.add_name("Grok")
|
||||||
|
|
||||||
|
|
||||||
@Model.register("MiniCPMForCausalLM")
|
@Model.register("MiniCPMForCausalLM")
|
||||||
class MiniCPMModel(Model):
|
class MiniCPMModel(Model):
|
||||||
model_arch = gguf.MODEL_ARCH.MINICPM
|
model_arch = gguf.MODEL_ARCH.MINICPM
|
||||||
|
@ -100,6 +100,7 @@ class MODEL_ARCH(IntEnum):
|
|||||||
LLAMA = auto()
|
LLAMA = auto()
|
||||||
FALCON = auto()
|
FALCON = auto()
|
||||||
BAICHUAN = auto()
|
BAICHUAN = auto()
|
||||||
|
GROK = auto()
|
||||||
GPT2 = auto()
|
GPT2 = auto()
|
||||||
GPTJ = auto()
|
GPTJ = auto()
|
||||||
GPTNEOX = auto()
|
GPTNEOX = auto()
|
||||||
@ -167,6 +168,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||||||
MODEL_ARCH.LLAMA: "llama",
|
MODEL_ARCH.LLAMA: "llama",
|
||||||
MODEL_ARCH.FALCON: "falcon",
|
MODEL_ARCH.FALCON: "falcon",
|
||||||
MODEL_ARCH.BAICHUAN: "baichuan",
|
MODEL_ARCH.BAICHUAN: "baichuan",
|
||||||
|
MODEL_ARCH.GROK: "grok",
|
||||||
MODEL_ARCH.GPT2: "gpt2",
|
MODEL_ARCH.GPT2: "gpt2",
|
||||||
MODEL_ARCH.GPTJ: "gptj",
|
MODEL_ARCH.GPTJ: "gptj",
|
||||||
MODEL_ARCH.GPTNEOX: "gptneox",
|
MODEL_ARCH.GPTNEOX: "gptneox",
|
||||||
@ -251,6 +253,28 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||||
MODEL_TENSOR.FFN_UP_EXP,
|
MODEL_TENSOR.FFN_UP_EXP,
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.GROK: [
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.OUTPUT,
|
||||||
|
MODEL_TENSOR.ROPE_FREQS,
|
||||||
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_Q,
|
||||||
|
MODEL_TENSOR.ATTN_K,
|
||||||
|
MODEL_TENSOR.ATTN_V,
|
||||||
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
|
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||||
|
MODEL_TENSOR.ATTN_OUT_NORM,
|
||||||
|
MODEL_TENSOR.FFN_GATE_INP,
|
||||||
|
MODEL_TENSOR.FFN_NORM,
|
||||||
|
MODEL_TENSOR.FFN_GATE,
|
||||||
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
MODEL_TENSOR.FFN_GATE_EXP,
|
||||||
|
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||||
|
MODEL_TENSOR.FFN_UP_EXP,
|
||||||
|
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||||
|
],
|
||||||
MODEL_ARCH.GPTNEOX: [
|
MODEL_ARCH.GPTNEOX: [
|
||||||
MODEL_TENSOR.TOKEN_EMBD,
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
MODEL_TENSOR.OUTPUT_NORM,
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
@ -23,6 +23,7 @@ class TensorNameMap:
|
|||||||
"model.embedding", # mamba-qbert
|
"model.embedding", # mamba-qbert
|
||||||
"backbone.embedding", # mamba
|
"backbone.embedding", # mamba
|
||||||
"backbone.embeddings", # mamba-hf
|
"backbone.embeddings", # mamba-hf
|
||||||
|
"transformer.in_out_embed", # Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
# Token type embeddings
|
# Token type embeddings
|
||||||
@ -66,6 +67,7 @@ class TensorNameMap:
|
|||||||
"lm_head.ln", # phi2
|
"lm_head.ln", # phi2
|
||||||
"model.norm_f", # mamba-qbert
|
"model.norm_f", # mamba-qbert
|
||||||
"backbone.norm_f", # mamba
|
"backbone.norm_f", # mamba
|
||||||
|
"transformer.rms_norm", # Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
# Rope frequencies
|
# Rope frequencies
|
||||||
@ -93,6 +95,7 @@ class TensorNameMap:
|
|||||||
"model.layers.{bid}.attention_norm", # internlm2
|
"model.layers.{bid}.attention_norm", # internlm2
|
||||||
"model.layers.{bid}.norm", # mamba-qbert
|
"model.layers.{bid}.norm", # mamba-qbert
|
||||||
"backbone.layers.{bid}.norm", # mamba
|
"backbone.layers.{bid}.norm", # mamba
|
||||||
|
"transformer.decoder_layer.{bid}.rms_norm", # Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention norm 2
|
# Attention norm 2
|
||||||
@ -121,7 +124,8 @@ class TensorNameMap:
|
|||||||
"encoder.layer.{bid}.attention.self.query", # bert
|
"encoder.layer.{bid}.attention.self.query", # bert
|
||||||
"transformer.h.{bid}.attn.q_proj", # gpt-j
|
"transformer.h.{bid}.attn.q_proj", # gpt-j
|
||||||
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
|
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
|
||||||
"model.layers.{bid}.attention.wq" # internlm2
|
"model.layers.{bid}.attention.wq", # internlm2
|
||||||
|
"transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention key
|
# Attention key
|
||||||
@ -131,7 +135,8 @@ class TensorNameMap:
|
|||||||
"encoder.layer.{bid}.attention.self.key", # bert
|
"encoder.layer.{bid}.attention.self.key", # bert
|
||||||
"transformer.h.{bid}.attn.k_proj", # gpt-j
|
"transformer.h.{bid}.attn.k_proj", # gpt-j
|
||||||
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
|
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
|
||||||
"model.layers.{bid}.attention.wk" # internlm2
|
"model.layers.{bid}.attention.wk", # internlm2
|
||||||
|
"transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention value
|
# Attention value
|
||||||
@ -141,7 +146,8 @@ class TensorNameMap:
|
|||||||
"encoder.layer.{bid}.attention.self.value", # bert
|
"encoder.layer.{bid}.attention.self.value", # bert
|
||||||
"transformer.h.{bid}.attn.v_proj", # gpt-j
|
"transformer.h.{bid}.attn.v_proj", # gpt-j
|
||||||
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
|
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
|
||||||
"model.layers.{bid}.attention.wv" # internlm2
|
"model.layers.{bid}.attention.wv", # internlm2
|
||||||
|
"transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention output
|
# Attention output
|
||||||
@ -162,12 +168,14 @@ class TensorNameMap:
|
|||||||
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
|
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
|
||||||
"model.layers.{bid}.attention.wo", # internlm2
|
"model.layers.{bid}.attention.wo", # internlm2
|
||||||
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
|
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
|
||||||
|
"transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention output norm
|
# Attention output norm
|
||||||
MODEL_TENSOR.ATTN_OUT_NORM: (
|
MODEL_TENSOR.ATTN_OUT_NORM: (
|
||||||
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
||||||
"encoder.layers.{bid}.norm1", # nomic-bert
|
"encoder.layers.{bid}.norm1", # nomic-bert
|
||||||
|
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
# Rotary embeddings
|
# Rotary embeddings
|
||||||
@ -190,11 +198,13 @@ class TensorNameMap:
|
|||||||
"model.layers.{bid}.ln2", # yi
|
"model.layers.{bid}.ln2", # yi
|
||||||
"h.{bid}.ln_2", # gpt2
|
"h.{bid}.ln_2", # gpt2
|
||||||
"model.layers.{bid}.ffn_norm", # internlm2
|
"model.layers.{bid}.ffn_norm", # internlm2
|
||||||
|
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_GATE_INP: (
|
MODEL_TENSOR.FFN_GATE_INP: (
|
||||||
"layers.{bid}.feed_forward.gate", # mixtral
|
"layers.{bid}.feed_forward.gate", # mixtral
|
||||||
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
|
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
|
||||||
|
"transformer.decoder_layer.{bid}.router" # Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
# Feed-forward up
|
# Feed-forward up
|
||||||
@ -223,6 +233,7 @@ class TensorNameMap:
|
|||||||
MODEL_TENSOR.FFN_UP_EXP: (
|
MODEL_TENSOR.FFN_UP_EXP: (
|
||||||
"layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral
|
"layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral
|
||||||
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
|
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
|
||||||
|
"transformer.decoder_layer.{bid}.moe.{xid}.linear_v", # Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
# AWQ-activation gate
|
# AWQ-activation gate
|
||||||
@ -243,6 +254,7 @@ class TensorNameMap:
|
|||||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||||
"layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral
|
"layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral
|
||||||
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral
|
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral
|
||||||
|
"transformer.decoder_layer.{bid}.moe.{xid}.linear" # Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
# Feed-forward down
|
# Feed-forward down
|
||||||
@ -270,6 +282,8 @@ class TensorNameMap:
|
|||||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||||
"layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral
|
"layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral
|
||||||
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral
|
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral
|
||||||
|
"transformer.decoder_layer.{bid}.moe.{xid}.linear_1", # Grok
|
||||||
|
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||||
@ -289,6 +303,7 @@ class TensorNameMap:
|
|||||||
MODEL_TENSOR.LAYER_OUT_NORM: (
|
MODEL_TENSOR.LAYER_OUT_NORM: (
|
||||||
"encoder.layer.{bid}.output.LayerNorm", # bert
|
"encoder.layer.{bid}.output.LayerNorm", # bert
|
||||||
"encoder.layers.{bid}.norm2", # nomic-bert
|
"encoder.layers.{bid}.norm2", # nomic-bert
|
||||||
|
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.SSM_IN: (
|
MODEL_TENSOR.SSM_IN: (
|
||||||
|
299
llama.cpp
299
llama.cpp
@ -195,6 +195,7 @@ enum llm_arch {
|
|||||||
LLM_ARCH_LLAMA,
|
LLM_ARCH_LLAMA,
|
||||||
LLM_ARCH_FALCON,
|
LLM_ARCH_FALCON,
|
||||||
LLM_ARCH_BAICHUAN,
|
LLM_ARCH_BAICHUAN,
|
||||||
|
LLM_ARCH_GROK,
|
||||||
LLM_ARCH_GPT2,
|
LLM_ARCH_GPT2,
|
||||||
LLM_ARCH_GPTJ,
|
LLM_ARCH_GPTJ,
|
||||||
LLM_ARCH_GPTNEOX,
|
LLM_ARCH_GPTNEOX,
|
||||||
@ -224,6 +225,7 @@ enum llm_arch {
|
|||||||
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||||
{ LLM_ARCH_LLAMA, "llama" },
|
{ LLM_ARCH_LLAMA, "llama" },
|
||||||
{ LLM_ARCH_FALCON, "falcon" },
|
{ LLM_ARCH_FALCON, "falcon" },
|
||||||
|
{ LLM_ARCH_GROK, "grok" },
|
||||||
{ LLM_ARCH_GPT2, "gpt2" },
|
{ LLM_ARCH_GPT2, "gpt2" },
|
||||||
{ LLM_ARCH_GPTJ, "gptj" },
|
{ LLM_ARCH_GPTJ, "gptj" },
|
||||||
{ LLM_ARCH_GPTNEOX, "gptneox" },
|
{ LLM_ARCH_GPTNEOX, "gptneox" },
|
||||||
@ -494,6 +496,28 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
|||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_GROK,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_OUTPUT, "output" },
|
||||||
|
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||||
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
|
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
|
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||||
|
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
|
||||||
|
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
|
||||||
|
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
LLM_ARCH_GPT2,
|
LLM_ARCH_GPT2,
|
||||||
{
|
{
|
||||||
@ -1635,6 +1659,7 @@ enum e_model {
|
|||||||
MODEL_40B,
|
MODEL_40B,
|
||||||
MODEL_65B,
|
MODEL_65B,
|
||||||
MODEL_70B,
|
MODEL_70B,
|
||||||
|
MODEL_314B,
|
||||||
MODEL_SMALL,
|
MODEL_SMALL,
|
||||||
MODEL_MEDIUM,
|
MODEL_MEDIUM,
|
||||||
MODEL_LARGE,
|
MODEL_LARGE,
|
||||||
@ -3419,6 +3444,7 @@ static const char * llama_model_type_name(e_model type) {
|
|||||||
case MODEL_40B: return "40B";
|
case MODEL_40B: return "40B";
|
||||||
case MODEL_65B: return "65B";
|
case MODEL_65B: return "65B";
|
||||||
case MODEL_70B: return "70B";
|
case MODEL_70B: return "70B";
|
||||||
|
case MODEL_314B: return "314B";
|
||||||
case MODEL_SMALL: return "0.1B";
|
case MODEL_SMALL: return "0.1B";
|
||||||
case MODEL_MEDIUM: return "0.4B";
|
case MODEL_MEDIUM: return "0.4B";
|
||||||
case MODEL_LARGE: return "0.8B";
|
case MODEL_LARGE: return "0.8B";
|
||||||
@ -3557,6 +3583,15 @@ static void llm_load_hparams(
|
|||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_GROK:
|
||||||
|
{
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
|
|
||||||
|
switch (hparams.n_layer) {
|
||||||
|
case 64: model.type = e_model::MODEL_314B; break;
|
||||||
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
|
}
|
||||||
|
} break;
|
||||||
case LLM_ARCH_FALCON:
|
case LLM_ARCH_FALCON:
|
||||||
{
|
{
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||||
@ -4394,6 +4429,54 @@ static bool llm_load_tensors(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_GROK:
|
||||||
|
{
|
||||||
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||||
|
|
||||||
|
// output
|
||||||
|
{
|
||||||
|
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||||
|
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
|
||||||
|
// if output is NULL, init from the input tok embed
|
||||||
|
if (model.output == NULL) {
|
||||||
|
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||||
|
ml.n_created--; // artificial tensor
|
||||||
|
ml.size_data += ggml_nbytes(model.output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_layer; ++i) {
|
||||||
|
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||||
|
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||||
|
|
||||||
|
auto & layer = model.layers[i];
|
||||||
|
|
||||||
|
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||||
|
|
||||||
|
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
|
||||||
|
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
|
||||||
|
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
|
||||||
|
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||||
|
|
||||||
|
layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
|
||||||
|
|
||||||
|
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||||
|
|
||||||
|
layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd});
|
||||||
|
|
||||||
|
GGML_ASSERT(hparams.n_expert > 0);
|
||||||
|
GGML_ASSERT(hparams.n_expert_used > 0);
|
||||||
|
|
||||||
|
// MoE branch
|
||||||
|
for (uint32_t x = 0; x < hparams.n_expert; ++x) {
|
||||||
|
layer.ffn_gate_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff});
|
||||||
|
layer.ffn_down_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd});
|
||||||
|
layer.ffn_up_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), {n_embd, n_ff});
|
||||||
|
}
|
||||||
|
|
||||||
|
layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
|
||||||
|
}
|
||||||
|
} break;
|
||||||
case LLM_ARCH_BAICHUAN:
|
case LLM_ARCH_BAICHUAN:
|
||||||
{
|
{
|
||||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||||
@ -5621,6 +5704,20 @@ static struct ggml_tensor * llm_build_kqv(
|
|||||||
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (model.arch == LLM_ARCH_GROK) {
|
||||||
|
// need to do the following:
|
||||||
|
// multiply by attn_output_multiplyer of 0.08838834764831845
|
||||||
|
// and then :
|
||||||
|
// kq = 30 * tanh(kq / 30)
|
||||||
|
// before the softmax below
|
||||||
|
|
||||||
|
//try from phi2
|
||||||
|
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||||
|
|
||||||
|
kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
|
||||||
|
kq = ggml_scale(ctx, kq, 30);
|
||||||
|
}
|
||||||
|
|
||||||
#if defined(GGML_USE_KOMPUTE)
|
#if defined(GGML_USE_KOMPUTE)
|
||||||
#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute")
|
#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute")
|
||||||
#pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024")
|
#pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024")
|
||||||
@ -6395,6 +6492,203 @@ struct llm_build_context {
|
|||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_cgraph * build_grok() {
|
||||||
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||||
|
|
||||||
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||||
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||||
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||||
|
|
||||||
|
struct ggml_tensor * cur;
|
||||||
|
struct ggml_tensor * inpL;
|
||||||
|
|
||||||
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||||
|
|
||||||
|
// multiply by embedding_multiplier_scale of 78.38367176906169
|
||||||
|
inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
|
||||||
|
|
||||||
|
// inp_pos - contains the positions
|
||||||
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||||
|
|
||||||
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
|
||||||
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
struct ggml_tensor * inpSA = inpL;
|
||||||
|
|
||||||
|
// norm
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
model.layers[il].attn_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "attn_norm", il);
|
||||||
|
|
||||||
|
|
||||||
|
// self-attention
|
||||||
|
{
|
||||||
|
// compute Q and K and RoPE them
|
||||||
|
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
if (model.layers[il].bq) {
|
||||||
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
if (model.layers[il].bk) {
|
||||||
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
if (model.layers[il].bv) {
|
||||||
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
Qcur = ggml_rope_custom(
|
||||||
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
|
||||||
|
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
Kcur = ggml_rope_custom(
|
||||||
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
|
||||||
|
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||||
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grok
|
||||||
|
// if attn_out_norm is present then apply it before adding the input
|
||||||
|
if (model.layers[il].attn_out_norm) {
|
||||||
|
cur = llm_build_norm(ctx0, cur, hparams,
|
||||||
|
model.layers[il].attn_out_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "attn_out_norm", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||||
|
cb(ffn_inp, "ffn_inp", il);
|
||||||
|
|
||||||
|
// feed-forward network
|
||||||
|
// MoE branch
|
||||||
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||||
|
model.layers[il].ffn_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "ffn_norm", il);
|
||||||
|
|
||||||
|
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
|
||||||
|
cb(logits, "ffn_moe_logits", il);
|
||||||
|
|
||||||
|
ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
|
||||||
|
cb(probs, "ffn_moe_probs", il);
|
||||||
|
|
||||||
|
// select experts
|
||||||
|
ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
|
||||||
|
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
||||||
|
|
||||||
|
ggml_tensor * weights = ggml_get_rows(ctx0,
|
||||||
|
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
|
||||||
|
cb(weights, "ffn_moe_weights", il);
|
||||||
|
|
||||||
|
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
|
||||||
|
|
||||||
|
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
|
||||||
|
cb(weights_sum, "ffn_moe_weights_sum", il);
|
||||||
|
|
||||||
|
weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
|
||||||
|
cb(weights, "ffn_moe_weights_norm", il);
|
||||||
|
|
||||||
|
// compute expert outputs
|
||||||
|
ggml_tensor * moe_out = nullptr;
|
||||||
|
|
||||||
|
for (int i = 0; i < n_expert_used; ++i) {
|
||||||
|
ggml_tensor * cur_expert;
|
||||||
|
|
||||||
|
ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur);
|
||||||
|
cb(cur_up, "ffn_moe_up", il);
|
||||||
|
|
||||||
|
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur);
|
||||||
|
cb(cur_gate, "ffn_moe_gate", il);
|
||||||
|
|
||||||
|
//GeLU
|
||||||
|
cur_gate = ggml_gelu(ctx0, cur_gate);
|
||||||
|
cb(cur_gate, "ffn_moe_gelu", il);
|
||||||
|
|
||||||
|
cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
|
||||||
|
cb(cur_expert, "ffn_moe_gate_par", il);
|
||||||
|
|
||||||
|
cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd]
|
||||||
|
cb(cur_expert, "ffn_moe_down", il);
|
||||||
|
|
||||||
|
cur_expert = ggml_mul(ctx0, cur_expert,
|
||||||
|
ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
|
||||||
|
cb(cur_expert, "ffn_moe_weighted", il);
|
||||||
|
|
||||||
|
if (i == 0) {
|
||||||
|
moe_out = cur_expert;
|
||||||
|
} else {
|
||||||
|
moe_out = ggml_add(ctx0, moe_out, cur_expert);
|
||||||
|
cb(moe_out, "ffn_moe_out", il);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = moe_out;
|
||||||
|
|
||||||
|
// Grok
|
||||||
|
// if layer_out_norm is present then apply it before adding the input
|
||||||
|
// Idea: maybe ffn_out_norm is a better name
|
||||||
|
if (model.layers[il].layer_out_norm) {
|
||||||
|
cur = llm_build_norm(ctx0, cur, hparams,
|
||||||
|
model.layers[il].layer_out_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "layer_out_norm", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||||
|
cb(cur, "ffn_out", il);
|
||||||
|
|
||||||
|
ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
|
||||||
|
if (layer_dir != nullptr) {
|
||||||
|
cur = ggml_add(ctx0, cur, layer_dir);
|
||||||
|
}
|
||||||
|
cb(cur, "l_out", il);
|
||||||
|
|
||||||
|
// input for next layer
|
||||||
|
inpL = cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = inpL;
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, cur, hparams,
|
||||||
|
model.output_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, -1);
|
||||||
|
cb(cur, "result_norm", -1);
|
||||||
|
|
||||||
|
// lm_head
|
||||||
|
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||||
|
|
||||||
|
// Grok
|
||||||
|
// multiply logits by output_multiplier_scale of 0.5773502691896257
|
||||||
|
|
||||||
|
cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
|
||||||
|
|
||||||
|
cb(cur, "result_output", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_cgraph * build_starcoder() {
|
struct ggml_cgraph * build_starcoder() {
|
||||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||||
|
|
||||||
@ -8818,6 +9112,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|||||||
{
|
{
|
||||||
result = llm.build_falcon();
|
result = llm.build_falcon();
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_GROK:
|
||||||
|
{
|
||||||
|
result = llm.build_grok();
|
||||||
|
} break;
|
||||||
case LLM_ARCH_STARCODER:
|
case LLM_ARCH_STARCODER:
|
||||||
{
|
{
|
||||||
result = llm.build_starcoder();
|
result = llm.build_starcoder();
|
||||||
@ -13561,6 +13859,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|||||||
|
|
||||||
// the pairs of head values are offset by n_rot/2
|
// the pairs of head values are offset by n_rot/2
|
||||||
case LLM_ARCH_FALCON:
|
case LLM_ARCH_FALCON:
|
||||||
|
case LLM_ARCH_GROK:
|
||||||
case LLM_ARCH_PERSIMMON:
|
case LLM_ARCH_PERSIMMON:
|
||||||
case LLM_ARCH_BERT:
|
case LLM_ARCH_BERT:
|
||||||
case LLM_ARCH_NOMIC_BERT:
|
case LLM_ARCH_NOMIC_BERT:
|
||||||
|
Loading…
Reference in New Issue
Block a user