convert-*.py: GGUF Naming Convention Refactor and Metadata Override Refactor (#7499)

Main thing is that the default output filename will take this form

{name}{parameters}{finetune}{version}{encoding}{kind}

In addition this add and remove some entries in the KV store and adds a metadata class with automatic heuristics capability to derive some values based on model card content

* No Change:
  - Internal GGUF Spec
    - `general.architecture`
    - `general.quantization_version`
    - `general.alignment`
    - `general.file_type`
  - General Model Details
    - `general.name`
    - `general.author`
    - `general.version`
    - `general.description`
  - Licensing details
    - `general.license`
  - Typically represents the converted GGUF repo (Unless made from scratch)
    - `general.url`
  - Model Source during conversion
    - `general.source.url`

* Removed:
  - Model Source during conversion
    - `general.source.huggingface.repository`

* Added:
  - General Model Details
    - `general.organization`
    - `general.finetune`
    - `general.basename`
    - `general.quantized_by`
    - `general.size_label`
  - Licensing details
    - `general.license.name`
    - `general.license.link`
  - Typically represents the converted GGUF repo (Unless made from scratch)
    - `general.doi`
    - `general.uuid`
    - `general.repo_url`
  - Model Source during conversion
    - `general.source.doi`
    - `general.source.uuid`
    - `general.source.repo_url`
  - Base Model Source
    - `general.base_model.count`
    - `general.base_model.{id}.name`
    - `general.base_model.{id}.author`
    - `general.base_model.{id}.version`
    - `general.base_model.{id}.organization`
    - `general.base_model.{id}.url` (Model Website/Paper)
    - `general.base_model.{id}.doi`
    - `general.base_model.{id}.uuid`
    - `general.base_model.{id}.repo_url` (Model Source Repository (git/svn/etc...))
  - Array based KV stores
    - `general.tags`
    - `general.languages`
    - `general.datasets`

---------

Co-authored-by: compilade <git@compilade.net>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
This commit is contained in:
Brian 2024-07-18 20:40:15 +10:00 committed by GitHub
parent 3807c3de04
commit 672a6f1018
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1185 additions and 239 deletions

View File

@ -48,34 +48,38 @@ class Model:
dir_model: Path
ftype: gguf.LlamaFileType
fname_out: Path | None
is_big_endian: bool
endianess: gguf.GGUFEndian
use_temp_file: bool
lazy: bool
model_name: str | None
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
block_count: int
tensor_map: gguf.TensorNameMap
tensor_names: set[str] | None
fname_out: Path
gguf_writer: gguf.GGUFWriter
model_name: str | None
metadata_override: Path | None
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path | None, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
self.dir_model = dir_model
self.ftype = ftype
self.fname_out = fname_out
self.is_big_endian = is_big_endian
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
self.lazy = not eager
self.model_name = model_name
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
@ -84,6 +88,10 @@ class Model:
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self.tensor_names = None
self.metadata_override = metadata_override
self.model_name = model_name
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
if self.ftype == gguf.LlamaFileType.GUESSED:
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
_, first_tensor = next(self.get_tensors())
@ -93,10 +101,8 @@ class Model:
else:
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
ftype_up: str = self.ftype.name.partition("_")[2].upper()
ftype_lw: str = ftype_up.lower()
# allow templating the file name with the output ftype, useful with the "auto" ftype
self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
# Configure GGUF Writer
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@ -193,7 +199,6 @@ class Model:
return new_name
def set_gguf_parameters(self):
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_block_count(self.block_count)
if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
@ -250,7 +255,7 @@ class Model:
return False
def write_tensors(self):
def prepare_tensors(self):
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
for name, data_torch in self.get_tensors():
@ -333,9 +338,67 @@ class Model:
self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
def set_type(self):
self.gguf_writer.add_type(gguf.GGUFType.MODEL)
def prepare_metadata(self, vocab_only: bool):
total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()
self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model, self.model_name, total_params)
# Fallback to model directory name if metadata name is still missing
if self.metadata.name is None:
self.metadata.name = self.dir_model.name
# Generate parameter weight class (useful for leader boards) if not yet determined
if self.metadata.size_label is None and total_params > 0:
self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
# Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0'
output_type: str = self.ftype.name.partition("_")[2]
# Filename Output
# Note: `not is_dir()` is used because `.is_file()` will not detect
# file template strings as it doesn't actually exist as a file
if self.fname_out is not None and not self.fname_out.is_dir():
# Output path is a custom defined templated filename
# Process templated file name with the output ftype, useful with the "auto" ftype
self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
else:
# Generate default filename based on model specification and available metadata
if not vocab_only:
fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type, model_type="LoRA" if total_params < 0 else None)
else:
fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=None, model_type="vocab")
# Check if preferred output directory path was provided
if self.fname_out is not None and self.fname_out.is_dir():
# output path is a directory
self.fname_out = self.fname_out / f"{fname_default}.gguf"
else:
# output in the same directory as the model by default
self.fname_out = self.dir_model / f"{fname_default}.gguf"
self.set_type()
logger.info("Set meta model")
self.metadata.set_gguf_meta_model(self.gguf_writer)
logger.info("Set model parameters")
self.set_gguf_parameters()
logger.info("Set model tokenizer")
self.set_vocab()
logger.info("Set model quantization version")
self.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
def write(self):
self.write_tensors()
self.gguf_writer.write_header_to_file(self.fname_out)
self.prepare_tensors()
self.prepare_metadata(vocab_only=False)
self.gguf_writer.write_header_to_file(path=self.fname_out)
self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.write_tensors_to_file(progress=True)
self.gguf_writer.close()
@ -343,7 +406,9 @@ class Model:
def write_vocab(self):
if len(self.gguf_writer.tensors) != 1:
raise ValueError('Splitting the vocabulary is not supported')
self.gguf_writer.write_header_to_file(self.fname_out)
self.prepare_metadata(vocab_only=True)
self.gguf_writer.write_header_to_file(path=self.fname_out)
self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.close()
@ -780,7 +845,6 @@ class GPTNeoXModel(Model):
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
@ -836,7 +900,6 @@ class BloomModel(Model):
model_arch = gguf.MODEL_ARCH.BLOOM
def set_gguf_parameters(self):
self.gguf_writer.add_name("Bloom")
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
@ -913,7 +976,6 @@ class MPTModel(Model):
def set_gguf_parameters(self):
block_count = self.hparams["n_layers"]
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_block_count(block_count)
@ -952,7 +1014,6 @@ class OrionModel(Model):
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
hf_repo = self.hparams.get("_name_or_path", "")
ctx_length = 0
if "max_sequence_length" in self.hparams:
@ -965,8 +1026,6 @@ class OrionModel(Model):
raise ValueError("gguf: can not find ctx length parameter.")
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_source_hf_repo(hf_repo)
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@ -990,7 +1049,6 @@ class BaichuanModel(Model):
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
hf_repo = self.hparams.get("_name_or_path", "")
ctx_length = 0
if "max_sequence_length" in self.hparams:
@ -1002,8 +1060,6 @@ class BaichuanModel(Model):
else:
raise ValueError("gguf: can not find ctx length parameter.")
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_source_hf_repo(hf_repo)
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@ -1117,7 +1173,6 @@ class XverseModel(Model):
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
hf_repo = self.hparams.get("_name_or_path", "")
ctx_length = 0
if "max_sequence_length" in self.hparams:
@ -1129,8 +1184,6 @@ class XverseModel(Model):
else:
raise ValueError("gguf: can not find ctx length parameter.")
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_source_hf_repo(hf_repo)
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@ -1189,7 +1242,6 @@ class FalconModel(Model):
if n_head_kv is None:
n_head_kv = self.hparams.get("n_head_kv", 1) # old name
self.gguf_writer.add_name("Falcon")
self.gguf_writer.add_context_length(2048) # not in config.json
self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@ -1234,7 +1286,6 @@ class StarCoderModel(Model):
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]
self.gguf_writer.add_name("StarCoder")
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
@ -1269,7 +1320,6 @@ class RefactModel(Model):
block_count = self.hparams["n_layer"]
self.gguf_writer.add_name("Refact")
# refact uses Alibi. So this is from config.json which might be used by training.
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
@ -1324,7 +1374,6 @@ class StableLMModel(Model):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
@ -1386,8 +1435,8 @@ class StableLMModel(Model):
return [(new_name, data_torch)]
def write_tensors(self):
super().write_tensors()
def prepare_tensors(self):
super().prepare_tensors()
if self._q_norms is not None or self._k_norms is not None:
# flatten two `list[dict[str, Tensor]]` into a single `list[str]`
@ -1503,8 +1552,8 @@ class LlamaModel(Model):
return [(self.map_tensor_name(name), data_torch)]
def write_tensors(self):
super().write_tensors()
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
@ -1567,7 +1616,6 @@ class GrokModel(Model):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_name("Grok")
_experts: list[dict[str, Tensor]] | None = None
@ -1616,7 +1664,6 @@ class DbrxModel(Model):
def set_gguf_parameters(self):
ffn_config = self.hparams["ffn_config"]
attn_config = self.hparams["attn_config"]
self.gguf_writer.add_name(self.hparams["model_type"])
self.gguf_writer.add_block_count(self.hparams["n_layers"])
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
@ -1685,7 +1732,6 @@ class MiniCPMModel(Model):
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
self.gguf_writer.add_name("MiniCPM")
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
@ -1755,7 +1801,6 @@ class QwenModel(Model):
self._set_vocab_qwen()
def set_gguf_parameters(self):
self.gguf_writer.add_name("Qwen")
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@ -1831,8 +1876,8 @@ class Qwen2MoeModel(Model):
return [(self.map_tensor_name(name), data_torch)]
def write_tensors(self):
super().write_tensors()
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
@ -1846,7 +1891,6 @@ class GPT2Model(Model):
model_arch = gguf.MODEL_ARCH.GPT2
def set_gguf_parameters(self):
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_context_length(self.hparams["n_ctx"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
@ -1889,7 +1933,6 @@ class Phi2Model(Model):
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
self.gguf_writer.add_name("Phi2")
self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))
self.gguf_writer.add_embedding_length(n_embd)
@ -2011,7 +2054,6 @@ class Phi3MiniModel(Model):
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
rope_dims = n_embd // n_head
self.gguf_writer.add_name("Phi3")
self.gguf_writer.add_context_length(max_pos_embds)
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
self.gguf_writer.add_embedding_length(n_embd)
@ -2068,7 +2110,6 @@ class PlamoModel(Model):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_name("PLaMo")
self.gguf_writer.add_context_length(4096) # not in config.json
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
@ -2113,7 +2154,6 @@ class CodeShellModel(Model):
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]
self.gguf_writer.add_name("CodeShell")
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
@ -2272,7 +2312,6 @@ class InternLM2Model(Model):
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
self.gguf_writer.add_name("InternLM2")
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@ -2440,7 +2479,6 @@ class GemmaModel(Model):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
@ -2481,7 +2519,6 @@ class Gemma2Model(Model):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
@ -2556,7 +2593,6 @@ class MambaModel(Model):
# Fail early for models which don't have a block expansion factor of 2
assert d_inner == 2 * d_model
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
@ -2735,7 +2771,6 @@ class OpenELMModel(Model):
assert self.block_count == len(self._num_query_heads)
assert self.block_count == len(self._ffn_dims)
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams["max_context_length"])
self.gguf_writer.add_embedding_length(n_embd)
@ -2909,8 +2944,8 @@ class ArcticModel(Model):
return [(self.map_tensor_name(name), data_torch)]
def write_tensors(self):
super().write_tensors()
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
@ -2988,8 +3023,8 @@ class DeepseekV2Model(Model):
return [(self.map_tensor_name(name), data_torch)]
def write_tensors(self):
super().write_tensors()
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
@ -3107,7 +3142,6 @@ class T5Model(Model):
self.gguf_writer.add_add_eos_token(True)
def set_gguf_parameters(self):
self.gguf_writer.add_name("T5")
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
n_ctx = 512
@ -3181,7 +3215,6 @@ class JaisModel(Model):
self._set_vocab_gpt2()
def set_gguf_parameters(self):
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
@ -3227,8 +3260,8 @@ class JaisModel(Model):
return tensors
def write_tensors(self):
super().write_tensors()
def prepare_tensors(self):
super().prepare_tensors()
self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
@ -3539,6 +3572,10 @@ def parse_args() -> argparse.Namespace:
"--no-tensor-first-split", action="store_true",
help="do not add tensors to the first split (disabled by default)"
)
parser.add_argument(
"--metadata", type=Path,
help="Specify the path for an authorship metadata override file"
)
return parser.parse_args()
@ -3564,7 +3601,10 @@ def split_str_to_n_bytes(split_str: str) -> int:
def main() -> None:
args = parse_args()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
if args.verbose:
logging.basicConfig(level=logging.DEBUG)
else:
logging.basicConfig(level=logging.INFO)
dir_model = args.model
@ -3585,37 +3625,33 @@ def main() -> None:
logger.error("Error: Cannot use temp file when splitting")
sys.exit(1)
fname_out = None
if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / 'ggml-model-{ftype}.gguf'
logger.info(f"Loading model: {dir_model.name}")
hparams = Model.load_hparams(dir_model)
with torch.inference_mode():
output_type = ftype_map[args.outtype]
model_architecture = hparams["architectures"][0]
try:
model_class = Model.from_model_architecture(hparams["architectures"][0])
model_class = Model.from_model_architecture(model_architecture)
except NotImplementedError:
logger.error(f"Model {hparams['architectures'][0]} is not supported")
logger.error(f"Model {model_architecture} is not supported")
sys.exit(1)
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file,
args.no_lazy, args.model_name, split_max_tensors=args.split_max_tensors,
model_instance = model_class(dir_model=dir_model, ftype=output_type, fname_out=fname_out,
is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
eager=args.no_lazy,
metadata_override=args.metadata, model_name=args.model_name,
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split)
logger.info("Set model parameters")
model_instance.gguf_writer.add_type(gguf.GGUFType.MODEL)
model_instance.set_gguf_parameters()
logger.info("Set model tokenizer")
model_instance.set_vocab()
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
if args.vocab_only:
logger.info("Exporting model vocab...")
model_instance.write_vocab()
@ -3623,6 +3659,7 @@ def main() -> None:
else:
logger.info("Exporting model...")
model_instance.write()
assert model_instance.fname_out is not None
out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out
logger.info(f"Model successfully exported to {out_path}")

