mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
llama : support for Llama-3_1-Nemotron-51B (#10669)
* conflict resolution * move comments after bracket to its own line
This commit is contained in:
parent
dab76c92cc
commit
6f0c9e034b
@ -1692,6 +1692,184 @@ class LlamaModel(Model):
|
|||||||
raise ValueError(f"Unprocessed experts: {experts}")
|
raise ValueError(f"Unprocessed experts: {experts}")
|
||||||
|
|
||||||
|
|
||||||
|
@Model.register("DeciLMForCausalLM")
|
||||||
|
class DeciModel(Model):
|
||||||
|
model_arch = gguf.MODEL_ARCH.DECI
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
|
||||||
|
# DeciLM-specific code
|
||||||
|
intermediate_size = int(2 * ffn_mult * n_embd / 3)
|
||||||
|
return DeciModel._find_multiple(intermediate_size, 256)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _find_multiple(n: int, k: int) -> int:
|
||||||
|
# DeciLM-specific code
|
||||||
|
if n % k == 0:
|
||||||
|
return n
|
||||||
|
return n + k - (n % k)
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B
|
||||||
|
_block_configs: list[dict[str,Any]] = self.hparams["block_configs"]
|
||||||
|
assert self.block_count == len(_block_configs)
|
||||||
|
self._num_kv_heads = list()
|
||||||
|
self._num_heads = list()
|
||||||
|
_ffn_multipliers = list()
|
||||||
|
# ***linear attention layer***
|
||||||
|
# if n_heads_in_group is None and replace_with_linear is True
|
||||||
|
# then _num_kv_heads[il] is 0 and _num_heads[il] is num_attention_heads
|
||||||
|
# ***attention-free layer***
|
||||||
|
# if n_heads_in_group is None and replace_with_linear is False
|
||||||
|
# then _num_kv_heads[il] is 0 and _num_heads[il] is 0
|
||||||
|
# ***normal attention-layer***
|
||||||
|
# if n_heads_in_group is not None, then
|
||||||
|
# _num_kv_heads[il] is num_attention_head // n_heads_in_group and
|
||||||
|
# _num_heads[il] is num_attention_head
|
||||||
|
for il in range(len(_block_configs)):
|
||||||
|
if _block_configs[il]["attention"]["n_heads_in_group"] is None:
|
||||||
|
if _block_configs[il]["attention"]["replace_with_linear"] is True:
|
||||||
|
self._num_kv_heads.append(0)
|
||||||
|
self._num_heads.append(self.hparams["num_attention_heads"])
|
||||||
|
else:
|
||||||
|
self._num_kv_heads.append(0)
|
||||||
|
self._num_heads.append(0)
|
||||||
|
else:
|
||||||
|
self._num_kv_heads.append(self.hparams["num_attention_heads"] // _block_configs[il]["attention"]["n_heads_in_group"])
|
||||||
|
self._num_heads.append(self.hparams["num_attention_heads"])
|
||||||
|
_ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"])
|
||||||
|
assert self.block_count == len(self._num_kv_heads)
|
||||||
|
assert self.block_count == len(self._num_heads)
|
||||||
|
assert self.block_count == len(_ffn_multipliers)
|
||||||
|
assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int)
|
||||||
|
assert isinstance(self._num_heads, list) and isinstance(self._num_heads[0], int)
|
||||||
|
assert isinstance(_ffn_multipliers, list) and isinstance(_ffn_multipliers[0], float)
|
||||||
|
self._ffn_dims: list[int] = [
|
||||||
|
DeciModel._ffn_mult_to_intermediate_size(multiplier, self.hparams["hidden_size"])
|
||||||
|
for multiplier in _ffn_multipliers
|
||||||
|
]
|
||||||
|
|
||||||
|
def set_vocab(self):
|
||||||
|
# Please change tokenizer_config.json of Llama-3_1-Nemotron-51B's
|
||||||
|
# eos_token from '|eot_id|' to '|end_of_text|'
|
||||||
|
if self.hparams.get("vocab_size", 128256) == 128256:
|
||||||
|
tokens, toktypes, tokpre = self.get_vocab_base()
|
||||||
|
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||||
|
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||||
|
self.gguf_writer.add_token_list(tokens)
|
||||||
|
self.gguf_writer.add_token_types(toktypes)
|
||||||
|
|
||||||
|
special_vocab = gguf.SpecialVocab(
|
||||||
|
self.dir_model, load_merges=True,
|
||||||
|
special_token_types = ['bos', 'eos', 'eom', 'eot']
|
||||||
|
)
|
||||||
|
special_vocab._set_special_token("bos", 128000)
|
||||||
|
special_vocab._set_special_token("eos", 128001)
|
||||||
|
special_vocab._set_special_token("eom", 128008)
|
||||||
|
special_vocab._set_special_token("eot", 128009)
|
||||||
|
special_vocab.add_to_gguf(self.gguf_writer)
|
||||||
|
else:
|
||||||
|
# DeciLM-7B
|
||||||
|
self._set_vocab_llama_hf()
|
||||||
|
# self._set_vocab_gpt2()
|
||||||
|
|
||||||
|
def set_gguf_parameters(self):
|
||||||
|
if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B
|
||||||
|
assert self.block_count == len(self._num_kv_heads)
|
||||||
|
assert self.block_count == len(self._num_heads)
|
||||||
|
assert self.block_count == len(self._ffn_dims)
|
||||||
|
self.gguf_writer.add_head_count_kv(self._num_kv_heads)
|
||||||
|
self.gguf_writer.add_head_count(self._num_heads)
|
||||||
|
self.gguf_writer.add_feed_forward_length(self._ffn_dims)
|
||||||
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
|
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||||
|
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||||
|
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||||
|
self.gguf_writer.add_key_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||||
|
self.gguf_writer.add_value_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||||
|
self.gguf_writer.add_file_type(self.ftype)
|
||||||
|
else: # DeciLM-7B
|
||||||
|
super().set_gguf_parameters()
|
||||||
|
if "num_key_value_heads_per_layer" in self.hparams: # DeciLM-7B
|
||||||
|
self._num_kv_heads: list[int] = self.hparams["num_key_value_heads_per_layer"]
|
||||||
|
assert self.block_count == len(self._num_kv_heads)
|
||||||
|
self.gguf_writer.add_head_count_kv(self._num_kv_heads)
|
||||||
|
hparams = self.hparams
|
||||||
|
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||||
|
|
||||||
|
if "head_dim" in hparams:
|
||||||
|
rope_dim = hparams["head_dim"]
|
||||||
|
else:
|
||||||
|
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||||
|
self.gguf_writer.add_rope_dimension_count(rope_dim)
|
||||||
|
|
||||||
|
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||||
|
if self.hparams["rope_scaling"].get("type") == "linear":
|
||||||
|
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||||
|
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
|
||||||
|
if n_head_kv is not None and n_head != n_head_kv:
|
||||||
|
n_head = n_head_kv
|
||||||
|
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||||
|
.swapaxes(1, 2)
|
||||||
|
.reshape(weights.shape))
|
||||||
|
|
||||||
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
|
n_head = self.hparams["num_attention_heads"]
|
||||||
|
if bid is not None:
|
||||||
|
if "num_key_value_heads_per_layer" in self.hparams:
|
||||||
|
n_kv_head = self.hparams["num_key_value_heads_per_layer"][bid]
|
||||||
|
elif "block_configs" in self.hparams:
|
||||||
|
n_kv_head = self._num_kv_heads[bid]
|
||||||
|
n_head = self._num_heads[bid]
|
||||||
|
else:
|
||||||
|
n_kv_head = self.hparams.get("num_key_value_heads")
|
||||||
|
else:
|
||||||
|
n_kv_head = self.hparams.get("num_key_value_heads")
|
||||||
|
|
||||||
|
if name.endswith(("q_proj.weight", "q_proj.bias")):
|
||||||
|
data_torch = DeciModel.permute(data_torch, n_head, n_head)
|
||||||
|
if name.endswith(("k_proj.weight", "k_proj.bias")):
|
||||||
|
data_torch = DeciModel.permute(data_torch, n_head, n_kv_head)
|
||||||
|
return [(self.map_tensor_name(name), data_torch)]
|
||||||
|
|
||||||
|
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||||
|
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
|
||||||
|
if rope_scaling.get("rope_type", '').lower() == "llama3":
|
||||||
|
base = self.hparams.get("rope_theta", 10000.0)
|
||||||
|
dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||||
|
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
||||||
|
|
||||||
|
factor = rope_scaling.get("factor", 8.0)
|
||||||
|
low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
|
||||||
|
high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
|
||||||
|
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||||
|
|
||||||
|
low_freq_wavelen = old_context_len / low_freq_factor
|
||||||
|
high_freq_wavelen = old_context_len / high_freq_factor
|
||||||
|
assert low_freq_wavelen != high_freq_wavelen
|
||||||
|
|
||||||
|
rope_factors = []
|
||||||
|
for freq in freqs:
|
||||||
|
wavelen = 2 * math.pi / freq
|
||||||
|
if wavelen < high_freq_wavelen:
|
||||||
|
rope_factors.append(1)
|
||||||
|
elif wavelen > low_freq_wavelen:
|
||||||
|
rope_factors.append(factor)
|
||||||
|
else:
|
||||||
|
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||||
|
rope_factors.append(1 / ((1 - smooth) / factor + smooth))
|
||||||
|
|
||||||
|
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
|
||||||
|
|
||||||
|
def prepare_tensors(self):
|
||||||
|
super().prepare_tensors()
|
||||||
|
|
||||||
|
|
||||||
@Model.register("BitnetForCausalLM")
|
@Model.register("BitnetForCausalLM")
|
||||||
class BitnetModel(Model):
|
class BitnetModel(Model):
|
||||||
model_arch = gguf.MODEL_ARCH.BITNET
|
model_arch = gguf.MODEL_ARCH.BITNET
|
||||||
|
@ -221,6 +221,7 @@ class GGUFType:
|
|||||||
|
|
||||||
class MODEL_ARCH(IntEnum):
|
class MODEL_ARCH(IntEnum):
|
||||||
LLAMA = auto()
|
LLAMA = auto()
|
||||||
|
DECI = auto()
|
||||||
FALCON = auto()
|
FALCON = auto()
|
||||||
BAICHUAN = auto()
|
BAICHUAN = auto()
|
||||||
GROK = auto()
|
GROK = auto()
|
||||||
@ -402,6 +403,7 @@ class MODEL_TENSOR(IntEnum):
|
|||||||
|
|
||||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||||
MODEL_ARCH.LLAMA: "llama",
|
MODEL_ARCH.LLAMA: "llama",
|
||||||
|
MODEL_ARCH.DECI: "deci",
|
||||||
MODEL_ARCH.FALCON: "falcon",
|
MODEL_ARCH.FALCON: "falcon",
|
||||||
MODEL_ARCH.BAICHUAN: "baichuan",
|
MODEL_ARCH.BAICHUAN: "baichuan",
|
||||||
MODEL_ARCH.GROK: "grok",
|
MODEL_ARCH.GROK: "grok",
|
||||||
@ -602,6 +604,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||||
MODEL_TENSOR.FFN_UP_EXP,
|
MODEL_TENSOR.FFN_UP_EXP,
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.DECI: [
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.OUTPUT,
|
||||||
|
MODEL_TENSOR.ROPE_FREQS,
|
||||||
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_Q,
|
||||||
|
MODEL_TENSOR.ATTN_K,
|
||||||
|
MODEL_TENSOR.ATTN_V,
|
||||||
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
|
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||||
|
MODEL_TENSOR.FFN_GATE_INP,
|
||||||
|
MODEL_TENSOR.FFN_NORM,
|
||||||
|
MODEL_TENSOR.FFN_GATE,
|
||||||
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
MODEL_TENSOR.FFN_GATE_EXP,
|
||||||
|
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||||
|
MODEL_TENSOR.FFN_UP_EXP,
|
||||||
|
],
|
||||||
MODEL_ARCH.GROK: [
|
MODEL_ARCH.GROK: [
|
||||||
MODEL_TENSOR.TOKEN_EMBD,
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
MODEL_TENSOR.OUTPUT_NORM,
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
@ -1448,6 +1470,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.ROPE_FREQS,
|
MODEL_TENSOR.ROPE_FREQS,
|
||||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.DECI: [
|
||||||
|
MODEL_TENSOR.ROPE_FREQS,
|
||||||
|
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||||
|
],
|
||||||
MODEL_ARCH.BAICHUAN: [
|
MODEL_ARCH.BAICHUAN: [
|
||||||
MODEL_TENSOR.ROPE_FREQS,
|
MODEL_TENSOR.ROPE_FREQS,
|
||||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||||
|
@ -198,6 +198,7 @@ class TensorNameMap:
|
|||||||
"transformer.h.{bid}.self_attention.dense", # falcon
|
"transformer.h.{bid}.self_attention.dense", # falcon
|
||||||
"h.{bid}.self_attention.dense", # bloom
|
"h.{bid}.self_attention.dense", # bloom
|
||||||
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2
|
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2
|
||||||
|
"model.layers.{bid}.self_attn.linear_attn", # deci
|
||||||
"layers.{bid}.attention.wo", # llama-pth
|
"layers.{bid}.attention.wo", # llama-pth
|
||||||
"encoder.layer.{bid}.attention.output.dense", # bert
|
"encoder.layer.{bid}.attention.output.dense", # bert
|
||||||
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
||||||
|
267
src/llama.cpp
267
src/llama.cpp
@ -146,6 +146,7 @@ static std::string format(const char * fmt, ...) {
|
|||||||
|
|
||||||
enum llm_arch {
|
enum llm_arch {
|
||||||
LLM_ARCH_LLAMA,
|
LLM_ARCH_LLAMA,
|
||||||
|
LLM_ARCH_DECI,
|
||||||
LLM_ARCH_FALCON,
|
LLM_ARCH_FALCON,
|
||||||
LLM_ARCH_BAICHUAN,
|
LLM_ARCH_BAICHUAN,
|
||||||
LLM_ARCH_GROK,
|
LLM_ARCH_GROK,
|
||||||
@ -203,6 +204,7 @@ enum llm_arch {
|
|||||||
|
|
||||||
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||||
{ LLM_ARCH_LLAMA, "llama" },
|
{ LLM_ARCH_LLAMA, "llama" },
|
||||||
|
{ LLM_ARCH_DECI, "deci" },
|
||||||
{ LLM_ARCH_FALCON, "falcon" },
|
{ LLM_ARCH_FALCON, "falcon" },
|
||||||
{ LLM_ARCH_GROK, "grok" },
|
{ LLM_ARCH_GROK, "grok" },
|
||||||
{ LLM_ARCH_GPT2, "gpt2" },
|
{ LLM_ARCH_GPT2, "gpt2" },
|
||||||
@ -674,6 +676,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_DECI,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_OUTPUT, "output" },
|
||||||
|
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||||
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
|
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
|
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||||
|
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
|
||||||
|
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||||
|
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
LLM_ARCH_BAICHUAN,
|
LLM_ARCH_BAICHUAN,
|
||||||
{
|
{
|
||||||
@ -5694,7 +5722,7 @@ static void llm_load_hparams(
|
|||||||
|
|
||||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
|
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
|
||||||
|
|
||||||
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
|
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) {
|
||||||
if (hparams.n_rot != hparams.n_embd_head_k) {
|
if (hparams.n_rot != hparams.n_embd_head_k) {
|
||||||
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
|
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
|
||||||
}
|
}
|
||||||
@ -5734,6 +5762,15 @@ static void llm_load_hparams(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_DECI:
|
||||||
|
{
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
|
switch (hparams.n_layer) {
|
||||||
|
case 32: model.type = e_model::MODEL_7B; break;
|
||||||
|
case 80: model.type = e_model::MODEL_70B; break;
|
||||||
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
|
}
|
||||||
|
} break;
|
||||||
case LLM_ARCH_MINICPM:
|
case LLM_ARCH_MINICPM:
|
||||||
{
|
{
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
@ -7939,6 +7976,68 @@ static bool llm_load_tensors(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_DECI:
|
||||||
|
{
|
||||||
|
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||||
|
|
||||||
|
// output
|
||||||
|
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||||
|
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
|
||||||
|
// if output is NULL, init from the input tok embed
|
||||||
|
if (model.output == NULL) {
|
||||||
|
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_layer; ++i) {
|
||||||
|
auto & layer = model.layers[i];
|
||||||
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i);
|
||||||
|
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i);
|
||||||
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i);
|
||||||
|
const int64_t n_ff = hparams.n_ff(i);
|
||||||
|
const int64_t n_head = hparams.n_head(i);
|
||||||
|
const int64_t n_head_kv = hparams.n_head_kv(i);
|
||||||
|
|
||||||
|
if (n_head_kv == 0 && n_head > 0) {
|
||||||
|
// linear attention for DeciLMCausalModel
|
||||||
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||||
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
||||||
|
}
|
||||||
|
else if (n_head_kv > 0) {
|
||||||
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||||
|
|
||||||
|
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||||
|
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
||||||
|
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
||||||
|
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// optional bias tensors
|
||||||
|
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
|
||||||
|
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||||
|
|
||||||
|
if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
|
||||||
|
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||||
|
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||||
|
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||||
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||||
|
|
||||||
|
// optional MLP bias
|
||||||
|
layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
}
|
||||||
|
} break;
|
||||||
case LLM_ARCH_MINICPM3:
|
case LLM_ARCH_MINICPM3:
|
||||||
{
|
{
|
||||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||||
@ -11308,6 +11407,167 @@ struct llm_build_context {
|
|||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_cgraph * build_deci() {
|
||||||
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||||
|
|
||||||
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
||||||
|
int32_t n_tokens = this->n_tokens;
|
||||||
|
|
||||||
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||||
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||||
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||||
|
|
||||||
|
struct ggml_tensor * cur;
|
||||||
|
struct ggml_tensor * inpL;
|
||||||
|
|
||||||
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
||||||
|
|
||||||
|
// inp_pos - contains the positions
|
||||||
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||||
|
|
||||||
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
|
||||||
|
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||||
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
struct ggml_tensor * inpSA = inpL;
|
||||||
|
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||||
|
const int64_t n_head = hparams.n_head(il);
|
||||||
|
|
||||||
|
if (n_head == 0) {
|
||||||
|
// attention-free layer of Llama-3_1-Nemotron-51B
|
||||||
|
cur = inpL;
|
||||||
|
} else {
|
||||||
|
// norm
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
model.layers[il].attn_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "attn_norm", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_head > 0 && n_head_kv == 0) {
|
||||||
|
// "linear attention" of Llama-3_1-Nemotron-51B
|
||||||
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
|
||||||
|
cb(cur, "wo", il);
|
||||||
|
} else if (n_head > 0) {
|
||||||
|
// self-attention
|
||||||
|
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||||
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
||||||
|
|
||||||
|
// compute Q and K and RoPE them
|
||||||
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
if (model.layers[il].bq) {
|
||||||
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
if (model.layers[il].bk) {
|
||||||
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
if (model.layers[il].bv) {
|
||||||
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
Qcur = ggml_rope_ext(
|
||||||
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
|
||||||
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
Kcur = ggml_rope_ext(
|
||||||
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
|
||||||
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||||
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (il == n_layer - 1) {
|
||||||
|
// skip computing output for unused tokens
|
||||||
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||||
|
n_tokens = n_outputs;
|
||||||
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||||
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
// For Granite architecture
|
||||||
|
if (hparams.f_residual_scale) {
|
||||||
|
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
// modified to support attention-free layer of Llama-3_1-Nemotron-51B
|
||||||
|
struct ggml_tensor * ffn_inp = cur;
|
||||||
|
if (n_head > 0) {
|
||||||
|
ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||||
|
cb(ffn_inp, "ffn_inp", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
// feed-forward network
|
||||||
|
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||||
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||||
|
model.layers[il].ffn_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "ffn_norm", il);
|
||||||
|
|
||||||
|
cur = llm_build_ffn(ctx0, lctx, cur,
|
||||||
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||||
|
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
||||||
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||||
|
NULL,
|
||||||
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||||
|
cb(cur, "ffn_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
// For Granite architecture
|
||||||
|
if (hparams.f_residual_scale) {
|
||||||
|
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||||
|
cb(cur, "ffn_out", il);
|
||||||
|
|
||||||
|
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||||
|
cb(cur, "l_out", il);
|
||||||
|
|
||||||
|
// input for next layer
|
||||||
|
inpL = cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = inpL;
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, cur, hparams,
|
||||||
|
model.output_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, -1);
|
||||||
|
cb(cur, "result_norm", -1);
|
||||||
|
|
||||||
|
// lm_head
|
||||||
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||||
|
|
||||||
|
// For Granite architecture
|
||||||
|
if (hparams.f_logit_scale) {
|
||||||
|
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
cb(cur, "result_output", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_cgraph * build_baichuan() {
|
struct ggml_cgraph * build_baichuan() {
|
||||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||||
|
|
||||||
@ -17422,6 +17682,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|||||||
{
|
{
|
||||||
result = llm.build_llama();
|
result = llm.build_llama();
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_DECI:
|
||||||
|
{
|
||||||
|
result = llm.build_deci();
|
||||||
|
} break;
|
||||||
case LLM_ARCH_BAICHUAN:
|
case LLM_ARCH_BAICHUAN:
|
||||||
{
|
{
|
||||||
result = llm.build_baichuan();
|
result = llm.build_baichuan();
|
||||||
@ -20797,6 +21061,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|||||||
|
|
||||||
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
||||||
case LLM_ARCH_LLAMA:
|
case LLM_ARCH_LLAMA:
|
||||||
|
case LLM_ARCH_DECI:
|
||||||
case LLM_ARCH_BAICHUAN:
|
case LLM_ARCH_BAICHUAN:
|
||||||
case LLM_ARCH_STARCODER:
|
case LLM_ARCH_STARCODER:
|
||||||
case LLM_ARCH_PLAMO:
|
case LLM_ARCH_PLAMO:
|
||||||
|
Loading…
Reference in New Issue
Block a user