convert : support Mixtral as LLAMA arch

Georgi Gerganov 2023-12-09 10:51:58 +02:00
parent fe680e3d10
commit dff8cbeb39
3 changed files with 52 additions and 10 deletions

View File

@@ -266,12 +266,23 @@ class Params:
             # LLaMA v1
             n_ctx = 2048

+        # print model keys
+        for k in model.keys():
+            print(k)
+
+        # check if MoE
+        if "layers.0.feed_forward.experts.0.w1.weight" in model:
+            n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
+            n_ctx = 32768
+        else:
+            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
+
         return Params(
             n_vocab    = model["tok_embeddings.weight"].shape[0],
             n_embd     = config["dim"],
             n_layer    = config["n_layers"],
             n_ctx      = n_ctx,
-            n_ff       = model["layers.0.feed_forward.w1.weight"].shape[0],
+            n_ff       = n_ff,
             n_head     = (n_head := config["n_heads"]),
             n_head_kv  = config.get("n_kv_heads", n_head),
             f_norm_eps = config["norm_eps"],
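
For context, the added branch keys entirely off the presence of the first expert's w1 tensor. Below is a minimal, self-contained sketch of that detection; the tensor names and the 32768 context length are taken from the hunk above, while the helper name, the plain shape tuples standing in for real torch tensors, and the example sizes are illustrative assumptions.

# Sketch of the MoE detection added above. Plain shape tuples stand in
# for real torch tensors; the sizes shown are illustrative only.
mixtral_like = {
    "tok_embeddings.weight":                     (32000, 4096),
    "layers.0.feed_forward.experts.0.w1.weight": (14336, 4096),
}
llama_like = {
    "tok_embeddings.weight":           (32000, 4096),
    "layers.0.feed_forward.w1.weight": (11008, 4096),
}

def guess_ff_and_ctx(model: dict) -> tuple[int, int]:
    # Mirrors the new branch: the first expert's w1 marks an MoE checkpoint,
    # n_ff is that tensor's first dimension, and n_ctx is pinned to 32768.
    n_ctx = 2048  # the LLaMA v1 fallback from the surrounding code
    if "layers.0.feed_forward.experts.0.w1.weight" in model:
        n_ff  = model["layers.0.feed_forward.experts.0.w1.weight"][0]
        n_ctx = 32768
    else:
        n_ff = model["layers.0.feed_forward.w1.weight"][0]
    return n_ff, n_ctx

print(guess_ff_and_ctx(mixtral_like))  # (14336, 32768)
print(guess_ff_and_ctx(llama_like))    # (11008, 2048)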

View File

@@ -111,10 +111,14 @@ class MODEL_TENSOR(IntEnum):
     ATTN_NORM     = auto()
     ATTN_NORM_2   = auto()
     ATTN_ROT_EMBD = auto()
+    FFN_GATE_INP  = auto()
+    FFN_NORM      = auto()
     FFN_GATE      = auto()
     FFN_DOWN      = auto()
     FFN_UP        = auto()
-    FFN_NORM      = auto()
+    FFN_GATE_EXP  = auto()
+    FFN_DOWN_EXP  = auto()
+    FFN_UP_EXP    = auto()
     ATTN_Q_NORM   = auto()
     ATTN_K_NORM   = auto()
@@ -154,10 +158,14 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
     MODEL_TENSOR.ATTN_Q_NORM:   "blk.{bid}.attn_q_norm",
     MODEL_TENSOR.ATTN_K_NORM:   "blk.{bid}.attn_k_norm",
+    MODEL_TENSOR.FFN_GATE_INP:  "blk.{bid}.ffn_gate_inp",
     MODEL_TENSOR.FFN_NORM:      "blk.{bid}.ffn_norm",
     MODEL_TENSOR.FFN_GATE:      "blk.{bid}.ffn_gate",
     MODEL_TENSOR.FFN_DOWN:      "blk.{bid}.ffn_down",
     MODEL_TENSOR.FFN_UP:        "blk.{bid}.ffn_up",
+    MODEL_TENSOR.FFN_GATE_EXP:  "blk.{bid}.ffn_gate.{xid}",
+    MODEL_TENSOR.FFN_DOWN_EXP:  "blk.{bid}.ffn_down.{xid}",
+    MODEL_TENSOR.FFN_UP_EXP:    "blk.{bid}.ffn_up.{xid}",
 }

 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -172,10 +180,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
         MODEL_TENSOR.FFN_NORM,
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
     ],
     MODEL_ARCH.GPTNEOX: [
         MODEL_TENSOR.TOKEN_EMBD,
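
The new *_EXP entries differ from the existing per-block names only by the extra {xid} (expert index) placeholder, and FFN_GATE_INP is the per-block routing gate that Mixtral stores as feed_forward.gate (see the mapping change below). A small sketch of how the templates expand; the template strings come from the hunks above, while the choice of block 1 and of 8 experts (matching the hard-coded n_experts in the tensor-mapping change below) is illustrative.

# Sketch: expanding the per-expert GGUF name templates added above.
# Template strings are copied from the diff; bid/xid values are examples.
FFN_GATE_INP = "blk.{bid}.ffn_gate_inp"
FFN_GATE_EXP = "blk.{bid}.ffn_gate.{xid}"
FFN_DOWN_EXP = "blk.{bid}.ffn_down.{xid}"
FFN_UP_EXP   = "blk.{bid}.ffn_up.{xid}"

bid = 1
print(FFN_GATE_INP.format(bid=bid))   # blk.1.ffn_gate_inp
for xid in range(8):                  # one gate/down/up tensor per expert
    print(FFN_GATE_EXP.format(bid=bid, xid=xid),
          FFN_DOWN_EXP.format(bid=bid, xid=xid),
          FFN_UP_EXP.format(bid=bid, xid=xid))
# blk.1.ffn_gate.0 blk.1.ffn_down.0 blk.1.ffn_up.0
# ...
# blk.1.ffn_gate.7 blk.1.ffn_down.7 blk.1.ffn_up.7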

View File

@@ -149,6 +149,10 @@ class TensorNameMap:
             "model.layers.{bid}.ln2", # yi
         ),

+        MODEL_TENSOR.FFN_GATE_INP: (
+            "layers.{bid}.feed_forward.gate", # mixtral
+        ),
+
         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
             "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
@@ -164,11 +168,19 @@
             "transformer.h.{bid}.mlp.w1", # qwen
         ),

+        MODEL_TENSOR.FFN_UP_EXP: (
+            "layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral
+        ),
+
         # Feed-forward gate
         MODEL_TENSOR.FFN_GATE: (
             "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
             "layers.{bid}.feed_forward.w1",     # llama-pth
             "transformer.h.{bid}.mlp.w2",       # qwen
         ),
+
+        MODEL_TENSOR.FFN_GATE_EXP: (
+            "layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral
+        ),

         # Feed-forward down
@@ -185,6 +197,10 @@
             "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
         ),

+        MODEL_TENSOR.FFN_DOWN_EXP: (
+            "layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral
+        ),
+
         MODEL_TENSOR.ATTN_Q_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
         ),
@@ -213,11 +229,14 @@
             for tensor, keys in self.block_mappings_cfg.items():
                 if tensor not in MODEL_TENSORS[arch]:
                     continue
-                tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
-                self.mapping[tensor_name] = (tensor, tensor_name)
-                for key in keys:
-                    key = key.format(bid = bid)
-                    self.mapping[key] = (tensor, tensor_name)
+                # TODO: make this configurable
+                n_experts = 8
+                for xid in range(n_experts):
+                    tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
+                    self.mapping[tensor_name] = (tensor, tensor_name)
+                    for key in keys:
+                        key = key.format(bid = bid, xid = xid)
+                        self.mapping[key] = (tensor, tensor_name)

     def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
         result = self.mapping.get(key)
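
Taken together, the rewritten loop builds both directions of the lookup for every expert index: each GGUF-side name maps to itself and each source-side name maps to the GGUF name. Below is a self-contained sketch of what it produces for one block, using only names that appear in this diff; plain strings stand in for the MODEL_TENSOR enum, and the real class iterates over all blocks, tensors and architectures.

# Sketch of the expanded mapping for block 0, FFN_GATE_EXP only.
# Strings stand in for the MODEL_TENSOR enum; templates are from the diff.
TENSOR_NAMES       = {"FFN_GATE_EXP": "blk.{bid}.ffn_gate.{xid}"}
block_mappings_cfg = {"FFN_GATE_EXP": ("layers.{bid}.feed_forward.experts.{xid}.w1",)}  # mixtral

mapping: dict[str, tuple[str, str]] = {}
n_experts = 8  # hard-coded in this commit, flagged with a TODO
bid = 0
for tensor, keys in block_mappings_cfg.items():
    for xid in range(n_experts):
        tensor_name = TENSOR_NAMES[tensor].format(bid=bid, xid=xid)
        mapping[tensor_name] = (tensor, tensor_name)           # GGUF name -> itself
        for key in keys:
            mapping[key.format(bid=bid, xid=xid)] = (tensor, tensor_name)  # source -> GGUF

print(mapping["layers.0.feed_forward.experts.3.w1"])  # ('FFN_GATE_EXP', 'blk.0.ffn_gate.3')
print(len(mapping))                                    # 16 entries: 8 GGUF + 8 Mixtral names

For tensors whose templates contain no {xid} (for example blk.{bid}.ffn_norm), str.format simply ignores the unused keyword, so the xid loop re-registers the same entry eight times; the result is still correct, just redundant, and the expert count itself stays hard-coded to 8 behind the TODO rather than being read from the model.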