View File

@ -251,6 +251,10 @@ def parse_args() -> argparse.Namespace:
"--verbose", action="store_true",
help="increase output verbosity",
)
parser.add_argument(
"--dry-run", action="store_true",
help="only print out what will be done, without writing any new files",
)
parser.add_argument(
"--base", type=Path, required=True,
help="directory containing base model file",
@ -300,6 +304,12 @@ if __name__ == '__main__':
# load base model
logger.info(f"Loading base model: {dir_base_model.name}")
hparams = Model.load_hparams(dir_base_model)
with open(lora_config, "r") as f:
lparams: dict[str, Any] = json.load(f)
alpha: float = lparams["lora_alpha"]
with torch.inference_mode():
try:
model_class = Model.from_model_architecture(hparams["architectures"][0])
@ -310,6 +320,14 @@ if __name__ == '__main__':
class LoraModel(model_class):
model_arch = model_class.model_arch
def set_type(self):
self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
def set_gguf_parameters(self):
self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
super().set_gguf_parameters()
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
tensor_map: dict[str, PartialLoraTensor] = {}
@ -357,18 +375,9 @@ if __name__ == '__main__':
is_big_endian=args.bigendian,
use_temp_file=False,
eager=args.no_lazy,
model_name=None,
dry_run=args.dry_run,
)
with open(lora_config, "r") as f:
lparams: dict[str, Any] = json.load(f)
alpha = lparams["lora_alpha"]
model_instance.gguf_writer.add_string(gguf.Keys.General.TYPE, gguf.GGUFType.ADAPTER)
model_instance.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
model_instance.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
logger.info("Exporting model...")
model_instance.write()
logger.info(f"Model successfully exported to {model_instance.fname_out}")

