convert-new.py : output gguf (#2635)

* convert-new.py : output gguf (WIP)

* convert-new.py : add gguf key-value pairs

* llama : add hparams.ctx_train + no longer print ftype

* convert-new.py : minor fixes

* convert-new.py : vocab-only option should work now

* llama : fix tokenizer to use llama_char_to_byte

* tests : add new ggml-vocab-llama.gguf

* convert-new.py : tensor name mapping

* convert-new.py : add map for skipping tensor serialization

* convert-new.py : convert script now works

* gguf.py : pick some of the refactoring from #2644

* convert-new.py : minor fixes
Georgi Gerganov 2023-08-17 17:19:52 +03:00 committed by GitHub
parent d6fd53afd6
commit e0429d38e4
9 changed files with 526 additions and 327 deletions
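
In short, convert-new.py now writes GGUF directly through gguf.GGUFWriter instead of the old ggjt format. A minimal sketch of the new conversion flow, using the function names from the diff below (the model path, the "spm" vocab type and the f16 output choice are illustrative, not part of the commit):

from pathlib import Path

model_plus  = load_some_model(Path("models/7B"))             # lazily load *.pth / *.bin / *.safetensors
params      = Params.load(model_plus)                        # n_ctx may come back as -1 and require --ctx
vocab       = load_vocab(Path("models/7B"), "spm")
model       = convert_model_names(model_plus.model, params)  # HF/pth tensor names -> gguf names
output_type = pick_output_type(model, "f16")
model       = convert_to_output_type(model, output_type)
OutputFile.write_all(Path("models/7B/ggml-model-f16.gguf"), params, model, vocab)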


@ -298,7 +298,7 @@ for part_name in part_names:
print( name + ", shape " + str(len(data.shape)) + ", " + str(old_dtype) + " --> " + str(data.dtype)) print( name + ", shape " + str(len(data.shape)) + ", " + str(old_dtype) + " --> " + str(data.dtype))
gguf_writer.write_tensor_to_file(data) gguf_writer.write_tensor_data(data)
gguf_writer.close() gguf_writer.close()


@ -1,5 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
import gguf
import argparse import argparse
import concurrent.futures import concurrent.futures
import copy import copy
@ -33,6 +34,13 @@ if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]' NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
ARCH=gguf.MODEL_ARCH.LLAMA
NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
#
# data types
#
@dataclass(frozen=True) @dataclass(frozen=True)
class UnquantizedDataType: class UnquantizedDataType:
name: str name: str
@ -44,14 +52,6 @@ DT_BF16 = UnquantizedDataType('BF16')
DataType = Union[UnquantizedDataType] DataType = Union[UnquantizedDataType]
DATA_TYPE_TO_FTYPE: Dict[DataType, int] = {
DT_F32: 0,
DT_F16: 1,
}
FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \
{ftype: dtype for (dtype, ftype) in DATA_TYPE_TO_FTYPE.items()}
DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = { DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
DT_BF16: np.dtype(np.uint16), DT_BF16: np.dtype(np.uint16),
DT_F16: np.dtype(np.float16), DT_F16: np.dtype(np.float16),
@ -62,6 +62,13 @@ DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \ NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
{dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()} {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
'BF16': DT_BF16,
'F16': DT_F16,
'F32': DT_F32,
'I32': DT_I32,
}
class GGMLFileType(enum.Enum): class GGMLFileType(enum.Enum):
AllF32 = 0 AllF32 = 0
MostlyF16 = 1 # except 1d tensors MostlyF16 = 1 # except 1d tensors
@ -77,48 +84,31 @@ class GGMLFileType(enum.Enum):
else: else:
raise ValueError(self) raise ValueError(self)
# TODO: this is LLaMA specific
def make_tensors_list() -> List[str]:
ret = [
'tok_embeddings.weight',
'norm.weight',
'output.weight',
]
for i in range(80): # maximum number of layer
ret += [
f'layers.{i}.attention.wq.weight',
f'layers.{i}.attention.wk.weight',
f'layers.{i}.attention.wv.weight',
f'layers.{i}.attention.wo.weight',
f'layers.{i}.attention_norm.weight',
f'layers.{i}.feed_forward.w1.weight',
f'layers.{i}.feed_forward.w2.weight',
f'layers.{i}.feed_forward.w3.weight',
f'layers.{i}.ffn_norm.weight',
]
return ret
# TODO: this should be generalized for non-LLaMA models
TENSORS_LIST = make_tensors_list()
TENSORS_SET = set(TENSORS_LIST)
def find_n_mult(n_ff: int, n_embd: int) -> int:
# hardcoded magic range
for n_mult in range(8192, 1, -1):
calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
if calc_ff == n_ff:
return n_mult
raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
#
# hparams loading
#
@dataclass @dataclass
class Params: class Params:
n_vocab: int n_vocab: int
n_embd: int n_embd: int
n_mult: int n_mult: int
n_head: int n_layer: int
n_layer: int n_ctx: int
n_kv_head: Optional[int] # This parameter is only used for Llama 2 n_ff: int
n_head: int
n_head_kv: int
f_norm_eps: float
@staticmethod
def find_n_mult(n_ff: int, n_embd: int) -> int:
# hardcoded magic range
for n_mult in range(8192, 1, -1):
calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
if calc_ff == n_ff:
return n_mult
raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
@staticmethod @staticmethod
def guessed(model: 'LazyModel') -> 'Params': def guessed(model: 'LazyModel') -> 'Params':
@ -137,37 +127,57 @@ class Params:
raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n" raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n"
"Suggestion: provide 'config.json' of the model in the same directory containing model files.") "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
n_head=n_embd // 128 # guessed n_head = n_embd // 128 # guessed
n_mult = 256 # guessed
# TODO: verify this
n_ff = int(2 * (4 * n_embd) / 3)
n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult)
return Params( return Params(
n_vocab = n_vocab, n_vocab = n_vocab,
n_embd = n_embd, n_embd = n_embd,
n_mult = 256, n_mult = n_mult,
n_head = n_head, n_layer = n_layer,
n_layer = n_layer, n_ctx = -1,
n_kv_head = None, n_ff = n_ff,
n_head = n_head,
n_head_kv = n_head,
f_norm_eps = 1e-5,
) )
@staticmethod @staticmethod
def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params': def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
config = json.load(open(config_path)) config = json.load(open(config_path))
n_vocab = config["vocab_size"]; n_vocab = config["vocab_size"];
n_embd = config["hidden_size"]; n_embd = config["hidden_size"];
n_head = config["num_attention_heads"]; n_layer = config["num_hidden_layers"];
n_layer = config["num_hidden_layers"]; n_ff = config["intermediate_size"];
n_ff = config["intermediate_size"]; n_head = config["num_attention_heads"];
n_kv_head = config.get("num_key_value_heads") n_head_kv = config["num_key_value_heads"];
f_norm_eps = config["rms_norm_eps"];
n_mult = find_n_mult(n_ff, n_embd); n_mult = Params.find_n_mult(n_ff, n_embd);
if "max_sequence_length" in config:
n_ctx = config["max_sequence_length"]
elif "max_position_embeddings" in config:
n_ctx = config["max_position_embeddings"]
else:
raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")
return Params( return Params(
n_vocab = n_vocab, n_vocab = n_vocab,
n_embd = n_embd, n_embd = n_embd,
n_mult = n_mult, n_mult = n_mult,
n_head = n_head, n_layer = n_layer,
n_layer = n_layer, n_ctx = n_ctx,
n_kv_head = n_kv_head, n_ff = n_ff,
n_head = n_head,
n_head_kv = n_head_kv,
f_norm_eps = f_norm_eps,
) )
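
A quick sanity check of the n_ff guess in Params.guessed() above, using n_embd = 4096 (the 7B embedding size, chosen here only for illustration):

n_embd = 4096
n_mult = 256
n_ff = int(2 * (4 * n_embd) / 3)                 # 10922
n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult)  # round up to a multiple of n_mult
print(n_ff)                                      # 11008, the actual LLaMA-7B feed-forward size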
# LLaMA v2 70B params.json # LLaMA v2 70B params.json
@ -176,22 +186,32 @@ class Params:
def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params': def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
config = json.load(open(config_path)) config = json.load(open(config_path))
n_vocab = config["vocab_size"]; n_vocab = config["vocab_size"];
n_embd = config["dim"]; n_embd = config["dim"];
n_head = config["n_heads"]; n_layer = config["n_layers"];
n_layer = config["n_layers"]; n_mult = config["multiple_of"];
n_mult = config["multiple_of"]; n_ctx = 2048 if config["norm_eps"] == 1e-06 else 4096 # hack to determine LLaMA v1 vs v2
n_ff = -1;
n_head = config["n_heads"];
n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head;
f_norm_eps = config["norm_eps"];
if n_vocab == -1: if n_vocab == -1:
n_vocab = model["tok_embeddings.weight"].shape[0] n_vocab = model["tok_embeddings.weight"].shape[0]
if n_ff == -1:
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
return Params( return Params(
n_vocab = n_vocab, n_vocab = n_vocab,
n_embd = n_embd, n_embd = n_embd,
n_mult = n_mult, n_mult = n_mult,
n_head = n_head, n_layer = n_layer,
n_layer = n_layer, n_ctx = n_ctx,
n_kv_head = None, n_ff = n_ff,
n_head = n_head,
n_head_kv = n_head_kv,
f_norm_eps = f_norm_eps,
) )
@staticmethod @staticmethod
@ -206,10 +226,13 @@ class Params:
else: else:
params = Params.guessed(model_plus.model) params = Params.guessed(model_plus.model)
print(f'params: n_vocab:{params.n_vocab} n_embd:{params.n_embd} n_mult:{params.n_mult} n_head:{params.n_head} n_layer:{params.n_layer}')
return params return params
#
# vocab
#
class BpeVocab: class BpeVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None: def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read()) self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
@ -294,13 +317,17 @@ class SentencePieceVocab:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>" return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
Vocab = Union[BpeVocab, SentencePieceVocab] Vocab = Union[BpeVocab, SentencePieceVocab]
def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray: #
if n_kv_head is not None and n_head != n_kv_head: # data loading
n_head //= n_kv_head # TODO: reuse (probably move to gguf.py?)
#
def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
if n_head_kv is not None and n_head != n_head_kv:
n_head //= n_head_kv
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2) .swapaxes(1, 2)
.reshape(weights.shape)) .reshape(weights.shape))
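
permute() rearranges the rows of the HF q/k projection weights into the layout the original LLaMA checkpoints (and ggml) expect. A toy numpy run, with shapes chosen purely for illustration, makes the reordering visible:

import numpy as np

def permute(weights, n_head, n_head_kv):
    if n_head_kv is not None and n_head != n_head_kv:
        n_head //= n_head_kv
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))

w = np.arange(8 * 3).reshape(8, 3)       # pretend wq: 2 heads, 4 rows per head
out = permute(w, n_head=2, n_head_kv=2)
print(out[:, 0])                         # [ 0  6  3  9 12 18 15 21] -> rows 0,2,1,3 | 4,6,5,7
# within each head, the two interleaved halves of the rotary dimensions are regrouped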
@ -312,7 +339,7 @@ class Tensor(metaclass=ABCMeta):
@abstractmethod @abstractmethod
def astype(self, data_type: DataType) -> 'Tensor': ... def astype(self, data_type: DataType) -> 'Tensor': ...
@abstractmethod @abstractmethod
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'Tensor': ... def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ...
@abstractmethod @abstractmethod
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ... def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
@abstractmethod @abstractmethod
@ -350,8 +377,8 @@ class UnquantizedTensor(Tensor):
r = self.ndarray.shape[0] // 3 r = self.ndarray.shape[0] // 3
return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...]) return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'UnquantizedTensor': def permute(self, n_head: int, n_head_kv: int) -> 'UnquantizedTensor':
return UnquantizedTensor(permute(self.ndarray, n_head, n_kv_head)) return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv))
def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray: def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
@ -374,18 +401,18 @@ GGMLCompatibleTensor = Union[UnquantizedTensor]
class DeferredPermutedTensor(Tensor): class DeferredPermutedTensor(Tensor):
def __init__(self, base: Tensor, n_head: int, n_kv_head: Optional[int] = None) -> None: def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
self.base = base self.base = base
self.n_head = n_head self.n_head = n_head
self.data_type = self.base.data_type self.data_type = self.base.data_type
def astype(self, data_type: DataType) -> Tensor: def astype(self, data_type: DataType) -> Tensor:
return self.base.astype(data_type).permute(self.n_head, self.n_kv_head) return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
def to_ggml(self) -> GGMLCompatibleTensor: def to_ggml(self) -> GGMLCompatibleTensor:
return self.base.to_ggml().permute(self.n_head, self.n_kv_head) return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor: def permute(self, n_head: int, n_head_kv: int) -> Tensor:
raise Exception("shouldn't permute twice") raise Exception("shouldn't permute twice")
@ -481,10 +508,10 @@ def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
return ModelPlus(model, paths, format, vocab) return ModelPlus(model, paths, format, vocab)
def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_kv_head: Optional[int] = None) -> LazyTensor: def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor:
def load() -> Tensor: def load() -> Tensor:
return lazy_tensor.load().permute(n_head, n_kv_head) return lazy_tensor.load().permute(n_head, n_head_kv)
return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_kv_head}) ' + lazy_tensor.description) return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor: def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
def load() -> Tensor: def load() -> Tensor:
@ -500,34 +527,6 @@ def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
s[0] = s[0] // 3 s[0] = s[0] // 3
return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description) return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
out: LazyModel = {}
out["tok_embeddings.weight"] = model["model.embed_tokens.weight"]
out["norm.weight"] = model["model.norm.weight"]
out["output.weight"] = model["lm_head.weight"]
for i in itertools.count():
if f"model.layers.{i}.self_attn.q_proj.weight" in model:
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_kv_head)
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
out[f"layers.{i}.attention.wk.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
out[f"layers.{i}.attention.wv.weight"] = part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
else:
break
out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]
out[f"layers.{i}.feed_forward.w1.weight"] = model[f"model.layers.{i}.mlp.gate_proj.weight"]
out[f"layers.{i}.feed_forward.w2.weight"] = model[f"model.layers.{i}.mlp.down_proj.weight"]
out[f"layers.{i}.feed_forward.w3.weight"] = model[f"model.layers.{i}.mlp.up_proj.weight"]
out[f"layers.{i}.attention_norm.weight"] = model[f"model.layers.{i}.input_layernorm.weight"]
out[f"layers.{i}.ffn_norm.weight"] = model[f"model.layers.{i}.post_attention_layernorm.weight"]
return out
# Functionality that simulates `torch.load` but where individual tensors are # Functionality that simulates `torch.load` but where individual tensors are
# only loaded into memory on demand, not all at once. # only loaded into memory on demand, not all at once.
@ -621,14 +620,6 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None) return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)
SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
'BF16': DT_BF16,
'F16': DT_F16,
'F32': DT_F32,
'I32': DT_I32,
}
def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
header_size, = struct.unpack('<Q', fp.read(8)) header_size, = struct.unpack('<Q', fp.read(8))
header: Dict[str, Dict[str, Any]] = json.loads(fp.read(header_size)) header: Dict[str, Dict[str, Any]] = json.loads(fp.read(header_size))
@ -678,7 +669,6 @@ def lazy_load_file(path: Path) -> ModelPlus:
In = TypeVar('In') In = TypeVar('In')
Out = TypeVar('Out') Out = TypeVar('Out')
def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]: def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]:
'''Parallel map, but with backpressure. If the caller doesn't call `next` '''Parallel map, but with backpressure. If the caller doesn't call `next`
fast enough, this will stop calling `func` at some point rather than fast enough, this will stop calling `func` at some point rather than
@ -715,88 +705,133 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None:
class OutputFile: class OutputFile:
def __init__(self, fname_out: Path) -> None: def __init__(self, fname_out: Path) -> None:
self.fout = open(fname_out, "wb") self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
def write_file_header(self, params: Params, file_type: GGMLFileType) -> None: def add_meta_arch(self, params: Params) -> None:
self.fout.write(b"ggjt"[::-1]) # magic self.gguf.add_context_length (params.n_ctx)
values = [ self.gguf.add_embedding_length (params.n_embd)
1, # file version self.gguf.add_block_count (params.n_layer)
params.n_vocab, self.gguf.add_feed_forward_length (params.n_ff)
params.n_embd, self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
params.n_mult, self.gguf.add_head_count (params.n_head)
params.n_head, self.gguf.add_head_count_kv (params.n_head_kv)
params.n_layer, self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)
params.n_embd // params.n_head, # rot (obsolete)
file_type.value,
]
self.fout.write(struct.pack("i" * len(values), *values))
def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None: def add_meta_vocab(self, vocab: Vocab) -> None:
sname = name.encode('utf-8') tokens = []
self.fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[data_type])) scores = []
self.fout.write(struct.pack("i" * len(shape), *shape[::-1]))
self.fout.write(sname)
self.fout.seek((self.fout.tell() + 31) & -32)
def write_vocab(self, vocab: Vocab) -> None:
for text, score in vocab.all_tokens(): for text, score in vocab.all_tokens():
self.fout.write(struct.pack("i", len(text))) tokens.append(text)
self.fout.write(text) scores.append(score)
self.fout.write(struct.pack("f", score))
self.gguf.add_tokenizer_model("llama")
self.gguf.add_token_list(tokens)
self.gguf.add_token_scores(scores)
#self.gguf.add_token_types(toktypes) # TODO: add this
# TODO: added / special tokens
def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
n_elements = 1
for dim in tensor.shape:
n_elements *= dim
data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
data_nbytes = n_elements * data_type.itemsize
self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes)
def write_meta(self) -> None:
self.gguf.write_header_to_file()
self.gguf.write_kv_data_to_file()
def write_tensor_info(self) -> None:
self.gguf.write_ti_data_to_file()
def close(self) -> None:
self.gguf.close()
@staticmethod @staticmethod
def write_vocab_only(fname_out: Path, vocab: Vocab) -> None: def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
of = OutputFile(fname_out)
params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
of = OutputFile(fname_out)
of.write_file_header(params, file_type=GGMLFileType.AllF32)
of.write_vocab(vocab)
of.fout.close()
@staticmethod
def write_all(fname_out: Path, params: Params, file_type: GGMLFileType, model: LazyModel, vocab: Vocab) -> None:
check_vocab_size(params, vocab) check_vocab_size(params, vocab)
of = OutputFile(fname_out) of = OutputFile(fname_out)
of.write_file_header(params, file_type)
print("Writing vocab...") # meta data
of.write_vocab(vocab) of.add_meta_arch(params)
of.add_meta_vocab(vocab)
of.write_meta()
of.close()
@staticmethod
def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
check_vocab_size(params, vocab)
of = OutputFile(fname_out)
# meta data
of.add_meta_arch(params)
of.add_meta_vocab(vocab)
# tensor info
for name, lazy_tensor in model.items():
of.add_tensor_info(name, lazy_tensor)
of.write_meta()
of.write_tensor_info()
def do_item(item: Tuple[str, LazyTensor]) -> NDArray: def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
name, lazy_tensor = item name, lazy_tensor = item
return lazy_tensor.load().to_ggml().ndarray return lazy_tensor.load().to_ggml().ndarray
# tensor data
ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8) ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8)
for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
padi = len(str(len(model))) padi = len(str(len(model)))
print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}") print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}")
of.write_tensor_header(name, lazy_tensor.shape, lazy_tensor.data_type) of.gguf.write_tensor_data(ndarray)
ndarray.tofile(of.fout)
of.fout.close()
of.close()
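
OutputFile is now a thin wrapper around gguf.GGUFWriter, and the ordering above is the important part: every key-value pair and every tensor-info record is emitted before any tensor data. Condensed, write_all() amounts to (a sketch of what the code above does, not extra functionality):

of = OutputFile(Path("out.gguf"))          # output name is illustrative
of.add_meta_arch(params)                   # llama.* key-value pairs
of.add_meta_vocab(vocab)                   # tokenizer.ggml.* key-value pairs
for name, lazy_tensor in model.items():
    of.add_tensor_info(name, lazy_tensor)  # name, shape, dtype, byte size
of.write_meta()                            # GGUF header + kv data
of.write_tensor_info()                     # tensor-info section
# ...followed by of.gguf.write_tensor_data(ndarray) for each tensor, then of.close()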
def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType: def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
wq_type = model["layers.0.attention.wq.weight"].data_type wq_type = model[NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
return GGMLFileType.AllF32 return GGMLFileType.AllF32
if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16): if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
return GGMLFileType.MostlyF16 return GGMLFileType.MostlyF16
name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()} name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
raise Exception(f"Unexpected combination of types: {name_to_type}") raise Exception(f"Unexpected combination of types: {name_to_type}")
def do_necessary_conversions(model: LazyModel, params: Params) -> LazyModel:
if "lm_head.weight" in model:
model = convert_transformers_to_orig(model, params)
model = filter_and_sort_tensors(model)
return model
def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel: def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
return {name: tensor.astype(output_type.type_for_tensor(name, tensor)) return {name: tensor.astype(output_type.type_for_tensor(name, tensor))
for (name, tensor) in model.items()} for (name, tensor) in model.items()}
def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
tmap = gguf.get_tensor_name_map(ARCH, params.n_layer)
out: LazyModel = {}
for name, lazy_tensor in model.items():
name_new = name
if name in tmap:
name_new = tmap[name]
elif name.endswith(".weight") and name[:-7] in tmap:
name_new = tmap[name[:-7]] + ".weight"
elif name.endswith(".bias") and name[:-5] in tmap:
name_new = tmap[name[:-5]] + ".bias"
else:
raise Exception(f"Unexpected tensor name: {name}")
if gguf.should_skip_tensor(ARCH, params.n_layer, name_new):
print(f"skipping tensor {name_new}")
else:
print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
out[name_new] = lazy_tensor
return out
def nth_multifile_path(path: Path, n: int) -> Optional[Path]: def nth_multifile_path(path: Path, n: int) -> Optional[Path]:
'''Given any path belonging to a multi-file model (e.g. foo.bin.1), return '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
@ -847,11 +882,6 @@ def load_some_model(path: Path) -> ModelPlus:
# Try the PyTorch patterns too, with lower priority # Try the PyTorch patterns too, with lower priority
globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"] globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
files = [file for glob in globs for file in path.glob(glob)] files = [file for glob in globs for file in path.glob(glob)]
if not files:
# Try GGML too, but with lower priority, since if both a non-GGML
# model and a GGML model exist in the same directory, we assume the
# latter was converted from the former.
files = list(path.glob("ggml-model*.bin*"))
if not files: if not files:
raise Exception(f"Can't find model in directory {path}") raise Exception(f"Can't find model in directory {path}")
if len(files) > 1: if len(files) > 1:
@ -868,12 +898,7 @@ def load_some_model(path: Path) -> ModelPlus:
return model_plus return model_plus
def filter_and_sort_tensors(model: LazyModel) -> LazyModel:
return {name: model[name] for name in TENSORS_LIST if name in model}
def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]: def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]:
print(f"vocabtype: {vocabtype}")
# Be extra-friendly and accept either a file or a directory. Also, if it's # Be extra-friendly and accept either a file or a directory. Also, if it's
# a directory, it might be the model directory, and tokenizer.model might # a directory, it might be the model directory, and tokenizer.model might
# be in the parent of that. # be in the parent of that.
@ -892,8 +917,10 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, Sentence
raise FileNotFoundError( raise FileNotFoundError(
f"Could not find tokenizer.model in {path} or its parent; " f"Could not find tokenizer.model in {path} or its parent; "
"if it's in another directory, pass the directory as --vocab-dir") "if it's in another directory, pass the directory as --vocab-dir")
print(f"Loading vocab file '{path}', type '{vocabtype}'")
added_tokens_path = path.parent / "added_tokens.json" added_tokens_path = path.parent / "added_tokens.json"
print(f"Loading vocab file {path}")
if vocabtype == "bpe": if vocabtype == "bpe":
return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None) return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None)
elif vocabtype == "spm": elif vocabtype == "spm":
@ -933,38 +960,52 @@ def main(args_in: Optional[List[str]] = None) -> None:
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)") parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
args = parser.parse_args(args_in) args = parser.parse_args(args_in)
vocab: Vocab
if args.dump_single: if args.dump_single:
model_plus = lazy_load_file(args.model) model_plus = lazy_load_file(args.model)
do_dump_model(model_plus) do_dump_model(model_plus)
elif args.vocab_only:
model_plus = load_some_model(args.model)
params = Params.load(model_plus)
if params.n_ctx == -1:
if args.ctx is None:
raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
"Please specify one with --ctx:\n"
" - LLaMA v1: --ctx 2048\n"
" - LLaMA v2: --ctx 4096\n")
params.n_ctx = args.ctx
print(f"params = {params}")
vocab: Vocab
if args.vocab_only:
vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype) vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
assert args.outfile, "need --outfile if using --vocab-only" assert args.outfile, "need --outfile if using --vocab-only"
outfile = args.outfile outfile = args.outfile
OutputFile.write_vocab_only(outfile, vocab) OutputFile.write_vocab_only(outfile, params, vocab)
print(f"Wrote {outfile}") print(f"Wrote {outfile}")
else: else:
model_plus = load_some_model(args.model)
if args.dump: if args.dump:
do_dump_model(model_plus) do_dump_model(model_plus)
return return
if model_plus.vocab is not None and args.vocab_dir is None: if model_plus.vocab is not None and args.vocab_dir is None:
vocab = model_plus.vocab vocab = model_plus.vocab
else: else:
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
vocab = load_vocab(vocab_dir, args.vocabtype) vocab = load_vocab(vocab_dir, args.vocabtype)
params = Params.load(model_plus)
model = model_plus.model model = model_plus.model
model = do_necessary_conversions(model, params) model = convert_model_names(model, params)
output_type = pick_output_type(model, args.outtype) output_type = pick_output_type(model, args.outtype)
model = convert_to_output_type(model, output_type) model = convert_to_output_type(model, output_type)
outfile = args.outfile or default_outfile(model_plus.paths, output_type) outfile = args.outfile or default_outfile(model_plus.paths, output_type)
OutputFile.write_all(outfile, params, output_type, model, vocab) OutputFile.write_all(outfile, params, model, vocab)
print(f"Wrote {outfile}") print(f"Wrote {outfile}")

gguf.py

@ -8,7 +8,7 @@ import sys
import struct import struct
import numpy as np import numpy as np
from enum import IntEnum from enum import IntEnum, auto
from typing import Any, IO, List from typing import Any, IO, List
# #
@ -33,24 +33,24 @@ KEY_GENERAL_SOURCE_URL = "general.source.url"
KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository" KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository"
# LLM # LLM
KEY_LLM_CONTEXT_LENGTH = "{llm}.context_length" KEY_LLM_CONTEXT_LENGTH = "{arch}.context_length"
KEY_LLM_EMBEDDING_LENGTH = "{llm}.embedding_length" KEY_LLM_EMBEDDING_LENGTH = "{arch}.embedding_length"
KEY_LLM_BLOCK_COUNT = "{llm}.block_count" KEY_LLM_BLOCK_COUNT = "{arch}.block_count"
KEY_LLM_FEED_FORWARD_LENGTH = "{llm}.feed_forward_length" KEY_LLM_FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
KEY_LLM_USE_PARALLEL_RESIDUAL = "{llm}.use_parallel_residual" KEY_LLM_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
KEY_LLM_TENSOR_DATA_LAYOUT = "{llm}.tensor_data_layout" KEY_LLM_TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
# attention # attention
KEY_ATTENTION_HEAD_COUNT = "{llm}.attention.head_count" KEY_ATTENTION_HEAD_COUNT = "{arch}.attention.head_count"
KEY_ATTENTION_HEAD_COUNT_KV = "{llm}.attention.head_count_kv" KEY_ATTENTION_HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
KEY_ATTENTION_MAX_ALIBI_BIAS = "{llm}.attention.max_alibi_bias" KEY_ATTENTION_MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
KEY_ATTENTION_CLAMP_KQV = "{llm}.attention.clamp_kqv" KEY_ATTENTION_CLAMP_KQV = "{arch}.attention.clamp_kqv"
KEY_ATTENTION_LAYERNORM_EPS = "{llm}.attention.layer_norm_epsilon" KEY_ATTENTION_LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
KEY_ATTENTION_LAYERNORM_RMS_EPS = "{llm}.attention.layer_norm_rms_epsilon" KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
# RoPE # RoPE
KEY_ROPE_DIMENSION_COUNT = "{llm}.rope.dimension_count" KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
KEY_ROPE_SCALE = "{llm}.rope.scale" KEY_ROPE_SCALE = "{arch}.rope.scale"
# tokenization # tokenization
KEY_TOKENIZER_MODEL = "tokenizer.ggml.model" KEY_TOKENIZER_MODEL = "tokenizer.ggml.model"
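
With the template placeholder renamed from {llm} to {arch} and the architecture now stored on the writer itself, the concrete key names fall out of a plain str.format call, e.g.:

KEY_LLM_CONTEXT_LENGTH = "{arch}.context_length"
print(KEY_LLM_CONTEXT_LENGTH.format(arch="llama"))    # llama.context_length
print(KEY_LLM_CONTEXT_LENGTH.format(arch="falcon"))   # falcon.context_length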
@ -70,34 +70,137 @@ KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
# recommended mapping of model tensor names for storage in gguf # recommended mapping of model tensor names for storage in gguf
# #
def get_tensor_name_map(n_blocks : int): class MODEL_ARCH(IntEnum):
LLAMA = auto()
FALCON = auto()
GPT2 = auto()
GPTJ = auto()
GPTNEOX = auto()
MPT = auto()
class MODEL_TENSOR(IntEnum):
TOKEN_EMBD = auto()
POS_EMBD = auto()
OUTPUT = auto()
OUTPUT_NORM = auto()
ROPE_FREQS = auto()
ATTN_Q = auto()
ATTN_K = auto()
ATTN_V = auto()
ATTN_QKV = auto()
ATTN_OUT = auto()
ATTN_NORM = auto()
ATTN_NORM_2 = auto()
ATTN_ROT_EMBD = auto()
FFN_GATE = auto()
FFN_DOWN = auto()
FFN_UP = auto()
FFN_NORM = auto()
MODEL_ARCH_NAMES = {
MODEL_ARCH.LLAMA : "llama",
MODEL_ARCH.FALCON : "falcon",
MODEL_ARCH.GPT2 : "gpt2",
MODEL_ARCH.GPTJ : "gptj",
MODEL_ARCH.GPTNEOX : "gptneox",
MODEL_ARCH.MPT : "mpt",
}
MODEL_TENSOR_NAMES = {
MODEL_ARCH.LLAMA : {
MODEL_TENSOR.TOKEN_EMBD : "token_embd",
MODEL_TENSOR.OUTPUT_NORM : "output_norm",
MODEL_TENSOR.OUTPUT : "output",
MODEL_TENSOR.ROPE_FREQS : "rope_freqs",
MODEL_TENSOR.ATTN_NORM : "blk.{bid}.attn_norm",
MODEL_TENSOR.ATTN_Q : "blk.{bid}.attn_q",
MODEL_TENSOR.ATTN_K : "blk.{bid}.attn_k",
MODEL_TENSOR.ATTN_V : "blk.{bid}.attn_v",
MODEL_TENSOR.ATTN_OUT : "blk.{bid}.attn_output",
MODEL_TENSOR.ATTN_ROT_EMBD : "blk.{bid}.attn_rot_embd",
MODEL_TENSOR.FFN_NORM : "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_GATE : "blk.{bid}.ffn_gate",
MODEL_TENSOR.FFN_DOWN : "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP : "blk.{bid}.ffn_up",
},
MODEL_ARCH.FALCON : {
MODEL_TENSOR.TOKEN_EMBD : "token_embd",
MODEL_TENSOR.OUTPUT_NORM : "output_norm",
MODEL_TENSOR.OUTPUT : "output",
MODEL_TENSOR.ATTN_NORM : "blk.{bid}.attn_norm",
MODEL_TENSOR.ATTN_NORM_2 : "blk.{bid}.attn_norm_2",
MODEL_TENSOR.ATTN_QKV : "blk.{bid}.attn_qkv",
MODEL_TENSOR.ATTN_OUT : "blk.{bid}.attn_output",
MODEL_TENSOR.FFN_DOWN : "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP : "blk.{bid}.ffn_up",
},
MODEL_ARCH.GPT2 : {
# TODO
},
# TODO
}
# tensors that will not be serialized
MODEL_TENSOR_SKIP = {
MODEL_ARCH.LLAMA : [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
}
def should_skip_tensor(arch : MODEL_ARCH, n_blocks : int, name : str) -> bool:
for skip in MODEL_TENSOR_SKIP.get(arch, []):
for i in range(n_blocks):
if name == MODEL_TENSOR_NAMES[arch][skip].format(bid=i):
return True
return False
def get_tensor_name_map(arch : MODEL_ARCH, n_blocks : int) -> dict:
tensor_map = {} tensor_map = {}
# Token embeddings # Token embeddings
mapped_to = "token_embd" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.TOKEN_EMBD, None)
tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox
tensor_map["transformer.wte"] = mapped_to # gpt2 mpt tensor_map["transformer.wte"] = mapped_to # gpt2 mpt
tensor_map["transformer.word_embeddings"] = mapped_to # falcon tensor_map["transformer.word_embeddings"] = mapped_to # falcon
tensor_map["model.embed_tokens"] = mapped_to # llama-hf tensor_map["model.embed_tokens"] = mapped_to # llama-hf
tensor_map["tok_embeddings"] = mapped_to # llama-pth tensor_map["tok_embeddings"] = mapped_to # llama-pth
# Position embeddings # Position embeddings
mapped_to = "pos_embd" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.POS_EMBD, None)
tensor_map["transformer.wpe"] = mapped_to # gpt2 tensor_map["transformer.wpe"] = mapped_to # gpt2
# Output
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT, None)
tensor_map["embed_out"] = mapped_to # gptneox
tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf
tensor_map["output"] = mapped_to # llama-pth
# Output norm # Output norm
mapped_to = "output_norm" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT_NORM, None)
tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox
tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon
tensor_map["transformer.norm_f"] = mapped_to # mpt tensor_map["transformer.norm_f"] = mapped_to # mpt
tensor_map["model.norm"] = mapped_to # llama-hf tensor_map["model.norm"] = mapped_to # llama-hf
tensor_map["norm"] = mapped_to # llama-pth tensor_map["norm"] = mapped_to # llama-pth
# Output
mapped_to = "output" # Rope frequencies
tensor_map["embed_out"] = mapped_to # gptneox mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ROPE_FREQS, None)
tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf
tensor_map["output"] = mapped_to # llama-pth tensor_map["rope.freqs"] = mapped_to # llama-pth
# Attention and fee-forward layer blocks
# Attention and feed-forward blocks
for i in range(0,n_blocks): for i in range(0,n_blocks):
# Attention norm # Attention norm
mapped_to = "blk."+str(i)+".attn_norm" # TODO: is there are simpler way to write these 2 lines in Python?
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM, None)
mapped_to = mapped_to.format(bid=i) if mapped_to else None
tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt
@ -105,56 +208,93 @@ def get_tensor_name_map(n_blocks : int):
tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b
tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth
# Attention norm 2 # Attention norm 2
mapped_to = "blk."+str(i)+".attn_norm_2" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM_2, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b
# Attention query-key-value # Attention query-key-value
mapped_to = "blk."+str(i)+".attn_qkv" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_QKV, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon
# Attention query # Attention query
mapped_to = "blk."+str(i)+".attn_q" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_Q, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth
# Attention key # Attention key
mapped_to = "blk."+str(i)+".attn_k" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_K, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth
# Attention value # Attention value
mapped_to = "blk."+str(i)+".attn_v" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_V, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth
# Attention output # Attention output
mapped_to = "blk."+str(i)+".attn_output" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_OUT, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon
tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth
# Rotary embeddings
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_ROT_EMBD, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".self_attn.rotary_emb.inv_freq"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.inner_attention.rope.freqs"] = mapped_to # llama-pth
# Feed-forward norm # Feed-forward norm
mapped_to = "blk."+str(i)+".ffn_norm" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_NORM, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt
tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth
# Feed-forward up # Feed-forward up
mapped_to = "blk."+str(i)+".ffn_up" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_UP, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon
tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth
# Feed-forward gate # Feed-forward gate
mapped_to = "blk."+str(i)+".ffn_gate" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_GATE, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth
# Feed-forward down # Feed-forward down
mapped_to = "blk."+str(i)+".ffn_down" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_DOWN, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt
@ -203,14 +343,16 @@ class GGUFValueType(IntEnum):
class GGUFWriter: class GGUFWriter:
def __init__(self, fout: IO): def __init__(self, path: str, arch: str):
self.fout = fout self.fout = open(path, "wb")
self.arch = arch
self.offset_tensor = 0 self.offset_tensor = 0
self.data_alignment = GGUF_DEFAULT_ALIGNMENT self.data_alignment = GGUF_DEFAULT_ALIGNMENT
self.kv_data = b"" self.kv_data = b""
self.kv_data_count = 0 self.kv_data_count = 0
self.ti_data = b"" self.ti_data = b""
self.ti_data_count = 0 self.ti_data_count = 0
self.add_architecture()
def write_header_to_file(self): def write_header_to_file(self):
self.fout.write(struct.pack("<I", GGUF_MAGIC)) self.fout.write(struct.pack("<I", GGUF_MAGIC))
@ -228,11 +370,6 @@ class GGUFWriter:
self.fout.write(self.ti_data) self.fout.write(self.ti_data)
self.flush() self.flush()
@classmethod
def open(cls, path: str) -> "GGUFWriter":
f = open(path, "wb")
return cls(f)
def add_key(self, key: str): def add_key(self, key: str):
self.add_val(key, GGUFValueType.STRING, add_vtype=False) self.add_val(key, GGUFValueType.STRING, add_vtype=False)
@ -269,7 +406,8 @@ class GGUFWriter:
self.add_val(val, GGUFValueType.BOOL) self.add_val(val, GGUFValueType.BOOL)
def add_string(self, key: str, val: str): def add_string(self, key: str, val: str):
if len(val) == 0: return if len(val) == 0:
return
self.add_key(key) self.add_key(key)
self.add_val(val, GGUFValueType.STRING) self.add_val(val, GGUFValueType.STRING)
@ -323,6 +461,8 @@ class GGUFWriter:
return ((x + n - 1) // n) * n return ((x + n - 1) // n) * n
def add_tensor_info(self, name: str, tensor_shape: np.ndarray, tensor_dtype: np.dtype, tensor_nbytes: int): def add_tensor_info(self, name: str, tensor_shape: np.ndarray, tensor_dtype: np.dtype, tensor_nbytes: int):
assert tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"
encoded_name = name.encode("utf8") encoded_name = name.encode("utf8")
self.ti_data += struct.pack("<I", len(encoded_name)) self.ti_data += struct.pack("<I", len(encoded_name))
self.ti_data += encoded_name self.ti_data += encoded_name
@ -331,14 +471,13 @@ class GGUFWriter:
for i in range(n_dims): for i in range(n_dims):
self.ti_data += struct.pack("<I", tensor_shape[n_dims - 1 - i]) self.ti_data += struct.pack("<I", tensor_shape[n_dims - 1 - i])
assert tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"
dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16 dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
self.ti_data += struct.pack("<I", dtype) self.ti_data += struct.pack("<I", dtype)
self.ti_data += struct.pack("<Q", self.offset_tensor) self.ti_data += struct.pack("<Q", self.offset_tensor)
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
self.ti_data_count += 1 self.ti_data_count += 1
def write_tensor_to_file(self, tensor: np.ndarray): def write_tensor_data(self, tensor: np.ndarray):
pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell() pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell()
if pad != 0: if pad != 0:
self.fout.write(bytes([0] * pad)) self.fout.write(bytes([0] * pad))
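
write_tensor_data() still pads the file position up to the data alignment before each tensor's bytes. A small worked example, assuming the default GGUF alignment of 32 and a made-up current offset:

def ggml_pad(x: int, n: int) -> int:
    return ((x + n - 1) // n) * n

offset = 100                        # current file position, purely illustrative
pad = ggml_pad(offset, 32) - offset
print(pad)                          # 28 zero bytes are written before the tensor data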
@ -355,15 +494,14 @@ class GGUFWriter:
def close(self): def close(self):
self.fout.close() self.fout.close()
def add_architecture(self, architecture: str): def add_architecture(self):
self.add_string(KEY_GENERAL_ARCHITECTURE, self.add_string(KEY_GENERAL_ARCHITECTURE, self.arch)
architecture)
def add_author(self, author: str): def add_author(self, author: str):
self.add_string(KEY_GENERAL_AUTHOR, author) self.add_string(KEY_GENERAL_AUTHOR, author)
def add_tensor_data_layout(self, layout: str): def add_tensor_data_layout(self, layout: str):
self.add_string(KEY_LLM_TENSOR_DATA_LAYOUT , layout) self.add_string(KEY_LLM_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_url(self, url: str): def add_url(self, url: str):
self.add_string(KEY_GENERAL_URL, url) self.add_string(KEY_GENERAL_URL, url)
@ -391,60 +529,60 @@ class GGUFWriter:
self.data_alignment = alignment self.data_alignment = alignment
self.add_uint32(KEY_GENERAL_ALIGNMENT, alignment) self.add_uint32(KEY_GENERAL_ALIGNMENT, alignment)
def add_context_length(self, llm: str, length: int): def add_context_length(self, length: int):
self.add_uint32( self.add_uint32(
KEY_LLM_CONTEXT_LENGTH.format(llm=llm), length) KEY_LLM_CONTEXT_LENGTH.format(arch=self.arch), length)
def add_embedding_length(self, llm: str, length: int): def add_embedding_length(self, length: int):
self.add_uint32( self.add_uint32(
KEY_LLM_EMBEDDING_LENGTH.format(llm=llm), length) KEY_LLM_EMBEDDING_LENGTH.format(arch=self.arch), length)
def add_block_count(self, llm: str, length: int): def add_block_count(self, length: int):
self.add_uint32( self.add_uint32(
KEY_LLM_BLOCK_COUNT.format(llm=llm), length) KEY_LLM_BLOCK_COUNT.format(arch=self.arch), length)
def add_feed_forward_length(self, llm: str, length: int): def add_feed_forward_length(self, length: int):
self.add_uint32( self.add_uint32(
KEY_LLM_FEED_FORWARD_LENGTH.format(llm=llm), length) KEY_LLM_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
def add_parallel_residual(self, llm: str, use: bool): def add_parallel_residual(self, use: bool):
self.add_bool( self.add_bool(
KEY_LLM_USE_PARALLEL_RESIDUAL.format(llm=llm), use) KEY_LLM_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
def add_tensor_data_layout(self, llm: str, layout: str): def add_tensor_data_layout(self, layout: str):
self.add_string( self.add_string(
KEY_LLM_TENSOR_DATA_LAYOUT.format(llm=llm), layout) KEY_LLM_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_head_count(self, llm: str, count: int): def add_head_count(self, count: int):
self.add_uint32( self.add_uint32(
KEY_ATTENTION_HEAD_COUNT.format(llm=llm), count) KEY_ATTENTION_HEAD_COUNT.format(arch=self.arch), count)
def add_head_count_kv(self, llm: str, count: int): def add_head_count_kv(self, count: int):
self.add_uint32( self.add_uint32(
KEY_ATTENTION_HEAD_COUNT_KV.format(llm=llm), count) KEY_ATTENTION_HEAD_COUNT_KV.format(arch=self.arch), count)
def add_max_alibi_bias(self, llm: str, bias: float): def add_max_alibi_bias(self, bias: float):
self.add_float32( self.add_float32(
KEY_ATTENTION_MAX_ALIBI_BIAS.format(llm=llm), bias) KEY_ATTENTION_MAX_ALIBI_BIAS.format(arch=self.arch), bias)
def add_clamp_kqv(self, llm: str, value: float): def add_clamp_kqv(self, value: float):
self.add_float32( self.add_float32(
KEY_ATTENTION_CLAMP_KQV.format(llm=llm), value) KEY_ATTENTION_CLAMP_KQV.format(arch=self.arch), value)
def add_layer_norm_eps(self, llm: str, value: float): def add_layer_norm_eps(self, value: float):
self.add_float32( self.add_float32(
KEY_ATTENTION_LAYERNORM_EPS.format(llm=llm), value) KEY_ATTENTION_LAYERNORM_EPS.format(arch=self.arch), value)
def add_layer_norm_rms_eps(self, llm: str, value: float): def add_layer_norm_rms_eps(self, value: float):
self.add_float32( self.add_float32(
KEY_ATTENTION_LAYERNORM_RMS_EPS.format(llm=llm), value) KEY_ATTENTION_LAYERNORM_RMS_EPS.format(arch=self.arch), value)
def add_rope_dimension_count(self, llm: str, count: int): def add_rope_dimension_count(self, count: int):
self.add_uint32( self.add_uint32(
KEY_ROPE_DIMENSION_COUNT.format(llm=llm), count) KEY_ROPE_DIMENSION_COUNT.format(arch=self.arch), count)
def add_rope_scale(self, llm: str, value: float): def add_rope_scale(self, value: float):
self.add_float32(KEY_ROPE_SCALE.format(llm=llm), value) self.add_float32(KEY_ROPE_SCALE.format(arch=self.arch), value)
def add_tokenizer_model(self, model: str): def add_tokenizer_model(self, model: str):
self.add_string(KEY_TOKENIZER_MODEL, model) self.add_string(KEY_TOKENIZER_MODEL, model)
@ -479,9 +617,8 @@ class GGUFWriter:
# Example usage: # Example usage:
if __name__ == "__main__": if __name__ == "__main__":
# Example usage with a file # Example usage with a file
gguf_writer = GGUFWriter.open("example.gguf") gguf_writer = GGUFWriter("example.gguf", "llama")
gguf_writer.add_architecture("llama")
gguf_writer.add_uint32("answer", 42) # Write a 32-bit integer gguf_writer.add_uint32("answer", 42) # Write a 32-bit integer
gguf_writer.add_float32("answer_in_float", 42.0) # Write a 32-bit float gguf_writer.add_float32("answer_in_float", 42.0) # Write a 32-bit float
gguf_writer.add_custom_alignment(64) gguf_writer.add_custom_alignment(64)
@ -493,7 +630,7 @@ if __name__ == "__main__":
gguf_writer.write_header_to_file() gguf_writer.write_header_to_file()
gguf_writer.write_kv_data_to_file() gguf_writer.write_kv_data_to_file()
gguf_writer.write_ti_data_to_file() gguf_writer.write_ti_data_to_file()
gguf_writer.write_tensor_to_file(tensor1) gguf_writer.write_tensor_data(tensor1)
gguf_writer.write_tensor_to_file(tensor2) gguf_writer.write_tensor_data(tensor2)
gguf_writer.close() gguf_writer.close()


@ -676,22 +676,21 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
// default hparams (LLaMA 7B) // default hparams (LLaMA 7B)
struct llama_hparams { struct llama_hparams {
uint32_t n_vocab = 32000; uint32_t n_vocab = 32000;
uint32_t n_ctx = 512; uint32_t n_ctx_train = 2048; // the context size used during training
uint32_t n_embd = 4096; uint32_t n_ctx = 512; // the context size used during inference
uint32_t n_head = 32; uint32_t n_embd = 4096;
uint32_t n_head_kv = 32; uint32_t n_head = 32;
uint32_t n_layer = 32; uint32_t n_head_kv = 32;
uint32_t n_rot = 64; uint32_t n_layer = 32;
uint32_t n_ff = 11008; uint32_t n_rot = 64;
uint32_t n_ff = 11008;
float f_norm_rms_eps = 1e-5; float f_norm_rms_eps = 1e-5;
float rope_freq_base = 10000.0f; float rope_freq_base = 10000.0f;
float rope_freq_scale = 1.0f; float rope_freq_scale = 1.0f;
enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
bool operator!=(const llama_hparams & other) const { bool operator!=(const llama_hparams & other) const {
return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
} }
@ -1023,7 +1022,8 @@ struct llama_model_loader {
int n_kv = 0; int n_kv = 0;
int n_tensors = 0; int n_tensors = 0;
int n_created = 0; int n_created = 0;
size_t n_tot_elements = 0;
int64_t n_elements = 0;
bool use_mmap = false; bool use_mmap = false;
@ -1051,9 +1051,9 @@ struct llama_model_loader {
for (int i = 0; i < n_tensors; i++) { for (int i = 0; i < n_tensors; i++) {
const char * name = gguf_get_tensor_name(ctx_gguf, i); const char * name = gguf_get_tensor_name(ctx_gguf, i);
struct ggml_tensor * t = ggml_get_tensor(ctx_meta, name); struct ggml_tensor * t = ggml_get_tensor(ctx_meta, name);
n_tot_elements += ggml_nelements(t); n_elements += ggml_nelements(t);
} }
// print meta data // print meta data
// TODO: make optional // TODO: make optional
{ {
@ -1123,6 +1123,10 @@ struct llama_model_loader {
struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<uint32_t> & ne, ggml_backend backend) { struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<uint32_t> & ne, ggml_backend backend) {
struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str()); struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str());
if (cur == NULL) {
throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
}
{ {
bool is_ok = true; bool is_ok = true;
for (size_t i = 0; i < ne.size(); ++i) { for (size_t i = 0; i < ne.size(); ++i) {
@ -1332,7 +1336,7 @@ static void llama_model_load_internal(
} }
GGUF_GET(hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens"); GGUF_GET(hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens");
GGUF_GET(hparams.n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length"); GGUF_GET(hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length");
GGUF_GET(hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.embedding_length"); GGUF_GET(hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.embedding_length");
GGUF_GET(hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.feed_forward_length"); GGUF_GET(hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.feed_forward_length");
GGUF_GET(hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.attention.head_count"); GGUF_GET(hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.attention.head_count");
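For reference, the GGUF keys wired into the hyperparameters in the hunk above, restated as plain data (key strings copied verbatim from the GGUF_GET calls; n_vocab comes from the length of the token array rather than from a scalar):

    LLAMA_HPARAM_KEYS = {
        "n_vocab":     "tokenizer.ggml.tokens",      # array; its length gives n_vocab
        "n_ctx_train": "llama.context_length",
        "n_embd":      "llama.embedding_length",
        "n_ff":        "llama.feed_forward_length",
        "n_head":      "llama.attention.head_count",
    }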
@ -1406,22 +1410,24 @@ static void llama_model_load_internal(
} }
{ {
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml->file_version)); LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml->file_version));
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx);
LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head);
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
LLAMA_LOG_INFO("%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype)); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml->n_tot_elements*1e-9); LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
LLAMA_LOG_INFO("%s: model size = %.2fB\n", __func__, ml->n_elements*1e-9);
// TODO: print number of tensors for each quantization
} }
if (vocab_only) { if (vocab_only) {
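The banner above now reports the training context, the model type, and the model size as the total element count in billions. The size line boils down to a simple scaling, shown here with an approximate stand-in for the LLaMA-7B parameter count:

    n_elements = 6_738_000_000                         # roughly LLaMA-7B (illustrative value)
    print(f"model size = {n_elements * 1e-9:.2f}B")    # prints "model size = 6.74B"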
@ -2310,6 +2316,18 @@ static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) {
return false; return false;
} }
static uint8_t llama_char_to_byte(const llama_vocab & vocab, uint8_t ch) {
if (llama_vocab_type(vocab) == "spm") {
return ch + 3;
}
if (llama_vocab_type(vocab) == "bpe") {
return ch - 32;
}
return false;
}
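The byte-fallback mapping: for SentencePiece ("spm") vocabularies the 256 raw-byte tokens <0x00>..<0xFF> sit right after the three control tokens <unk>, <s>, </s>, which is where the +3 offset comes from. A sketch of the spm case only (the bpe branch above uses a different offset and is not modeled here):

    def spm_byte_to_token_id(byte_value: int) -> int:
        # ids 0..2 are <unk>, <s>, </s>; byte-fallback tokens start at id 3
        return byte_value + 3

    def spm_token_id_to_byte(token_id: int) -> int:
        return token_id - 3

    assert spm_byte_to_token_id(0x41) == 68   # 'A' (0x41) falls back to token <0x41>, id 68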
static std::string llama_escape_whitespace(const std::string& text) { static std::string llama_escape_whitespace(const std::string& text) {
std::string result; std::string result;
bool escaping = false; bool escaping = false;
@ -2446,7 +2464,7 @@ private:
if (p == rev_merge.end()) { if (p == rev_merge.end()) {
// output any symbols that did not form tokens as bytes. // output any symbols that did not form tokens as bytes.
for (int j = 0; j < (int)symbol.n; ++j) { for (int j = 0; j < (int)symbol.n; ++j) {
llama_vocab::id token_id = llama_byte_to_char(vocab_, symbol.text[j]); llama_vocab::id token_id = llama_char_to_byte(vocab_, symbol.text[j]);
output.push_back(token_id); output.push_back(token_id);
} }
return; return;
@ -3373,7 +3391,6 @@ static void llama_convert_tensor_internal(struct ggml_tensor * tensor, std::vect
static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
ggml_type quantized_type; ggml_type quantized_type;
llama_ftype ftype = params->ftype; llama_ftype ftype = params->ftype;
int nthread = params->nthread;
switch (params->ftype) { switch (params->ftype) {
case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break; case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break;
@ -3399,6 +3416,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
} }
int nthread = params->nthread;
if (nthread <= 0) { if (nthread <= 0) {
nthread = std::thread::hardware_concurrency(); nthread = std::thread::hardware_concurrency();
} }
@ -3669,6 +3688,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
} }
} }
// TODO: after the GGUF PR, this likely won't work and needs to be updated
int llama_apply_lora_from_file_internal(const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads) { int llama_apply_lora_from_file_internal(const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads) {
LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
@ -4876,8 +4896,8 @@ int llama_token_to_str_with_model(const struct llama_model * model, llama_token
return 0; return 0;
} }
int llama_token_to_str(const struct llama_context * ctx, llama_token token, char * str, int length) { int llama_token_to_str(const struct llama_context * ctx, llama_token token, char * buf, int length) {
return llama_token_to_str_with_model(&ctx->model, token, str, length); return llama_token_to_str_with_model(&ctx->model, token, buf, length);
} }
std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) { std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
@ -4894,13 +4914,13 @@ std::string llama_token_to_str(const struct llama_context * ctx, llama_token tok
return std::string(result.data(), result.size()); return std::string(result.data(), result.size());
} }
int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * str, int length) { int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * buf, int length) {
if (0 <= token && token < llama_n_vocab_from_model(&ctx->model)) { if (0 <= token && token < llama_n_vocab_from_model(&ctx->model)) {
std::string result = ctx->model.vocab.id_to_token[token].tok; std::string result = ctx->model.vocab.id_to_token[token].tok;
if (length < (int) result.length()) { if (length < (int) result.length()) {
return -result.length(); return -result.length();
} }
memcpy(str, result.c_str(), result.length()); memcpy(buf, result.c_str(), result.length());
return result.length(); return result.length();
} }
return 0; return 0;

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File: tests/CMakeLists.txt

@ -26,10 +26,10 @@ llama_build_and_test_executable(test-quantize-fns.cpp)
llama_build_and_test_executable(test-quantize-perf.cpp) llama_build_and_test_executable(test-quantize-perf.cpp)
llama_build_and_test_executable(test-sampling.cpp) llama_build_and_test_executable(test-sampling.cpp)
llama_build_executable(test-tokenizer-0.cpp) llama_build_executable(test-tokenizer-0.cpp)
llama_test_executable(test-tokenizer-0.llama test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.bin) llama_test_executable(test-tokenizer-0.llama test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
llama_build_executable(test-tokenizer-1.cpp) llama_build_executable(test-tokenizer-1.cpp)
llama_test_executable(test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.bin) llama_test_executable(test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
llama_test_executable(test-tokenizer-1.aquila test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.bin) #llama_test_executable(test-tokenizer-1.aquila test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
llama_build_and_test_executable(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp) llama_build_and_test_executable(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp)
llama_build_and_test_executable(test-grad0.cpp) # SLOW llama_build_and_test_executable(test-grad0.cpp) # SLOW
# llama_build_and_test_executable(test-opt.cpp) # SLOW # llama_build_and_test_executable(test-opt.cpp) # SLOW

View File: tests/test-tokenizer-0.cpp

@ -89,6 +89,8 @@ int main(int argc, char **argv) {
return 2; return 2;
} }
bool success = true;
for (const auto & test_kv : k_tests()) { for (const auto & test_kv : k_tests()) {
std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, true); std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, true);
fprintf(stderr, "%s : '%s' tokenized to '%s'\n", fprintf(stderr, "%s : '%s' tokenized to '%s'\n",
@ -103,7 +105,8 @@ int main(int argc, char **argv) {
} }
if (!correct) { if (!correct) {
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str()); fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
fprintf(stderr, "%s : detokenized to: '%s'\n", __func__, unescape_whitespace(ctx, test_kv.second).c_str());
fprintf(stderr, "%s : expected tokens: ", __func__); fprintf(stderr, "%s : expected tokens: ", __func__);
for (const auto & t : test_kv.second) { for (const auto & t : test_kv.second) {
fprintf(stderr, "%6d, ", t); fprintf(stderr, "%6d, ", t);
@ -115,9 +118,7 @@ int main(int argc, char **argv) {
} }
fprintf(stderr, "\n"); fprintf(stderr, "\n");
llama_free_model(model); success = false;
llama_free(ctx);
return 3;
} }
} }
@ -126,5 +127,5 @@ int main(int argc, char **argv) {
llama_backend_free(); llama_backend_free();
return 0; return success ? 0 : 3;
} }
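The tokenizer test now runs every case and decides the exit code at the end, instead of freeing the model and returning 3 on the first mismatch. The same accumulate-then-report pattern in Python terms (tokenize and test_cases are placeholders, not real APIs):

    def run_tests(test_cases, tokenize):
        success = True
        for text, expected_tokens in test_cases:
            if tokenize(text) != expected_tokens:
                print(f"failed test: {text!r}")
                success = False
        return 0 if success else 3

    # exit code mirrors the C++ test: 0 on success, 3 if any case failed
    # raise SystemExit(run_tests(cases, my_tokenize))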