gguf-py : add special token modification capability (#7166)

* Add special token modification capability

To be able to fix/amend special tokens in a GGUF, let's add two new arguments:
* `--special-token <name> <value>` where `<name>` can be `bos`, `eos`, `prefix`, `middle`, etc., while `<value>` is the token's string value, e.g. `"<|fim▁begin|>"`
* `--special-token-by-id <name> <id>` where `<id>` is the numeric ID of the token, e.g. `32006`

So, for example, to add fill-in-middle (FIM) tokens to a GGUF you would do the following:
```bash
gguf-new-metadata.py input.gguf output.gguf --special-token prefix "<|fim▁begin|>" --special-token middle "<|fim▁end|>" --special-token suffix "<|fim▁hole|>"
```
(Yes, `<|fim▁end|>` is the `middle` token here, because an infill completion is a `prefix`/`suffix`/`middle` sequence where `middle` is the part left unfilled for the model to generate; see the sketch after these examples.)
or
```bash
gguf-new-metadata.py input.gguf output.gguf --special-token prefix "<fim_prefix>" --special-token middle "<fim_middle>" --special-token suffix "<fim_suffix>"
```
etc...
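
To make that ordering concrete, here is a rough sketch (illustrative only, not part of this change; the code fragments are made up) of the prompt an infill request assembles from these three tokens, with the model generating the unfilled `middle` part after the final token:
```bash
# Illustrative DeepSeek-style infill prompt assembly: prefix token, code before
# the gap, suffix token, code after the gap, then the middle token, after which
# the model fills in the gap.
before='def add(a, b):'
after='    return result'
printf '%s\n' "<|fim▁begin|>${before}<|fim▁hole|>${after}<|fim▁end|>"
```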

NB: The tokens have to exist in the vocabulary already; an unknown token `<name>` is ignored (with a warning), while a `<value>` that cannot be found in the token list, or an `<id>` outside of it, fails (with an error).
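
To set a token by ID instead, a hypothetical example reusing the `32006` ID from above (the correct ID depends on the model's vocabulary):
```bash
gguf-new-metadata.py input.gguf output.gguf --special-token-by-id eos 32006
```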

* improve help text

* flake--

* fix multiple tokens warning

* make script executable

* switch to namedtuple, no need to dataclass

* typing++

* add progress bar

* fail on invalid token id
Sigbjørn Skjæret authored on 2024-05-09 12:56:00 +02:00 (committed by GitHub)
commit 22842164bc (parent 4734524882)

gguf-py/scripts/gguf-new-metadata.py (92 changed lines) Normal file → Executable file

```diff
@@ -7,7 +7,8 @@ import json
 from pathlib import Path
 
 import numpy as np
-from typing import Any, Sequence
+from tqdm import tqdm
+from typing import Any, Sequence, NamedTuple
 
 # Necessary to load the local gguf package
 if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
@@ -18,6 +19,12 @@ import gguf
 logger = logging.getLogger("gguf-new-metadata")
 
 
+class MetadataDetails(NamedTuple):
+    type: gguf.GGUFValueType
+    value: Any
+    description: str = ''
+
+
 def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
     if np.uint32(1) == np.uint32(1).newbyteorder("<"):
         # Host is little endian
@@ -59,7 +66,16 @@ def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
     return decode_field(field)
 
 
-def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, str], remove_metadata: Sequence[str]) -> None:
+def find_token(token_list: Sequence[int], token: str) -> Sequence[int]:
+    token_ids = [index for index, value in enumerate(token_list) if value == token]
+
+    if len(token_ids) == 0:
+        raise LookupError(f'Unable to find "{token}" in token list!')
+
+    return token_ids
+
+
+def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, MetadataDetails], remove_metadata: Sequence[str]) -> None:
     for field in reader.fields.values():
         # Suppress virtual fields and fields written by GGUFWriter
         if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
@@ -75,54 +91,64 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
             logger.debug(f'Removing {field.name}')
             continue
 
-        old_val = decode_field(field)
+        old_val = MetadataDetails(field.types[0], decode_field(field))
         val = new_metadata.get(field.name, old_val)
 
         if field.name in new_metadata:
-            logger.debug(f'Modifying {field.name}: "{old_val}" -> "{val}"')
+            logger.debug(f'Modifying {field.name}: "{old_val.value}" -> "{val.value}" {val.description}')
             del new_metadata[field.name]
-        elif val is not None:
+        elif val.value is not None:
             logger.debug(f'Copying {field.name}')
 
-        if val is not None:
+        if val.value is not None:
             writer.add_key(field.name)
-            writer.add_val(val, field.types[0])
+            writer.add_val(val.value, val.type)
 
     if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
         logger.debug('Adding chat template(s)')
-        writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE])
+        writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE].value)
         del new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE]
 
-    # TODO: Support other types than string?
     for key, val in new_metadata.items():
-        logger.debug(f'Adding {key}: {val}')
+        logger.debug(f'Adding {key}: "{val.value}" {val.description}')
        writer.add_key(key)
-        writer.add_val(val, gguf.GGUFValueType.STRING)
+        writer.add_val(val.value, val.type)
 
+    total_bytes = 0
+
     for tensor in reader.tensors:
+        total_bytes += tensor.n_bytes
         # Dimensions are written in reverse order, so flip them first
         shape = np.flipud(tensor.shape).tolist()
         writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
 
+    bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+
     writer.write_header_to_file()
     writer.write_kv_data_to_file()
     writer.write_ti_data_to_file()
 
     for tensor in reader.tensors:
         writer.write_tensor_data(tensor.data)
+        bar.update(tensor.n_bytes)
 
     writer.close()
 
 
 def main() -> None:
+    tokenizer_metadata = (getattr(gguf.Keys.Tokenizer, n) for n in gguf.Keys.Tokenizer.__dict__.keys() if not n.startswith('_'))
+    token_names = dict((n.split('.')[-1][:-len('_token_id')], n) for n in tokenizer_metadata if n.endswith('_token_id'))
+
     parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata")
     parser.add_argument("input", type=Path, help="GGUF format model input filename")
     parser.add_argument("output", type=Path, help="GGUF format model output filename")
-    parser.add_argument("--general-name", type=str, help="The models general.name")
-    parser.add_argument("--general-description", type=str, help="The models general.description")
-    parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)")
-    parser.add_argument("--chat-template-config", type=Path, help="Config file (tokenizer_config.json) containing chat template(s)")
-    parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model")
+    parser.add_argument("--general-name", type=str, help="The models general.name", metavar='"name"')
+    parser.add_argument("--general-description", type=str, help="The models general.description", metavar='"Description ..."')
+    parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)", metavar='"{% ... %} ..."')
+    parser.add_argument("--chat-template-config", type=Path, help="Config file containing chat template(s)", metavar='tokenizer_config.json')
+    parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model", metavar='general.url')
+    parser.add_argument("--special-token", action="append", type=str, help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '"<token>"'))
+    parser.add_argument("--special-token-by-id", action="append", type=str, help="Special token by id", nargs=2, metavar=(' | '.join(token_names.keys()), '0'))
     parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation")
     parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
 
     args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"])
@@ -133,20 +159,20 @@ def main() -> None:
     remove_metadata = args.remove_metadata or []
 
     if args.general_name:
-        new_metadata[gguf.Keys.General.NAME] = args.general_name
+        new_metadata[gguf.Keys.General.NAME] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_name)
 
     if args.general_description:
-        new_metadata[gguf.Keys.General.DESCRIPTION] = args.general_description
+        new_metadata[gguf.Keys.General.DESCRIPTION] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_description)
 
     if args.chat_template:
-        new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template
+        new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template)
 
     if args.chat_template_config:
         with open(args.chat_template_config, 'r') as fp:
             config = json.load(fp)
             template = config.get('chat_template')
             if template:
-                new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = template
+                new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)
 
     if remove_metadata:
         logger.warning('*** Warning *** Warning *** Warning **')
@@ -166,6 +192,32 @@ def main() -> None:
     arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)
     endianess = get_byteorder(reader)
 
+    token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or []
+
+    for name, token in args.special_token or []:
+        if name not in token_names:
+            logger.warning(f'Unknown special token "{name}", ignoring...')
+        else:
+            ids = find_token(token_list, token)
+            new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, ids[0], f'= {token}')
+
+            if len(ids) > 1:
+                logger.warning(f'Multiple "{token}" tokens found, choosing ID {ids[0]}, use --special-token-by-id if you want another:')
+                logger.warning(', '.join(str(i) for i in ids))
+
+    for name, id_string in args.special_token_by_id or []:
+        if name not in token_names:
+            logger.warning(f'Unknown special token "{name}", ignoring...')
+        elif not id_string.isdecimal():
+            raise LookupError(f'Token ID "{id_string}" is not a valid ID!')
+        else:
+            id_int = int(id_string)
+
+            if id_int >= 0 and id_int < len(token_list):
+                new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, id_int, f'= {token_list[id_int]}')
+            else:
+                raise LookupError(f'Token ID {id_int} is not within token list!')
+
     if os.path.isfile(args.output) and not args.force:
         logger.warning('*** Warning *** Warning *** Warning **')
         logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')
```
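
For reference, a hypothetical invocation that would hit the new failure path above (any ID at or beyond the vocabulary size triggers it; `999999` is only an illustration):
```bash
# Fails with: LookupError: Token ID 999999 is not within token list!
gguf-new-metadata.py input.gguf output.gguf --special-token-by-id eos 999999
```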