mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
convert-hf : make model class definitions self-contained (#5825)
This commit is contained in:
parent
bbde6eb256
commit
c7a0ad8ec9
@ -8,9 +8,10 @@ import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Sequence, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterator, Sequence, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -35,8 +36,11 @@ class SentencePieceTokenTypes(IntEnum):
|
||||
UNUSED = 5
|
||||
BYTE = 6
|
||||
|
||||
AnyModel = TypeVar("AnyModel", bound="type[Model]")
|
||||
|
||||
class Model(ABC):
|
||||
_model_classes: dict[str, type[Model]] = {}
|
||||
|
||||
class Model:
|
||||
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool):
|
||||
self.dir_model = dir_model
|
||||
self.ftype = ftype
|
||||
@ -47,10 +51,14 @@ class Model:
|
||||
self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
|
||||
self.part_names = self._get_part_names()
|
||||
self.hparams = Model.load_hparams(self.dir_model)
|
||||
self.model_arch = self._get_model_architecture()
|
||||
self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False)
|
||||
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_arch(self) -> gguf.MODEL_ARCH:
|
||||
pass
|
||||
|
||||
def find_hparam(self, keys: Sequence[str], optional: bool = False) -> Any:
|
||||
key = next((k for k in keys if k in self.hparams), None)
|
||||
if key is not None:
|
||||
@ -176,55 +184,21 @@ class Model:
|
||||
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
@staticmethod
|
||||
def from_model_architecture(model_architecture):
|
||||
if model_architecture == "GPTNeoXForCausalLM":
|
||||
return GPTNeoXModel
|
||||
if model_architecture == "BloomForCausalLM":
|
||||
return BloomModel
|
||||
if model_architecture == "MPTForCausalLM":
|
||||
return MPTModel
|
||||
if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
|
||||
return BaichuanModel
|
||||
if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
|
||||
return FalconModel
|
||||
if model_architecture == "GPTBigCodeForCausalLM":
|
||||
return StarCoderModel
|
||||
if model_architecture == "GPTRefactForCausalLM":
|
||||
return RefactModel
|
||||
if model_architecture == "PersimmonForCausalLM":
|
||||
return PersimmonModel
|
||||
if model_architecture in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
|
||||
return StableLMModel
|
||||
if model_architecture == "QWenLMHeadModel":
|
||||
return QwenModel
|
||||
if model_architecture == "Qwen2ForCausalLM":
|
||||
return Model
|
||||
if model_architecture == "MixtralForCausalLM":
|
||||
return MixtralModel
|
||||
if model_architecture == "GPT2LMHeadModel":
|
||||
return GPT2Model
|
||||
if model_architecture == "PhiForCausalLM":
|
||||
return Phi2Model
|
||||
if model_architecture == "PlamoForCausalLM":
|
||||
return PlamoModel
|
||||
if model_architecture == "CodeShellForCausalLM":
|
||||
return CodeShellModel
|
||||
if model_architecture == "OrionForCausalLM":
|
||||
return OrionModel
|
||||
if model_architecture == "InternLM2ForCausalLM":
|
||||
return InternLM2Model
|
||||
if model_architecture == "MiniCPMForCausalLM":
|
||||
return MiniCPMModel
|
||||
if model_architecture == "BertModel":
|
||||
return BertModel
|
||||
if model_architecture == "NomicBertModel":
|
||||
return NomicBertModel
|
||||
if model_architecture == "GemmaForCausalLM":
|
||||
return GemmaModel
|
||||
if model_architecture == "Starcoder2ForCausalLM":
|
||||
return Model
|
||||
return Model
|
||||
@classmethod
|
||||
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
|
||||
assert names
|
||||
def func(modelcls: type[Model]):
|
||||
for name in names:
|
||||
cls._model_classes[name] = modelcls
|
||||
return modelcls
|
||||
return func
|
||||
|
||||
@classmethod
|
||||
def from_model_architecture(cls, arch):
|
||||
try:
|
||||
return cls._model_classes[arch]
|
||||
except KeyError:
|
||||
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
|
||||
|
||||
def _is_model_safetensors(self) -> bool:
|
||||
return Model.count_model_parts(self.dir_model, ".safetensors") > 0
|
||||
@ -239,57 +213,6 @@ class Model:
|
||||
return ("pytorch_model.bin",)
|
||||
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
|
||||
|
||||
def _get_model_architecture(self) -> gguf.MODEL_ARCH:
|
||||
arch = self.hparams["architectures"][0]
|
||||
if arch == "GPTNeoXForCausalLM":
|
||||
return gguf.MODEL_ARCH.GPTNEOX
|
||||
if arch == "BloomForCausalLM":
|
||||
return gguf.MODEL_ARCH.BLOOM
|
||||
if arch == "MPTForCausalLM":
|
||||
return gguf.MODEL_ARCH.MPT
|
||||
if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
|
||||
return gguf.MODEL_ARCH.BAICHUAN
|
||||
if arch in ("FalconForCausalLM", "RWForCausalLM"):
|
||||
return gguf.MODEL_ARCH.FALCON
|
||||
if arch == "GPTBigCodeForCausalLM":
|
||||
return gguf.MODEL_ARCH.STARCODER
|
||||
if arch == "GPTRefactForCausalLM":
|
||||
return gguf.MODEL_ARCH.REFACT
|
||||
if arch == "PersimmonForCausalLM":
|
||||
return gguf.MODEL_ARCH.PERSIMMON
|
||||
if arch in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
|
||||
return gguf.MODEL_ARCH.STABLELM
|
||||
if arch == "QWenLMHeadModel":
|
||||
return gguf.MODEL_ARCH.QWEN
|
||||
if arch == "Qwen2ForCausalLM":
|
||||
return gguf.MODEL_ARCH.QWEN2
|
||||
if arch == "MixtralForCausalLM":
|
||||
return gguf.MODEL_ARCH.LLAMA
|
||||
if arch == "GPT2LMHeadModel":
|
||||
return gguf.MODEL_ARCH.GPT2
|
||||
if arch == "PhiForCausalLM":
|
||||
return gguf.MODEL_ARCH.PHI2
|
||||
if arch == "PlamoForCausalLM":
|
||||
return gguf.MODEL_ARCH.PLAMO
|
||||
if arch == "CodeShellForCausalLM":
|
||||
return gguf.MODEL_ARCH.CODESHELL
|
||||
if arch == "OrionForCausalLM":
|
||||
return gguf.MODEL_ARCH.ORION
|
||||
if arch == "InternLM2ForCausalLM":
|
||||
return gguf.MODEL_ARCH.INTERNLM2
|
||||
if arch == "MiniCPMForCausalLM":
|
||||
return gguf.MODEL_ARCH.MINICPM
|
||||
if arch == "BertModel":
|
||||
return gguf.MODEL_ARCH.BERT
|
||||
if arch == "NomicBertModel":
|
||||
return gguf.MODEL_ARCH.NOMIC_BERT
|
||||
if arch == "GemmaForCausalLM":
|
||||
return gguf.MODEL_ARCH.GEMMA
|
||||
if arch == "Starcoder2ForCausalLM":
|
||||
return gguf.MODEL_ARCH.STARCODER2
|
||||
|
||||
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
||||
|
||||
def _set_vocab_gpt2(self):
|
||||
dir_model = self.dir_model
|
||||
hparams = self.hparams
|
||||
@ -457,7 +380,10 @@ class Model:
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
|
||||
@Model.register("GPTNeoXForCausalLM")
|
||||
class GPTNeoXModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.GPTNEOX
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
|
||||
@ -474,7 +400,10 @@ class GPTNeoXModel(Model):
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
|
||||
|
||||
|
||||
@Model.register("BloomForCausalLM")
|
||||
class BloomModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.BLOOM
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_name("Bloom")
|
||||
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
|
||||
@ -566,7 +495,10 @@ class BloomModel(Model):
|
||||
print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
|
||||
|
||||
|
||||
@Model.register("MPTForCausalLM")
|
||||
class MPTModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MPT
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["n_layers"]
|
||||
self.gguf_writer.add_name(self.dir_model.name)
|
||||
@ -629,7 +561,10 @@ class MPTModel(Model):
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
@Model.register("OrionForCausalLM")
|
||||
class OrionModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.ORION
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
@ -708,7 +643,10 @@ class OrionModel(Model):
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
@Model.register("BaichuanForCausalLM", "BaiChuanForCausalLM")
|
||||
class BaichuanModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.BAICHUAN
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
@ -823,7 +761,10 @@ class BaichuanModel(Model):
|
||||
return weights[r * n_part:r * n_part + r, ...]
|
||||
|
||||
|
||||
@Model.register("FalconForCausalLM", "RWForCausalLM")
|
||||
class FalconModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.FALCON
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams.get("num_hidden_layers")
|
||||
if block_count is None:
|
||||
@ -916,7 +857,10 @@ class FalconModel(Model):
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
@Model.register("GPTBigCodeForCausalLM")
|
||||
class StarCoderModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.STARCODER
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["n_layer"]
|
||||
|
||||
@ -931,7 +875,10 @@ class StarCoderModel(Model):
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
|
||||
@Model.register("GPTRefactForCausalLM")
|
||||
class RefactModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.REFACT
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
hidden_dim = self.hparams["n_embd"]
|
||||
inner_dim = 4 * hidden_dim
|
||||
@ -1015,7 +962,10 @@ class RefactModel(Model):
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
@Model.register("PersimmonForCausalLM")
|
||||
class PersimmonModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.PERSIMMON
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
@ -1063,7 +1013,10 @@ class PersimmonModel(Model):
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
@Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
|
||||
class StableLMModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.STABLELM
|
||||
|
||||
def set_vocab(self):
|
||||
if (self.dir_model / "tokenizer.json").is_file():
|
||||
self._set_vocab_gpt2()
|
||||
@ -1087,12 +1040,18 @@ class StableLMModel(Model):
|
||||
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
|
||||
|
||||
|
||||
@Model.register("MixtralForCausalLM")
|
||||
class MixtralModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
|
||||
@Model.register("MiniCPMForCausalLM")
|
||||
class MiniCPMModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MINICPM
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
self.gguf_writer.add_name("MiniCPM")
|
||||
@ -1169,7 +1128,10 @@ class MiniCPMModel(Model):
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
@Model.register("QWenLMHeadModel")
|
||||
class QwenModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN
|
||||
|
||||
@staticmethod
|
||||
def token_bytes_to_string(b):
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
@ -1249,7 +1211,15 @@ class QwenModel(Model):
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
@Model.register("Qwen2ForCausalLM")
|
||||
class Qwen2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN2
|
||||
|
||||
|
||||
@Model.register("GPT2LMHeadModel")
|
||||
class GPT2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.GPT2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_name(self.dir_model.name)
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
||||
@ -1311,7 +1281,10 @@ class GPT2Model(Model):
|
||||
self.gguf_writer.add_tensor("output.weight", data)
|
||||
|
||||
|
||||
@Model.register("PhiForCausalLM")
|
||||
class Phi2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.PHI2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
|
||||
|
||||
@ -1333,7 +1306,10 @@ class Phi2Model(Model):
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
|
||||
|
||||
@Model.register("PlamoForCausalLM")
|
||||
class PlamoModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.PLAMO
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
@ -1412,7 +1388,10 @@ class PlamoModel(Model):
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
@Model.register("CodeShellForCausalLM")
|
||||
class CodeShellModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.CODESHELL
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["n_layer"]
|
||||
|
||||
@ -1477,7 +1456,10 @@ class CodeShellModel(Model):
|
||||
print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
|
||||
|
||||
|
||||
@Model.register("InternLM2ForCausalLM")
|
||||
class InternLM2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.INTERNLM2
|
||||
|
||||
def set_vocab(self):
|
||||
# (TODO): Is there a better way?
|
||||
# Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character
|
||||
@ -1649,7 +1631,10 @@ in chat mode so that the conversation can end normally.")
|
||||
self.post_write_tensors(tensor_map, name, data_torch)
|
||||
|
||||
|
||||
@Model.register("BertModel")
|
||||
class BertModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.BERT
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.vocab_size = None
|
||||
@ -1679,7 +1664,7 @@ class BertModel(Model):
|
||||
else:
|
||||
raise NotImplementedError("Only MEAN and CLS pooling types supported")
|
||||
|
||||
self.gguf_writer.add_pooling_type(pooling_type.value)
|
||||
self.gguf_writer.add_pooling_type(pooling_type)
|
||||
|
||||
def set_vocab(self):
|
||||
path = self.dir_model
|
||||
@ -1755,7 +1740,10 @@ class BertModel(Model):
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
@Model.register("NomicBertModel")
|
||||
class NomicBertModel(BertModel):
|
||||
model_arch = gguf.MODEL_ARCH.NOMIC_BERT
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@ -1792,7 +1780,10 @@ class NomicBertModel(BertModel):
|
||||
yield name, data
|
||||
|
||||
|
||||
@Model.register("GemmaForCausalLM")
|
||||
class GemmaModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
@ -1848,6 +1839,11 @@ class GemmaModel(Model):
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
@Model.register("Starcoder2ForCausalLM")
|
||||
class StarCoder2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.STARCODER2
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
@ -362,7 +362,7 @@ class GGUFWriter:
|
||||
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
|
||||
|
||||
def add_pooling_type(self, value: PoolingType) -> None:
|
||||
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value)
|
||||
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
|
||||
|
||||
def add_rope_dimension_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
|
||||
|
Loading…
Reference in New Issue
Block a user