convert : refactor vocab selection logic (#6355)

This commit is contained in:
Jared Van Bortel 2024-03-28 11:44:36 -04:00 committed by GitHub
parent 66ba560256
commit be55134a53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 204 additions and 176 deletions

View File

@ -23,7 +23,7 @@ if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf
from convert import HfVocab
from convert import LlamaHfVocab
###### MODEL DEFINITIONS ######
@ -230,7 +230,7 @@ class Model(ABC):
def _set_vocab_gpt2(self):
dir_model = self.dir_model
hparams = self.hparams
tokens: list[bytearray] = []
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
@ -243,8 +243,7 @@ class Model(ABC):
for i in range(vocab_size):
if i not in reverse_vocab:
pad_token = f"[PAD{i}]".encode('utf-8')
tokens.append(bytearray(pad_token))
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
@ -266,7 +265,7 @@ class Model(ABC):
def _set_vocab_qwen(self):
dir_model = self.dir_model
hparams = self.hparams
tokens: list[bytearray] = []
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
@ -291,8 +290,7 @@ class Model(ABC):
for i in range(vocab_size):
if i not in reverse_vocab:
pad_token = f"[PAD{i}]".encode("utf-8")
tokens.append(bytearray(pad_token))
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
@ -372,12 +370,8 @@ class Model(ABC):
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_hf(self):
path = self.dir_model
added_tokens_path = self.dir_model
vocab = HfVocab(
path, added_tokens_path if added_tokens_path.exists() else None
)
def _set_vocab_llama_hf(self):
vocab = LlamaHfVocab(self.dir_model)
tokens = []
scores = []
toktypes = []
@ -1099,7 +1093,7 @@ class MiniCPMModel(Model):
self.gguf_writer.add_file_type(self.ftype)
def set_vocab(self):
self._set_vocab_hf()
self._set_vocab_llama_hf()
def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
if n_kv_head is not None and n_head != n_kv_head:
@ -1700,11 +1694,8 @@ class BertModel(Model):
self.gguf_writer.add_pooling_type(pooling_type)
def set_vocab(self):
path = self.dir_model
added_tokens_path = self.dir_model if self.dir_model.exists() else None
# use huggingface vocab to get all tokens
vocab = HfVocab(path, added_tokens_path)
vocab = LlamaHfVocab(self.dir_model, ignore_nonllama=True)
tokens, scores, toktypes = zip(*vocab.all_tokens())
assert len(tokens) == vocab.vocab_size
self.vocab_size = vocab.vocab_size

View File

@ -106,12 +106,12 @@ def main():
tensor_map = gguf.get_tensor_name_map(arch, block_count)
print(tensor_map)
for name in tensors.keys():
data = tensors[name]
data_torch = tensors[name]
if name.endswith(".self_attention.rotary_emb.inv_freq"):
continue
old_dtype = data.dtype
old_dtype = data_torch.dtype
# TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
data = data.to(torch.float32).squeeze().numpy()
data = data_torch.to(torch.float32).squeeze().numpy()
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None:
print("Can not map tensor '" + name + "'")

View File

@ -16,13 +16,14 @@ import re
import signal
import struct
import sys
import textwrap
import time
import zipfile
from abc import ABCMeta, abstractmethod
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable
import numpy as np
from sentencepiece import SentencePieceProcessor
@ -43,6 +44,9 @@ ARCH = gguf.MODEL_ARCH.LLAMA
DEFAULT_CONCURRENCY = 8
ADDED_TOKENS_FILE = 'added_tokens.json'
FAST_TOKENIZER_FILE = 'tokenizer.json'
#
# data types
#
@ -188,8 +192,10 @@ class Params:
n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model)
if n_layer < 1:
raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n"
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")
msg = """\
failed to guess 'n_layer'. This model is unknown or unsupported.
Suggestion: provide 'config.json' of the model in the same directory containing model files."""
raise KeyError(textwrap.dedent(msg))
n_head = n_embd // 128 # guessed
n_mult = 256 # guessed
@ -211,7 +217,8 @@ class Params:
@staticmethod
def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path))
with open(config_path) as f:
config = json.load(f)
rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
rope_scaling = config.get("rope_scaling")
@ -233,8 +240,10 @@ class Params:
elif "max_position_embeddings" in config:
n_ctx = config["max_position_embeddings"]
else:
raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")
msg = """\
failed to guess 'n_ctx'. This model is unknown or unsupported.
Suggestion: provide 'config.json' of the model in the same directory containing model files."""
raise KeyError(textwrap.dedent(msg))
n_experts = None
n_experts_used = None
@ -265,7 +274,8 @@ class Params:
# {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
@staticmethod
def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path))
with open(config_path) as f:
config = json.load(f)
n_experts = None
n_experts_used = None
@ -331,47 +341,86 @@ class Params:
# vocab
#
class BpeVocab:
@runtime_checkable
class BaseVocab(Protocol):
tokenizer_model: ClassVar[str]
name: ClassVar[str]
class NoVocab(BaseVocab):
tokenizer_model = "no_vocab"
name = "no_vocab"
def __repr__(self) -> str:
return "<NoVocab for a model without integrated vocabulary>"
@runtime_checkable
class Vocab(BaseVocab, Protocol):
vocab_size: int
added_tokens_dict: dict[str, int]
added_tokens_list: list[str]
fname_tokenizer: Path
def __init__(self, base_path: Path): ...
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
class BpeVocab(Vocab):
tokenizer_model = "gpt2"
name = "bpe"
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
if isinstance(self.bpe_tokenizer.get('model'), dict):
self.vocab = self.bpe_tokenizer["model"]["vocab"]
else:
self.vocab = self.bpe_tokenizer
added_tokens: dict[str, int]
if fname_added_tokens is not None:
# FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else:
# Fall back to trying to find the added tokens in tokenizer.json
tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json'
if not tokenizer_json_file.is_file():
added_tokens = {}
else:
tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8"))
added_tokens = dict(
(item['content'], item['id'])
for item in tokenizer_json.get('added_tokens', [])
# Added tokens here can be duplicates of the main vocabulary.
if item['content'] not in self.bpe_tokenizer)
def __init__(self, base_path: Path):
added_tokens: dict[str, int] = {}
vocab_size: int = len(self.vocab)
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
actual_ids = sorted(added_tokens.values())
if (fname_tokenizer := base_path / 'vocab.json').exists():
# "slow" tokenizer
with open(fname_tokenizer, encoding="utf-8") as f:
self.vocab = json.load(f)
try:
# FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
added_tokens = json.load(f)
except FileNotFoundError:
pass
else:
# "fast" tokenizer
fname_tokenizer = base_path / FAST_TOKENIZER_FILE
# if this fails, FileNotFoundError propagates to caller
with open(fname_tokenizer, encoding="utf-8") as f:
tokenizer_json = json.load(f)
tokenizer_model: dict[str, Any] = tokenizer_json['model']
if (
tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
or tokenizer_json['decoder']['type'] != 'ByteLevel'
):
raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')
self.vocab = tokenizer_model["vocab"]
if (added := tokenizer_json.get('added_tokens')) is not None:
# Added tokens here can be duplicates of the main vocabulary.
added_tokens = {item['content']: item['id']
for item in added
if item['content'] not in self.vocab}
vocab_size = len(self.vocab)
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
actual_ids = sorted(added_tokens.values())
if expected_ids != actual_ids:
expected_end_id = vocab_size + len(actual_ids) - 1
raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}")
raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
f"{vocab_size} - {expected_end_id}; got {actual_ids}")
items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
self.added_tokens_dict = added_tokens
self.added_tokens_list = [text for (text, idx) in items]
self.vocab_size_base: int = vocab_size
self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
self.vocab_size_base = vocab_size
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}
@ -392,19 +441,25 @@ class BpeVocab:
return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
class SentencePieceVocab:
class SentencePieceVocab(Vocab):
tokenizer_model = "llama"
name = "spm"
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
added_tokens: dict[str, int]
if fname_added_tokens is not None:
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else:
added_tokens = {}
def __init__(self, base_path: Path):
added_tokens: dict[str, int] = {}
if (fname_tokenizer := base_path / 'tokenizer.model').exists():
# normal location
try:
with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f:
added_tokens = json.load(f)
except FileNotFoundError:
pass
elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
# not found in alternate location either
raise FileNotFoundError('Cannot find tokenizer.model')
vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
vocab_size = self.sentencepiece_tokenizer.vocab_size()
new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
@ -414,18 +469,17 @@ class SentencePieceVocab:
raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
# Token pieces that were added to the base vocabulary.
self.added_tokens_dict = added_tokens
self.added_tokens_dict = added_tokens
self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
self.vocab_size_base = vocab_size
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
tokenizer = self.sentencepiece_tokenizer
for i in range(tokenizer.vocab_size()):
piece = tokenizer.id_to_piece(i)
text: bytes = piece.encode("utf-8")
text = piece.encode("utf-8")
score: float = tokenizer.get_score(i)
toktype = gguf.TokenType.NORMAL
@ -458,27 +512,42 @@ class SentencePieceVocab:
return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
class HfVocab:
class LlamaHfVocab(Vocab):
tokenizer_model = "llama"
name = "hfft"
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None = None) -> None:
def __init__(self, base_path: Path, ignore_nonllama: bool = False):
fname_tokenizer = base_path / FAST_TOKENIZER_FILE
# if this fails, FileNotFoundError propagates to caller
with open(fname_tokenizer, encoding='utf-8') as f:
tokenizer_json = json.load(f)
# pre-check so we know if we need transformers
tokenizer_model: dict[str, Any] = tokenizer_json['model']
if ignore_nonllama:
pass # workaround incorrect use of this class for WordPiece
elif (
tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
or tokenizer_json['decoder']['type'] != 'Sequence'
):
raise FileNotFoundError('Cannot find Llama BPE tokenizer')
try:
from transformers import AutoTokenizer
except ImportError as e:
raise ImportError(
"To use HfVocab, please install the `transformers` package. "
"To use LlamaHfVocab, please install the `transformers` package. "
"You can install it with `pip install transformers`."
) from e
print("fname_tokenizer:", fname_tokenizer)
# Allow the tokenizer to default to slow or fast versions.
# Explicitly set tokenizer to use local paths.
self.tokenizer = AutoTokenizer.from_pretrained(
fname_tokenizer,
cache_dir=fname_tokenizer,
base_path,
cache_dir=base_path,
local_files_only=True,
)
assert self.tokenizer.is_fast # assume tokenizer.json is used
# Initialize lists and dictionaries for added tokens
self.added_tokens_list = []
@ -506,8 +575,7 @@ class HfVocab:
self.vocab_size_base = self.tokenizer.vocab_size
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
self.fname_tokenizer = fname_tokenizer
def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
reverse_vocab = {
@ -559,18 +627,7 @@ class HfVocab:
yield from self.added_tokens()
def __repr__(self) -> str:
return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
class NoVocab:
tokenizer_model = "no_vocab"
name = "no_vocab"
def __repr__(self) -> str:
return "<NoVocab for a model without integrated vocabulary>"
Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab | NoVocab"
return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
#
@ -588,7 +645,7 @@ def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
.reshape(weights.shape))
class Tensor(metaclass=ABCMeta):
class Tensor(ABC):
data_type: DataType
@abstractmethod
@ -610,7 +667,7 @@ def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
class UnquantizedTensor(Tensor):
def __init__(self, ndarray: NDArray) -> None:
def __init__(self, ndarray: NDArray):
assert isinstance(ndarray, np.ndarray)
self.ndarray = ndarray
self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
@ -689,7 +746,7 @@ class ModelPlus:
model: LazyModel
paths: list[Path] # Where this was read from.
format: Literal['ggml', 'torch', 'safetensors', 'none']
vocab: Vocab | None # For GGML models (which have vocab built in), the vocab.
vocab: BaseVocab | None # For GGML models (which have vocab built in), the vocab.
def merge_sharded(models: list[LazyModel]) -> LazyModel:
@ -698,7 +755,7 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel:
names = {name: None for model in models for name in model}
def convert(name: str) -> LazyTensor:
lazy_tensors: list[LazyTensor] = [model[name] for model in models]
lazy_tensors = [model[name] for model in models]
if len(lazy_tensors) == 1:
# only one file; don't go through this procedure since there might
# be quantized tensors
@ -719,7 +776,7 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel:
def load() -> UnquantizedTensor:
ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors]
concatenated: NDArray = np.concatenate(ndarrays, axis=axis)
concatenated = np.concatenate(ndarrays, axis=axis)
return UnquantizedTensor(concatenated)
description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]'
return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description)
@ -807,10 +864,10 @@ class LazyUnpickler(pickle.Unpickler):
def load(offset: int, elm_count: int) -> NDArray:
dtype = data_type.dtype
fp = self.zip_file.open(info)
fp.seek(offset * dtype.itemsize)
size = elm_count * dtype.itemsize
data = fp.read(size)
with self.zip_file.open(info) as fp:
fp.seek(offset * dtype.itemsize)
size = elm_count * dtype.itemsize
data = fp.read(size)
assert len(data) == size
return np.frombuffer(data, dtype)
description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
@ -831,7 +888,7 @@ class LazyUnpickler(pickle.Unpickler):
def rebuild_from_type_v2(func, new_type, args, state):
return func(*args)
CLASSES: dict[tuple[str, str], Any] = {
CLASSES = {
# getattr used here as a workaround for mypy not being smart enough to determine
# the staticmethods have a __func__ attribute.
('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
@ -890,7 +947,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
def must_read(fp: IO[bytes], length: int) -> bytes:
ret = fp.read(length)
if len(ret) < length:
raise Exception("unexpectedly reached end of file")
raise EOFError("unexpectedly reached end of file")
return ret
@ -948,13 +1005,14 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
yield result
def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) -> None:
# Handle special case where the model's vocab size is not set
if params.n_vocab == -1:
raise ValueError(
f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}"
"The model's vocab size is set to -1 in params.json. Please update it manually."
+ (f" Maybe {vocab.vocab_size}?" if isinstance(vocab, Vocab) else ""),
)
if isinstance(vocab, NoVocab):
if not isinstance(vocab, Vocab):
return # model has no vocab
# Check for a vocab size mismatch
@ -979,11 +1037,11 @@ def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> N
if vocab.vocab_size < params.n_vocab:
msg += " Add the --pad-vocab option and try again."
raise Exception(msg)
raise ValueError(msg)
class OutputFile:
def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
def add_meta_arch(self, params: Params) -> None:
@ -1034,8 +1092,6 @@ class OutputFile:
self.gguf.add_file_type(params.ftype)
def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]:
assert not isinstance(vocab, NoVocab)
tokens = []
scores = []
toktypes = []
@ -1135,7 +1191,7 @@ class OutputFile:
@staticmethod
def write_all(
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
pad_vocab: bool = False,
) -> None:
@ -1145,11 +1201,11 @@ class OutputFile:
# meta data
of.add_meta_arch(params)
if isinstance(vocab, NoVocab):
of.gguf.add_tokenizer_model(vocab.tokenizer_model)
else:
if isinstance(vocab, Vocab):
of.add_meta_vocab(vocab)
of.add_meta_special_vocab(svocab)
else: # NoVocab
of.gguf.add_tokenizer_model(vocab.tokenizer_model)
# tensor info
for name, lazy_tensor in model.items():
@ -1176,7 +1232,7 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
raise Exception(f"Unexpected combination of types: {name_to_type}")
raise ValueError(f"Unexpected combination of types: {name_to_type}")
def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
@ -1186,7 +1242,7 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyM
def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel:
tmap = gguf.TensorNameMap(ARCH, params.n_layer)
should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
should_skip = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
tmp = model
@ -1213,8 +1269,7 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) ->
if skip_unknown:
print(f"Unexpected tensor name: {name} - skipping")
continue
else:
raise Exception(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
raise ValueError(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
if tensor_type in should_skip:
print(f"skipping tensor {name_new}")
@ -1231,7 +1286,7 @@ def nth_multifile_path(path: Path, n: int) -> Path | None:
the nth path in the model.
'''
# Support the following patterns:
patterns: list[tuple[str, str]] = [
patterns = [
# - x.00.pth, x.01.pth, etc.
(r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'),
# - x-00001-of-00002.bin, x-00002-of-00002.bin, etc.
@ -1277,9 +1332,9 @@ def load_some_model(path: Path) -> ModelPlus:
globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
files = [file for glob in globs for file in path.glob(glob)]
if not files:
raise Exception(f"Can't find model in directory {path}")
raise FileNotFoundError(f"Can't find model in directory {path}")
if len(files) > 1:
raise Exception(f"Found multiple models in {path}, not sure which to pick: {files}")
raise ValueError(f"Found multiple models in {path}, not sure which to pick: {files}")
path = files[0]
paths = find_multifile_paths(path)
@ -1293,36 +1348,14 @@ def load_some_model(path: Path) -> ModelPlus:
class VocabFactory:
_FILES = {"spm": "tokenizer.model", "bpe": "vocab.json", "hfft": "tokenizer.json"}
_VOCAB_CLASSES: list[type[Vocab]] = [SentencePieceVocab, BpeVocab, LlamaHfVocab]
def __init__(self, path: Path):
self.path = path
self.file_paths = self._detect_files()
print(f"Found vocab files: {self.file_paths}")
def _detect_files(self) -> dict[str, Path | None]:
def locate(file: str) -> Path | None:
if (path := self.path / file).exists():
return path
if (path := self.path.parent / file).exists():
return path
return None
return {vt: locate(f) for vt, f in self._FILES.items()}
def _select_file(self, vocab_types: list[str]) -> tuple[str, Path]:
for vtype in vocab_types:
try:
path = self.file_paths[vtype]
except KeyError:
raise ValueError(f"Unsupported vocabulary type {vtype}") from None
if path is not None:
return vtype, path
raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}")
def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab:
def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gguf.SpecialVocab:
load_merges = vocab.name == "bpe"
n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None
n_vocab = vocab.vocab_size if isinstance(vocab, Vocab) else None
return gguf.SpecialVocab(
model_parent_path,
load_merges=load_merges,
@ -1331,27 +1364,29 @@ class VocabFactory:
)
def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab:
vocab_type, path = self._select_file(vocab_types)
print(f"Loading vocab file {path!r}, type {vocab_type!r}")
vocab_classes: dict[str, type[Vocab]] = {cls.name: cls for cls in self._VOCAB_CLASSES}
selected_vocabs: dict[str, type[Vocab]] = {}
for vtype in vocab_types:
try:
selected_vocabs[vtype] = vocab_classes[vtype]
except KeyError:
raise ValueError(f"Unsupported vocabulary type {vtype}") from None
added_tokens_path = path.parent / "added_tokens.json"
if vocab_type == "bpe":
return BpeVocab(
path, added_tokens_path if added_tokens_path.exists() else None
)
if vocab_type == "spm":
return SentencePieceVocab(
path, added_tokens_path if added_tokens_path.exists() else None
)
if vocab_type == "hfft":
return HfVocab(
path.parent, added_tokens_path if added_tokens_path.exists() else None
)
raise ValueError(vocab_type)
for vtype, cls in selected_vocabs.items():
try:
vocab = cls(self.path)
break
except FileNotFoundError:
pass # ignore unavailable tokenizers
else:
raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}")
def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
vocab: Vocab
if len(vocab_types) == 1 and "no_vocab" in vocab_types:
print(f"Loaded vocab file {vocab.fname_tokenizer!r}, type {vocab.name!r}")
return vocab
def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]:
vocab: BaseVocab
if vocab_types is None:
vocab = NoVocab()
else:
vocab = self._create_vocab_by_path(vocab_types)
@ -1408,10 +1443,8 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
args = parser.parse_args(args_in)
if args.no_vocab:
if args.vocab_only:
raise ValueError("no need to specify --vocab-only if using --no-vocab")
args.vocab_type = "no_vocab"
if args.no_vocab and args.vocab_only:
raise ValueError("--vocab-only does not make sense with --no-vocab")
if args.dump_single:
model_plus = lazy_load_file(args.model)
@ -1433,10 +1466,12 @@ def main(args_in: list[str] | None = None) -> None:
params = Params.load(model_plus)
if params.n_ctx == -1:
if args.ctx is None:
raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
"Please specify one with --ctx:\n"
" - LLaMA v1: --ctx 2048\n"
" - LLaMA v2: --ctx 4096\n")
msg = """\
The model doesn't have a context size, and you didn't specify one with --ctx
Please specify one with --ctx:
- LLaMA v1: --ctx 2048
- LLaMA v2: --ctx 4096"""
parser.error(textwrap.dedent(msg))
params.n_ctx = args.ctx
if args.outtype:
@ -1451,9 +1486,11 @@ def main(args_in: list[str] | None = None) -> None:
model_parent_path = model_plus.paths[0].parent
vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
vocab_factory = VocabFactory(vocab_path)
vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type.split(","), model_parent_path)
vocab_types = None if args.no_vocab else args.vocab_type.split(",")
vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path)
if args.vocab_only:
assert isinstance(vocab, Vocab)
if not args.outfile:
raise ValueError("need --outfile if using --vocab-only")
outfile = args.outfile

View File

@ -60,9 +60,9 @@ extern "C" {
enum llama_vocab_type {
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
LLAMA_VOCAB_TYPE_SPM = 1, // SentencePiece
LLAMA_VOCAB_TYPE_BPE = 2, // Byte Pair Encoding
LLAMA_VOCAB_TYPE_WPM = 3, // WordPiece
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
};
// note: these values should be synchronized with ggml_rope