mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
llama : add OpenELM support (#7359)
* Initial OpenELM support (270M only so far) * Fill out missing entries in llama_model_type_name * fixup! Initial OpenELM support (270M only so far) Fix formatting * llama : support all OpenELM models * llama : add variable GQA and variable FFN sizes Some metadata keys can now also be arrays to support setting their value per-layer for models like OpenELM. * llama : minor spacing changes Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * llama : use std::array for per-layer hparams * llama : fix save/load state * llama : do not print hparams for vocab-only models * llama : handle n_head == 0 * llama : use const ref for print_f and fix division by zero * llama : fix t5 uses of n_head and n_ff * llama : minor comment --------- Co-authored-by: Francis Couture-Harpin <git@compilade.net> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
6f63d646c1
commit
d7fd29fff1
@ -13,7 +13,7 @@ import sys
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
@ -677,6 +677,51 @@ class Model:
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
|
||||
tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
|
||||
logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
|
||||
vocab_reader = gguf.GGUFReader(tokenizer_path, "r")
|
||||
|
||||
default_pre = "mpt" if model_name == "gpt-neox" else "default"
|
||||
|
||||
field = vocab_reader.get_field(gguf.Keys.Tokenizer.MODEL)
|
||||
assert field # tokenizer model
|
||||
self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]).decode("utf-8"))
|
||||
|
||||
field = vocab_reader.get_field(gguf.Keys.Tokenizer.PRE)
|
||||
self.gguf_writer.add_tokenizer_pre(bytes(field.parts[-1]).decode("utf-8") if field else default_pre)
|
||||
|
||||
field = vocab_reader.get_field(gguf.Keys.Tokenizer.LIST)
|
||||
assert field # token list
|
||||
self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
|
||||
|
||||
if model_name == "llama-spm":
|
||||
field = vocab_reader.get_field(gguf.Keys.Tokenizer.SCORES)
|
||||
assert field # token scores
|
||||
self.gguf_writer.add_token_scores([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
|
||||
|
||||
field = vocab_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
|
||||
assert field # token types
|
||||
self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
|
||||
|
||||
if model_name != "llama-spm":
|
||||
field = vocab_reader.get_field(gguf.Keys.Tokenizer.MERGES)
|
||||
assert field # token merges
|
||||
self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
|
||||
|
||||
if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)) is not None:
|
||||
self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0])
|
||||
if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)) is not None:
|
||||
self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0])
|
||||
if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)) is not None:
|
||||
self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])
|
||||
if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.PAD_ID)) is not None:
|
||||
self.gguf_writer.add_pad_token_id(field.parts[-1].tolist()[0])
|
||||
if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_BOS)) is not None:
|
||||
self.gguf_writer.add_add_bos_token(field.parts[-1].tolist()[0])
|
||||
if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_EOS)) is not None:
|
||||
self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])
|
||||
|
||||
|
||||
@Model.register("GPTNeoXForCausalLM")
|
||||
class GPTNeoXModel(Model):
|
||||
@ -2439,39 +2484,7 @@ class MambaModel(Model):
|
||||
self._set_vocab_sentencepiece()
|
||||
else:
|
||||
# Use the GPT-NeoX tokenizer when no tokenizer files are present
|
||||
tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf"
|
||||
logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
|
||||
neox_reader = gguf.GGUFReader(tokenizer_path, "r")
|
||||
|
||||
field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
|
||||
self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]).decode("utf-8") if field else "gpt2")
|
||||
|
||||
field = neox_reader.get_field(gguf.Keys.Tokenizer.PRE)
|
||||
self.gguf_writer.add_tokenizer_pre(bytes(field.parts[-1]).decode("utf-8") if field else "mpt")
|
||||
|
||||
field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
|
||||
assert field
|
||||
self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
|
||||
|
||||
field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
|
||||
assert field
|
||||
self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
|
||||
|
||||
field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES)
|
||||
assert field
|
||||
self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
|
||||
|
||||
field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
|
||||
self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0] if field else 1)
|
||||
|
||||
field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
|
||||
self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0] if field else 0)
|
||||
|
||||
field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
|
||||
self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0] if field else 0)
|
||||
|
||||
field = neox_reader.get_field(gguf.Keys.Tokenizer.PAD_ID)
|
||||
self.gguf_writer.add_pad_token_id(field.parts[-1].tolist()[0] if field else 0)
|
||||
self._set_vocab_builtin("gpt-neox", vocab_size)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
d_model = self.find_hparam(["hidden_size", "d_model"])
|
||||
@ -2623,6 +2636,82 @@ class JinaBertV2Model(BertModel):
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
|
||||
@Model.register("OpenELMForCausalLM")
|
||||
class OpenELMModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.OPENELM
|
||||
|
||||
@staticmethod
|
||||
def _make_divisible(v: float | int, divisor: int) -> int:
|
||||
# ref: https://huggingface.co/apple/OpenELM-270M-Instruct/blob/eb111ff2e6724348e5b905984063d4064d4bc579/configuration_openelm.py#L34-L38
|
||||
new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
ffn_multipliers: list[float] = self.hparams["ffn_multipliers"]
|
||||
ffn_dim_divisor: int = self.hparams["ffn_dim_divisor"]
|
||||
self._n_embd: int = self.hparams["model_dim"]
|
||||
self._num_kv_heads: list[int] = self.hparams["num_kv_heads"]
|
||||
self._num_query_heads: list[int] = self.hparams["num_query_heads"]
|
||||
self._ffn_dims: list[int] = [
|
||||
OpenELMModel._make_divisible(multiplier * self._n_embd, ffn_dim_divisor)
|
||||
for multiplier in ffn_multipliers
|
||||
]
|
||||
assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int)
|
||||
assert isinstance(self._num_query_heads, list) and isinstance(self._num_query_heads[0], int)
|
||||
|
||||
# Uses the tokenizer from meta-llama/Llama-2-7b-hf
|
||||
def set_vocab(self):
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
self._set_vocab_builtin("llama-spm", self.hparams["vocab_size"])
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
n_embd = self._n_embd
|
||||
head_dim = self.hparams["head_dim"]
|
||||
rot_pct = 1.0
|
||||
assert self.block_count == len(self._num_kv_heads)
|
||||
assert self.block_count == len(self._num_query_heads)
|
||||
assert self.block_count == len(self._ffn_dims)
|
||||
|
||||
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_context_length(self.hparams["max_context_length"])
|
||||
self.gguf_writer.add_embedding_length(n_embd)
|
||||
self.gguf_writer.add_feed_forward_length(self._ffn_dims)
|
||||
self.gguf_writer.add_head_count(self._num_query_heads)
|
||||
self.gguf_writer.add_head_count_kv(self._num_kv_heads)
|
||||
self.gguf_writer.add_rope_freq_base(self.hparams["rope_freq_constant"])
|
||||
# https://huggingface.co/apple/OpenELM-270M-Instruct/blob/c401df2/modeling_openelm.py#L30
|
||||
self.gguf_writer.add_layer_norm_rms_eps(1e-6)
|
||||
self.gguf_writer.add_rope_dimension_count(int(rot_pct * head_dim))
|
||||
self.gguf_writer.add_key_length(head_dim)
|
||||
self.gguf_writer.add_value_length(head_dim)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
|
||||
if "n_layers" in keys:
|
||||
return self.hparams["num_transformer_layers"]
|
||||
|
||||
return super().find_hparam(keys, optional)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
|
||||
# split ff
|
||||
if bid is not None and name == f"transformer.layers.{bid}.ffn.proj_1.weight":
|
||||
ff_dim = self._ffn_dims[bid]
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:ff_dim])
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[ff_dim:])
|
||||
return
|
||||
|
||||
yield (self.map_tensor_name(name), data_torch)
|
||||
|
||||
|
||||
@Model.register("ArcticForCausalLM")
|
||||
class ArcticModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.ARCTIC
|
||||
|
@ -160,6 +160,7 @@ class MODEL_ARCH(IntEnum):
|
||||
COMMAND_R = auto()
|
||||
DBRX = auto()
|
||||
OLMO = auto()
|
||||
OPENELM = auto()
|
||||
ARCTIC = auto()
|
||||
DEEPSEEK2 = auto()
|
||||
BITNET = auto()
|
||||
@ -285,6 +286,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.COMMAND_R: "command-r",
|
||||
MODEL_ARCH.DBRX: "dbrx",
|
||||
MODEL_ARCH.OLMO: "olmo",
|
||||
MODEL_ARCH.OPENELM: "openelm",
|
||||
MODEL_ARCH.ARCTIC: "arctic",
|
||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||
MODEL_ARCH.BITNET: "bitnet",
|
||||
@ -861,6 +863,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.OPENELM: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.ARCTIC: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
@ -480,8 +480,11 @@ class GGUFWriter:
|
||||
def add_leading_dense_block_count(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)
|
||||
|
||||
def add_feed_forward_length(self, length: int) -> None:
|
||||
def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
|
||||
if isinstance(length, int):
|
||||
self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
|
||||
else:
|
||||
self.add_array(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
|
||||
|
||||
def add_expert_feed_forward_length(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
|
||||
@ -495,11 +498,17 @@ class GGUFWriter:
|
||||
def add_decoder_start_token_id(self, id: int) -> None:
|
||||
self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
|
||||
|
||||
def add_head_count(self, count: int) -> None:
|
||||
def add_head_count(self, count: int | Sequence[int]) -> None:
|
||||
if isinstance(count, int):
|
||||
self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
|
||||
else:
|
||||
self.add_array(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_head_count_kv(self, count: int) -> None:
|
||||
def add_head_count_kv(self, count: int | Sequence[int]) -> None:
|
||||
if isinstance(count, int):
|
||||
self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
|
||||
else:
|
||||
self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
|
||||
|
||||
def add_key_length(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)
|
||||
|
@ -24,6 +24,7 @@ class TensorNameMap:
|
||||
"backbone.embedding", # mamba
|
||||
"backbone.embeddings", # mamba-hf
|
||||
"transformer.in_out_embed", # Grok
|
||||
"transformer.token_embeddings", # openelm
|
||||
"shared", # t5
|
||||
),
|
||||
|
||||
@ -37,6 +38,7 @@ class TensorNameMap:
|
||||
"word_embeddings_layernorm", # bloom
|
||||
"embeddings.LayerNorm", # bert
|
||||
"emb_ln", # nomic-bert
|
||||
"transformer.norm", # openelm
|
||||
),
|
||||
|
||||
# Position embeddings
|
||||
@ -69,6 +71,7 @@ class TensorNameMap:
|
||||
"model.norm_f", # mamba-qbert
|
||||
"backbone.norm_f", # mamba
|
||||
"transformer.rms_norm", # Grok
|
||||
"transformer.norm", # openelm
|
||||
),
|
||||
|
||||
# Rope frequencies
|
||||
@ -98,6 +101,7 @@ class TensorNameMap:
|
||||
"backbone.layers.{bid}.norm", # mamba
|
||||
"transformer.decoder_layer.{bid}.rms_norm", # Grok
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
|
||||
"transformer.layers.{bid}.attn_norm", # openelm
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
@ -119,7 +123,8 @@ class TensorNameMap:
|
||||
"h.{bid}.attn.c_attn", # gpt2
|
||||
"transformer.h.{bid}.mixer.Wqkv", # phi2
|
||||
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
|
||||
"model.layers.{bid}.self_attn.qkv_proj" # phi3
|
||||
"model.layers.{bid}.self_attn.qkv_proj", # phi3
|
||||
"transformer.layers.{bid}.attn.qkv_proj", # openelm
|
||||
),
|
||||
|
||||
# Attention query
|
||||
@ -177,6 +182,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
|
||||
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
|
||||
"transformer.layers.{bid}.attn.out_proj", # openelm
|
||||
),
|
||||
|
||||
# Attention output norm
|
||||
@ -212,6 +218,7 @@ class TensorNameMap:
|
||||
"h.{bid}.ln_2", # gpt2
|
||||
"model.layers.{bid}.ffn_norm", # internlm2
|
||||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||
"transformer.layers.{bid}.ffn_norm", # openelm
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
@ -327,6 +334,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.mlp.fc2", # nomic-bert
|
||||
"model.layers.{bid}.mlp.c_proj", # starcoder2
|
||||
"encoder.layer.{bid}.mlp.wo", # jina-bert-v2
|
||||
"transformer.layers.{bid}.ffn.proj_2", # openelm
|
||||
"model.layers.{bid}.residual_mlp.w2", # arctic
|
||||
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
|
||||
),
|
||||
@ -348,7 +356,8 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
|
||||
"model.layers.{bid}.self_attn.q_norm", # cohere
|
||||
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_q" # jina-bert-v2
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
|
||||
"transformer.layers.{bid}.attn.q_norm", # openelm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_K_NORM: (
|
||||
@ -356,7 +365,8 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
|
||||
"model.layers.{bid}.self_attn.k_norm", # cohere
|
||||
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_k" # jina-bert-v2
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
|
||||
"transformer.layers.{bid}.attn.k_norm", # openelm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ROPE_FREQS: (
|
||||
|
569
src/llama.cpp
569
src/llama.cpp
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user