Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-11-11 13:30:35 +00:00)
convert.py : add python logging instead of print() (#6511)
* convert.py: add python logging instead of print()
* convert.py: verbose flag takes priority over dump flag log suppression
* convert.py: named instance logging
* convert.py: use explicit logger id string
* convert.py: convert extra print() to named logger
* convert.py: sys.stderr.write --> logger.error
* *.py: convert all python scripts to use the logging module
* requirements.txt: remove extra line
* flake8: update flake8 ignore and exclude to match ci settings
* gh-actions: add flake8-no-print to flake8 lint step
* pre-commit: add flake8-no-print to flake8 and also update pre-commit version
* convert-hf-to-gguf.py: print() to logger conversion
* *.py: logging basicConfig refactor to use conditional expression
* *.py: removed commented out logging
* fixup! *.py: logging basicConfig refactor to use conditional expression
* constants.py: logger.error then exit should be a raised exception instead
* *.py: convert logger error and sys.exit() into a raised exception (for atypical errors)
* gguf-convert-endian.py: refactor convert_byteorder() to use a tqdm progress bar
* verify-checksum-models.py: this is the result of the program, it should be printed to stdout
* compare-llama-bench.py: add blank line for readability during missing repo response
* reader.py: read_gguf_file() uses print() over logging
* convert.py: warnings go to stderr and won't hurt the dump output
* gguf-dump.py: dump_metadata() should print to stdout
* convert-hf-to-gguf.py: print --> logger.debug or ValueError()
* verify-checksum-models.py: use print() for printing the table
* *.py: refactor logging.basicConfig()
* gguf-py/gguf/*.py: use __name__ as logger name, since these modules are imported rather than run directly
* python-lint.yml: use the .flake8 file instead
* constants.py: logger no longer required
* convert-hf-to-gguf.py: add additional logging
* convert-hf-to-gguf.py: print() --> logger
* *.py: fix flake8 warnings
* revert changes to convert-hf-to-gguf.py for get_name()
* convert-hf-to-gguf-update.py: use a triple-quoted f-string instead
* *.py: accidentally corrected the wrong line
* *.py: add compilade warning suggestions and style fixes
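Note (not part of the diff): the conversion pattern the commit message describes is plain standard-library logging: a named logger per script, a logging.basicConfig() level chosen with a conditional expression on the verbose flag, and exceptions raised instead of logger.error() plus sys.exit(). A minimal sketch, assuming an illustrative logger name and --verbose flag (not the exact code from the commit):

import argparse
import logging

logger = logging.getLogger("convert")  # named logger instance


def map_tensor_name(name: str, table: dict[str, str]) -> str:
    new_name = table.get(name)
    if new_name is None:
        # atypical error: raise instead of logger.error() + sys.exit()
        raise ValueError(f"Can not map tensor {name!r}")
    return new_name


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # level picked with a conditional expression, as the commit message describes
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    logger.debug("details only shown with --verbose")
    logger.info("normal progress message")              # replaces print(...)
    logger.warning("diagnostic kept off stdout dumps")  # replaces sys.stderr.write(...)
    logger.info(map_tensor_name("wte.weight", {"wte.weight": "token_embd.weight"}))


if __name__ == "__main__":
    main()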
commit a2ac89d6ef (parent 433def286e)
.flake8 (3 changed lines)
@@ -1,3 +1,4 @@
 [flake8]
 max-line-length = 125
-ignore = W503
+ignore = E203,E211,E221,E225,E231,E241,E251,E261,E266,E501,E701,E704,W503
+exclude = examples/*,examples/*/**,*/**/__init__.py
.github/workflows/python-lint.yml (vendored, 3 changed lines)
@@ -20,5 +20,4 @@ jobs:
     - name: flake8 Lint
       uses: py-actions/flake8@v2
       with:
-        ignore: "E203,E211,E221,E225,E231,E241,E251,E261,E266,E501,E701,E704,W503"
-        exclude: "examples/*,examples/*/**,*/**/__init__.py,convert-hf-to-gguf-update.py"
+        plugins: "flake8-no-print"
.pre-commit-config.yaml
@@ -3,13 +3,14 @@
 exclude: prompts/.*.txt
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v3.2.0
+  rev: v4.6.0
   hooks:
   - id: trailing-whitespace
   - id: end-of-file-fixer
   - id: check-yaml
   - id: check-added-large-files
 - repo: https://github.com/PyCQA/flake8
-  rev: 6.0.0
+  rev: 7.0.0
   hooks:
   - id: flake8
+    additional_dependencies: [flake8-no-print]
convert-hf-to-gguf-update.py
@@ -21,6 +21,7 @@
 # TODO: automate the update of convert-hf-to-gguf.py
 #
 
+import logging
 import os
 import requests
 import sys
@@ -28,12 +29,17 @@ import json
 
 from hashlib import sha256
 from enum import IntEnum, auto
+from transformers import AutoTokenizer
 
+logger = logging.getLogger("convert-hf-to-gguf-update")
+
 
 class TOKENIZER_TYPE(IntEnum):
     SPM = auto()
     BPE = auto()
     WPM = auto()
 
 
 # TODO: this string has to exercise as much pre-tokenizer functionality as possible
 # will be updated with time - contributions welcome
 chktxt = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL'
@@ -41,36 +47,38 @@
 if len(sys.argv) == 2:
     token = sys.argv[1]
 else:
-    print("Usage: python convert-hf-to-gguf-update.py <huggingface_token>")
+    logger.info("Usage: python convert-hf-to-gguf-update.py <huggingface_token>")
     sys.exit(1)
 
 # TODO: add models here, base models preferred
 models = [
-    { "name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", },
+    {"name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", },
-    { "name": "llama-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B", },
+    {"name": "llama-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B", },
-    { "name": "phi-3", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", },
+    {"name": "phi-3", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", },
-    { "name": "deepseek-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-llm-7b-base", },
+    {"name": "deepseek-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-llm-7b-base", },
-    { "name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", },
+    {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", },
-    { "name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", },
+    {"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", },
-    { "name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", },
+    {"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", },
-    { "name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", },
+    {"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", },
-    { "name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", },
+    {"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", },
-    { "name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", },
+    {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", },
 ]
 
 # make directory "models/tokenizers" if it doesn't exist
 if not os.path.exists("models/tokenizers"):
     os.makedirs("models/tokenizers")
 
 
 def download_file_with_auth(url, token, save_path):
     headers = {"Authorization": f"Bearer {token}"}
     response = requests.get(url, headers=headers)
     if response.status_code == 200:
         with open(save_path, 'wb') as f:
             f.write(response.content)
-        print(f"File {save_path} downloaded successfully")
+        logger.info(f"File {save_path} downloaded successfully")
     else:
-        print(f"Failed to download file. Status code: {response.status_code}")
+        logger.info(f"Failed to download file. Status code: {response.status_code}")
 
 
 # download the tokenizer models
 for model in models:
@@ -81,10 +89,10 @@ for model in models:
     if not os.path.exists(f"models/tokenizers/{name}"):
         os.makedirs(f"models/tokenizers/{name}")
     else:
-        print(f"Directory models/tokenizers/{name} already exists - skipping")
+        logger.info(f"Directory models/tokenizers/{name} already exists - skipping")
         continue
 
-    print(f"Downloading {name} to models/tokenizers/{name}")
+    logger.info(f"Downloading {name} to models/tokenizers/{name}")
 
     url = f"{repo}/raw/main/config.json"
     save_path = f"models/tokenizers/{name}/config.json"
@@ -115,76 +123,76 @@ for model in models:
         continue
 
     # create the tokenizer
-    from transformers import AutoTokenizer
     tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
 
     chktok = tokenizer.encode(chktxt)
     chkhsh = sha256(str(chktok).encode()).hexdigest()
 
-    print(f"model: {name}")
+    logger.info(f"model: {name}")
-    print(f"tokt: {tokt}")
+    logger.info(f"tokt: {tokt}")
-    print(f"repo: {model['repo']}")
+    logger.info(f"repo: {model['repo']}")
-    print(f"chktok: {chktok}")
+    logger.info(f"chktok: {chktok}")
-    print(f"chkhsh: {chkhsh}")
+    logger.info(f"chkhsh: {chkhsh}")
 
     # print the "pre_tokenizer" content from the tokenizer.json
     with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
         cfg = json.load(f)
         pre_tokenizer = cfg["pre_tokenizer"]
-        print("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
+        logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
 
-    print(f"\n")
+    logger.info("")
 
    src_ifs += f"        if chkhsh == \"{chkhsh}\":\n"
    src_ifs += f"            # ref: {model['repo']}\n"
    src_ifs += f"            res = \"{name}\"\n"
 
-src_func = ""
-src_func += "    def get_vocab_base_pre(self, tokenizer) -> str:\n"
-src_func += "        # encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that\n"
-src_func += "        # is specific for the BPE pre-tokenizer used by the model\n"
-src_func += "        # we will use this unique identifier to write a \"tokenizer.ggml.pre\" entry in the GGUF file which we can\n"
-src_func += "        # use in llama.cpp to implement the same pre-tokenizer\n"
-src_func += "\n"
-src_func += f"        chktxt = {repr(chktxt)}\n"
-src_func += "\n"
-src_func += "        chktok = tokenizer.encode(chktxt)\n"
-src_func += "        chkhsh = sha256(str(chktok).encode()).hexdigest()\n"
-src_func += "\n"
-src_func += "        print(f\"chktok: {chktok}\")\n"
-src_func += "        print(f\"chkhsh: {chkhsh}\")\n"
-src_func += "\n"
-src_func += "        res = None\n"
-src_func += "\n"
-src_func += "        # NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script\n"
-src_func += "        #       or pull the latest version of the model from Huggingface\n"
-src_func += "        #       don't edit the hashes manually!\n"
-src_func += f"{src_ifs}\n"
-src_func += "        if res is None:\n"
-src_func += "            print(\"\\n\")\n"
-src_func += "            print(\"**************************************************************************************\")\n"
-src_func += "            print(\"** WARNING: The BPE pre-tokenizer was not recognized!\")\n"
-src_func += "            print(\"** There are 2 possible reasons for this:\")\n"
-src_func += "            print(\"** - the model has not been added to convert-hf-to-gguf-update.py yet\")\n"
-src_func += "            print(\"** - the pre-tokenization config has changed upstream\")\n"
-src_func += "            print(\"** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.\")\n"
-src_func += "            print(\"** ref: https://github.com/ggerganov/llama.cpp/pull/6920\")\n"
-src_func += "            print(\"**\")\n"
-src_func += "            print(f\"** chkhsh: {chkhsh}\")\n"
-src_func += "            print(\"**************************************************************************************\")\n"
-src_func += "            print(\"\\n\")\n"
-src_func += "            raise NotImplementedError(\"BPE pre-tokenizer was not recognized - update get_vocab_base_pre()\")\n"
-src_func += "\n"
-src_func += "        print(f\"tokenizer.ggml.pre: {res}\")\n"
-src_func += "        print(f\"chkhsh: {chkhsh}\")\n"
-src_func += "\n"
-src_func += "        return res\n"
 
-print(src_func)
-print("\n")
-print("!!! Copy-paste the function above into convert-hf-to-gguf.py !!!")
-print("\n")
+src_func = f"""
+    def get_vocab_base_pre(self, tokenizer) -> str:
+        # encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that
+        # is specific for the BPE pre-tokenizer used by the model
+        # we will use this unique identifier to write a "tokenizer.ggml.pre" entry in the GGUF file which we can
+        # use in llama.cpp to implement the same pre-tokenizer
+
+        chktxt = {repr(chktxt)}
+
+        chktok = tokenizer.encode(chktxt)
+        chkhsh = sha256(str(chktok).encode()).hexdigest()
+
+        print(f"chktok: {{chktok}}")
+        print(f"chkhsh: {{chkhsh}}")
+
+        res = None
+
+        # NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script
+        #       or pull the latest version of the model from Huggingface
+        #       don't edit the hashes manually!
+{src_ifs}
+        if res is None:
+            print("\\n")
+            print("**************************************************************************************")
+            print("** WARNING: The BPE pre-tokenizer was not recognized!")
+            print("** There are 2 possible reasons for this:")
+            print("** - the model has not been added to convert-hf-to-gguf-update.py yet")
+            print("** - the pre-tokenization config has changed upstream")
+            print("** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.")
+            print("** ref: https://github.com/ggerganov/llama.cpp/pull/6920")
+            print("**")
+            print(f"** chkhsh: {{chkhsh}}")
+            print("**************************************************************************************")
+            print("\\n")
+            raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()")
+
+        print(f"tokenizer.ggml.pre: {{repr(res)}}")
+        print(f"chkhsh: {{chkhsh}}")
+
+        return res
+"""
+
+print(src_func) # noqa: NP100
+
+logger.info("\n")
+logger.info("!!! Copy-paste the function above into convert-hf-to-gguf.py !!!")
+logger.info("\n")
 
 # generate tests for each tokenizer model
 
@@ -250,7 +258,6 @@ for model in models:
     tokt = model["tokt"]
 
     # create the tokenizer
-    from transformers import AutoTokenizer
     tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
 
     with open(f"models/ggml-vocab-{name}.gguf.inp", "w", encoding="utf-8") as f:
@@ -265,15 +272,15 @@ for model in models:
         f.write(f" {r}")
         f.write("\n")
 
-    print(f"Tests for {name} written in ./models/ggml-vocab-{name}.gguf.*")
+    logger.info(f"Tests for {name} written in ./models/ggml-vocab-{name}.gguf.*")
 
 # generate commands for creating vocab files
 
-print("\nRun the following commands to generate the vocab files for testing:\n")
+logger.info("\nRun the following commands to generate the vocab files for testing:\n")
 
 for model in models:
     name = model["name"]
 
-    print(f"python3 convert-hf-to-gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only")
+    logger.info(f"python3 convert-hf-to-gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only")
 
-print("\n")
+logger.info("\n")
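Note (not part of the diff): the generator above now emits get_vocab_base_pre() as one triple-quoted f-string, so placeholders that must survive into the generated source are written with doubled braces while single braces interpolate at generation time. A tiny self-contained sketch of that escaping rule, with made-up values:

# Doubled braces {{...}} stay literal in an f-string; single braces interpolate immediately.
name = "llama-bpe"          # illustrative values, not taken from the real script
chkhsh = "0123abcd"

src = f"""
        if chkhsh == "{chkhsh}":
            # ref: {name}
            res = "{name}"
        print(f"chkhsh: {{chkhsh}}")  # emitted verbatim, evaluated later by the generated code
"""
print(src)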
convert-hf-to-gguf.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import logging
 import argparse
 import contextlib
 import json
@@ -26,6 +27,8 @@ import gguf
 
 from convert import LlamaHfVocab, permute
 
+logger = logging.getLogger("hf-to-gguf")
+
 
 ###### MODEL DEFINITIONS ######
 
@@ -76,7 +79,7 @@ class Model(ABC):
 
     def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
         for part_name in self.part_names:
-            print(f"gguf: loading model part '{part_name}'")
+            logger.info(f"gguf: loading model part '{part_name}'")
             ctx: ContextManager[Any]
             if self.is_safetensors:
                 from safetensors import safe_open
@@ -95,42 +98,42 @@ class Model(ABC):
 
         if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
             self.gguf_writer.add_context_length(n_ctx)
-            print(f"gguf: context length = {n_ctx}")
+            logger.info(f"gguf: context length = {n_ctx}")
 
         n_embd = self.find_hparam(["hidden_size", "n_embd"])
         self.gguf_writer.add_embedding_length(n_embd)
-        print(f"gguf: embedding length = {n_embd}")
+        logger.info(f"gguf: embedding length = {n_embd}")
 
         if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
             self.gguf_writer.add_feed_forward_length(n_ff)
-            print(f"gguf: feed forward length = {n_ff}")
+            logger.info(f"gguf: feed forward length = {n_ff}")
 
         n_head = self.find_hparam(["num_attention_heads", "n_head"])
         self.gguf_writer.add_head_count(n_head)
-        print(f"gguf: head count = {n_head}")
+        logger.info(f"gguf: head count = {n_head}")
 
         if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
             self.gguf_writer.add_head_count_kv(n_head_kv)
-            print(f"gguf: key-value head count = {n_head_kv}")
+            logger.info(f"gguf: key-value head count = {n_head_kv}")
 
         if (rope_theta := self.hparams.get("rope_theta")) is not None:
             self.gguf_writer.add_rope_freq_base(rope_theta)
-            print(f"gguf: rope theta = {rope_theta}")
+            logger.info(f"gguf: rope theta = {rope_theta}")
         if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
             self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
-            print(f"gguf: rms norm epsilon = {f_rms_eps}")
+            logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
         if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
             self.gguf_writer.add_layer_norm_eps(f_norm_eps)
-            print(f"gguf: layer norm epsilon = {f_norm_eps}")
+            logger.info(f"gguf: layer norm epsilon = {f_norm_eps}")
         if (n_experts := self.hparams.get("num_local_experts")) is not None:
            self.gguf_writer.add_expert_count(n_experts)
-            print(f"gguf: expert count = {n_experts}")
+            logger.info(f"gguf: expert count = {n_experts}")
         if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
             self.gguf_writer.add_expert_used_count(n_experts_used)
-            print(f"gguf: experts used count = {n_experts_used}")
+            logger.info(f"gguf: experts used count = {n_experts_used}")
 
         self.gguf_writer.add_file_type(self.ftype)
-        print(f"gguf: file type = {self.ftype}")
+        logger.info(f"gguf: file type = {self.ftype}")
 
     def write_tensors(self):
         block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
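Note (not part of the diff): the hunk above repeatedly uses the walrus operator to combine an optional hyperparameter lookup with a None check, then logs the value that was written. A minimal standalone sketch of that pattern; find_hparam is a method on Model in the real script and is reimplemented here as a free function purely for illustration:

import logging
from typing import Any

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hf-to-gguf")


def find_hparam(hparams: dict[str, Any], keys: list[str], optional: bool = False) -> Any:
    # return the first key that exists; optional lookups yield None instead of raising
    for key in keys:
        if key in hparams:
            return hparams[key]
    if optional:
        return None
    raise KeyError(f"could not find any of {keys}")


hparams = {"max_position_embeddings": 4096, "hidden_size": 4096}  # illustrative config

# lookup and None-check on one line, then log what gets written to the GGUF metadata
if (n_ctx := find_hparam(hparams, ["max_position_embeddings", "n_ctx"], optional=True)) is not None:
    logger.info(f"gguf: context length = {n_ctx}")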
@@ -151,8 +154,7 @@ class Model(ABC):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -169,7 +171,7 @@ class Model(ABC):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
 
@@ -274,8 +276,8 @@ class Model(ABC):
         chktok = tokenizer.encode(chktxt)
         chkhsh = sha256(str(chktok).encode()).hexdigest()
 
-        print(f"chktok: {chktok}")
-        print(f"chkhsh: {chkhsh}")
+        logger.debug(f"chktok: {chktok}")
+        logger.debug(f"chkhsh: {chkhsh}")
 
         res = None
 
@@ -308,22 +310,22 @@ class Model(ABC):
             res = "gpt-2"
 
         if res is None:
-            print("\n")
-            print("**************************************************************************************")
-            print("** WARNING: The BPE pre-tokenizer was not recognized!")
-            print("** There are 2 possible reasons for this:")
-            print("** - the model has not been added to convert-hf-to-gguf-update.py yet")
-            print("** - the pre-tokenization config has changed upstream")
-            print("** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.")
-            print("** ref: https://github.com/ggerganov/llama.cpp/pull/6920")
-            print("**")
-            print(f"** chkhsh: {chkhsh}")
-            print("**************************************************************************************")
-            print("\n")
+            logger.warning("\n")
+            logger.warning("**************************************************************************************")
+            logger.warning("** WARNING: The BPE pre-tokenizer was not recognized!")
+            logger.warning("** There are 2 possible reasons for this:")
+            logger.warning("** - the model has not been added to convert-hf-to-gguf-update.py yet")
+            logger.warning("** - the pre-tokenization config has changed upstream")
+            logger.warning("** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.")
+            logger.warning("** ref: https://github.com/ggerganov/llama.cpp/pull/6920")
+            logger.warning("**")
+            logger.warning(f"** chkhsh: {chkhsh}")
+            logger.warning("**************************************************************************************")
+            logger.warning("\n")
             raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()")
 
-        print(f"tokenizer.ggml.pre: {res}")
-        print(f"chkhsh: {chkhsh}")
+        logger.debug(f"tokenizer.ggml.pre: {res}")
+        logger.debug(f"chkhsh: {chkhsh}")
 
         return res
 
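Note (not part of the diff): the hunk above keeps the chkhsh mechanism intact: encode a fixed check text, hash the resulting token IDs with SHA-256, and look the digest up among known pre-tokenizer identifiers. A toy sketch of the idea; the stand-in toy_encode() and the hand-written mapping are assumptions, since the real script uses tokenizer.encode() from transformers and a generated hash table:

from hashlib import sha256

chktxt = "🚀 (normal) 3 33 333 3333 ...!!!??"  # shortened stand-in for the real check string


def toy_encode(text: str) -> list[int]:
    # stand-in for tokenizer.encode(); any change in tokenization changes the hash
    return [len(piece) for piece in text.split()]


chktok = toy_encode(chktxt)
chkhsh = sha256(str(chktok).encode()).hexdigest()

known = {chkhsh: "llama-bpe"}  # in the real script this mapping is generated by convert-hf-to-gguf-update.py
res = known.get(chkhsh)
if res is None:
    # same failure mode as above: an unrecognized pre-tokenizer must be added, not guessed
    raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()")
print(f"tokenizer.ggml.pre: {res}")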
@@ -439,9 +441,7 @@ class Model(ABC):
 
         if vocab_size > len(tokens):
             pad_count = vocab_size - len(tokens)
-            print(
-                f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]"
-            )
+            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
             for i in range(1, pad_count + 1):
                 tokens.append(f"[PAD{i}]")
                 scores.append(-1000.0)
@@ -553,7 +553,7 @@ class BloomModel(Model):
                 ),
                 axis=0,
             )
-            print("re-format attention.linear_qkv.weight")
+            logger.info("re-format attention.linear_qkv.weight")
         elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
             qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
             data = np.concatenate(
@@ -564,13 +564,12 @@ class BloomModel(Model):
                 ),
                 axis=0,
             )
-            print("re-format attention.linear_qkv.bias")
+            logger.info("re-format attention.linear_qkv.bias")
 
         # map tensor names
         new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
         if new_name is None:
-            print(f"Can not map tensor {name!r}")
-            sys.exit()
+            raise ValueError(f"Can not map tensor {name!r}")
 
         n_dims = len(data.shape)
         data_dtype = data.dtype
@@ -587,13 +586,13 @@ class BloomModel(Model):
         if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
             data = data.astype(np.float16)
 
-        print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+        logger.info(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
 
         self.gguf_writer.add_tensor(new_name, data)
 
         if not has_lm_head and name == "word_embeddings.weight":
             self.gguf_writer.add_tensor("output.weight", data)
-            print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+            logger.info(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
 
 
 @Model.register("MPTForCausalLM")
@@ -653,8 +652,7 @@ class MPTModel(Model):
             else:
                 new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -671,7 +669,7 @@ class MPTModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
 
@@ -697,8 +695,7 @@ class OrionModel(Model):
         elif "model_max_length" in self.hparams:
             ctx_length = self.hparams["model_max_length"]
         else:
-            print("gguf: can not find ctx length parameter.")
-            sys.exit()
+            raise ValueError("gguf: can not find ctx length parameter.")
 
         self.gguf_writer.add_file_type(self.ftype)
         self.gguf_writer.add_name(self.dir_model.name)
@@ -736,8 +733,7 @@ class OrionModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -754,7 +750,7 @@ class OrionModel(Model):
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)
 
 
@@ -779,8 +775,7 @@ class BaichuanModel(Model):
         elif "model_max_length" in self.hparams:
             ctx_length = self.hparams["model_max_length"]
         else:
-            print("gguf: can not find ctx length parameter.")
-            sys.exit()
+            raise ValueError("gguf: can not find ctx length parameter.")
 
         self.gguf_writer.add_name(self.dir_model.name)
         self.gguf_writer.add_source_hf_repo(hf_repo)
@@ -809,7 +804,7 @@ class BaichuanModel(Model):
 
         for i in range(block_count):
             if (w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight")) is not None:
-                print(f"Unpacking and permuting layer {i}")
+                logger.info(f"Unpacking and permuting layer {i}")
                 model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = \
                     self._reverse_hf_permute_part(w, 0, head_count, head_count)
                 model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = \
@@ -834,8 +829,7 @@ class BaichuanModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -852,7 +846,7 @@ class BaichuanModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)
 
     def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
@@ -937,8 +931,7 @@ class XverseModel(Model):
         elif "model_max_length" in self.hparams:
             ctx_length = self.hparams["model_max_length"]
         else:
-            print("gguf: can not find ctx length parameter.")
-            sys.exit()
+            raise ValueError("gguf: can not find ctx length parameter.")
 
         self.gguf_writer.add_name(self.dir_model.name)
         self.gguf_writer.add_source_hf_repo(hf_repo)
@@ -987,8 +980,7 @@ class XverseModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -1005,7 +997,7 @@ class XverseModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)
 
     def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
@@ -1092,8 +1084,7 @@ class FalconModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -1110,7 +1101,7 @@ class FalconModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
 
@@ -1197,8 +1188,7 @@ class RefactModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight",))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -1215,7 +1205,7 @@ class RefactModel(Model):
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
 
@@ -1264,10 +1254,9 @@ class PersimmonModel(Model):
             data = data_torch.to(torch.float32).squeeze().numpy()
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
             n_dims = len(data.shape)
-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)
 
 
@@ -1332,8 +1321,7 @@ class StableLMModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -1350,7 +1338,7 @@ class StableLMModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and not new_name.endswith("_norm.weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.debug(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
 
@@ -1366,8 +1354,7 @@ class StableLMModel(Model):
                 merged_name = f"model.layers.{bid}.self_attn.{layer_name}.weight"
                 new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
                 if new_name is None:
-                    print(f"Can not map tensor {name!r}")
-                    sys.exit()
+                    raise ValueError(f"Can not map tensor {name!r}")
                 if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")):
                     data = data.astype(np.float32)
 
@@ -1375,7 +1362,7 @@ class StableLMModel(Model):
                 if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and not new_name.endswith("_norm.weight") and n_dims == 2:
                     data = data.astype(np.float16)
 
-                print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
+                logger.debug(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
 
                 self.gguf_writer.add_tensor(new_name, data)
 
@@ -1480,10 +1467,9 @@ class LlamaModel(Model):
 
                     new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
                     if new_name is None:
-                        print(f"Can not map tensor {name!r}")
-                        sys.exit()
+                        raise ValueError(f"Can not map tensor {name!r}")
 
-                    print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
+                    logger.info(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
 
                     self.gguf_writer.add_tensor(new_name, data)
                 continue
@@ -1491,8 +1477,7 @@ class LlamaModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -1509,7 +1494,7 @@ class LlamaModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
 
@@ -1584,10 +1569,9 @@ class GrokModel(Model):
 
                     new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
                     if new_name is None:
-                        print(f"Can not map tensor {name!r}")
-                        sys.exit()
+                        raise ValueError(f"Can not map tensor {name!r}")
 
-                    print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
+                    logger.info(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
 
                     self.gguf_writer.add_tensor(new_name, data)
                 continue
@@ -1595,8 +1579,7 @@ class GrokModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -1613,7 +1596,7 @@ class GrokModel(Model):
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
 
@@ -1646,7 +1629,7 @@ class DbrxModel(Model):
         self.gguf_writer.add_layer_norm_eps(1e-5)
 
         self.gguf_writer.add_file_type(self.ftype)
-        print(f"gguf: file type = {self.ftype}")
+        logger.info(f"gguf: file type = {self.ftype}")
 
     def write_tensors(self):
         block_count = self.hparams.get("n_layers")
@@ -1689,8 +1672,7 @@ class DbrxModel(Model):
            # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15
            new_name = tensor_map.get_name(name if not experts else name + ".weight", try_suffixes=(".weight",))
            if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
            n_dims = len(data.shape)
            data_dtype = data.dtype
@@ -1698,8 +1680,7 @@ class DbrxModel(Model):
            # Most of the codebase that takes in 1D tensors only handles F32 tensors
            # and most of the outputs tensors are F32.
            if data_dtype != np.float32 and n_dims == 1:
-                print(f"Can not map tensor {name!r}: all 1D tensors must be F32")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}: all 1D tensors must be F32")
 
            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
@@ -1709,7 +1690,7 @@ class DbrxModel(Model):
            if self.ftype == 1 and data_dtype == np.float32 and n_dims > 1:
                data = data.astype(np.float16)
 
-            print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+            logger.debug(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
 
            self.gguf_writer.add_tensor(new_name, data)
 
@@ -1771,8 +1752,7 @@ class MiniCPMModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -1789,7 +1769,7 @@ class MiniCPMModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
 
@@ -1855,8 +1835,7 @@ class QwenModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -1873,7 +1852,7 @@ class QwenModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)
 
 
@@ -1950,10 +1929,9 @@ class Qwen2MoeModel(Model):
 
                    new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
                    if new_name is None:
-                        print(f"Can not map tensor {name!r}")
-                        sys.exit()
+                        raise ValueError(f"Can not map tensor {name!r}")
 
-                    print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
+                    logger.debug(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
 
                    self.gguf_writer.add_tensor(new_name, data)
                continue
@@ -1961,8 +1939,7 @@ class Qwen2MoeModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -1979,7 +1956,7 @@ class Qwen2MoeModel(Model):
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+            logger.debug(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
 
@@ -2024,8 +2001,7 @@ class GPT2Model(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
             n_dims = len(data.shape)
             data_dtype = data.dtype
@@ -2042,13 +2018,13 @@ class GPT2Model(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
 
             # note: GPT2 output is tied to (same as) wte in original model
             if new_name == "token_embd.weight":
-                print(f"output.weight, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+                logger.info(f"output.weight, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
                 self.gguf_writer.add_tensor("output.weight", data)
 
 
@@ -2087,8 +2063,7 @@ class Phi3MiniModel(Model):
         tokenizer_path = self.dir_model / 'tokenizer.model'
 
         if not tokenizer_path.is_file():
-            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
-            sys.exit(1)
+            raise ValueError(f'Error: Missing {tokenizer_path}')
 
         tokenizer = SentencePieceProcessor(str(tokenizer_path))
 
@@ -2126,7 +2101,7 @@ class Phi3MiniModel(Model):
         for key in added_tokens_json:
             token_id = added_tokens_json[key]
             if (token_id >= vocab_size):
-                print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
+                logger.debug(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
                 continue
 
             tokens[token_id] = key.encode("utf-8")
@@ -2208,8 +2183,7 @@ class PlamoModel(Model):
            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
+                raise ValueError(f"Can not map tensor {name!r}")
 
            # shuffle for broadcasting of gqa in ggml_mul_mat
            if new_name.endswith("attn_q.weight"):
@@ -2240,7 +2214,7 @@ class PlamoModel(Model):
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
|
||||||
data = data.astype(np.float16)
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
|
|
||||||
self.gguf_writer.add_tensor(new_name, data)
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
@ -2286,8 +2260,7 @@ class CodeShellModel(Model):
|
|||||||
# map tensor names
|
# map tensor names
|
||||||
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||||
if new_name is None:
|
if new_name is None:
|
||||||
print(f"Can not map tensor {name!r}")
|
raise ValueError(f"Can not map tensor {name!r}")
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
n_dims = len(data.shape)
|
n_dims = len(data.shape)
|
||||||
data_dtype = data.dtype
|
data_dtype = data.dtype
|
||||||
@ -2304,13 +2277,13 @@ class CodeShellModel(Model):
|
|||||||
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
|
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
|
||||||
data = data.astype(np.float16)
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
|
|
||||||
self.gguf_writer.add_tensor(new_name, data)
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
if not has_lm_head and name == "transformer.wte.weight":
|
if not has_lm_head and name == "transformer.wte.weight":
|
||||||
self.gguf_writer.add_tensor("output.weight", data)
|
self.gguf_writer.add_tensor("output.weight", data)
|
||||||
print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
|
logger.info(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
|
||||||
|
|
||||||
|
|
||||||
@Model.register("InternLM2ForCausalLM")
|
@Model.register("InternLM2ForCausalLM")
|
||||||
@ -2332,7 +2305,7 @@ class InternLM2Model(Model):
|
|||||||
toktypes: list[int] = []
|
toktypes: list[int] = []
|
||||||
|
|
||||||
if not tokenizer_path.is_file():
|
if not tokenizer_path.is_file():
|
||||||
print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
|
logger.error(f'Error: Missing {tokenizer_path}')
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
sentencepiece_model = model.ModelProto()
|
sentencepiece_model = model.ModelProto()
|
||||||
@ -2349,7 +2322,7 @@ class InternLM2Model(Model):
|
|||||||
if text == b"\x00":
|
if text == b"\x00":
|
||||||
# (TODO): fixme
|
# (TODO): fixme
|
||||||
# Hack here and replace the \x00 characters.
|
# Hack here and replace the \x00 characters.
|
||||||
print(f"InternLM2 convert token '{text}' to '🐉'!")
|
logger.debug(f"InternLM2 convert token '{text}' to '🐉'!")
|
||||||
text = "🐉"
|
text = "🐉"
|
||||||
|
|
||||||
toktype = SentencePieceTokenTypes.NORMAL
|
toktype = SentencePieceTokenTypes.NORMAL
|
||||||
@ -2390,7 +2363,7 @@ class InternLM2Model(Model):
|
|||||||
# TODO: this is a hack, should be fixed
|
# TODO: this is a hack, should be fixed
|
||||||
# https://github.com/ggerganov/llama.cpp/pull/6745#issuecomment-2067687048
|
# https://github.com/ggerganov/llama.cpp/pull/6745#issuecomment-2067687048
|
||||||
special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer)
|
special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer)
|
||||||
print(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \
|
logger.warning(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \
|
||||||
in chat mode so that the conversation can end normally.")
|
in chat mode so that the conversation can end normally.")
|
||||||
|
|
||||||
special_vocab.add_to_gguf(self.gguf_writer)
|
special_vocab.add_to_gguf(self.gguf_writer)
|
||||||
@ -2435,8 +2408,7 @@ in chat mode so that the conversation can end normally.")
|
|||||||
# map tensor names
|
# map tensor names
|
||||||
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||||
if new_name is None:
|
if new_name is None:
|
||||||
print(f"Can not map tensor {name!r}")
|
raise ValueError(f"Can not map tensor {name!r}")
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
n_dims = len(data.shape)
|
n_dims = len(data.shape)
|
||||||
data_dtype = data.dtype
|
data_dtype = data.dtype
|
||||||
@ -2453,7 +2425,7 @@ in chat mode so that the conversation can end normally.")
|
|||||||
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
|
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
|
||||||
data = data.astype(np.float16)
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
self.gguf_writer.add_tensor(new_name, data)
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
def write_tensors(self):
|
def write_tensors(self):
|
||||||
@ -2564,8 +2536,7 @@ class BertModel(Model):
|
|||||||
# map tensor names
|
# map tensor names
|
||||||
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||||
if new_name is None:
|
if new_name is None:
|
||||||
print(f"Can not map tensor {name!r}")
|
raise ValueError(f"Can not map tensor {name!r}")
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
# convert any unsupported data types to float32
|
# convert any unsupported data types to float32
|
||||||
if data_torch.dtype not in (torch.float16, torch.float32):
|
if data_torch.dtype not in (torch.float16, torch.float32):
|
||||||
@ -2585,7 +2556,7 @@ class BertModel(Model):
|
|||||||
# if f32 desired, convert any float16 to float32
|
# if f32 desired, convert any float16 to float32
|
||||||
new_dtype = np.float32
|
new_dtype = np.float32
|
||||||
|
|
||||||
print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")
|
logger.info(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")
|
||||||
|
|
||||||
if data.dtype != new_dtype:
|
if data.dtype != new_dtype:
|
||||||
data = data.astype(new_dtype)
|
data = data.astype(new_dtype)
|
||||||
@ -2664,7 +2635,7 @@ class GemmaModel(Model):
|
|||||||
# lm_head is not used in llama.cpp, while autoawq will include this tensor in model
|
# lm_head is not used in llama.cpp, while autoawq will include this tensor in model
|
||||||
# To prevent errors, skip loading lm_head.weight.
|
# To prevent errors, skip loading lm_head.weight.
|
||||||
if name == "lm_head.weight":
|
if name == "lm_head.weight":
|
||||||
print(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
|
logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
old_dtype = data_torch.dtype
|
old_dtype = data_torch.dtype
|
||||||
@ -2681,8 +2652,7 @@ class GemmaModel(Model):
|
|||||||
# map tensor names
|
# map tensor names
|
||||||
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||||
if new_name is None:
|
if new_name is None:
|
||||||
print(f"Can not map tensor {name!r}")
|
raise ValueError(f"Can not map tensor {name!r}")
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
n_dims = len(data.shape)
|
n_dims = len(data.shape)
|
||||||
data_dtype = data.dtype
|
data_dtype = data.dtype
|
||||||
@ -2693,7 +2663,7 @@ class GemmaModel(Model):
|
|||||||
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
|
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
|
||||||
data = data.astype(np.float16)
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
|
|
||||||
self.gguf_writer.add_tensor(new_name, data)
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
@ -2721,7 +2691,7 @@ class MambaModel(Model):
|
|||||||
else:
|
else:
|
||||||
# Use the GPT-NeoX tokenizer when no tokenizer files are present
|
# Use the GPT-NeoX tokenizer when no tokenizer files are present
|
||||||
tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf"
|
tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf"
|
||||||
print(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
|
logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
|
||||||
neox_reader = gguf.GGUFReader(tokenizer_path, "r")
|
neox_reader = gguf.GGUFReader(tokenizer_path, "r")
|
||||||
|
|
||||||
field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
|
field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
|
||||||
@ -2793,17 +2763,16 @@ class MambaModel(Model):
|
|||||||
# map tensor names
|
# map tensor names
|
||||||
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||||
if new_name is None:
|
if new_name is None:
|
||||||
print(f"Can not map tensor {name!r}")
|
raise ValueError(f"Can not map tensor {name!r}")
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
if name.endswith(".A_log"):
|
if name.endswith(".A_log"):
|
||||||
print("A_log --> A ==> " + new_name)
|
logger.debug("A_log --> A ==> " + new_name)
|
||||||
data_torch = -torch.exp(data_torch)
|
data_torch = -torch.exp(data_torch)
|
||||||
|
|
||||||
# assuming token_embd.weight is seen before output.weight
|
# assuming token_embd.weight is seen before output.weight
|
||||||
if tok_embd is not None and new_name == output_name:
|
if tok_embd is not None and new_name == output_name:
|
||||||
if torch.equal(tok_embd, data_torch):
|
if torch.equal(tok_embd, data_torch):
|
||||||
print(f"{output_name} is equivalent to {tok_embd_name}, omitting")
|
logger.debug(f"{output_name} is equivalent to {tok_embd_name}, omitting")
|
||||||
continue
|
continue
|
||||||
if new_name == tok_embd_name:
|
if new_name == tok_embd_name:
|
||||||
tok_embd = data_torch
|
tok_embd = data_torch
|
||||||
@ -2826,7 +2795,7 @@ class MambaModel(Model):
|
|||||||
if self.ftype == 1 and data_dtype == np.float32 and new_weight_name.endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
|
if self.ftype == 1 and data_dtype == np.float32 and new_weight_name.endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
|
||||||
data = data.astype(np.float16)
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
|
|
||||||
self.gguf_writer.add_tensor(new_name, data)
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
@ -2885,8 +2854,7 @@ class OlmoModel(Model):
|
|||||||
# map tensor names
|
# map tensor names
|
||||||
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||||
if new_name is None:
|
if new_name is None:
|
||||||
print(f"Can not map tensor {name!r}")
|
raise ValueError(f"Can not map tensor {name!r}")
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
n_dims = len(data.shape)
|
n_dims = len(data.shape)
|
||||||
data_dtype = data.dtype
|
data_dtype = data.dtype
|
||||||
@ -2903,7 +2871,7 @@ class OlmoModel(Model):
|
|||||||
if self.ftype == 1 and data_dtype == np.float32 and n_dims == 2:
|
if self.ftype == 1 and data_dtype == np.float32 and n_dims == 2:
|
||||||
data = data.astype(np.float16)
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
|
|
||||||
self.gguf_writer.add_tensor(new_name, data)
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
@ -2936,6 +2904,7 @@ def parse_args() -> argparse.Namespace:
|
|||||||
)
|
)
|
||||||
parser.add_argument("--use-temp-file", action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)")
|
parser.add_argument("--use-temp-file", action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)")
|
||||||
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
|
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
|
||||||
|
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -2943,6 +2912,8 @@ def parse_args() -> argparse.Namespace:
|
|||||||
def main() -> None:
|
def main() -> None:
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||||
|
|
||||||
dir_model = args.model
|
dir_model = args.model
|
||||||
|
|
||||||
if args.awq_path:
|
if args.awq_path:
|
||||||
@ -2951,15 +2922,15 @@ def main() -> None:
|
|||||||
tmp_model_path = args.model / "weighted_model"
|
tmp_model_path = args.model / "weighted_model"
|
||||||
dir_model = tmp_model_path
|
dir_model = tmp_model_path
|
||||||
if tmp_model_path.is_dir():
|
if tmp_model_path.is_dir():
|
||||||
print(f"{tmp_model_path} exists as a weighted model.")
|
logger.info(f"{tmp_model_path} exists as a weighted model.")
|
||||||
else:
|
else:
|
||||||
tmp_model_path.mkdir(parents=True, exist_ok=True)
|
tmp_model_path.mkdir(parents=True, exist_ok=True)
|
||||||
print("Saving new weighted model ...")
|
logger.info("Saving new weighted model ...")
|
||||||
add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
|
add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
|
||||||
print(f"Saved weighted model at {tmp_model_path}.")
|
logger.info(f"Saved weighted model at {tmp_model_path}.")
|
||||||
|
|
||||||
if not dir_model.is_dir():
|
if not dir_model.is_dir():
|
||||||
print(f'Error: {args.model} is not a directory', file=sys.stderr)
|
logger.error(f'Error: {args.model} is not a directory')
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
ftype_map = {
|
ftype_map = {
|
||||||
@ -2973,7 +2944,7 @@ def main() -> None:
|
|||||||
# output in the same directory as the model by default
|
# output in the same directory as the model by default
|
||||||
fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
|
fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
|
||||||
|
|
||||||
print(f"Loading model: {dir_model.name}")
|
logger.info(f"Loading model: {dir_model.name}")
|
||||||
|
|
||||||
hparams = Model.load_hparams(dir_model)
|
hparams = Model.load_hparams(dir_model)
|
||||||
|
|
||||||
@ -2981,20 +2952,20 @@ def main() -> None:
|
|||||||
model_class = Model.from_model_architecture(hparams["architectures"][0])
|
model_class = Model.from_model_architecture(hparams["architectures"][0])
|
||||||
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file)
|
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file)
|
||||||
|
|
||||||
print("Set model parameters")
|
logger.info("Set model parameters")
|
||||||
model_instance.set_gguf_parameters()
|
model_instance.set_gguf_parameters()
|
||||||
|
|
||||||
print("Set model tokenizer")
|
logger.info("Set model tokenizer")
|
||||||
model_instance.set_vocab()
|
model_instance.set_vocab()
|
||||||
|
|
||||||
if args.vocab_only:
|
if args.vocab_only:
|
||||||
print(f"Exporting model vocab to '{fname_out}'")
|
logger.info(f"Exporting model vocab to '{fname_out}'")
|
||||||
model_instance.write_vocab()
|
model_instance.write_vocab()
|
||||||
else:
|
else:
|
||||||
print(f"Exporting model to '{fname_out}'")
|
logger.info(f"Exporting model to '{fname_out}'")
|
||||||
model_instance.write()
|
model_instance.write()
|
||||||
|
|
||||||
print(f"Model successfully exported to '{fname_out}'")
|
logger.info(f"Model successfully exported to '{fname_out}'")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
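Taken together, the convert-hf-to-gguf.py hunks above follow one pattern: a module-level named logger, a --verbose flag that switches logging.basicConfig() between DEBUG and INFO, and raise ValueError(...) in place of print() followed by sys.exit() when a tensor cannot be mapped. Below is a minimal, self-contained sketch of that pattern; the logger id, function name, and tensor name are illustrative, not taken from the script.

import logging
import sys

logger = logging.getLogger("hf-to-gguf")  # named instance logger, as in the hunks above


def map_tensor(new_name, name):
    # unrecoverable mapping failures now raise instead of print() + sys.exit()
    if new_name is None:
        raise ValueError(f"Can not map tensor {name!r}")
    logger.info(f"{name} -> {new_name}")


if __name__ == "__main__":
    verbose = "--verbose" in sys.argv  # stand-in for the argparse flag added above
    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO)
    try:
        map_tensor(None, "model.layers.0.unknown.weight")
    except ValueError as err:
        logger.error(err)
        sys.exit(1)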
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 from __future__ import annotations

+import logging
 import argparse
 import os
 import struct
@@ -14,6 +15,8 @@ if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf

+logger = logging.getLogger("ggml-to-gguf")
+

 class GGMLFormat(IntEnum):
     GGML = 0
@@ -125,7 +128,6 @@ class Tensor:
         self.start_offset = offset
         self.len_bytes = n_bytes
         offset += n_bytes
-        # print(n_dims, name_len, dtype, self.dims, self.name, pad)
         return offset - orig_offset


@@ -175,7 +177,7 @@ class GGMLModel:
         offset += self.validate_header(data, offset)
         hp = Hyperparameters()
         offset += hp.load(data, offset)
-        print(f'* File format: {self.file_format.name}v{self.format_version} with ftype {hp.ftype.name}')
+        logger.info(f'* File format: {self.file_format.name}v{self.format_version} with ftype {hp.ftype.name}')
         self.validate_conversion(hp.ftype)
         vocab = Vocab(load_scores = self.file_format > GGMLFormat.GGML)
         offset += vocab.load(data, offset, hp.n_vocab)
@@ -215,12 +217,12 @@ class GGMLToGGUF:
                 if float(hp.n_head) / float(x) == gqa:
                     n_kv_head = x
             assert n_kv_head is not None, "Couldn't determine n_kv_head from GQA param"
-            print(f'- Guessed n_kv_head = {n_kv_head} based on GQA {cfg.gqa}')
+            logger.info(f'- Guessed n_kv_head = {n_kv_head} based on GQA {cfg.gqa}')
             self.n_kv_head = n_kv_head
         self.name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, ggml_model.hyperparameters.n_layer)

     def save(self):
-        print('* Preparing to save GGUF file')
+        logger.info('* Preparing to save GGUF file')
         gguf_writer = gguf.GGUFWriter(
             self.cfg.output,
             gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA],
@@ -230,11 +232,11 @@ class GGMLToGGUF:
         if self.special_vocab is not None:
             self.special_vocab.add_to_gguf(gguf_writer)
         self.add_tensors(gguf_writer)
-        print(" gguf: write header")
+        logger.info(" gguf: write header")
         gguf_writer.write_header_to_file()
-        print(" gguf: write metadata")
+        logger.info(" gguf: write metadata")
         gguf_writer.write_kv_data_to_file()
-        print(" gguf: write tensors")
+        logger.info(" gguf: write tensors")
         gguf_writer.write_tensors_to_file()
         gguf_writer.close()

@@ -250,7 +252,7 @@ class GGMLToGGUF:
             name = cfg.name if cfg.name is not None else cfg.input.name
         except UnicodeDecodeError:
             name = None
-        print('* Adding model parameters and KV items')
+        logger.info('* Adding model parameters and KV items')
         if name is not None:
             gguf_writer.add_name(name)
         gguf_writer.add_description(desc)
@@ -287,7 +289,7 @@ class GGMLToGGUF:
         toktypes = []
         if self.vocab_override is not None:
             vo = self.vocab_override
-            print('* Adding vocab item(s)')
+            logger.info('* Adding vocab item(s)')
             for (idx, (vbytes, score, ttype)) in enumerate(vo.all_tokens()):
                 tokens.append(vbytes)
                 scores.append(score)
@@ -299,7 +301,7 @@ class GGMLToGGUF:
             if len(toktypes) > 0:
                 gguf_writer.add_token_types(toktypes)
             return
-        print(f'* Adding {hp.n_vocab} vocab item(s)')
+        logger.info(f'* Adding {hp.n_vocab} vocab item(s)')
         assert len(self.model.vocab.items) >= 3, 'Cannot handle unexpectedly short model vocab'
         for (tokid, (vbytes, vscore)) in enumerate(self.model.vocab.items):
             tt = 1 # Normal
@@ -334,7 +336,7 @@ class GGMLToGGUF:
     def add_tensors(self, gguf_writer):
         tensor_map = self.name_map
         data = self.data
-        print(f'* Adding {len(self.model.tensors)} tensor(s)')
+        logger.info(f'* Adding {len(self.model.tensors)} tensor(s)')
         for tensor in self.model.tensors:
             name = str(tensor.name, 'UTF-8')
             mapped_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
@@ -344,7 +346,6 @@ class GGMLToGGUF:
                 temp = tempdims[1]
                 tempdims[1] = tempdims[0]
                 tempdims[0] = temp
-            # print(f'+ {tensor.name} | {mapped_name} {tensor.dims} :: {tempdims}')
             gguf_writer.add_tensor(
                 mapped_name,
                 data[tensor.start_offset:tensor.start_offset + tensor.len_bytes],
@@ -401,33 +402,35 @@ def handle_args():
                         help="directory containing tokenizer.model, if separate from model file - only meaningful with --model-metadata-dir")
     parser.add_argument("--vocabtype", default="spm,hfft",
                         help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm,hfft)")
+    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
     return parser.parse_args()


 def main():
     cfg = handle_args()
-    print(f'* Using config: {cfg}')
-    print('\n=== WARNING === Be aware that this conversion script is best-effort. Use a native GGUF model if possible. === WARNING ===\n')
+    logging.basicConfig(level=logging.DEBUG if cfg.verbose else logging.INFO)
+    logger.info(f'* Using config: {cfg}')
+    logger.warning('=== WARNING === Be aware that this conversion script is best-effort. Use a native GGUF model if possible. === WARNING ===')
     if cfg.model_metadata_dir is None and (cfg.gqa == 1 or cfg.eps == '5.0e-06'):
-        print('- Note: If converting LLaMA2, specifying "--eps 1e-5" is required. 70B models also need "--gqa 8".')
+        logger.info('- Note: If converting LLaMA2, specifying "--eps 1e-5" is required. 70B models also need "--gqa 8".')
     data = np.memmap(cfg.input, mode = 'r')
     model = GGMLModel()
-    print('* Scanning GGML input file')
+    logger.info('* Scanning GGML input file')
     offset = model.load(data, 0) # noqa
-    print(f'* GGML model hyperparameters: {model.hyperparameters}')
+    logger.info(f'* GGML model hyperparameters: {model.hyperparameters}')
     vocab_override = None
     params_override = None
     special_vocab = None
     if cfg.model_metadata_dir is not None:
         (params_override, vocab_override, special_vocab) = handle_metadata(cfg, model.hyperparameters)
-        print('!! Note: When overriding params the --gqa, --eps and --context-length options are ignored.')
-        print(f'* Overriding params: {params_override}')
-        print(f'* Overriding vocab: {vocab_override}')
-        print(f'* Special vocab: {special_vocab}')
+        logger.info('!! Note: When overriding params the --gqa, --eps and --context-length options are ignored.')
+        logger.info(f'* Overriding params: {params_override}')
+        logger.info(f'* Overriding vocab: {vocab_override}')
+        logger.info(f'* Special vocab: {special_vocab}')
     else:
-        print('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n')
+        logger.warning('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n')
         if model.file_format == GGMLFormat.GGML:
-            print('! This is a very old GGML file that does not contain vocab scores. Strongly recommend using model metadata!')
+            logger.info('! This is a very old GGML file that does not contain vocab scores. Strongly recommend using model metadata!')
     converter = GGMLToGGUF(
         model, data, cfg,
         params_override = params_override,
@@ -435,7 +438,7 @@ def main():
         special_vocab = special_vocab
     )
     converter.save()
-    print(f'* Successful completion. Output saved to: {cfg.output}')
+    logger.info(f'* Successful completion. Output saved to: {cfg.output}')


 if __name__ == '__main__':
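One detail worth noting in the hunk above: this standalone script registers its logger under an explicit id ("ggml-to-gguf") rather than a module path. As a general logging convention, and not something this diff itself spells out, fixed ids read well for command-line scripts, while importable modules usually key the logger on their import path so records show where they came from:

import logging

script_logger = logging.getLogger("ggml-to-gguf")  # standalone script: fixed, human-readable id
module_logger = logging.getLogger(__name__)        # importable module: logger follows the import path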
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 from __future__ import annotations

+import logging
 import json
 import os
 import struct
@@ -15,6 +16,8 @@ if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
 import gguf

+logger = logging.getLogger("lora-to-gguf")
+
 NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}


@@ -48,11 +51,9 @@ def write_tensor_header(fout: BinaryIO, name: str, shape: Sequence[int], data_ty

 if __name__ == '__main__':
     if len(sys.argv) < 2:
-        print(f"Usage: python {sys.argv[0]} <path> [arch]")
-        print(
-            "Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
-        )
-        print(f"Arch must be one of {list(gguf.MODEL_ARCH_NAMES.values())} (default: llama)")
+        logger.info(f"Usage: python {sys.argv[0]} <path> [arch]")
+        logger.info("Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'")
+        logger.info(f"Arch must be one of {list(gguf.MODEL_ARCH_NAMES.values())} (default: llama)")
         sys.exit(1)

     input_json = os.path.join(sys.argv[1], "adapter_config.json")
@@ -70,7 +71,7 @@ if __name__ == '__main__':
     arch_name = sys.argv[2] if len(sys.argv) == 3 else "llama"

     if arch_name not in gguf.MODEL_ARCH_NAMES.values():
-        print(f"Error: unsupported architecture {arch_name}")
+        logger.error(f"Error: unsupported architecture {arch_name}")
         sys.exit(1)

     arch = list(gguf.MODEL_ARCH_NAMES.keys())[list(gguf.MODEL_ARCH_NAMES.values()).index(arch_name)]
@@ -80,21 +81,21 @@ if __name__ == '__main__':
         params = json.load(f)

     if params["peft_type"] != "LORA":
-        print(f"Error: unsupported adapter type {params['peft_type']}, expected LORA")
+        logger.error(f"Error: unsupported adapter type {params['peft_type']}, expected LORA")
         sys.exit(1)

     if params["fan_in_fan_out"] is True:
-        print("Error: param fan_in_fan_out is not supported")
+        logger.error("Error: param fan_in_fan_out is not supported")
         sys.exit(1)

     if params["bias"] is not None and params["bias"] != "none":
-        print("Error: param bias is not supported")
+        logger.error("Error: param bias is not supported")
         sys.exit(1)

     # TODO: these seem to be layers that have been trained but without lora.
     # doesn't seem widely used but eventually should be supported
     if params["modules_to_save"] is not None and len(params["modules_to_save"]) > 0:
-        print("Error: param modules_to_save is not supported")
+        logger.error("Error: param modules_to_save is not supported")
         sys.exit(1)

     with open(output_path, "wb") as fout:
@@ -125,13 +126,13 @@ if __name__ == '__main__':
                 suffix = k[-len(lora_suffixes[0]):]
                 k = k[: -len(lora_suffixes[0])]
             else:
-                print(f"Error: unrecognized tensor name {orig_k}")
+                logger.error(f"Error: unrecognized tensor name {orig_k}")
                 sys.exit(1)

             tname = name_map.get_name(k)
             if tname is None:
-                print(f"Error: could not map tensor name {orig_k}")
-                print(" Note: the arch parameter must be specified if the model is not llama")
+                logger.error(f"Error: could not map tensor name {orig_k}")
+                logger.error(" Note: the arch parameter must be specified if the model is not llama")
                 sys.exit(1)

             if suffix == ".lora_A.weight":
@@ -141,8 +142,8 @@ if __name__ == '__main__':
             else:
                 assert False

-            print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
+            logger.info(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
             write_tensor_header(fout, tname, t.shape, t.dtype)
             t.tofile(fout)

-    print(f"Converted {input_json} and {input_model} to {output_path}")
+    logger.info(f"Converted {input_json} and {input_model} to {output_path}")
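The LoRA converter keeps its sys.exit(1) calls and only swaps the reporting call: usage text goes through logger.info() and invalid-input messages through logger.error(), as in the hunks above. A small sketch of that command-line error pattern, with an illustrative helper that is not part of the script:

import logging
import sys

logger = logging.getLogger("lora-to-gguf")


def require(condition, message):
    # invalid CLI input: report through the logger, then exit non-zero
    if not condition:
        logger.error(message)
        sys.exit(1)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    require(len(sys.argv) >= 2, f"Usage: python {sys.argv[0]} <path> [arch]")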
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 from __future__ import annotations

+import logging
 import argparse
 import os
 import sys
@@ -14,6 +15,8 @@ if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf

+logger = logging.getLogger("persimmon-to-gguf")
+

 def _flatten_dict(dct, tensors, prefix=None):
     assert isinstance(dct, dict)
@@ -30,9 +33,9 @@ def _flatten_dict(dct, tensors, prefix=None):

 def _get_sentencepiece_tokenizer_info(dir_model: Path):
     tokenizer_path = dir_model / 'adept_vocab.model'
-    print('gguf: getting sentencepiece tokenizer from', tokenizer_path)
+    logger.info('getting sentencepiece tokenizer from', tokenizer_path)
     tokenizer = SentencePieceProcessor(str(tokenizer_path))
-    print('gguf: adding tokens')
+    logger.info('adding tokens')
     tokens: list[bytes] = []
     scores: list[float] = []
     toktypes: list[int] = []
@@ -67,8 +70,10 @@ def main():
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
     parser.add_argument("--ckpt-path", type=Path, help="path to persimmon checkpoint .pt file")
     parser.add_argument("--model-dir", type=Path, help="directory containing model e.g. 8b_chat_model_release")
     parser.add_argument("--adept-inference-dir", type=str, help="path to adept-inference code directory")
+    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
     args = parser.parse_args()
+    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
     sys.path.append(str(args.adept_inference_dir))
     persimmon_model = torch.load(args.ckpt_path)
     hparams = persimmon_model['args']
@@ -107,7 +112,7 @@ def main():
     gguf_writer.add_eos_token_id(71013)

     tensor_map = gguf.get_tensor_name_map(arch, block_count)
-    print(tensor_map)
+    logger.info(tensor_map)
     for name in tensors.keys():
         data_torch = tensors[name]
         if name.endswith(".self_attention.rotary_emb.inv_freq"):
@@ -117,22 +122,21 @@ def main():
         data = data_torch.to(torch.float32).squeeze().numpy()
         new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
         if new_name is None:
-            print("Can not map tensor '" + name + "'")
-            sys.exit()
+            raise ValueError(f"Can not map tensor '{name}'")
         n_dims = len(data.shape)
-        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+        logger.debug(f"{new_name}, n_dims = {str(n_dims)}, {str(old_dtype)} --> {str(data.dtype)}")
         gguf_writer.add_tensor(new_name, data)
-    print("gguf: write header")
+    logger.info("gguf: write header")
     gguf_writer.write_header_to_file()
-    print("gguf: write metadata")
+    logger.info("gguf: write metadata")
     gguf_writer.write_kv_data_to_file()
-    print("gguf: write tensors")
+    logger.info("gguf: write tensors")
     gguf_writer.write_tensors_to_file()

     gguf_writer.close()

-    print(f"gguf: model successfully exported to '{args.outfile}'")
-    print("")
+    logger.info(f"gguf: model successfully exported to '{args.outfile}'")


 if __name__ == '__main__':
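A short aside on the logging call signature, since it differs from print(): print() simply joins its positional arguments, whereas logging treats everything after the first argument as lazy %-style substitutions into the message string. A sketch of the two usual ways to log a value, with an illustrative stand-in path:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("persimmon-to-gguf")

tokenizer_path = "adept_vocab.model"  # illustrative value

# %-style placeholder, formatted lazily by the logging machinery
logger.info("getting sentencepiece tokenizer from %s", tokenizer_path)

# pre-formatted f-string, the style used throughout this change set
logger.info(f"getting sentencepiece tokenizer from {tokenizer_path}")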
60
convert.py
60
convert.py
@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import argparse
|
import argparse
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import enum
|
import enum
|
||||||
@ -35,6 +36,8 @@ import gguf
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing_extensions import Self, TypeAlias
|
from typing_extensions import Self, TypeAlias
|
||||||
|
|
||||||
|
logger = logging.getLogger("convert")
|
||||||
|
|
||||||
if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
|
if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
|
||||||
faulthandler.register(signal.SIGUSR1)
|
faulthandler.register(signal.SIGUSR1)
|
||||||
|
|
||||||
@ -643,7 +646,6 @@ class LlamaHfVocab(Vocab):
|
|||||||
|
|
||||||
|
|
||||||
def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
|
def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
|
||||||
# print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) )
|
|
||||||
if n_head_kv is not None and n_head != n_head_kv:
|
if n_head_kv is not None and n_head != n_head_kv:
|
||||||
n_head = n_head_kv
|
n_head = n_head_kv
|
||||||
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||||
@ -1033,12 +1035,12 @@ def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False)
|
|||||||
|
|
||||||
# Check for a vocab size mismatch
|
# Check for a vocab size mismatch
|
||||||
if params.n_vocab == vocab.vocab_size:
|
if params.n_vocab == vocab.vocab_size:
|
||||||
print("Ignoring added_tokens.json since model matches vocab size without it.")
|
logger.warning("Ignoring added_tokens.json since model matches vocab size without it.")
|
||||||
return
|
return
|
||||||
|
|
||||||
if pad_vocab and params.n_vocab > vocab.vocab_size:
|
if pad_vocab and params.n_vocab > vocab.vocab_size:
|
||||||
pad_count = params.n_vocab - vocab.vocab_size
|
pad_count = params.n_vocab - vocab.vocab_size
|
||||||
print(
|
logger.debug(
|
||||||
f"Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>"
|
f"Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>"
|
||||||
)
|
)
|
||||||
for i in range(1, pad_count + 1):
|
for i in range(1, pad_count + 1):
|
||||||
@ -1166,7 +1168,7 @@ class OutputFile:
|
|||||||
elapsed = time.time() - start
|
elapsed = time.time() - start
|
||||||
size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
|
size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
|
||||||
padi = len(str(len(model)))
|
padi = len(str(len(model)))
|
||||||
print(
|
logger.info(
|
||||||
f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
|
f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
|
||||||
)
|
)
|
||||||
self.gguf.write_tensor_data(ndarray)
|
self.gguf.write_tensor_data(ndarray)
|
||||||
@ -1281,12 +1283,12 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) ->
|
|||||||
# HF models permut or pack some of the tensors, so we need to undo that
|
# HF models permut or pack some of the tensors, so we need to undo that
|
||||||
for i in itertools.count():
|
for i in itertools.count():
|
||||||
if f"model.layers.{i}.self_attn.q_proj.weight" in model:
|
if f"model.layers.{i}.self_attn.q_proj.weight" in model:
|
||||||
print(f"Permuting layer {i}")
|
logger.debug(f"Permuting layer {i}")
|
||||||
tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head)
|
tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head)
|
||||||
tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv)
|
tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv)
|
||||||
# tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
|
# tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
|
||||||
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
|
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
|
||||||
print(f"Unpacking and permuting layer {i}")
|
logger.debug(f"Unpacking and permuting layer {i}")
|
||||||
tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head)
|
tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head)
|
||||||
tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv)
|
tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv)
|
||||||
tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
|
tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
|
||||||
@ -1299,15 +1301,15 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) ->
|
|||||||
tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None)
|
tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None)
|
||||||
if name_new is None:
|
if name_new is None:
|
||||||
if skip_unknown:
|
if skip_unknown:
|
||||||
print(f"Unexpected tensor name: {name} - skipping")
|
logger.warning(f"Unexpected tensor name: {name} - skipping")
|
||||||
continue
|
continue
|
||||||
raise ValueError(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
|
raise ValueError(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
|
||||||
|
|
||||||
if tensor_type in should_skip:
|
if tensor_type in should_skip:
|
||||||
print(f"skipping tensor {name_new}")
|
logger.debug(f"skipping tensor {name_new}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
|
logger.debug(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
|
||||||
out[name_new] = lazy_tensor
|
out[name_new] = lazy_tensor
|
||||||
|
|
||||||
return out
|
return out
|
||||||
@ -1372,7 +1374,7 @@ def load_some_model(path: Path) -> ModelPlus:
|
|||||||
paths = find_multifile_paths(path)
|
paths = find_multifile_paths(path)
|
||||||
models_plus: list[ModelPlus] = []
|
models_plus: list[ModelPlus] = []
|
||||||
for path in paths:
|
for path in paths:
|
||||||
print(f"Loading model file {path}")
|
logger.info(f"Loading model file {path}")
|
||||||
models_plus.append(lazy_load_file(path))
|
models_plus.append(lazy_load_file(path))
|
||||||
|
|
||||||
model_plus = merge_multifile_models(models_plus)
|
model_plus = merge_multifile_models(models_plus)
|
||||||
@ -1413,7 +1415,7 @@ class VocabFactory:
|
|||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}")
|
raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}")
|
||||||
|
|
||||||
print(f"Loaded vocab file {vocab.fname_tokenizer!r}, type {vocab.name!r}")
|
logger.info(f"Loaded vocab file {vocab.fname_tokenizer!r}, type {vocab.name!r}")
|
||||||
return vocab
|
return vocab
|
||||||
|
|
||||||
def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]:
|
def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]:
|
||||||
@ -1438,19 +1440,19 @@ def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
|
|||||||
}[file_type]
|
}[file_type]
|
||||||
ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
|
ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
|
||||||
if ret in model_paths:
|
if ret in model_paths:
|
||||||
sys.stderr.write(
|
logger.error(
|
||||||
f"Error: Default output path ({ret}) would overwrite the input. "
|
f"Error: Default output path ({ret}) would overwrite the input. "
|
||||||
"Please explicitly specify a path using --outfile.\n")
|
"Please explicitly specify a path using --outfile.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def do_dump_model(model_plus: ModelPlus) -> None:
|
def do_dump_model(model_plus: ModelPlus) -> None:
|
||||||
print(f"model_plus.paths = {model_plus.paths!r}")
|
print(f"model_plus.paths = {model_plus.paths!r}") # noqa: NP100
|
||||||
print(f"model_plus.format = {model_plus.format!r}")
|
print(f"model_plus.format = {model_plus.format!r}") # noqa: NP100
|
||||||
print(f"model_plus.vocab = {model_plus.vocab!r}")
|
print(f"model_plus.vocab = {model_plus.vocab!r}") # noqa: NP100
|
||||||
for name, lazy_tensor in model_plus.model.items():
|
for name, lazy_tensor in model_plus.model.items():
|
||||||
print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}")
|
print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}") # noqa: NP100
|
||||||
|
|
||||||
|
|
||||||
def main(args_in: list[str] | None = None) -> None:
|
def main(args_in: list[str] | None = None) -> None:
|
||||||
@ -1473,8 +1475,18 @@ def main(args_in: list[str] | None = None) -> None:
|
|||||||
parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine")
|
parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine")
|
||||||
parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
|
parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
|
||||||
parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
|
parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
|
||||||
|
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||||
|
|
||||||
args = parser.parse_args(args_in)
|
args = parser.parse_args(args_in)
|
||||||
|
|
||||||
|
if args.verbose:
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
elif args.dump_single or args.dump:
|
||||||
|
# Avoid printing anything besides the dump output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
else:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
if args.no_vocab and args.vocab_only:
|
if args.no_vocab and args.vocab_only:
|
||||||
raise ValueError("--vocab-only does not make sense with --no-vocab")
|
raise ValueError("--vocab-only does not make sense with --no-vocab")
|
||||||
|
|
||||||
@ -1491,6 +1503,7 @@ def main(args_in: list[str] | None = None) -> None:
|
|||||||
if args.dump:
|
if args.dump:
|
||||||
do_dump_model(model_plus)
|
do_dump_model(model_plus)
|
||||||
return
|
return
|
||||||
|
|
||||||
endianess = gguf.GGUFEndian.LITTLE
|
endianess = gguf.GGUFEndian.LITTLE
|
||||||
if args.big_endian:
|
if args.big_endian:
|
||||||
endianess = gguf.GGUFEndian.BIG
|
endianess = gguf.GGUFEndian.BIG
|
||||||
@ -1513,7 +1526,7 @@ def main(args_in: list[str] | None = None) -> None:
|
|||||||
"q8_0": GGMLFileType.MostlyQ8_0,
|
"q8_0": GGMLFileType.MostlyQ8_0,
|
||||||
}[args.outtype]
|
}[args.outtype]
|
||||||
|
|
||||||
print(f"params = {params}")
|
logger.info(f"params = {params}")
|
||||||
|
|
||||||
model_parent_path = model_plus.paths[0].parent
|
model_parent_path = model_plus.paths[0].parent
|
||||||
vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
|
vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
|
||||||
@ -1528,15 +1541,14 @@ def main(args_in: list[str] | None = None) -> None:
|
|||||||
outfile = args.outfile
|
outfile = args.outfile
|
||||||
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
|
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
|
||||||
endianess=endianess, pad_vocab=args.pad_vocab)
|
endianess=endianess, pad_vocab=args.pad_vocab)
|
||||||
print(f"Wrote {outfile}")
|
logger.info(f"Wrote {outfile}")
|
||||||
return
|
return
|
||||||
|
|
||||||
if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab:
|
if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab:
|
||||||
vocab = model_plus.vocab
|
vocab = model_plus.vocab
|
||||||
|
|
||||||
print(f"Vocab info: {vocab}")
|
logger.info(f"Vocab info: {vocab}")
|
||||||
print(f"Special vocab info: {special_vocab}")
|
logger.info(f"Special vocab info: {special_vocab}")
|
||||||
|
|
||||||
model = model_plus.model
|
model = model_plus.model
|
||||||
model = convert_model_names(model, params, args.skip_unknown)
|
model = convert_model_names(model, params, args.skip_unknown)
|
||||||
ftype = pick_output_type(model, args.outtype)
|
ftype = pick_output_type(model, args.outtype)
|
||||||
@ -1544,11 +1556,11 @@ def main(args_in: list[str] | None = None) -> None:
|
|||||||
outfile = args.outfile or default_outfile(model_plus.paths, ftype)
|
outfile = args.outfile or default_outfile(model_plus.paths, ftype)
|
||||||
|
|
||||||
params.ftype = ftype
|
params.ftype = ftype
|
||||||
print(f"Writing {outfile}, format {ftype}")
|
logger.info(f"Writing {outfile}, format {ftype}")
|
||||||
|
|
||||||
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
|
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
|
||||||
concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
|
concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
|
||||||
print(f"Wrote {outfile}")
|
logger.info(f"Wrote {outfile}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
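Note: the argument-parsing hunks above add a --verbose flag and derive the log level from it, dropping to WARNING when a dump flag is set so that the dump stays the only stdout output. A minimal, self-contained sketch of that pattern (flag names follow the diff; the surrounding script is illustrative only):

import argparse
import logging

logger = logging.getLogger("example-convert")

parser = argparse.ArgumentParser()
parser.add_argument("--dump", action="store_true")
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args()

if args.verbose:
    logging.basicConfig(level=logging.DEBUG)    # --verbose wins over the dump flag
elif args.dump:
    logging.basicConfig(level=logging.WARNING)  # keep the dump output clean
else:
    logging.basicConfig(level=logging.INFO)

logger.info("visible at INFO and DEBUG, suppressed in dump mode")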
@@ -1,11 +1,14 @@
#!/usr/bin/env python

+import logging
import argparse
import asyncio
import os
import sys
from tempfile import gettempdir, NamedTemporaryFile

+logger = logging.getLogger("ggml-vk-generate-shaders")

shader_f32 = """
#define FLOAT_TYPE float
"""
@@ -2498,7 +2501,7 @@ async def string_to_spv(name, code, defines, fp16=True):

stdout, stderr = await proc.communicate()

-print(" ".join(cmd))
+logger.info(" ".join(cmd))

if proc.returncode:
raise RuntimeError(f"{name=} {f.name=} {stdout=} {stderr=}")
@@ -2507,7 +2510,7 @@ async def string_to_spv(name, code, defines, fp16=True):

cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
code_with_lines = "\n".join([f"{i + 1}: {line}" for i, line in enumerate(preprocessed_code.splitlines())])
-print(f"ERROR compiling {name}\n\n{code_with_lines}\n\n{error}")
+logger.error(f"cannot compile {name}\n\n{code_with_lines}\n\n{error}")
f.close()
os.remove(f.name)
sys.exit(proc.returncode)
@@ -2520,7 +2523,7 @@ async def string_to_spv(name, code, defines, fp16=True):


async def main():
-print("ggml_vulkan: Generating and compiling shaders to SPIR-V")
+logger.info("ggml_vulkan: Generating and compiling shaders to SPIR-V")

tasks = []

@@ -2768,9 +2771,12 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GGML Vulkan Shader Generator")

parser.add_argument("--glslc", help="Path to glslc")
+parser.add_argument("--verbose", action="store_true", help="increase output verbosity")

args = parser.parse_args()

+logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

if args.glslc:
GLSLC = args.glslc
@@ -1,8 +1,10 @@
#!/usr/bin/env python3
+import logging
import sys
from pathlib import Path
from gguf.gguf_reader import GGUFReader

+logger = logging.getLogger("reader")

sys.path.insert(0, str(Path(__file__).parent.parent))

@@ -18,28 +20,28 @@ def read_gguf_file(gguf_file_path):
reader = GGUFReader(gguf_file_path)

# List all key-value pairs in a columnized format
-print("Key-Value Pairs:")
+print("Key-Value Pairs:") # noqa: NP100
max_key_length = max(len(key) for key in reader.fields.keys())
for key, field in reader.fields.items():
value = field.parts[field.data[0]]
-print(f"{key:{max_key_length}} : {value}")
+print(f"{key:{max_key_length}} : {value}") # noqa: NP100
-print("----")
+print("----") # noqa: NP100

# List all tensors
-print("Tensors:")
+print("Tensors:") # noqa: NP100
tensor_info_format = "{:<30} | Shape: {:<15} | Size: {:<12} | Quantization: {}"
-print(tensor_info_format.format("Tensor Name", "Shape", "Size", "Quantization"))
+print(tensor_info_format.format("Tensor Name", "Shape", "Size", "Quantization")) # noqa: NP100
-print("-" * 80)
+print("-" * 80) # noqa: NP100
for tensor in reader.tensors:
shape_str = "x".join(map(str, tensor.shape))
size_str = str(tensor.n_elements)
quantization_str = tensor.tensor_type.name
-print(tensor_info_format.format(tensor.name, shape_str, size_str, quantization_str))
+print(tensor_info_format.format(tensor.name, shape_str, size_str, quantization_str)) # noqa: NP100


if __name__ == '__main__':
if len(sys.argv) < 2:
-print("Usage: reader.py <path_to_gguf_file>")
+logger.info("Usage: reader.py <path_to_gguf_file>")
sys.exit(1)
gguf_file_path = sys.argv[1]
read_gguf_file(gguf_file_path)
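Note: in the reader example above the table is the program's product, so it stays on print() and each call carries a `# noqa: NP100` marker, the code reported by the flake8-no-print plugin enabled in the lint step. A small sketch of that convention, with illustrative names and values:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("example-reader")

def emit_table(rows):
    # Program output: keep print(), silence the no-print check explicitly.
    for row in rows:
        print(" | ".join(str(cell) for cell in row))  # noqa: NP100

logger.info("loading done")  # diagnostics go through logging instead
emit_table([("Tensor Name", "Shape"), ("token_embd.weight", "4096x32000")])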
@@ -1,6 +1,5 @@
from __future__ import annotations

-import sys
from enum import Enum, IntEnum, auto
from typing import Any

@@ -854,8 +853,7 @@ class GGUFValueType(IntEnum):
return GGUFValueType.INT32
# TODO: need help with 64-bit types in Python
else:
-print("Unknown type:", type(val))
+raise ValueError(f"Unknown type: {type(val)}")
-sys.exit()


# Note: Does not support GGML_QKK_64
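Note: the change above replaces an error print() followed by sys.exit() with a raised ValueError, so failures in library code surface as exceptions the caller can handle. A hedged sketch of that pattern (function and values are made up for illustration):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("example-constants")

def value_type(val):
    if isinstance(val, int):
        return "INT32"
    # Raise instead of printing and exiting; the caller decides how to report it.
    raise ValueError(f"Unknown type: {type(val)}")

try:
    value_type(object())
except ValueError as err:
    logger.error(err)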
@@ -4,6 +4,7 @@
#
from __future__ import annotations

+import logging
import os
from collections import OrderedDict
from typing import Any, Literal, NamedTuple, TypeVar, Union
@@ -27,6 +28,7 @@ from gguf.constants import (
GGUFValueType,
)

+logger = logging.getLogger(__name__)

READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]

@@ -142,7 +144,7 @@ class GGUFReader:
# TODO: add option to generate error on duplicate keys
# raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')

-print(f'Warning: Duplicate key {field.name} at offset {field.offset}')
+logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
self.fields[field.name + '_{}'.format(field.offset)] = field
else:
self.fields[field.name] = field
@@ -1,5 +1,6 @@
from __future__ import annotations

+import logging
import os
import shutil
import struct
@@ -24,6 +25,8 @@ from .constants import (
TokenType,
)

+logger = logging.getLogger(__name__)


class WriterState(Enum):
EMPTY = auto()
@@ -67,7 +70,7 @@ class GGUFWriter:
self.use_temp_file = use_temp_file
self.temp_file = None
self.tensors = []
-print("gguf: This GGUF file is for {0} Endian only".format(
+logger.info("gguf: This GGUF file is for {0} Endian only".format(
"Big" if self.endianess == GGUFEndian.BIG else "Little",
))
self.state = WriterState.EMPTY
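Note: the two library modules above take their logger from logging.getLogger(__name__) because they are imported rather than run directly, and they never call logging.basicConfig() themselves; only the entry-point script does. A minimal sketch of that split, with made-up module and function names:

import logging

# In an imported library module this would be `logging.getLogger(__name__)`,
# which names the logger after the module's import path.
lib_logger = logging.getLogger("mylib.writer")

def write_all():
    lib_logger.info("writing tensors")

if __name__ == "__main__":
    # Only the entry-point script configures handlers and the level.
    logging.basicConfig(level=logging.INFO)
    write_all()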
@@ -1,13 +1,15 @@
from __future__ import annotations

+import logging
import json
import os
-import sys
from pathlib import Path
from typing import Any, Callable

from .gguf_writer import GGUFWriter

+logger = logging.getLogger(__name__)


class SpecialVocab:
merges: list[str]
@@ -40,38 +42,29 @@ class SpecialVocab:
def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
if self.merges:
if not quiet:
-print(f'gguf: Adding {len(self.merges)} merge(s).')
+logger.info(f'Adding {len(self.merges)} merge(s).')
gw.add_token_merges(self.merges)
elif self.load_merges:
-print(
+logger.warning('Adding merges requested but no merges found, output may be non-functional.')
-'gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.',
-file = sys.stderr,
-)
for typ, tokid in self.special_token_ids.items():
id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
if id_handler is None:
-print(
+logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
-f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping',
-file = sys.stderr,
-)
continue
if not quiet:
-print(f'gguf: Setting special token type {typ} to {tokid}')
+logger.info(f'Setting special token type {typ} to {tokid}')
id_handler(tokid)
for typ, value in self.add_special_token.items():
add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
if add_handler is None:
-print(
+logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
-f'gguf: WARNING: No handler for add_{typ}_token with value {value} - skipping',
-file = sys.stderr,
-)
continue
if not quiet:
-print(f'gguf: Setting add_{typ}_token to {value}')
+logger.info(f'Setting add_{typ}_token to {value}')
add_handler(value)
if self.chat_template is not None:
if not quiet:
-print(f'gguf: Setting chat_template to {self.chat_template}')
+logger.info(f'Setting chat_template to {self.chat_template}')
gw.add_chat_template(self.chat_template)

def _load(self, path: Path) -> None:
@@ -99,10 +92,7 @@ class SpecialVocab:
continue
parts = line.split(None, 3)
if len(parts) != 2:
-print(
+logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
-f'gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring',
-file = sys.stderr,
-)
continue
merges.append(f'{parts[0]} {parts[1]}')
self.merges = merges
@@ -118,10 +108,7 @@ class SpecialVocab:
return
self.special_token_ids[typ] = tid
return
-print(
+logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')
-f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping',
-file = sys.stderr,
-)

def _try_load_from_tokenizer_json(self, path: Path) -> bool:
tokenizer_file = path / 'tokenizer.json'
@@ -144,10 +131,7 @@ class SpecialVocab:
if chat_template is None or isinstance(chat_template, (str, list)):
self.chat_template = chat_template
else:
-print(
+logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')
-f'gguf: WARNING: Bad type for chat_template field in {tokenizer_config_file!r} - ignoring',
-file = sys.stderr
-)
for typ in self.special_token_types:
add_entry = tokenizer_config.get(f'add_{typ}_token')
if isinstance(add_entry, bool):
@@ -1,9 +1,11 @@
#!/usr/bin/env python3
from __future__ import annotations

+import logging
import argparse
import os
import sys
+from tqdm import tqdm
from pathlib import Path

import numpy as np
@@ -14,6 +16,8 @@ if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent /

import gguf

+logger = logging.getLogger("gguf-convert-endian")


def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None:
if np.uint32(1) == np.uint32(1).newbyteorder("<"):
@@ -29,11 +33,11 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None
else:
file_endian = host_endian
order = host_endian if args.order == "native" else args.order
-print(f"* Host is {host_endian.upper()} endian, GGUF file seems to be {file_endian.upper()} endian")
+logger.info(f"* Host is {host_endian.upper()} endian, GGUF file seems to be {file_endian.upper()} endian")
if file_endian == order:
-print(f"* File is already {order.upper()} endian. Nothing to do.")
+logger.info(f"* File is already {order.upper()} endian. Nothing to do.")
sys.exit(0)
-print("* Checking tensors for conversion compatibility")
+logger.info("* Checking tensors for conversion compatibility")
for tensor in reader.tensors:
if tensor.tensor_type not in (
gguf.GGMLQuantizationType.F32,
@@ -41,51 +45,64 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None
gguf.GGMLQuantizationType.Q8_0,
):
raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}")
-print(f"* Preparing to convert from {file_endian.upper()} to {order.upper()}")
+logger.info(f"* Preparing to convert from {file_endian.upper()} to {order.upper()}")
if args.dry_run:
return
-print("\n*** Warning *** Warning *** Warning **")
+logger.warning("*** Warning *** Warning *** Warning **")
-print("* This conversion process may damage the file. Ensure you have a backup.")
+logger.warning("* This conversion process may damage the file. Ensure you have a backup.")
if order != host_endian:
-print("* Requested endian differs from host, you will not be able to load the model on this machine.")
+logger.warning("* Requested endian differs from host, you will not be able to load the model on this machine.")
-print("* The file will be modified immediately, so if conversion fails or is interrupted")
+logger.warning("* The file will be modified immediately, so if conversion fails or is interrupted")
-print("* the file will be corrupted. Enter exactly YES if you are positive you want to proceed:")
+logger.warning("* the file will be corrupted. Enter exactly YES if you are positive you want to proceed:")
response = input("YES, I am sure> ")
if response != "YES":
-print("You didn't enter YES. Okay then, see ya!")
+logger.warning("You didn't enter YES. Okay then, see ya!")
sys.exit(0)
-print(f"\n* Converting fields ({len(reader.fields)})")
+logger.info(f"* Converting fields ({len(reader.fields)})")
for idx, field in enumerate(reader.fields.values()):
-print(f"- {idx:4}: Converting field {repr(field.name)}, part count: {len(field.parts)}")
+logger.info(f"- {idx:4}: Converting field {repr(field.name)}, part count: {len(field.parts)}")
for part in field.parts:
part.byteswap(inplace=True)
-print(f"\n* Converting tensors ({len(reader.tensors)})")
+logger.info(f"* Converting tensors ({len(reader.tensors)})")
-for idx, tensor in enumerate(reader.tensors):
-print(
+for idx, tensor in enumerate(pbar := tqdm(reader.tensors, desc="Converting tensor")):
-f" - {idx:4}: Converting tensor {repr(tensor.name)}, type={tensor.tensor_type.name}, "
+log_message = (
-f"elements={tensor.n_elements}... ",
+f"Converting tensor {repr(tensor.name)}, "
-end="",
+f"type={tensor.tensor_type.name}, "
+f"elements={tensor.n_elements} "
)
-tensor_type = tensor.tensor_type
+# Byte-swap each part of the tensor's field
for part in tensor.field.parts:
part.byteswap(inplace=True)
-if tensor_type != gguf.GGMLQuantizationType.Q8_0:
+# Byte-swap tensor data if necessary
+if tensor.tensor_type == gguf.GGMLQuantizationType.Q8_0:
+# Handle Q8_0 tensor blocks (block_q8_0)
+# Specific handling of block_q8_0 is required.
+# Each block_q8_0 consists of an f16 delta (scaling factor) followed by 32 int8 quantizations.

+block_size = 34 # 34 bytes = <f16 delta scaling factor> + 32 * <int8 quant>

+n_blocks = len(tensor.data) // block_size
+for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)):
+block_offs = block_num * block_size

+# Byte-Swap f16 sized delta field
+delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
+delta.byteswap(inplace=True)

+# Byte-Swap Q8 weights
+if block_num % 100000 == 0:
+inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]")

+else:
+# Handle other tensor types
tensor.data.byteswap(inplace=True)
-print()
-continue
+pbar.set_description(log_message)
-# A Q8_0 block consists of a f16 delta followed by 32 int8 quants, so 34 bytes
-block_size = 34
+logger.info("* Completion")
-n_blocks = len(tensor.data) // block_size
-for block_num in range(n_blocks):
-block_offs = block_num * block_size
-# I know I said f16, but it doesn't matter here - any simple 16 bit type works.
-delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
-delta.byteswap(inplace=True)
-if block_num % 100000 == 0:
-print(f"[{(n_blocks - block_num) // 1000}K]", end="")
-sys.stdout.flush()
-print()
-print("* Completion")


def main() -> None:
@@ -102,8 +119,13 @@ def main() -> None:
"--dry-run", action="store_true",
help="Don't actually change anything",
)
+parser.add_argument("--verbose", action="store_true", help="increase output verbosity")

args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
-print(f'* Loading: {args.model}')
+logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

+logger.info(f'* Loading: {args.model}')
reader = gguf.GGUFReader(args.model, 'r' if args.dry_run else 'r+')
convert_byteorder(reader, args)
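Note: the Q8_0 branch above swaps only the two-byte f16 delta at the start of each 34-byte block; the 32 int8 quants that follow are single bytes and need no conversion. A standalone numpy sketch of that per-block swap (the data buffer here is a stand-in for tensor.data; illustrative only):

import numpy as np

def byteswap_q8_0(data: np.ndarray) -> None:
    block_size = 34                       # 2-byte f16 delta + 32 * int8 quant
    n_blocks = len(data) // block_size
    for block_num in range(n_blocks):
        block_offs = block_num * block_size
        # View the first two bytes of the block as one 16-bit value and swap it
        # in place; the remaining 32 bytes are endianness-neutral.
        delta = data[block_offs:block_offs + 2].view(dtype=np.uint16)
        delta.byteswap(inplace=True)

byteswap_q8_0(np.zeros(34 * 4, dtype=np.uint8))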
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
from __future__ import annotations

+import logging
import argparse
import os
import sys
@@ -15,6 +16,8 @@ if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent /

from gguf import GGUFReader, GGUFValueType # noqa: E402

+logger = logging.getLogger("gguf-dump")


def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]:
host_endian = 'LITTLE' if np.uint32(1) == np.uint32(1).newbyteorder("<") else 'BIG'
@@ -29,8 +32,8 @@ def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]:
# please see the comments in the modify_gguf.py example.
def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
host_endian, file_endian = get_file_host_endian(reader)
-print(f'* File is {file_endian} endian, script is running on a {host_endian} endian host.')
+print(f'* File is {file_endian} endian, script is running on a {host_endian} endian host.') # noqa: NP100
-print(f'\n* Dumping {len(reader.fields)} key/value pair(s)')
+print(f'* Dumping {len(reader.fields)} key/value pair(s)') # noqa: NP100
for n, field in enumerate(reader.fields.values(), 1):
if not field.types:
pretty_type = 'N/A'
@@ -39,20 +42,21 @@ def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count
else:
pretty_type = str(field.types[-1].name)
-print(f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}', end = '')
+log_message = f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}'
if len(field.types) == 1:
curr_type = field.types[0]
if curr_type == GGUFValueType.STRING:
-print(' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60])), end = '')
+log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60]))
elif field.types[0] in reader.gguf_scalar_to_np:
-print(' = {0}'.format(field.parts[-1][0]), end = '')
+log_message += ' = {0}'.format(field.parts[-1][0])
-print()
+print(log_message) # noqa: NP100
if args.no_tensors:
return
-print(f'\n* Dumping {len(reader.tensors)} tensor(s)')
+print(f'* Dumping {len(reader.tensors)} tensor(s)') # noqa: NP100
for n, tensor in enumerate(reader.tensors, 1):
prettydims = ', '.join('{0:5}'.format(d) for d in list(tensor.shape) + [1] * (4 - len(tensor.shape)))
-print(f' {n:5}: {tensor.n_elements:10} | {prettydims} | {tensor.tensor_type.name:7} | {tensor.name}')
+print(f' {n:5}: {tensor.n_elements:10} | {prettydims} | {tensor.tensor_type.name:7} | {tensor.name}') # noqa: NP100


def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None:
@@ -103,10 +107,17 @@ def main() -> None:
parser.add_argument("--no-tensors", action="store_true", help="Don't dump tensor metadata")
parser.add_argument("--json", action="store_true", help="Produce JSON output")
parser.add_argument("--json-array", action="store_true", help="Include full array values in JSON output (long)")
+parser.add_argument("--verbose", action="store_true", help="increase output verbosity")

args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])

+logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

if not args.json:
-print(f'* Loading: {args.model}')
+logger.info(f'* Loading: {args.model}')

reader = GGUFReader(args.model, 'r')

if args.json:
dump_metadata_json(reader, args)
else:
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
+import logging
import argparse
import os
import sys
@@ -10,6 +11,8 @@ if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent /

from gguf import GGUFReader # noqa: E402

+logger = logging.getLogger("gguf-set-metadata")


def minimal_example(filename: str) -> None:
reader = GGUFReader(filename, 'r+')
@@ -41,36 +44,33 @@ def minimal_example(filename: str) -> None:
def set_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
field = reader.get_field(args.key)
if field is None:
-print(f'! Field {repr(args.key)} not found', file = sys.stderr)
+logger.error(f'! Field {repr(args.key)} not found')
sys.exit(1)
# Note that field.types is a list of types. This is because the GGUF
# format supports arrays. For example, an array of UINT32 would
# look like [GGUFValueType.ARRAY, GGUFValueType.UINT32]
handler = reader.gguf_scalar_to_np.get(field.types[0]) if field.types else None
if handler is None:
-print(
+logger.error(f'! This tool only supports changing simple values, {repr(args.key)} has unsupported type {field.types}')
-f'! This tool only supports changing simple values, {repr(args.key)} has unsupported type {field.types}',
-file = sys.stderr,
-)
sys.exit(1)
current_value = field.parts[field.data[0]][0]
new_value = handler(args.value)
-print(f'* Preparing to change field {repr(args.key)} from {current_value} to {new_value}')
+logger.info(f'* Preparing to change field {repr(args.key)} from {current_value} to {new_value}')
if current_value == new_value:
-print(f'- Key {repr(args.key)} already set to requested value {current_value}')
+logger.info(f'- Key {repr(args.key)} already set to requested value {current_value}')
sys.exit(0)
if args.dry_run:
sys.exit(0)
if not args.force:
-print('*** Warning *** Warning *** Warning **')
+logger.warning('*** Warning *** Warning *** Warning **')
-print('* Changing fields in a GGUF file can make it unusable. Proceed at your own risk.')
+logger.warning('* Changing fields in a GGUF file can make it unusable. Proceed at your own risk.')
-print('* Enter exactly YES if you are positive you want to proceed:')
+logger.warning('* Enter exactly YES if you are positive you want to proceed:')
response = input('YES, I am sure> ')
if response != 'YES':
-print("You didn't enter YES. Okay then, see ya!")
+logger.info("You didn't enter YES. Okay then, see ya!")
sys.exit(0)
field.parts[field.data[0]][0] = new_value
-print('* Field changed. Successful completion.')
+logger.info('* Field changed. Successful completion.')


def main() -> None:
@@ -80,8 +80,13 @@ def main() -> None:
parser.add_argument("value", type=str, help="Metadata value to set")
parser.add_argument("--dry-run", action="store_true", help="Don't actually change anything")
parser.add_argument("--force", action="store_true", help="Change the field without confirmation")
+parser.add_argument("--verbose", action="store_true", help="increase output verbosity")

args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
-print(f'* Loading: {args.model}')
+logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

+logger.info(f'* Loading: {args.model}')
reader = GGUFReader(args.model, 'r' if args.dry_run else 'r+')
set_metadata(reader, args)
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

+import logging
import argparse
import heapq
import sys
@@ -11,9 +12,11 @@ try:
import git
from tabulate import tabulate
except ImportError as e:
-print("ERROR: the following Python libraries are required: GitPython, tabulate.")
+print("the following Python libraries are required: GitPython, tabulate.") # noqa: NP100
raise e

+logger = logging.getLogger("compare-llama-bench")

# Properties by which to differentiate results per commit:
KEY_PROPERTIES = [
"cpu_info", "gpu_info", "n_gpu_layers", "main_gpu", "cuda", "opencl", "metal", "gpu_blas",
@@ -94,8 +97,7 @@ parser.add_argument("-s", "--show", help=help_s)
known_args, unknown_args = parser.parse_known_args()

if unknown_args:
-print(f"ERROR: Received unknown args: {unknown_args}.")
+logger.error(f"Received unknown args: {unknown_args}.")
-print()
parser.print_help()
sys.exit(1)

@@ -108,8 +110,7 @@ if input_file is None:
input_file = sqlite_files[0]

if input_file is None:
-print("ERROR: Cannot find a suitable input file, please provide one.")
+logger.error("Cannot find a suitable input file, please provide one.")
-print()
parser.print_help()
sys.exit(1)

@@ -194,23 +195,19 @@ if known_args.baseline is not None:
hexsha8_baseline = get_commit_hexsha8(known_args.baseline)
name_baseline = known_args.baseline
if hexsha8_baseline is None:
-print(f"ERROR: cannot find data for baseline={known_args.baseline}.")
+logger.error(f"cannot find data for baseline={known_args.baseline}.")
sys.exit(1)
# Otherwise, search for the most recent parent of master for which there is data:
elif repo is not None:
hexsha8_baseline = find_parent_in_data(repo.heads.master.commit)

if hexsha8_baseline is None:
-print("ERROR: No baseline was provided and did not find data for any master branch commits.")
+logger.error("No baseline was provided and did not find data for any master branch commits.")
-print()
parser.print_help()
sys.exit(1)
else:
-print(
+logger.error("No baseline was provided and the current working directory "
-"ERROR: No baseline was provided and the current working directory "
+"is not part of a git repository from which a baseline could be inferred.")
-"is not part of a git repository from which a baseline could be inferred."
-)
-print()
parser.print_help()
sys.exit(1)

@@ -227,7 +224,7 @@ if known_args.compare is not None:
hexsha8_compare = get_commit_hexsha8(known_args.compare)
name_compare = known_args.compare
if hexsha8_compare is None:
-print(f"ERROR: cannot find data for compare={known_args.compare}.")
+logger.error(f"cannot find data for compare={known_args.compare}.")
sys.exit(1)
# Otherwise, search for the commit for llama-bench was most recently run
# and that is not a parent of master:
@@ -241,16 +238,12 @@ elif repo is not None:
break

if hexsha8_compare is None:
-print("ERROR: No compare target was provided and did not find data for any non-master commits.")
+logger.error("No compare target was provided and did not find data for any non-master commits.")
-print()
parser.print_help()
sys.exit(1)
else:
-print(
+logger.error("No compare target was provided and the current working directory "
-"ERROR: No compare target was provided and the current working directory "
+"is not part of a git repository from which a compare target could be inferred.\n")
-"is not part of a git repository from which a compare target could be inferred."
-)
-print()
parser.print_help()
sys.exit(1)

@@ -284,8 +277,7 @@ if known_args.show is not None:
if prop not in KEY_PROPERTIES[:-2]: # Last two values are n_prompt, n_gen.
unknown_cols.append(prop)
if unknown_cols:
-print(f"ERROR: Unknown values for --show: {', '.join(unknown_cols)}")
+logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}")
-print()
parser.print_usage()
sys.exit(1)
rows_show = get_rows(show)
@@ -369,7 +361,7 @@ if "gpu_info" in show:
headers = [PRETTY_NAMES[p] for p in show]
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]

-print(tabulate(
+logger.info(tabulate(
table,
headers=headers,
floatfmt=".2f",
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

+import logging
import argparse
import os
import subprocess
@@ -7,6 +8,8 @@ import sys

import yaml

+logger = logging.getLogger("run-with-preset")

CLI_ARGS_MAIN_PERPLEXITY = [
"batch-size", "cfg-negative-prompt", "cfg-scale", "chunks", "color", "ctx-size", "escape",
"export", "file", "frequency-penalty", "grammar", "grammar-file", "hellaswag",
@@ -56,6 +59,7 @@ parser.add_argument("-bin", "--binary", help="The binary to run.")
parser.add_argument("yaml_files", nargs="*",
help="Arbitrary number of YAML files from which to read preset values. "
"If two files specify the same values the later one will be used.")
+parser.add_argument("--verbose", action="store_true", help="increase output verbosity")

known_args, unknown_args = parser.parse_known_args()

@@ -63,6 +67,8 @@ if not known_args.yaml_files and not unknown_args:
parser.print_help()
sys.exit(0)

+logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)

props = dict()

for yaml_file in known_args.yaml_files:
@@ -85,7 +91,7 @@ elif binary.lower().endswith("llama-bench"):
elif binary.lower().endswith("server"):
cli_args = CLI_ARGS_SERVER
else:
-print(f"Unknown binary: {binary}")
+logger.error(f"Unknown binary: {binary}")
sys.exit(1)

command_list = [binary]
@@ -121,11 +127,11 @@ for cli_arg in cli_args:

num_unused = len(props)
if num_unused > 10:
-print(f"The preset file contained a total of {num_unused} unused properties.")
+logger.info(f"The preset file contained a total of {num_unused} unused properties.")
elif num_unused > 0:
-print("The preset file contained the following unused properties:")
+logger.info("The preset file contained the following unused properties:")
for prop, value in props.items():
-print(f" {prop}: {value}")
+logger.info(f" {prop}: {value}")

command_list += unknown_args
@@ -1,8 +1,11 @@
#!/usr/bin/env python3

+import logging
import os
import hashlib

+logger = logging.getLogger("verify-checksum-models")


def sha256sum(file):
block_size = 16 * 1024 * 1024 # 16 MB block size
@@ -27,7 +30,7 @@ hash_list_file = os.path.join(llama_path, "SHA256SUMS")

# Check if the hash list file exists
if not os.path.exists(hash_list_file):
-print(f"Hash list file not found: {hash_list_file}")
+logger.error(f"Hash list file not found: {hash_list_file}")
exit(1)

# Read the hash file content and split it into an array of lines
@@ -46,7 +49,7 @@ for line in hash_list:
file_path = os.path.join(llama_path, filename)

# Informing user of the progress of the integrity check
-print(f"Verifying the checksum of {file_path}")
+logger.info(f"Verifying the checksum of {file_path}")

# Check if the file exists
if os.path.exists(file_path):
@@ -73,9 +76,9 @@ for line in hash_list:


# Print column headers for results table
-print("\n" + "filename".ljust(40) + "valid checksum".center(20) + "file missing".center(20))
+print("filename".ljust(40) + "valid checksum".center(20) + "file missing".center(20)) # noqa: NP100
-print("-" * 80)
+print("-" * 80) # noqa: NP100

# Output the results as a table
for r in results:
-print(f"{r['filename']:40} {r['valid checksum']:^20} {r['file missing']:^20}")
+print(f"{r['filename']:40} {r['valid checksum']:^20} {r['file missing']:^20}") # noqa: NP100
@@ -7,15 +7,20 @@
# python3 tests/test-tokenizer-0-bpe.py ~/Data/huggingface/deepseek-coder-6.7b-instruct/
#

+import logging
import argparse

from transformers import AutoTokenizer

+logger = logging.getLogger("test-tokenizer-0-bpe")

parser = argparse.ArgumentParser()
parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
parser.add_argument("--fname-tok", help="path to a text file to tokenize")
-args = parser.parse_args()
+parser.add_argument("--verbose", action="store_true", help="increase output verbosity")

+args = parser.parse_args()
+logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
dir_tokenizer = args.dir_tokenizer

tokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
@@ -64,30 +69,34 @@ tests = [
]

for text in tests:
-print('text: ', text)
+logger.info(f"text: {text}")
-print(tokenizer.encode(text))
+logger.info(tokenizer.encode(text))
-print(tokenizer.decode(tokenizer.encode(text)))
+logger.info(tokenizer.decode(tokenizer.encode(text)))

-print("\n\ntests for C++:\n")
+logger.info("tests for C++:")
for text in tests:
res = tokenizer.encode(text)

+# Modify text representation for logging
k = text.replace('\n', '\\n')
k = k.replace('\t', '\\t')
k = '"' + k + '"'
-print("{ %-24s, { " % k, end='')
-for x in res:
-print("%7d," % x, end='')
-print(" }, },")

-print(tokenizer.encode('hello'))
+# Log the modified text and its encoding
-print(tokenizer.encode('world'))
+log_message = "{ %-24s, { " % k
-print(tokenizer.encode(' world'))
+for x in res:
-print(tokenizer.encode('hello world'))
+log_message += "%7d," % x
+log_message += " }, },"
+logger.info(log_message)

+logger.info(tokenizer.encode('hello'))
+logger.info(tokenizer.encode('world'))
+logger.info(tokenizer.encode(' world'))
+logger.info(tokenizer.encode('hello world'))

fname_tok = args.fname_tok
if fname_tok:
-print('tokenizing file: ', fname_tok)
+logger.info(f"tokenizing file: {fname_tok}")
fname_out = fname_tok + '.tok'
with open(fname_tok, 'r', encoding='utf-8') as f:
lines = f.readlines()
@@ -112,6 +121,6 @@ if fname_tok:
# else:
# f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n')
f.write(str(x) + ' \'' + tokenizer.decode(x).strip() + '\'\n')
-print('len(res): ', len(res))
+logger.info(f"len(res): {len(res)}")
-print('len(lines): ', len(lines))
+logger.info(f"len(lines): {len(lines)}")
-print('results written to: ', fname_out)
+logger.info(f"results written to: {fname_out}")
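Note: the tokenizer test above (and the SentencePiece variant that follows) replaces print(..., end='') sequences with a single accumulated string, because each logging call emits one complete record. A small sketch of that accumulation with illustrative token ids:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("example-tokenizer-test")

res = [15043, 1, 29871]
log_message = "{ %-24s, { " % '"hello"'
for x in res:
    log_message += "%7d," % x
log_message += " }, },"
logger.info(log_message)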
@@ -7,15 +7,22 @@
#


+import logging
import argparse

from sentencepiece import SentencePieceProcessor

+logger = logging.getLogger("test-tokenizer-0-spm")

parser = argparse.ArgumentParser()
parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
parser.add_argument("--fname-tok", help="path to a text file to tokenize")
+parser.add_argument("--verbose", action="store_true", help="increase output verbosity")

args = parser.parse_args()

+logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

dir_tokenizer = args.dir_tokenizer

tokenizer = SentencePieceProcessor(dir_tokenizer + '/tokenizer.model')
@@ -65,41 +72,46 @@ tests = [


for text in tests:
-print('text: ', text)
+message_log = (f"text: {text}\n"
-print('\nwith bos:')
+"with bos:\n"
-print(tokenizer.encode(text, add_bos=True))
+f"{tokenizer.encode(text, add_bos=True)}\n"
-print(tokenizer.decode(tokenizer.encode(text, add_bos=True)))
+f"{tokenizer.decode(tokenizer.encode(text, add_bos=True))}\n"
-print('\nwithout bos:')
+"without bos:\n"
-print(tokenizer.encode(text, add_bos=False))
+f"{tokenizer.encode(text, add_bos=False)}\n"
-print(tokenizer.decode(tokenizer.encode(text, add_bos=False)))
+f"{tokenizer.decode(tokenizer.encode(text, add_bos=False))}\n")
+logger.info(message_log)

-print("'" + tokenizer.id_to_piece(15043) + "'") # '_Hello'
+logger.info(f"'{tokenizer.id_to_piece(15043)}'") # '_Hello'
-print("'" + tokenizer.id_to_piece(29871) + "'") # '_'
+logger.info(f"'{tokenizer.id_to_piece(29871)}'") # '_'
-print("'" + tokenizer.decode([15043]) + "'") # 'Hello'
+logger.info(f"'{tokenizer.decode([15043])}'") # 'Hello'
-print("'" + tokenizer.decode([15043, 15043]) + "'") # 'Hello Hello'
+logger.info(f"'{tokenizer.decode([15043, 15043])}'") # 'Hello Hello'
-print("'" + tokenizer.decode([29871, 15043]) + "'") # ' Hello'
+logger.info(f"'{tokenizer.decode([29871, 15043])}'") # ' Hello'
-print("'" + tokenizer.decode([29871, 15043, 29871, 15043]) + "'") # ' Hello Hello'
+logger.info(f"'{tokenizer.decode([29871, 15043, 29871, 15043])}'") # ' Hello Hello'

-print("\n\ntests for C++:\n")
+logger.info("\n\ntests for C++:\n")
for text in tests:
res = tokenizer.encode(text, add_bos=False)

+# Modify text representation for logging
k = text.replace('\n', '\\n')
k = k.replace('\t', '\\t')
k = '"' + k + '"'
-print("{ %-24s, { " % k, end='')
-for x in res:
-print("%7d," % x, end='')
-print(" }, },")

-print(tokenizer.encode('hello'))
+# Log the modified text and its encoding
-print(tokenizer.encode('world'))
+log_message = "{ %-24s, { " % k
-print(tokenizer.encode(' world'))
+for x in res:
-print(tokenizer.encode('hello world'))
+log_message += "%7d," % x
+log_message += " }, },"
+logger.info(log_message)

+logger.info(tokenizer.encode('hello'))
+logger.info(tokenizer.encode('world'))
+logger.info(tokenizer.encode(' world'))
+logger.info(tokenizer.encode('hello world'))

fname_tok = args.fname_tok
if fname_tok:
-print('tokenizing file: ', fname_tok)
+logger.info(f"tokenizing file: {fname_tok}")
fname_out = fname_tok + '.tok'
with open(fname_tok, 'r', encoding='utf-8') as f:
lines = f.readlines()
@@ -109,6 +121,6 @@ if fname_tok:
with open(fname_out, 'w', encoding='utf-8') as f:
for x in res:
f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n')
-print('len(res): ', len(res))
+logger.info(f"len(res): {len(res)}")
-print('len(lines): ', len(lines))
+logger.info(f"len(lines): {len(lines)}")
-print('results written to: ', fname_out)
+logger.info(f"results written to: {fname_out}")