llama : support IBM Granite architecture (#9412)

* feat(gguf-py): Add Granite model and params to gguf-py

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat(convert_hf_to_gguf): Add registration and param setup for Granite

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat(llama.cpp): Add config parsing for Granite multiplier params

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat(llama.cpp): First pass at full port of granite deviations from llama

Something is still not working right since the results are mostly terrible,
but on occasion it's producing relevant results at this point, so
_something_ is working.

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(llama.cpp): Determine granite language 3b instruct by vocab size

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(convert_hf_to_gguf): Use LlamaModel as base for GraniteModel

The defaults in LlamaModel are needed for Granite as well

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(llama.cpp): Switch Granite param names to use _scale for consistency

Other scalar multipliers are called *_scale, so this provides a more
consistent naming convention.

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(convert_hf_to_gguf/gguf-py): _multiplier -> _scale

The transformers names with _multiplier will now be converted to the _scale
equivalent during conversion.

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(llama.cpp): Use separate switch clause for granite in llm_load_hparams

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart 2024-09-17 00:44:58 -06:00 committed by GitHub
parent 37f3a3810e
commit 0d2ec43833
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 135 additions and 1 deletions

View File

@ -4080,6 +4080,36 @@ class ExaoneModel(Model):
super().prepare_tensors() super().prepare_tensors()
@Model.register("GraniteForCausalLM")
class GraniteModel(LlamaModel):
"""Conversion for IBM's GraniteForCausalLM"""
model_arch = gguf.MODEL_ARCH.GRANITE
def set_gguf_parameters(self):
"""Granite uses standard llama parameters with the following differences:
- No head_dim support
- New multiplier params:
- attention_scale
- embedding_scale
- residual_scale
- logits_scaling
"""
if head_dim := self.hparams.pop("head_dim", None):
logger.warning("Ignoring head_dim (%s) from config for Granite", head_dim)
super().set_gguf_parameters()
# NOTE: Convert _multiplier params to _scale params for naming
# consistency
if attention_scale := self.hparams.get("attention_multiplier"):
self.gguf_writer.add_attention_scale(attention_scale)
if embedding_scale := self.hparams.get("embedding_multiplier"):
self.gguf_writer.add_embedding_scale(embedding_scale)
if residual_scale := self.hparams.get("residual_multiplier"):
self.gguf_writer.add_residual_scale(residual_scale)
if logits_scaling := self.hparams.get("logits_scaling"):
self.gguf_writer.add_logit_scale(logits_scaling)
###### CONVERSION LOGIC ###### ###### CONVERSION LOGIC ######
# tree of lazy tensors # tree of lazy tensors

View File

@ -97,6 +97,8 @@ class Keys:
RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers" RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim" TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim"
TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim" TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
RESIDUAL_SCALE = "{arch}.residual_scale"
EMBEDDING_SCALE = "{arch}.embedding_scale"
class Attention: class Attention:
HEAD_COUNT = "{arch}.attention.head_count" HEAD_COUNT = "{arch}.attention.head_count"
@ -112,6 +114,7 @@ class Keys:
KV_LORA_RANK = "{arch}.attention.kv_lora_rank" KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count" REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window" SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
class Rope: class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count" DIMENSION_COUNT = "{arch}.rope.dimension_count"
@ -231,6 +234,7 @@ class MODEL_ARCH(IntEnum):
JAIS = auto() JAIS = auto()
NEMOTRON = auto() NEMOTRON = auto()
EXAONE = auto() EXAONE = auto()
GRANITE = auto()
class MODEL_TENSOR(IntEnum): class MODEL_TENSOR(IntEnum):
@ -387,6 +391,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.JAIS: "jais", MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron", MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone", MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.GRANITE: "granite",
} }
TENSOR_NAMES: dict[MODEL_TENSOR, str] = { TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@ -1224,6 +1229,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
], ],
MODEL_ARCH.GRANITE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO # TODO
} }

View File

@ -679,6 +679,12 @@ class GGUFWriter:
def add_time_decay_extra_dim(self, dim: int) -> None: def add_time_decay_extra_dim(self, dim: int) -> None:
self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim) self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim)
def add_residual_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.RESIDUAL_SCALE.format(arch=self.arch), value)
def add_embedding_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.EMBEDDING_SCALE.format(arch=self.arch), value)
def add_wkv_head_size(self, size: int) -> None: def add_wkv_head_size(self, size: int) -> None:
self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size) self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
@ -703,6 +709,9 @@ class GGUFWriter:
def add_sliding_window(self, value: int) -> None: def add_sliding_window(self, value: int) -> None:
self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value) self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
def add_attention_scale(self, value: float) -> None:
self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
def add_pooling_type(self, value: PoolingType) -> None: def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

View File

