From c8ee87f1411dfc48aa9c0d0d094b38b15b2a7827 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 16 Aug 2023 19:55:49 +0300 Subject: [PATCH] gguf.py : merge all files in gguf.py --- constants.py | 50 ------- convert-gptneox-h5-to-gguf.py | 10 +- convert-llama-7b-pth-to-gguf.py | 5 +- convert-llama-h5-to-gguf.py | 4 +- gguf.py | 256 +++++++++++++++++++++++++------- gguf_namemap.py | 95 ------------ 6 files changed, 213 insertions(+), 207 deletions(-) delete mode 100644 constants.py delete mode 100644 gguf_namemap.py diff --git a/constants.py b/constants.py deleted file mode 100644 index 2f96fb960..000000000 --- a/constants.py +++ /dev/null @@ -1,50 +0,0 @@ -GGUF_MAGIC = 0x47475546 -GGUF_VERSION = 1 -GGUF_DEFAULT_ALIGNMENT = 32 - -# general -KEY_GENERAL_ARCHITECTURE = "general.architecture" -KEY_GENERAL_QUANTIZATION_VERSION = "general.quantization_version" -KEY_GENERAL_ALIGNMENT = "general.alignment" -KEY_GENERAL_NAME = "general.name" -KEY_GENERAL_AUTHOR = "general.author" -KEY_GENERAL_URL = "general.url" -KEY_GENERAL_DESCRIPTION = "general.description" -KEY_GENERAL_FILE_TYPE = "general.file_type" -KEY_GENERAL_LICENSE = "general.license" -KEY_GENERAL_SOURCE_URL = "general.source.url" -KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository" - -# LLM -KEY_LLM_CONTEXT_LENGTH = "{llm}.context_length" -KEY_LLM_EMBEDDING_LENGTH = "{llm}.embedding_length" -KEY_LLM_BLOCK_COUNT = "{llm}.block_count" -KEY_LLM_FEED_FORWARD_LENGTH = "{llm}.feed_forward_length" -KEY_LLM_USE_PARALLEL_RESIDUAL = "{llm}.use_parallel_residual" -KEY_LLM_TENSOR_DATA_LAYOUT = "{llm}.tensor_data_layout" - -# attention -KEY_ATTENTION_HEAD_COUNT = "{llm}.attention.head_count" -KEY_ATTENTION_HEAD_COUNT_KV = "{llm}.attention.head_count_kv" -KEY_ATTENTION_MAX_ALIBI_BIAS = "{llm}.attention.max_alibi_bias" -KEY_ATTENTION_CLAMP_KQV = "{llm}.attention.clamp_kqv" -KEY_ATTENTION_LAYERNORM_EPS = "{llm}.attention.layer_norm_epsilon" -KEY_ATTENTION_LAYERNORM_RMS_EPS = "{llm}.attention.layer_norm_rms_epsilon" - -# RoPE -KEY_ROPE_DIMENSION_COUNT = "{llm}.rope.dimension_count" -KEY_ROPE_SCALE = "{llm}.rope.scale" - -# tokenization -KEY_TOKENIZER_MODEL = "tokenizer.ggml.model" -KEY_TOKENIZER_LIST = "tokenizer.ggml.tokens" -KEY_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type" -KEY_TOKENIZER_SCORES = "tokenizer.ggml.scores" -KEY_TOKENIZER_MERGES = "tokenizer.ggml.merges" -KEY_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id" -KEY_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id" -KEY_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id" -KEY_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id" -KEY_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id" -KEY_TOKENIZER_HF_JSON = "tokenizer.huggingface.json" -KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world" diff --git a/convert-gptneox-h5-to-gguf.py b/convert-gptneox-h5-to-gguf.py index a3d78c9ef..79876eee3 100644 --- a/convert-gptneox-h5-to-gguf.py +++ b/convert-gptneox-h5-to-gguf.py @@ -1,15 +1,15 @@ # HF gptneox--> gguf conversion import gguf -import gguf_namemap as tmap import os import sys import struct import json import numpy as np +import torch + from typing import Any, List from pathlib import Path -import torch from transformers import AutoTokenizer # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py @@ -188,7 +188,7 @@ if Path(dir_model + "/tokenizer.json").is_file(): # TENSORS -tensor_map = tmap.get_tensor_namemap(block_count) +tensor_map = gguf.get_tensor_name_map(block_count) # tensor info print("gguf: get tensor metadata") @@ -227,7 +227,7 @@ for part_name in part_names: sys.exit() n_dims = len(data.shape) - data_dtype = data.dtype + data_dtype = data.dtype # if f32 desired, convert any float16 to float32 if ftype == 0 and data.dtype == np.float16: @@ -292,7 +292,7 @@ for part_name in part_names: sys.exit() n_dims = len(data.shape) - data_dtype = data.dtype + data_dtype = data.dtype # if f32 desired, convert any float16 to float32 if ftype == 0 and data.dtype == np.float16: diff --git a/convert-llama-7b-pth-to-gguf.py b/convert-llama-7b-pth-to-gguf.py index 27841939d..c4e425ee3 100644 --- a/convert-llama-7b-pth-to-gguf.py +++ b/convert-llama-7b-pth-to-gguf.py @@ -3,18 +3,17 @@ # HF files required in the model dir: config.json tokenizer_config.json tokenizer.json tokenizer.model import gguf -import gguf_namemap as tmap import os import sys import struct import json import numpy as np import torch + from typing import Any, List from pathlib import Path from sentencepiece import SentencePieceProcessor - #NDArray = np.ndarray[Any, Any] # compatible with python < 3.9 NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]' @@ -189,7 +188,7 @@ if Path(dir_model + "/tokenizer.json").is_file(): # TENSORS -tensor_map = tmap.get_tensor_namemap(block_count) +tensor_map = gguf.get_tensor_name_map(block_count) # tensor info print("gguf: get tensor metadata") diff --git a/convert-llama-h5-to-gguf.py b/convert-llama-h5-to-gguf.py index 35f85bb97..a2b3f9a30 100644 --- a/convert-llama-h5-to-gguf.py +++ b/convert-llama-h5-to-gguf.py @@ -1,8 +1,6 @@ # HF llama --> gguf conversion import gguf -import gguf_namemap as tmap - import os import sys import struct @@ -201,7 +199,7 @@ if Path(dir_model + "/tokenizer.json").is_file(): # TENSORS -tensor_map = tmap.get_tensor_namemap(block_count) +tensor_map = gguf.get_tensor_name_map(block_count) # tensor info print("gguf: get tensor metadata") diff --git a/gguf.py b/gguf.py index 1b15554f3..e7f6f0ac8 100644 --- a/gguf.py +++ b/gguf.py @@ -4,14 +4,169 @@ 3. After development is done, Convert it to a proper pip-installable Python package, and possibly move it to its own repo under ggml-org. """ +import sys import struct -import constants +import numpy as np + from enum import IntEnum from typing import Any, IO, List -import numpy as np -import sys +# +# constants +# +GGUF_MAGIC = 0x47475546 +GGUF_VERSION = 1 +GGUF_DEFAULT_ALIGNMENT = 32 + +# general +KEY_GENERAL_ARCHITECTURE = "general.architecture" +KEY_GENERAL_QUANTIZATION_VERSION = "general.quantization_version" +KEY_GENERAL_ALIGNMENT = "general.alignment" +KEY_GENERAL_NAME = "general.name" +KEY_GENERAL_AUTHOR = "general.author" +KEY_GENERAL_URL = "general.url" +KEY_GENERAL_DESCRIPTION = "general.description" +KEY_GENERAL_FILE_TYPE = "general.file_type" +KEY_GENERAL_LICENSE = "general.license" +KEY_GENERAL_SOURCE_URL = "general.source.url" +KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository" + +# LLM +KEY_LLM_CONTEXT_LENGTH = "{llm}.context_length" +KEY_LLM_EMBEDDING_LENGTH = "{llm}.embedding_length" +KEY_LLM_BLOCK_COUNT = "{llm}.block_count" +KEY_LLM_FEED_FORWARD_LENGTH = "{llm}.feed_forward_length" +KEY_LLM_USE_PARALLEL_RESIDUAL = "{llm}.use_parallel_residual" +KEY_LLM_TENSOR_DATA_LAYOUT = "{llm}.tensor_data_layout" + +# attention +KEY_ATTENTION_HEAD_COUNT = "{llm}.attention.head_count" +KEY_ATTENTION_HEAD_COUNT_KV = "{llm}.attention.head_count_kv" +KEY_ATTENTION_MAX_ALIBI_BIAS = "{llm}.attention.max_alibi_bias" +KEY_ATTENTION_CLAMP_KQV = "{llm}.attention.clamp_kqv" +KEY_ATTENTION_LAYERNORM_EPS = "{llm}.attention.layer_norm_epsilon" +KEY_ATTENTION_LAYERNORM_RMS_EPS = "{llm}.attention.layer_norm_rms_epsilon" + +# RoPE +KEY_ROPE_DIMENSION_COUNT = "{llm}.rope.dimension_count" +KEY_ROPE_SCALE = "{llm}.rope.scale" + +# tokenization +KEY_TOKENIZER_MODEL = "tokenizer.ggml.model" +KEY_TOKENIZER_LIST = "tokenizer.ggml.tokens" +KEY_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type" +KEY_TOKENIZER_SCORES = "tokenizer.ggml.scores" +KEY_TOKENIZER_MERGES = "tokenizer.ggml.merges" +KEY_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id" +KEY_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id" +KEY_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id" +KEY_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id" +KEY_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id" +KEY_TOKENIZER_HF_JSON = "tokenizer.huggingface.json" +KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world" + +# +# recommended mapping of model tensor names for storage in gguf +# + +def get_tensor_name_map(n_blocks : int): + tensor_map = {} + # Token embeddings + mapped_to = "token_embd" + tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox + tensor_map["transformer.wte"] = mapped_to # gpt2 mpt + tensor_map["transformer.word_embeddings"] = mapped_to # falcon + tensor_map["model.embed_tokens"] = mapped_to # llama-hf + tensor_map["tok_embeddings"] = mapped_to # llama-pth + # Position embeddings + mapped_to = "pos_embd" + tensor_map["transformer.wpe"] = mapped_to # gpt2 + # Output norm + mapped_to = "output_norm" + tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox + tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon + tensor_map["transformer.norm_f"] = mapped_to # mpt + tensor_map["model.norm"] = mapped_to # llama-hf + tensor_map["norm"] = mapped_to # llama-pth + # Output + mapped_to = "output" + tensor_map["embed_out"] = mapped_to # gptneox + tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf + tensor_map["output"] = mapped_to # llama-pth + # Attention and fee-forward layer blocks + for i in range(0,n_blocks): + # Attention norm + mapped_to = "blk."+str(i)+".attn_norm" + tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".input_layernorm"] = mapped_to # falcon7b + tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b + tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth + # Attention norm 2 + mapped_to = "blk."+str(i)+".attn_norm_2" + tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b + # Attention query-key-value + mapped_to = "blk."+str(i)+".attn_qkv" + tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon + # Attention query + mapped_to = "blk."+str(i)+".attn_q" + tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth + # Attention key + mapped_to = "blk."+str(i)+".attn_k" + tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth + # Attention value + mapped_to = "blk."+str(i)+".attn_v" + tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth + # Attention output + mapped_to = "blk."+str(i)+".attn_output" + tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon + tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth + # Feed-forward norm + mapped_to = "blk."+str(i)+".ffn_norm" + tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt + tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth + # Feed-forward up + mapped_to = "blk."+str(i)+".ffn_up" + tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon + tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth + # Feed-forward gate + mapped_to = "blk."+str(i)+".ffn_gate" + tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth + # Feed-forward down + mapped_to = "blk."+str(i)+".ffn_down" + tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # falcon + tensor_map["model.layers."+str(i)+".mlp.down_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".feed_forward.w2"] = mapped_to # llama-pth + + return tensor_map + +# +# implementation +# class GGMLQuantizationType(IntEnum): F32 = 0 @@ -19,16 +174,16 @@ class GGMLQuantizationType(IntEnum): class GGUFValueType(IntEnum): - UINT8 = 0 - INT8 = 1 - UINT16 = 2 - INT16 = 3 - UINT32 = 4 - INT32 = 5 + UINT8 = 0 + INT8 = 1 + UINT16 = 2 + INT16 = 3 + UINT32 = 4 + INT32 = 5 FLOAT32 = 6 - BOOL = 7 - STRING = 8 - ARRAY = 9 + BOOL = 7 + STRING = 8 + ARRAY = 9 @staticmethod def get_type(val): @@ -51,15 +206,15 @@ class GGUFWriter: def __init__(self, fout: IO): self.fout = fout self.offset_tensor = 0 - self.data_alignment = constants.GGUF_DEFAULT_ALIGNMENT + self.data_alignment = GGUF_DEFAULT_ALIGNMENT self.kv_data = b"" self.kv_data_count = 0 self.ti_data = b"" self.ti_data_count = 0 def write_header_to_file(self): - self.fout.write(struct.pack("