mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 02:14:35 +00:00
llama : add IBM Granite MoE architecture (#9438)
* feat(gguf-py): Add granitemoe architecture This includes the addition of new tensor names for the new moe layers. These may not be correct at this point due to the need for the hack in gguf_writer.py to double-check the length of the shape for these layers. Branch: GraniteMoE Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat(convert_hf_to_gguf): Add GraniteMoeModel GraniteMoe has the same configuration deltas as Granite Branch: GraniteMoE Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix(granitemoe convert): Split the double-sized input layer into gate and up After a lot of staring and squinting, it's clear that the standard mixtral expert implementation is equivalent to the vectorized parallel experts in granite. The difference is that in granite, the w1 and w3 are concatenated into a single tensor "input_linear." Rather than reimplementing all of the math on the llama.cpp side, the much simpler route is to just split this tensor during conversion and follow the standard mixtral route. Branch: GraniteMoE Co-Authored-By: alex.brooks@ibm.com Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat(granitemoe): Implement granitemoe GraniteMoE follows the mixtral architecture (once the input_linear layers are split into gate_exps/up_exps). The main delta is the addition of the same four multipliers used in Granite. Branch: GraniteMoE Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * Typo fix in docstring Co-Authored-By: ggerganov@gmail.com Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix(conversion): Simplify tensor name mapping in conversion Branch: GraniteMoE Co-Authored-By: git@compilade.net Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix(convert): Remove unused tensor name mappings Branch: GraniteMoE Co-Authored-By: git@compilade.net Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix(convert): Sanity check on merged FFN tensor sizes Branch: GraniteMoE Co-Authored-By: git@compilade.net Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Allow "output" layer in granite moe architecture (convert and cpp) Branch: GraniteMoE Co-Authored-By: git@compilade.net Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix(granite): Add missing 'output' tensor for Granite This is a fix for the previous `granite` architecture PR. Recent snapshots have included this (`lm_head.weights`) as part of the architecture Branch: GraniteMoE Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> --------- Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
904837e0cb
commit
3d6bf6919f
@ -4102,16 +4102,45 @@ class GraniteModel(LlamaModel):
|
||||
# consistency
|
||||
if attention_scale := self.hparams.get("attention_multiplier"):
|
||||
self.gguf_writer.add_attention_scale(attention_scale)
|
||||
logger.info("gguf: (granite) attention_scale = %s", attention_scale)
|
||||
if embedding_scale := self.hparams.get("embedding_multiplier"):
|
||||
self.gguf_writer.add_embedding_scale(embedding_scale)
|
||||
logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
|
||||
if residual_scale := self.hparams.get("residual_multiplier"):
|
||||
self.gguf_writer.add_residual_scale(residual_scale)
|
||||
if logits_scaling := self.hparams.get("logits_scaling"):
|
||||
self.gguf_writer.add_logit_scale(logits_scaling)
|
||||
logger.info("gguf: (granite) residual_scale = %s", residual_scale)
|
||||
if logits_scale := self.hparams.get("logits_scaling"):
|
||||
self.gguf_writer.add_logit_scale(logits_scale)
|
||||
logger.info("gguf: (granite) logits_scale = %s", logits_scale)
|
||||
|
||||
|
||||
@Model.register("GraniteMoeForCausalLM")
|
||||
class GraniteMoeModel(GraniteModel):
|
||||
"""Conversion for IBM's GraniteMoeForCausalLM"""
|
||||
model_arch = gguf.MODEL_ARCH.GRANITE_MOE
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
"""In modeling_granitemoe, the JetMoe implementation of parallel experts
|
||||
is used. This essentially merges w1 and w3 into a single tensor with 2x
|
||||
the hidden size that is then split during forward. To keep compatibility
|
||||
with existing mixtral support, we pull them apart here.
|
||||
"""
|
||||
|
||||
if name.endswith("block_sparse_moe.input_linear.weight"):
|
||||
ffn_dim = self.hparams["intermediate_size"]
|
||||
assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
|
||||
gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
|
||||
return [
|
||||
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
|
||||
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
|
||||
]
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
# tree of lazy tensors
|
||||
class LazyTorchTensor(gguf.LazyBase):
|
||||
_tensor_type = torch.Tensor
|
||||
|
@ -235,6 +235,7 @@ class MODEL_ARCH(IntEnum):
|
||||
NEMOTRON = auto()
|
||||
EXAONE = auto()
|
||||
GRANITE = auto()
|
||||
GRANITE_MOE = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
@ -392,6 +393,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.NEMOTRON: "nemotron",
|
||||
MODEL_ARCH.EXAONE: "exaone",
|
||||
MODEL_ARCH.GRANITE: "granite",
|
||||
MODEL_ARCH.GRANITE_MOE: "granitemoe",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
@ -1232,6 +1234,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_ARCH.GRANITE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
@ -1242,6 +1245,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.GRANITE_MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
@ -251,11 +251,12 @@ class TensorNameMap:
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP: (
|
||||
"layers.{bid}.feed_forward.gate", # mixtral
|
||||
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
|
||||
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
|
||||
"transformer.decoder_layer.{bid}.router", # Grok
|
||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||
"layers.{bid}.feed_forward.gate", # mixtral
|
||||
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
|
||||
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
|
||||
"transformer.decoder_layer.{bid}.router", # Grok
|
||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
@ -364,10 +365,11 @@ class TensorNameMap:
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
||||
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
|
@ -215,6 +215,7 @@ enum llm_arch {
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_RWKV6,
|
||||
LLM_ARCH_GRANITE,
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -266,6 +267,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_EXAONE, "exaone" },
|
||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||
{ LLM_ARCH_GRANITE, "granite" },
|
||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@ -1467,6 +1469,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
@ -1478,6 +1481,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
@ -2396,7 +2417,7 @@ struct llama_hparams {
|
||||
float f_max_alibi_bias = 0.0f;
|
||||
float f_logit_scale = 0.0f;
|
||||
|
||||
// Additional scale factors (Granite)
|
||||
// Additional scale factors (Granite/Granite MoE)
|
||||
float f_residual_scale = 0.0f;
|
||||
float f_embedding_scale = 0.0f;
|
||||
float f_attention_scale = 0.0f;
|
||||
@ -6048,6 +6069,7 @@ static void llm_load_hparams(
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
||||
@ -6056,6 +6078,7 @@ static void llm_load_hparams(
|
||||
ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 32: model.type = e_model::MODEL_3B; break;
|
||||
case 40: model.type = e_model::MODEL_3B; break;
|
||||
// Add additional layer/vocab/etc checks here for other model sizes
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
@ -6810,7 +6833,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
||||
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
|
||||
}
|
||||
|
||||
if (model.arch == LLM_ARCH_GRANITE) {
|
||||
if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
|
||||
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
|
||||
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
|
||||
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
|
||||
@ -6984,6 +7007,7 @@ static bool llm_load_tensors(
|
||||
case LLM_ARCH_REFACT:
|
||||
case LLM_ARCH_MINICPM:
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
|
||||
@ -15930,6 +15954,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
switch (model.arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
{
|
||||
result = llm.build_llama();
|
||||
} break;
|
||||
@ -19231,6 +19256,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
case LLM_ARCH_CHATGLM:
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
|
Loading…
Reference in New Issue
Block a user