@ -214,6 +214,7 @@ enum llm_arch {
LLM_ARCH_NEMOTRON, LLM_ARCH_NEMOTRON,
LLM_ARCH_EXAONE, LLM_ARCH_EXAONE,
LLM_ARCH_RWKV6, LLM_ARCH_RWKV6,
LLM_ARCH_GRANITE,
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
}; };
@ -264,6 +265,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_UNKNOWN, "(unknown)" }, { LLM_ARCH_UNKNOWN, "(unknown)" },
}; };
@ -303,6 +305,8 @@ enum llm_kv {
LLM_KV_RESCALE_EVERY_N_LAYERS, LLM_KV_RESCALE_EVERY_N_LAYERS,
LLM_KV_TIME_MIX_EXTRA_DIM, LLM_KV_TIME_MIX_EXTRA_DIM,
LLM_KV_TIME_DECAY_EXTRA_DIM, LLM_KV_TIME_DECAY_EXTRA_DIM,
LLM_KV_RESIDUAL_SCALE,
LLM_KV_EMBEDDING_SCALE,
LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV, LLM_KV_ATTENTION_HEAD_COUNT_KV,
@ -317,6 +321,7 @@ enum llm_kv {
LLM_KV_ATTENTION_KV_LORA_RANK, LLM_KV_ATTENTION_KV_LORA_RANK,
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SCALE,
LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_FREQ_BASE, LLM_KV_ROPE_FREQ_BASE,
@ -407,6 +412,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" }, { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" },
{ LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" }, { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" },
{ LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@ -421,6 +428,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
@ -1454,6 +1462,22 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" },
}, },
}, },
{
LLM_ARCH_GRANITE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{ {
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
{ {
@ -2372,6 +2396,11 @@ struct llama_hparams {
float f_max_alibi_bias = 0.0f; float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f; float f_logit_scale = 0.0f;
// Additional scale factors (Granite)
float f_residual_scale = 0.0f;
float f_embedding_scale = 0.0f;
float f_attention_scale = 0.0f;
bool causal_attn = true; bool causal_attn = true;
bool use_alibi = false; bool use_alibi = false;
bool attn_soft_cap = false; bool attn_soft_cap = false;
@ -2434,6 +2463,9 @@ struct llama_hparams {
if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return true; if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return true;
if (!is_float_close(this->rope_yarn_log_mul, other.rope_yarn_log_mul, EPSILON)) return true; if (!is_float_close(this->rope_yarn_log_mul, other.rope_yarn_log_mul, EPSILON)) return true;
if (!is_float_close(this->f_residual_scale, other.f_residual_scale, EPSILON)) return true;
if (!is_float_close(this->f_embedding_scale, other.f_embedding_scale, EPSILON)) return true;
if (!is_float_close(this->f_attention_scale, other.f_attention_scale, EPSILON)) return true;
return false; return false;
} }
@ -6019,6 +6051,20 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN; default: model.type = e_model::MODEL_UNKNOWN;
} }
} break; } break;
case LLM_ARCH_GRANITE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
switch (hparams.n_layer) {
case 40: model.type = e_model::MODEL_3B; break;
// Add additional layer/vocab/etc checks here for other model sizes
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0; default: (void)0;
} }
@ -6717,6 +6763,12 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
} }
if (model.arch == LLM_ARCH_GRANITE) {
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
}
} }
// Returns false if cancelled by progress_callback // Returns false if cancelled by progress_callback
@ -6885,6 +6937,7 @@ static bool llm_load_tensors(
case LLM_ARCH_LLAMA: case LLM_ARCH_LLAMA:
case LLM_ARCH_REFACT: case LLM_ARCH_REFACT:
case LLM_ARCH_MINICPM: case LLM_ARCH_MINICPM:
case LLM_ARCH_GRANITE:
{ {
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@ -8868,6 +8921,11 @@ static struct ggml_tensor * llm_build_inp_embd(
ggml_set_input(lctx.inp_embd); ggml_set_input(lctx.inp_embd);
} }
// For Granite architecture
if (hparams.f_embedding_scale != 0.0f) {
inpL = ggml_scale(ctx, inpL, hparams.f_embedding_scale);
}
cb(inpL, "inp_embd", -1); cb(inpL, "inp_embd", -1);
return inpL; return inpL;
@ -10146,6 +10204,7 @@ struct llm_build_context {
// KQ_mask (mask for 1 head, it will be broadcasted to all heads) // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL; struct ggml_tensor * inpSA = inpL;
@ -10198,7 +10257,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, lctx, kv_self, gf, cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
} }
if (il == n_layer - 1) { if (il == n_layer - 1) {
@ -10209,6 +10268,11 @@ struct llm_build_context {
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
} }
// For Granite architecture
if (hparams.f_residual_scale) {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il); cb(ffn_inp, "ffn_inp", il);
@ -10245,6 +10309,11 @@ struct llm_build_context {
cb(cur, "ffn_moe_out", il); cb(cur, "ffn_moe_out", il);
} }
// For Granite architecture
if (hparams.f_residual_scale) {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
cur = ggml_add(ctx0, cur, ffn_inp); cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
@ -10264,6 +10333,12 @@ struct llm_build_context {
// lm_head // lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
// For Granite architecture
if (hparams.f_logit_scale) {
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
}
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
@ -15789,6 +15864,7 @@ static struct ggml_cgraph * llama_build_graph(
switch (model.arch) { switch (model.arch) {
case LLM_ARCH_LLAMA: case LLM_ARCH_LLAMA:
case LLM_ARCH_GRANITE:
{ {
result = llm.build_llama(); result = llm.build_llama();
} break; } break;
@ -19089,6 +19165,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_ARCTIC: case LLM_ARCH_ARCTIC:
case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_CHATGLM: case LLM_ARCH_CHATGLM:
case LLM_ARCH_GRANITE:
return LLAMA_ROPE_TYPE_NORM; return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2 // the pairs of head values are offset by n_rot/2