View File

@ -24,7 +24,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar, Optional
from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar
import numpy as np
@ -346,42 +346,6 @@ class Params:
return params
@dataclass
class Metadata:
name: Optional[str] = None
author: Optional[str] = None
version: Optional[str] = None
url: Optional[str] = None
description: Optional[str] = None
license: Optional[str] = None
source_url: Optional[str] = None
source_hf_repo: Optional[str] = None
@staticmethod
def load(metadata_path: Path) -> Metadata:
if metadata_path is None or not metadata_path.exists():
return Metadata()
with open(metadata_path, 'r') as file:
data = json.load(file)
# Create a new Metadata instance
metadata = Metadata()
# Assigning values to Metadata attributes if they exist in the JSON file
# This is based on LLM_KV_NAMES mapping in llama.cpp
metadata.name = data.get("general.name")
metadata.author = data.get("general.author")
metadata.version = data.get("general.version")
metadata.url = data.get("general.url")
metadata.description = data.get("general.description")
metadata.license = data.get("general.license")
metadata.source_url = data.get("general.source.url")
metadata.source_hf_repo = data.get("general.source.huggingface.repository")
return metadata
#
# data loading
# TODO: reuse (probably move to gguf.py?)
@ -806,7 +770,7 @@ class OutputFile:
def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
def add_meta_model(self, params: Params, metadata: Metadata | None) -> None:
def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None:
# Metadata About The Model And Its Provenence
name = "LLaMA"
if metadata is not None and metadata.name is not None:
@ -824,16 +788,73 @@ class OutputFile:
self.gguf.add_author(metadata.author)
if metadata.version is not None:
self.gguf.add_version(metadata.version)
if metadata.url is not None:
self.gguf.add_url(metadata.url)
if metadata.organization is not None:
self.gguf.add_organization(metadata.organization)
if metadata.finetune is not None:
self.gguf.add_finetune(metadata.finetune)
if metadata.basename is not None:
self.gguf.add_basename(metadata.basename)
if metadata.description is not None:
self.gguf.add_description(metadata.description)
if metadata.quantized_by is not None:
self.gguf.add_quantized_by(metadata.quantized_by)
if metadata.size_label is not None:
self.gguf.add_size_label(metadata.size_label)
if metadata.license is not None:
self.gguf.add_licence(metadata.license)
self.gguf.add_license(metadata.license)
if metadata.license_name is not None:
self.gguf.add_license_name(metadata.license_name)
if metadata.license_link is not None:
self.gguf.add_license_link(metadata.license_link)
if metadata.url is not None:
self.gguf.add_url(metadata.url)
if metadata.doi is not None:
self.gguf.add_doi(metadata.doi)
if metadata.uuid is not None:
self.gguf.add_uuid(metadata.uuid)
if metadata.repo_url is not None:
self.gguf.add_repo_url(metadata.repo_url)
if metadata.source_url is not None:
self.gguf.add_source_url(metadata.source_url)
if metadata.source_hf_repo is not None:
self.gguf.add_source_hf_repo(metadata.source_hf_repo)
if metadata.source_doi is not None:
self.gguf.add_source_doi(metadata.source_doi)
if metadata.source_uuid is not None:
self.gguf.add_source_uuid(metadata.source_uuid)
if metadata.source_repo_url is not None:
self.gguf.add_source_repo_url(metadata.source_repo_url)
if metadata.base_models is not None:
self.gguf.add_base_model_count(len(metadata.base_models))
for key, base_model_entry in enumerate(metadata.base_models):
if "name" in base_model_entry:
self.gguf.add_base_model_name(key, base_model_entry["name"])
if "author" in base_model_entry:
self.gguf.add_base_model_author(key, base_model_entry["author"])
if "version" in base_model_entry:
self.gguf.add_base_model_version(key, base_model_entry["version"])
if "organization" in base_model_entry:
self.gguf.add_base_model_organization(key, base_model_entry["organization"])
if "url" in base_model_entry:
self.gguf.add_base_model_url(key, base_model_entry["url"])
if "doi" in base_model_entry:
self.gguf.add_base_model_doi(key, base_model_entry["doi"])
if "uuid" in base_model_entry:
self.gguf.add_base_model_uuid(key, base_model_entry["uuid"])
if "repo_url" in base_model_entry:
self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"])
if metadata.tags is not None:
self.gguf.add_tags(metadata.tags)
if metadata.languages is not None:
self.gguf.add_languages(metadata.languages)
if metadata.datasets is not None:
self.gguf.add_datasets(metadata.datasets)
def add_meta_arch(self, params: Params) -> None:
# Metadata About The Neural Architecture Itself
@ -944,7 +965,7 @@ class OutputFile:
@staticmethod
def write_vocab_only(
fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: gguf.Metadata | None = None,
) -> None:
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
@ -978,7 +999,7 @@ class OutputFile:
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
pad_vocab: bool = False,
metadata: Metadata | None = None,
metadata: gguf.Metadata | None = None,
) -> None:
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
@ -1021,35 +1042,32 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
raise ValueError(f"Unexpected combination of types: {name_to_type}")
def model_parameter_count(model: LazyModel) -> int:
total_model_parameters = 0
for i, (name, lazy_tensor) in enumerate(model.items()):
sum_weights_in_tensor = 1
def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]]) -> tuple[int, int, int]:
total_params = 0
shared_params = 0
expert_params = 0
for name, lazy_tensor in tensors:
# We don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
continue
# Got A Tensor
sum_weights_in_tensor: int = 1
# Tensor Volume
for dim in lazy_tensor.shape:
sum_weights_in_tensor *= dim
total_model_parameters += sum_weights_in_tensor
return total_model_parameters
if ".experts." in name:
if ".experts.0." in name:
expert_params += sum_weights_in_tensor
else:
shared_params += sum_weights_in_tensor
def model_parameter_count_rounded_notation(model_params_count: int) -> str:
if model_params_count > 1e12 :
# Trillions Of Parameters
scaled_model_params = model_params_count * 1e-12
scale_suffix = "T"
elif model_params_count > 1e9 :
# Billions Of Parameters
scaled_model_params = model_params_count * 1e-9
scale_suffix = "B"
elif model_params_count > 1e6 :
# Millions Of Parameters
scaled_model_params = model_params_count * 1e-6
scale_suffix = "M"
else:
# Thousands Of Parameters
scaled_model_params = model_params_count * 1e-3
scale_suffix = "K"
total_params += sum_weights_in_tensor
return f"{round(scaled_model_params)}{scale_suffix}"
return total_params, shared_params, expert_params
def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
@ -1231,34 +1249,24 @@ class VocabFactory:
return vocab, special_vocab
def default_convention_outfile(file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> str:
quantization = {
def default_convention_outfile(file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> str:
name = metadata.name if metadata.name is not None else None
basename = metadata.basename if metadata.basename is not None else None
finetune = metadata.finetune if metadata.finetune is not None else None
version = metadata.version if metadata.version is not None else None
size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(*model_params_count, expert_count=expert_count or 0)
output_type = {
GGMLFileType.AllF32: "F32",
GGMLFileType.MostlyF16: "F16",
GGMLFileType.MostlyQ8_0: "Q8_0",
}[file_type]
parameters = model_parameter_count_rounded_notation(model_params_count)
expert_count = ""
if params.n_experts is not None:
expert_count = f"{params.n_experts}x"
version = ""
if metadata is not None and metadata.version is not None:
version = f"-{metadata.version}"
name = "ggml-model"
if metadata is not None and metadata.name is not None:
name = metadata.name
elif params.path_model is not None:
name = params.path_model.name
return f"{name}{version}-{expert_count}{parameters}-{quantization}"
return gguf.naming_convention(name, basename, finetune, version, size_label, output_type)
def default_outfile(model_paths: list[Path], file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> Path:
default_filename = default_convention_outfile(file_type, params, model_params_count, metadata)
def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> Path:
default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata)
ret = model_paths[0].parent / f"{default_filename}.gguf"
if ret in model_paths:
logger.error(
@ -1297,8 +1305,9 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("--metadata", type=Path, help="Specify the path for a metadata file")
parser.add_argument("--metadata", type=Path, help="Specify the path for an authorship metadata override file")
parser.add_argument("--get-outfile", action="store_true", help="get calculated default outfile name")
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
args = parser.parse_args(args_in)
@ -1310,32 +1319,36 @@ def main(args_in: list[str] | None = None) -> None:
else:
logging.basicConfig(level=logging.INFO)
metadata = Metadata.load(args.metadata)
model_name = args.model_name
dir_model = args.model
metadata = gguf.Metadata.load(args.metadata, dir_model, model_name)
if args.get_outfile:
model_plus = load_some_model(args.model)
model_plus = load_some_model(dir_model)
params = Params.load(model_plus)
model = convert_model_names(model_plus.model, params, args.skip_unknown)
model_params_count = model_parameter_count(model_plus.model)
ftype = pick_output_type(model, args.outtype)
print(f"{default_convention_outfile(ftype, params, model_params_count, metadata)}") # noqa: NP100
model = convert_model_names(model_plus.model, params, args.skip_unknown)
model_params_count = per_model_weight_count_estimation(model_plus.model.items())
ftype = pick_output_type(model, args.outtype)
if (metadata is None or metadata.name is None) and params.path_model is not None:
metadata.name = params.path_model.name
print(f"{default_convention_outfile(ftype, params.n_experts, model_params_count, metadata)}") # noqa: NP100
return
if args.no_vocab and args.vocab_only:
raise ValueError("--vocab-only does not make sense with --no-vocab")
if args.dump_single:
model_plus = lazy_load_file(args.model)
model_plus = lazy_load_file(dir_model)
do_dump_model(model_plus)
return
if not args.vocab_only:
model_plus = load_some_model(args.model)
model_plus = load_some_model(dir_model)
else:
model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
model_params_count = model_parameter_count(model_plus.model)
logger.info(f"model parameters count : {model_params_count} ({model_parameter_count_rounded_notation(model_params_count)})")
model_plus = ModelPlus(model = {}, paths = [dir_model / 'dummy'], format = 'none', vocab = None)
if args.dump:
do_dump_model(model_plus)
@ -1368,7 +1381,7 @@ def main(args_in: list[str] | None = None) -> None:
logger.info(f"params = {params}")
model_parent_path = model_plus.paths[0].parent
vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
vocab_path = Path(args.vocab_dir or dir_model or model_parent_path)
vocab_factory = VocabFactory(vocab_path)
vocab_types = None if args.no_vocab else args.vocab_type.split(",")
vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path)
@ -1399,13 +1412,21 @@ def main(args_in: list[str] | None = None) -> None:
assert params is not None
if metadata.name is None and params.path_model is not None:
metadata.name = params.path_model.name
model_params_count = per_model_weight_count_estimation(model_plus.model.items())
logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count[0])})")
logger.info(f"Vocab info: {vocab}")
logger.info(f"Special vocab info: {special_vocab}")
model = model_plus.model
model = convert_model_names(model, params, args.skip_unknown)
ftype = pick_output_type(model, args.outtype)
model = convert_to_output_type(model, ftype)
outfile = args.outfile or default_outfile(model_plus.paths, ftype, params, model_params_count, metadata)
outfile = args.outfile or default_outfile(model_plus.paths, ftype, params.n_experts, model_params_count, metadata=metadata)
metadata.size_label = gguf.size_label(*model_params_count, expert_count=params.n_experts or 0)
params.ftype = ftype
logger.info(f"Writing {outfile}, format {ftype}")

