diff --git a/convert-llama-ggmlv3-to-gguf.py b/convert-llama-ggmlv3-to-gguf.py new file mode 100644 index 000000000..30038072f --- /dev/null +++ b/convert-llama-ggmlv3-to-gguf.py @@ -0,0 +1,334 @@ +import sys, struct, math, argparse +from pathlib import Path + +import numpy as np + +import gguf + +# Note: Does not support GGML_QKK_64 +QK_K = 256 +# Items here are (block size, type size) +GGML_QUANT_SIZES = { + gguf.GGMLQuantizationType.F32 : (1, 4), + gguf.GGMLQuantizationType.F16 : (1, 2), + gguf.GGMLQuantizationType.Q4_0 : (32, 2 + 16), + gguf.GGMLQuantizationType.Q4_1 : (32, 2 + 2 + 16), + gguf.GGMLQuantizationType.Q5_0 : (32, 2 + 4 + 16), + gguf.GGMLQuantizationType.Q5_1 : (32, 2 + 2 + 4 + 16), + gguf.GGMLQuantizationType.Q8_0 : (32, 2 + 32), + gguf.GGMLQuantizationType.Q8_1 : (32, 4 + 4 + 32), + gguf.GGMLQuantizationType.Q2_K : (256, 2 + 2 + QK_K // 16 + QK_K // 4), + gguf.GGMLQuantizationType.Q3_K : (256, 2 + QK_K // 4 + QK_K // 8 + 12), + gguf.GGMLQuantizationType.Q4_K : (256, 2 + 2 + QK_K // 2 + 12), + gguf.GGMLQuantizationType.Q5_K : (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), + gguf.GGMLQuantizationType.Q6_K : (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), + gguf.GGMLQuantizationType.Q8_K : (256, 4 + QK_K + QK_K // 8), +} + +class Hyperparameters: + def __init__(self): + self.n_vocab = self.n_embd = self.n_mult = self.n_head = self.n_layer = self.n_rot = self.ftype = 0 + self.n_ff = 0 + + def set_n_ff(self, model): + ff_tensor_idx = model.tensor_map.get(b'layers.0.feed_forward.w1.weight') + assert ff_tensor_idx is not None, 'Missing layer 0 FF tensor' + ff_tensor = model.tensors[ff_tensor_idx] + self.n_ff = ff_tensor.dims[1] + + def load(self, data, offset): + ( + self.n_vocab, + self.n_embd, + self.n_mult, + self.n_head, + self.n_layer, + self.n_rot, + self.ftype, + ) = struct.unpack('<7I', data[offset:offset + (4 * 7)]) + return 4 * 7 + + def __str__(self): + return f'' + +class Vocab: + def __init__(self): + self.items = [] + + def load(self, data, offset, n_vocab): + orig_offset = offset + for _ in range(n_vocab): + itemlen = struct.unpack('= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}' + assert name_len < 4096, 'Absurd tensor name length' + quant = GGML_QUANT_SIZES.get(dtype) + assert quant is not None, 'Unknown tensor type' + (blksize, tysize) = quant + offset += 12 + self.dtype= dtype + self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)]) + offset += 4 * n_dims + self.name = bytes(data[offset:offset + name_len]) + offset += name_len + pad = ((offset + 31) & ~31) - offset + offset += pad + n_elems = np.prod(self.dims) + n_bytes = (n_elems * tysize) // blksize + self.start_offset = offset + self.len_bytes = n_bytes + offset += n_bytes + # print(n_dims, name_len, dtype, self.dims, self.name, pad) + return offset - orig_offset + +class GGMLV3Model: + def __init__(self): + self.hyperparameters = None + self.vocab = None + self.tensor_map = {} + self.tensors = [] + + def validate_header(self, data, offset): + if bytes(data[offset:offset + 4]) != b'tjgg' or struct.unpack(' 0: + gguf_writer.add_token_types(toktypes) + return + print(f'* Adding {hp.n_vocab} vocab item(s)') + for (tokid, (vbytes, vscore)) in enumerate(self.model.vocab.items): + tt = 1 # Normal + if len(vbytes) == 0: + tt = 3 # Control + elif tokid >= 3 and tokid <= 258 and len(vbytes) == 1: + hv = hex(vbytes[0])[2:].upper() + vbytes = bytes(f'<0x{hv}>', encoding = 'UTF-8') + tt = 6 # Byte + else: + vbytes = vbytes.replace(b' ', b'\xe2\x96\x81') + toktypes.append(tt) + tokens.append(vbytes) + scores.append(vscore) + gguf_writer.add_token_list(tokens) + gguf_writer.add_token_scores(scores) + gguf_writer.add_token_types(toktypes) + + def add_tensors(self, gguf_writer): + nm = self.name_map + data = self.data + print(f'* Adding {len(self.model.tensors)} tensor(s)') + for tensor in self.model.tensors: + name = str(tensor.name, 'UTF-8') + if name.endswith('.weight'): + name = name[:-7] + suffix = '.weight' + elif name.endswith('.bias'): + name = name[:-5] + suffix = '.bias' + mapped_name = nm.get(name) + assert mapped_name is not None, f'Bad name {name}' + mapped_name += suffix + tempdims = list(tensor.dims[:]) + if len(tempdims) > 1: + temp = tempdims[1] + tempdims[1] = tempdims[0] + tempdims[0] = temp + # print(f'+ {tensor.name} | {mapped_name} {tensor.dims} :: {tempdims}') + gguf_writer.add_tensor(mapped_name, data[tensor.start_offset:tensor.start_offset + tensor.len_bytes], raw_shape = tempdims, raw_dtype = tensor.dtype) + +def handle_metadata(cfg, hp): + import convert + assert cfg.model_metadata_dir.is_dir(), 'Metadata dir is not a directory' + hf_config_path = cfg.model_metadata_dir / "config.json" + orig_config_path = cfg.model_metadata_dir / "params.json" + # We pass a fake model here. "original" mode will check the shapes of some + # tensors if information is missing in the .json file: other than that, the + # model data isn't used so this should be safe (at least for now). + fakemodel = { + 'tok_embeddings.weight': convert.LazyTensor.__new__(convert.LazyTensor), + 'layers.0.feed_forward.w1.weight': convert.LazyTensor.__new__(convert.LazyTensor), + } + fakemodel['tok_embeddings.weight'].shape = [hp.n_vocab] + fakemodel['layers.0.feed_forward.w1.weight'].shape = [hp.n_ff] + if hf_config_path.exists(): + params = convert.Params.loadHFTransformerJson(fakemodel, hf_config_path) + elif orig_config_path.exists(): + params = convert.Params.loadOriginalParamsJson(fakemodel, orig_config_path) + else: + raise ValueError('Unable to load metadata') + vocab = convert.load_vocab(cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir, cfg.vocabtype) + convert.check_vocab_size(params, vocab) + return (params, vocab) + +def handle_args(): + parser = argparse.ArgumentParser(description = 'Convert GGMLv3 models to GGUF') + parser.add_argument('--input', '-i', type = Path, help = 'Input GGMLv3 filename') + parser.add_argument('--output', '-o', type = Path, help ='Output GGUF filename') + parser.add_argument('--name', help = 'Set model name') + parser.add_argument('--desc', help = 'Set model description') + parser.add_argument('--gqa', type = int, default = 1, help = 'grouped-query attention factor (use 8 for LLaMA2 70B)') + parser.add_argument('--eps', default = '5.0e-06', help = 'RMS norm eps: Use 1e-6 for LLaMA1 and OpenLLaMA, use 1e-5 for LLaMA2') + parser.add_argument('--context-length', '-c', type=int, default = 2048, help = 'Default max context length: LLaMA1 is typically 2048, LLaMA2 is typically 4096') + parser.add_argument('--model-metadata-dir', '-m', type = Path, help ='Load HuggingFace/.pth vocab and metadata from the specified directory') + parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file - only meaningful with --model-metadata-dir") + parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm)", default="spm") + return parser.parse_args() + +def main(): + cfg = handle_args() + print(f'* Using config: {cfg}') + print('\n=== WARNING === Be aware that this conversion script is best-effort. Use a native GGUF model if possible. === WARNING ===\n') + data = np.memmap(cfg.input, mode = 'r') + model = GGMLV3Model() + print('* Scanning GGML input file') + offset = model.load(data, 0) + print(f'* GGML model hyperparameters: {model.hyperparameters}') + vocab_override = None + params_override = None + if cfg.model_metadata_dir is not None: + (params_override, vocab_override) = handle_metadata(cfg, model.hyperparameters) + print('!! Note: When overriding params the --gqa, --eps and --context-length options are ignored.') + print(f'* Overriding params: {params_override}') + print(f'* Overriding vocab: {vocab_override}') + else: + print('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n') + converter = GGMLToGGUF(model, data, cfg, params_override = params_override, vocab_override = vocab_override) + converter.save() + print(f'* Successful completion. Output saved to: {cfg.output}') + +main() diff --git a/gguf.py b/gguf.py index d461b8d40..60ee52f09 100644 --- a/gguf.py +++ b/gguf.py @@ -5,7 +5,7 @@ import tempfile import numpy as np from enum import IntEnum, auto -from typing import Any, IO, List +from typing import Any, IO, List, Optional # # constants @@ -325,8 +325,20 @@ def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> dict: class GGMLQuantizationType(IntEnum): - F32 = 0 - F16 = 1 + F32 = 0 + F16 = 1 + Q4_0 = 2 + Q4_1 = 3 + Q5_0 = 6 + Q5_1 = 7 + Q8_0 = 8 + Q8_1 = 9 + Q2_K = 10 + Q3_K = 11 + Q4_K = 12 + Q5_K = 13 + Q6_K = 14 + Q8_K = 15 class GGUFValueType(IntEnum): @@ -359,7 +371,7 @@ class GGUFValueType(IntEnum): class GGUFWriter: - def __init__(self, path: str, arch: str): + def __init__(self, path: str, arch: str, use_temp_file = True): self.fout = open(path, "wb") self.arch = arch self.offset_tensor = 0 @@ -369,6 +381,8 @@ class GGUFWriter: self.ti_data = b"" self.ti_data_count = 0 self.add_architecture() + self.use_temp_file = use_temp_file + self.tensors = [] def write_header_to_file(self): self.fout.write(struct.pack(" int: return ((x + n - 1) // n) * n - def add_tensor_info(self, name: str, tensor_shape: np.ndarray, tensor_dtype: np.dtype, tensor_nbytes: int): - assert tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now" + def add_tensor_info(self, name: str, tensor_shape: np.ndarray, tensor_dtype: np.dtype, tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None): + assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now" encoded_name = name.encode("utf8") self.ti_data += struct.pack("