View File

@ -78,5 +78,13 @@ python -m build
python -m twine upload dist/*
```
## Run Unit Tests
From root of this repository you can run this command to run all the unit tests
```bash
python -m unittest discover ./gguf-py -v
```
## TODO
- [ ] Include conversion scripts as command line entry points in this package.

View File

@ -5,3 +5,5 @@ from .gguf_writer import *
from .quants import *
from .tensor_mapping import *
from .vocab import *
from .utility import *
from .metadata import *

View File

@ -19,19 +19,60 @@ GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
class Keys:
class General:
TYPE = "general.type"
ARCHITECTURE = "general.architecture"
QUANTIZATION_VERSION = "general.quantization_version"
ALIGNMENT = "general.alignment"
NAME = "general.name"
AUTHOR = "general.author"
VERSION = "general.version"
URL = "general.url"
DESCRIPTION = "general.description"
LICENSE = "general.license"
SOURCE_URL = "general.source.url"
SOURCE_HF_REPO = "general.source.huggingface.repository"
FILE_TYPE = "general.file_type"
TYPE = "general.type"
ARCHITECTURE = "general.architecture"
QUANTIZATION_VERSION = "general.quantization_version"
ALIGNMENT = "general.alignment"
FILE_TYPE = "general.file_type"
# Authorship Metadata
NAME = "general.name"
AUTHOR = "general.author"
VERSION = "general.version"
ORGANIZATION = "general.organization"
FINETUNE = "general.finetune"
BASENAME = "general.basename"
DESCRIPTION = "general.description"
QUANTIZED_BY = "general.quantized_by"
SIZE_LABEL = "general.size_label"
# Licensing details
LICENSE = "general.license"
LICENSE_NAME = "general.license.name"
LICENSE_LINK = "general.license.link"
# Typically represents the converted GGUF repo (Unless native)
URL = "general.url" # Model Website/Paper
DOI = "general.doi"
UUID = "general.uuid"
REPO_URL = "general.repo_url" # Model Source Repository (git/svn/etc...)
# Model Source during conversion
SOURCE_URL = "general.source.url" # Model Website/Paper
SOURCE_DOI = "general.source.doi"
SOURCE_UUID = "general.source.uuid"
SOURCE_REPO_URL = "general.source.repo_url" # Model Source Repository (git/svn/etc...)
# Base Model Source. There can be more than one source if it's a merged
# model like with 'Mistral-7B-Merge-14-v0.1'. This will assist in
# tracing linage of models as it is finetuned or merged over time.
BASE_MODEL_COUNT = "general.base_model.count"
BASE_MODEL_NAME = "general.base_model.{id}.name"
BASE_MODEL_AUTHOR = "general.base_model.{id}.author"
BASE_MODEL_VERSION = "general.base_model.{id}.version"
BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization"
BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper
BASE_MODEL_DOI = "general.base_model.{id}.doi"
BASE_MODEL_UUID = "general.base_model.{id}.uuid"
BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
# Array based KV stores
TAGS = "general.tags"
LANGUAGES = "general.languages"
DATASETS = "general.datasets"
class LLM:
VOCAB_SIZE = "{arch}.vocab_size"
@ -1233,7 +1274,6 @@ KEY_GENERAL_URL = Keys.General.URL
KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION
KEY_GENERAL_LICENSE = Keys.General.LICENSE
KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL
KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO
KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE
# LLM

View File

@ -7,6 +7,7 @@ import struct
import tempfile
from dataclasses import dataclass
from enum import Enum, auto
from math import prod
from pathlib import Path
from io import BufferedWriter
from typing import IO, Any, Sequence, Mapping
@ -106,6 +107,53 @@ class GGUFWriter:
self.add_architecture()
def get_total_parameter_count(self) -> tuple[int, int, int, int]:
total_params = 0
shared_params = 0
expert_params = 0
expert_sum = 0
n_expert_tensors = 0
last_lora_a: tuple[str, TensorInfo] | None = None
for tensors in self.tensors:
for name, info in tensors.items():
shape = info.shape
if name.endswith(".lora_a"):
last_lora_a = (name, info)
continue
elif name.endswith(".lora_b"):
if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
# Bail when the LoRA pair can't be found trivially
logger.warning("can't measure LoRA size correctly, tensor order is unusual")
return 0, 0, 0, 0
else:
shape = (*shape[:-1], last_lora_a[1].shape[-1])
size = prod(shape)
if "_exps." in name:
expert_params += (size // shape[-3])
expert_sum += shape[-3]
n_expert_tensors += 1
else:
shared_params += size
total_params += size
# Hopefully this should work even for variable-expert-count models
expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
# Negate the total to signal it's likely not exact
if last_lora_a is not None:
total_params = -total_params
# NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
return total_params, shared_params, expert_params, expert_count
def format_shard_names(self, path: Path) -> list[Path]:
if len(self.tensors) == 1:
return [path]
@ -115,6 +163,7 @@ class GGUFWriter:
if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
# allow calling this multiple times as long as the path is the same
return
if self.state is not WriterState.NO_FILE:
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
@ -136,6 +185,8 @@ class GGUFWriter:
if self.dry_run:
logger.info("Dry run, not writing files")
for name in filenames:
print(name) # noqa: NP100
exit()
return filenames
@ -430,29 +481,12 @@ class GGUFWriter:
def add_architecture(self) -> None:
self.add_string(Keys.General.ARCHITECTURE, self.arch)
def add_author(self, author: str) -> None:
self.add_string(Keys.General.AUTHOR, author)
def add_quantization_version(self, quantization_version: int) -> None:
self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
def add_version(self, version: str) -> None:
self.add_string(Keys.General.VERSION, version)
def add_tensor_data_layout(self, layout: str) -> None:
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_url(self, url: str) -> None:
self.add_string(Keys.General.URL, url)
def add_description(self, description: str) -> None:
self.add_string(Keys.General.DESCRIPTION, description)
def add_licence(self, licence: str) -> None:
self.add_string(Keys.General.LICENSE, licence)
def add_source_url(self, url: str) -> None:
self.add_string(Keys.General.SOURCE_URL, url)
def add_source_hf_repo(self, repo: str) -> None:
self.add_string(Keys.General.SOURCE_HF_REPO, repo)
def add_custom_alignment(self, alignment: int) -> None:
self.data_alignment = alignment
self.add_uint32(Keys.General.ALIGNMENT, alignment)
def add_file_type(self, ftype: int) -> None:
self.add_uint32(Keys.General.FILE_TYPE, ftype)
@ -460,13 +494,101 @@ class GGUFWriter:
def add_name(self, name: str) -> None:
self.add_string(Keys.General.NAME, name)
def add_quantization_version(self, quantization_version: int) -> None:
self.add_uint32(
Keys.General.QUANTIZATION_VERSION, quantization_version)
def add_author(self, author: str) -> None:
self.add_string(Keys.General.AUTHOR, author)
def add_custom_alignment(self, alignment: int) -> None:
self.data_alignment = alignment
self.add_uint32(Keys.General.ALIGNMENT, alignment)
def add_version(self, version: str) -> None:
self.add_string(Keys.General.VERSION, version)
def add_organization(self, organization: str) -> None:
self.add_string(Keys.General.ORGANIZATION, organization)
def add_finetune(self, finetune: str) -> None:
self.add_string(Keys.General.FINETUNE, finetune)
def add_basename(self, basename: str) -> None:
self.add_string(Keys.General.BASENAME, basename)
def add_description(self, description: str) -> None:
self.add_string(Keys.General.DESCRIPTION, description)
def add_quantized_by(self, quantized: str) -> None:
self.add_string(Keys.General.QUANTIZED_BY, quantized)
def add_size_label(self, size_label: str) -> None:
self.add_string(Keys.General.SIZE_LABEL, size_label)
def add_license(self, license: str) -> None:
self.add_string(Keys.General.LICENSE, license)
def add_license_name(self, license: str) -> None:
self.add_string(Keys.General.LICENSE_NAME, license)
def add_license_link(self, license: str) -> None:
self.add_string(Keys.General.LICENSE_LINK, license)
def add_url(self, url: str) -> None:
self.add_string(Keys.General.URL, url)
def add_doi(self, doi: str) -> None:
self.add_string(Keys.General.DOI, doi)
def add_uuid(self, uuid: str) -> None:
self.add_string(Keys.General.UUID, uuid)
def add_repo_url(self, repo_url: str) -> None:
self.add_string(Keys.General.REPO_URL, repo_url)
def add_source_url(self, url: str) -> None:
self.add_string(Keys.General.SOURCE_URL, url)
def add_source_doi(self, doi: str) -> None:
self.add_string(Keys.General.SOURCE_DOI, doi)
def add_source_uuid(self, uuid: str) -> None:
self.add_string(Keys.General.SOURCE_UUID, uuid)
def add_source_repo_url(self, repo_url: str) -> None:
self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
def add_base_model_count(self, source_count: int) -> None:
self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
def add_base_model_name(self, source_id: int, name: str) -> None:
self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)
def add_base_model_author(self, source_id: int, author: str) -> None:
self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)
def add_base_model_version(self, source_id: int, version: str) -> None:
self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)
def add_base_model_organization(self, source_id: int, organization: str) -> None:
self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
def add_base_model_url(self, source_id: int, url: str) -> None:
self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
def add_base_model_doi(self, source_id: int, doi: str) -> None:
self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)
def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)
def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
def add_tags(self, tags: Sequence[str]) -> None:
self.add_array(Keys.General.TAGS, tags)
def add_languages(self, languages: Sequence[str]) -> None:
self.add_array(Keys.General.LANGUAGES, languages)
def add_datasets(self, datasets: Sequence[str]) -> None:
self.add_array(Keys.General.DATASETS, datasets)
def add_tensor_data_layout(self, layout: str) -> None:
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_vocab_size(self, size: int) -> None:
self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)

485
gguf-py/gguf/metadata.py Normal file
View File

@ -0,0 +1,485 @@
from __future__ import annotations
import re
import json
import yaml
import logging
from pathlib import Path
from typing import Any, Literal, Optional
from dataclasses import dataclass
from .constants import Keys
import gguf
logger = logging.getLogger("metadata")
@dataclass
class Metadata:
# Authorship Metadata to be written to GGUF KV Store
name: Optional[str] = None
author: Optional[str] = None
version: Optional[str] = None
organization: Optional[str] = None
finetune: Optional[str] = None
basename: Optional[str] = None
description: Optional[str] = None
quantized_by: Optional[str] = None
size_label: Optional[str] = None
url: Optional[str] = None
doi: Optional[str] = None
uuid: Optional[str] = None
repo_url: Optional[str] = None
source_url: Optional[str] = None
source_doi: Optional[str] = None
source_uuid: Optional[str] = None
source_repo_url: Optional[str] = None
license: Optional[str] = None
license_name: Optional[str] = None
license_link: Optional[str] = None
base_models: Optional[list[dict]] = None
tags: Optional[list[str]] = None
languages: Optional[list[str]] = None
datasets: Optional[list[str]] = None
@staticmethod
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
# This grabs as many contextual authorship metadata as possible from the model repository
# making any conversion as required to match the gguf kv store metadata format
# as well as giving users the ability to override any authorship metadata that may be incorrect
# Create a new Metadata instance
metadata = Metadata()
model_card = Metadata.load_model_card(model_path)
hf_params = Metadata.load_hf_parameters(model_path)
# heuristics
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
# Metadata Override File Provided
# This is based on LLM_KV_NAMES mapping in llama.cpp
metadata_override = Metadata.load_metadata_override(metadata_override_path)
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label)
metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
# Base Models is received here as an array of models
metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)
# Direct Metadata Override (via direct cli argument)
if model_name is not None:
metadata.name = model_name
return metadata
@staticmethod
def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
if metadata_override_path is None or not metadata_override_path.is_file():
return {}
with open(metadata_override_path, "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]:
if model_path is None or not model_path.is_dir():
return {}
model_card_path = model_path / "README.md"
if not model_card_path.is_file():
return {}
# The model card metadata is assumed to always be in YAML
# ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
with open(model_card_path, "r", encoding="utf-8") as f:
if f.readline() == "---\n":
raw = f.read().partition("---\n")[0]
data = yaml.safe_load(raw)
if isinstance(data, dict):
return data
else:
logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
return {}
else:
return {}
@staticmethod
def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
if model_path is None or not model_path.is_dir():
return {}
config_path = model_path / "config.json"
if not config_path.is_file():
return {}
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def id_to_title(string):
# Convert capitalization into title form unless acronym or version number
return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
@staticmethod
def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
# Huggingface often store model id as '<org>/<model name>'
# so let's parse it and apply some heuristics if possible for model name components
if model_id is None:
# model ID missing
return None, None, None, None, None, None
if ' ' in model_id:
# model ID is actually a normal human sentence
# which means its most likely a normal model name only
# not part of the hugging face naming standard, but whatever
return model_id, None, None, None, None, None
if '/' in model_id:
# model ID (huggingface style)
org_component, model_full_name_component = model_id.split('/', 1)
else:
# model ID but missing org components
org_component, model_full_name_component = None, model_id
# Check if we erroneously matched against './' or '../' etc...
if org_component is not None and org_component[0] == '.':
org_component = None
name_parts: list[str] = model_full_name_component.split('-')
name_types: list[
set[Literal["basename", "size_label", "finetune", "version", "type"]]
] = [set() for _ in name_parts]
# Annotate the name
for i, part in enumerate(name_parts):
# Version
if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
name_types[i].add("version")
# Quant type (should not be there for base models, but still annotated)
elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE):
name_types[i].add("type")
name_parts[i] = part.upper()
# Model size
elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE):
part = part.replace("_", ".")
# Handle weird bloom-7b1 notation
if part[-1].isdecimal():
part = part[:-2] + "." + part[-1] + part[-2]
# Normalize the size suffixes
if len(part) > 1 and part[-2].isdecimal():
if part[-1] in "kmbt":
part = part[:-1] + part[-1].upper()
if total_params != 0:
try:
label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1]))
# Only use it as a size label if it's close or bigger than the model size
# Note that LoRA adapters don't necessarily include all layers,
# so this is why bigger label sizes are accepted.
# Do not use the size label when it's smaller than 1/8 of the model size
if (total_params < 0 and label_params < abs(total_params) // 8) or (
# Check both directions when the current model isn't a LoRA adapter
total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8
):
# Likely a context length
name_types[i].add("finetune")
# Lowercase the size when it's a context length
part = part[:-1] + part[-1].lower()
except ValueError:
# Failed to convert the size label to float, use it anyway
pass
if len(name_types[i]) == 0:
name_types[i].add("size_label")
name_parts[i] = part
# Some easy to recognize finetune names
elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
name_types[i].add("finetune")
if part.lower() == "lora":
name_parts[i] = "LoRA"
at_start = True
# Find the basename through the annotated name
for part, t in zip(name_parts, name_types):
if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
t.add("basename")
else:
if at_start:
at_start = False
if len(t) == 0:
t.add("finetune")
# Remove the basename annotation from trailing version
for part, t in zip(reversed(name_parts), reversed(name_types)):
if "basename" in t:
if len(t) > 1:
t.remove("basename")
else:
break
basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None
finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
# TODO: should the basename version always be excluded?
# TODO: should multiple versions be joined together?
version = ([v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t] or [None])[-1]
if size_label is None and finetune is None and version is None:
# Too ambiguous, output nothing
basename = None
return model_full_name_component, org_component, basename, finetune, version, size_label
@staticmethod
def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
# Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Model Card Heuristics
########################
if model_card is not None:
if "model_name" in model_card and metadata.name is None:
# Not part of huggingface model card standard but notice some model creator using it
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
metadata.name = model_card.get("model_name")
if "model_creator" in model_card and metadata.author is None:
# Not part of huggingface model card standard but notice some model creator using it
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
metadata.author = model_card.get("model_creator")
if "model_type" in model_card and metadata.basename is None:
# Not part of huggingface model card standard but notice some model creator using it
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
metadata.basename = model_card.get("model_type")
if "base_model" in model_card:
# This represents the parent models that this is based on
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
# Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
metadata_base_models = []
base_model_value = model_card.get("base_model", None)
if base_model_value is not None:
if isinstance(base_model_value, str):
metadata_base_models.append(base_model_value)
elif isinstance(base_model_value, list):
metadata_base_models.extend(base_model_value)
if metadata.base_models is None:
metadata.base_models = []
for model_id in metadata_base_models:
# NOTE: model size of base model is assumed to be similar to the size of the current model
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
base_model = {}
if model_full_name_component is not None:
base_model["name"] = Metadata.id_to_title(model_full_name_component)
if org_component is not None:
base_model["organization"] = Metadata.id_to_title(org_component)
if version is not None:
base_model["version"] = version
if org_component is not None and model_full_name_component is not None:
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
metadata.base_models.append(base_model)
if "license" in model_card and metadata.license is None:
metadata.license = model_card.get("license")
if "license_name" in model_card and metadata.license_name is None:
metadata.license_name = model_card.get("license_name")
if "license_link" in model_card and metadata.license_link is None:
metadata.license_link = model_card.get("license_link")
tags_value = model_card.get("tags", None)
if tags_value is not None:
if metadata.tags is None:
metadata.tags = []
if isinstance(tags_value, str):
metadata.tags.append(tags_value)
elif isinstance(tags_value, list):
metadata.tags.extend(tags_value)
pipeline_tags_value = model_card.get("pipeline_tag", None)
if pipeline_tags_value is not None:
if metadata.tags is None:
metadata.tags = []
if isinstance(pipeline_tags_value, str):
metadata.tags.append(pipeline_tags_value)
elif isinstance(pipeline_tags_value, list):
metadata.tags.extend(pipeline_tags_value)
language_value = model_card.get("languages", model_card.get("language", None))
if language_value is not None:
if metadata.languages is None:
metadata.languages = []
if isinstance(language_value, str):
metadata.languages.append(language_value)
elif isinstance(language_value, list):
metadata.languages.extend(language_value)
dataset_value = model_card.get("datasets", model_card.get("dataset", None))
if dataset_value is not None:
if metadata.datasets is None:
metadata.datasets = []
if isinstance(dataset_value, str):
metadata.datasets.append(dataset_value)
elif isinstance(dataset_value, list):
metadata.datasets.extend(dataset_value)
# Hugging Face Parameter Heuristics
####################################
if hf_params is not None:
hf_name_or_path = hf_params.get("_name_or_path")
if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
# Use _name_or_path only if its actually a model name and not some computer path
# e.g. 'meta-llama/Llama-2-7b-hf'
model_id = hf_name_or_path
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
if metadata.name is None and model_full_name_component is not None:
metadata.name = Metadata.id_to_title(model_full_name_component)
if metadata.organization is None and org_component is not None:
metadata.organization = Metadata.id_to_title(org_component)
if metadata.basename is None and basename is not None:
metadata.basename = basename
if metadata.finetune is None and finetune is not None:
metadata.finetune = finetune
if metadata.version is None and version is not None:
metadata.version = version
if metadata.size_label is None and size_label is not None:
metadata.size_label = size_label
# Directory Folder Name Fallback Heuristics
############################################
if model_path is not None:
model_id = model_path.name
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
if metadata.name is None and model_full_name_component is not None:
metadata.name = Metadata.id_to_title(model_full_name_component)
if metadata.organization is None and org_component is not None:
metadata.organization = Metadata.id_to_title(org_component)
if metadata.basename is None and basename is not None:
metadata.basename = basename
if metadata.finetune is None and finetune is not None:
metadata.finetune = finetune
if metadata.version is None and version is not None:
metadata.version = version
if metadata.size_label is None and size_label is not None:
metadata.size_label = size_label
return metadata
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
assert self.name is not None
gguf_writer.add_name(self.name)
if self.author is not None:
gguf_writer.add_author(self.author)
if self.version is not None:
gguf_writer.add_version(self.version)
if self.organization is not None:
gguf_writer.add_organization(self.organization)
if self.finetune is not None:
gguf_writer.add_finetune(self.finetune)
if self.basename is not None:
gguf_writer.add_basename(self.basename)
if self.description is not None:
gguf_writer.add_description(self.description)
if self.quantized_by is not None:
gguf_writer.add_quantized_by(self.quantized_by)
if self.size_label is not None:
gguf_writer.add_size_label(self.size_label)
if self.license is not None:
gguf_writer.add_license(self.license)
if self.license_name is not None:
gguf_writer.add_license_name(self.license_name)
if self.license_link is not None:
gguf_writer.add_license_link(self.license_link)
if self.url is not None:
gguf_writer.add_url(self.url)
if self.doi is not None:
gguf_writer.add_doi(self.doi)
if self.uuid is not None:
gguf_writer.add_uuid(self.uuid)
if self.repo_url is not None:
gguf_writer.add_repo_url(self.repo_url)
if self.source_url is not None:
gguf_writer.add_source_url(self.source_url)
if self.source_doi is not None:
gguf_writer.add_source_doi(self.source_doi)
if self.source_uuid is not None:
gguf_writer.add_source_uuid(self.source_uuid)
if self.source_repo_url is not None:
gguf_writer.add_source_repo_url(self.source_repo_url)
if self.base_models is not None:
gguf_writer.add_base_model_count(len(self.base_models))
for key, base_model_entry in enumerate(self.base_models):
if "name" in base_model_entry:
gguf_writer.add_base_model_name(key, base_model_entry["name"])
if "author" in base_model_entry:
gguf_writer.add_base_model_author(key, base_model_entry["author"])
if "version" in base_model_entry:
gguf_writer.add_base_model_version(key, base_model_entry["version"])
if "organization" in base_model_entry:
gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
if "url" in base_model_entry:
gguf_writer.add_base_model_url(key, base_model_entry["url"])
if "doi" in base_model_entry:
gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
if "uuid" in base_model_entry:
gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
if "repo_url" in base_model_entry:
gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
if self.tags is not None:
gguf_writer.add_tags(self.tags)
if self.languages is not None:
gguf_writer.add_languages(self.languages)
if self.datasets is not None:
gguf_writer.add_datasets(self.datasets)

69
gguf-py/gguf/utility.py Normal file
View File

@ -0,0 +1,69 @@
from __future__ import annotations
from typing import Literal
def fill_templated_filename(filename: str, output_type: str | None) -> str:
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
ftype_lowercase: str = output_type.lower() if output_type is not None else ""
ftype_uppercase: str = output_type.upper() if output_type is not None else ""
return filename.format(ftype_lowercase,
outtype=ftype_lowercase, ftype=ftype_lowercase,
OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
if model_params_count > 1e12 :
# Trillions Of Parameters
scaled_model_params = model_params_count * 1e-12
scale_suffix = "T"
elif model_params_count > 1e9 :
# Billions Of Parameters
scaled_model_params = model_params_count * 1e-9
scale_suffix = "B"
elif model_params_count > 1e6 :
# Millions Of Parameters
scaled_model_params = model_params_count * 1e-6
scale_suffix = "M"
else:
# Thousands Of Parameters
scaled_model_params = model_params_count * 1e-3
scale_suffix = "K"
fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
return f"{scaled_model_params:.{fix}f}{scale_suffix}"
def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
if expert_count > 0:
pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
size_class = f"{expert_count}x{pretty_size}"
else:
size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
return size_class
def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
# Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
if base_name is not None:
name = base_name.strip().title().replace(' ', '-').replace('/', '-')
elif model_name is not None:
name = model_name.strip().title().replace(' ', '-').replace('/', '-')
else:
name = "ggml-model"
parameters = f"-{size_label}" if size_label is not None else ""
finetune = f"-{finetune_string.strip().title().replace(' ', '-')}" if finetune_string is not None else ""
version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
return f"{name}{parameters}{finetune}{version}{encoding}{kind}"

View File

@ -22,6 +22,7 @@ classifiers = [
python = ">=3.8"
numpy = ">=1.17"
tqdm = ">=4.27"
pyyaml = ">=5.1"
[tool.poetry.dev-dependencies]
pytest = "^5.2"

View File

@ -0,0 +1 @@
from .test_metadata import *

View File

@ -1,7 +0,0 @@
import gguf # noqa: F401 # pyright: ignore[reportUnusedImport]
# TODO: add tests
def test_write_gguf() -> None:
pass

158
gguf-py/tests/test_metadata.py Executable file
View File

@ -0,0 +1,158 @@
#!/usr/bin/env python3
import unittest
from pathlib import Path
import os
import sys
# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
sys.path.insert(0, str(Path(__file__).parent.parent))
import gguf
class TestMetadataMethod(unittest.TestCase):
def test_id_to_title(self):
self.assertEqual(gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), "Mixtral 8x7B Instruct v0.1")
self.assertEqual(gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B")
self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO")
def test_get_model_id_components(self):
# This is the basic standard form with organization marker
self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
# Similar to basic standard form but without organization marker
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
# Missing version
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))
# Missing finetune
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
# Base name and size label only
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"),
('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
# Base name and version only
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"),
('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
## Edge Cases ##
# This is too ambiguous... best to err on caution and output nothing
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"),
('Mixtral', None, None, None, None, None))
# Basename has numbers mixed in and also size label provided. Must avoid capturing number in basename
self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))
# Can't detect all non standard form in a heuristically safe way... best to err in caution and output nothing...
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))
# Capture 'sub size labels' e.g. A14B in '57B-A14B' usually refers to activated params/weight count
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"),
('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B'))
# Check that it can handle a real model id with no version code
# Note that 4k in this string is non standard and microsoft were referring to context length rather than weight count
self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4 * 10**9),
('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3', '4k-instruct', None, 'mini'))
# There is some legitimate models with only thousands of parameters
self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))
# None standard and not easy to disambiguate
self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))
# This is a real model_id where they append 2DPO to refer to Direct Preference Optimization
self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"),
('daybreak-kunoichi-2dpo-7b', 'crestf411', 'daybreak-kunoichi', '2dpo', None, '7B'))
# This is a real model id where the weight size has a decimal point
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"),
('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B'))
# Uses an underscore in the size label
self.assertEqual(gguf.Metadata.get_model_id_components("smallcloudai/Refact-1_6B-fim"),
('Refact-1_6B-fim', 'smallcloudai', 'Refact', 'fim', None, '1.6B'))
# Uses Iter3 for the version
self.assertEqual(gguf.Metadata.get_model_id_components("UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3"),
('Gemma-2-9B-It-SPPO-Iter3', 'UCLA-AGI', 'Gemma-2', 'It-SPPO', 'Iter3', '9B'))
# Has two potential versions in the basename
self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Hermes-2-Theta-Llama-3-8B"),
('Hermes-2-Theta-Llama-3-8B', 'NousResearch', 'Hermes-2-Theta-Llama-3', None, None, '8B'))
# Potential version in the basename
self.assertEqual(gguf.Metadata.get_model_id_components("SeaLLMs/SeaLLMs-v3-7B-Chat"),
('SeaLLMs-v3-7B-Chat', 'SeaLLMs', 'SeaLLMs-v3', 'Chat', None, '7B'))
# Underscore in the basename, and 1m for the context size
self.assertEqual(gguf.Metadata.get_model_id_components("internlm/internlm2_5-7b-chat-1m", 7 * 10**9),
('internlm2_5-7b-chat-1m', 'internlm', 'internlm2_5', 'chat-1m', None, '7B'))
# Version before the finetune name
self.assertEqual(gguf.Metadata.get_model_id_components("pszemraj/jamba-900M-v0.13-KIx2"),
('jamba-900M-v0.13-KIx2', 'pszemraj', 'jamba', 'KIx2', 'v0.13', '900M'))
# TODO: hf suffix which could be ignored but isn't
self.assertEqual(gguf.Metadata.get_model_id_components("state-spaces/mamba-2.8b-hf"),
('mamba-2.8b-hf', 'state-spaces', 'mamba', 'hf', None, '2.8B'))
# Two sizes, don't merge them, the other is the number of tokens on which it was trained
self.assertEqual(gguf.Metadata.get_model_id_components("abacaj/llama-161M-100B", 161 * 10**6),
('llama-161M-100B', 'abacaj', 'llama', '100b', None, '161M'))
# It's a trap, there is no size label
self.assertEqual(gguf.Metadata.get_model_id_components("SparseLLM/relu-100B", 1340 * 10**6),
('relu-100B', 'SparseLLM', 'relu', '100b', None, None))
# Weird size notation
self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))
def test_apply_metadata_heuristic_from_model_card(self):
model_card = {
'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
'model-index': [{'name': 'Mixtral-8x7B-Instruct-v0.1', 'results': []}],
'language': ['en'],
'datasets': ['teknium/OpenHermes-2.5'],
'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}],
'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"]
}
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
expect = gguf.Metadata()
expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': 'v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
expect.languages=['en']
expect.datasets=['teknium/OpenHermes-2.5']
self.assertEqual(got, expect)
def test_apply_metadata_heuristic_from_hf_parameters(self):
hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None)
expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
self.assertEqual(got, expect)
def test_apply_metadata_heuristic_from_model_dir(self):
model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=None, model_path=model_dir_path)
expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
self.assertEqual(got, expect)
if __name__ == "__main__":
unittest.main()