from __future__ import annotations

import logging
import json
import os
from pathlib import Path
from typing import Any, Callable, Sequence, Mapping, Iterable

from .gguf_writer import GGUFWriter

logger = logging.getLogger(__name__)


class SpecialVocab:
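    """Special-token metadata for a model, ready to be written to a GGUF file.

    Collects BPE merges, special token ids, add_*_token flags and the chat
    template from the files in a model directory and writes them to a
    GGUFWriter via add_to_gguf().
    """
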
    merges: list[str]
    add_special_token: dict[str, bool]
    special_token_ids: dict[str, int]
    chat_template: str | Sequence[Mapping[str, str]] | None

    def __init__(
        self, path: str | os.PathLike[str], load_merges: bool = False,
        special_token_types: Iterable[str] | None = None,
        n_vocab: int | None = None,
    ):
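        """Scan the model directory at path for special vocab metadata.

        special_token_types defaults to ('bos', 'eos', 'unk', 'sep', 'pad',
        'cls', 'mask'); n_vocab, when given, is used to range-check token ids.
        """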
        self.special_token_ids = {}
        self.add_special_token = {}
        self.n_vocab = n_vocab
        self.load_merges = load_merges
        self.merges = []
        self.chat_template = None
        if special_token_types is not None:
            self.special_token_types = special_token_types
        else:
            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
        self._load(Path(path))

    def __repr__(self) -> str:
        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
        )

    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
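        """Write collected merges, special token ids, add_*_token flags and
        the chat template to the given GGUFWriter.

        Unless quiet is set, each field is logged as it is written; fields the
        writer has no handler for are skipped with a warning.
        """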
        if self.merges:
            if not quiet:
                logger.info(f'Adding {len(self.merges)} merge(s).')
            gw.add_token_merges(self.merges)
        elif self.load_merges:
            logger.warning('Adding merges requested but no merges found, output may be non-functional.')
        for typ, tokid in self.special_token_ids.items():
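            # GGUFWriter exposes one setter per special token type (e.g.
            # add_bos_token_id); resolve it by name so unknown types are
            # skipped with a warning rather than raising.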
            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
            if id_handler is None:
                logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting special token type {typ} to {tokid}')
            id_handler(tokid)
        for typ, value in self.add_special_token.items():
            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
            if add_handler is None:
                logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting add_{typ}_token to {value}')
            add_handler(value)
        if self.chat_template is not None:
            if not quiet:
                logger.info(f'Setting chat_template to {self.chat_template}')
            gw.add_chat_template(self.chat_template)

    def _load(self, path: Path) -> None:
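        """Populate this vocab from tokenizer.json, tokenizer_config.json and
        config.json under path, falling back to merges.txt for merges."""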
        self._try_load_from_tokenizer_json(path)
        self._try_load_from_config_json(path)
        if self.load_merges and not self.merges:
            self._try_load_merges_txt(path)

    def _try_load_merges_txt(self, path: Path) -> bool:
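        """Load BPE merges from a merges.txt file next to the model.

        Returns True if the file was found and parsed, False otherwise.
        """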
        merges_file = path / 'merges.txt'
        if not merges_file.is_file():
            return False
        # Open with an explicit encoding: merges.txt is UTF-8, and falling back
        # to the platform default (e.g. cp1252 on Windows) can fail with a
        # UnicodeDecodeError on some models.
        with open(merges_file, 'r', encoding = 'utf-8') as fp:
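            # merges.txt typically starts with a '#version' comment line; if
            # the first line is a comment, skip it and count from line 1.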
            first_line = next(fp, '').strip()
            if not first_line.startswith('#'):
                fp.seek(0)
                line_num = 0
            else:
                line_num = 1
            merges = []
            for line in fp:
                line_num += 1
                line = line.strip()
                if not line:
                    continue
                parts = line.split(None, 3)
                if len(parts) != 2:
                    logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
                    continue
                merges.append(f'{parts[0]} {parts[1]}')
        self.merges = merges
        return True

    def _set_special_token(self, typ: str, tid: Any) -> None:
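        """Record a special token id, ignoring non-integer values.

        Raises ValueError for negative ids; ids at or beyond n_vocab (when
        known) are skipped with a warning, and the first id seen for a given
        type wins.
        """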
        if not isinstance(tid, int):
            return
        if tid < 0:
            raise ValueError(f'invalid value for special token type {typ}: {tid}')
        if self.n_vocab is None or tid < self.n_vocab:
            if typ in self.special_token_ids:
                return
            self.special_token_ids[typ] = tid
            return
        logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')

    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
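        """Read merges and special token metadata from tokenizer.json and
        tokenizer_config.json, resolving token contents to ids through the
        tokenizer's added_tokens entries."""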
        tokenizer_file = path / 'tokenizer.json'
        if tokenizer_file.is_file():
            with open(tokenizer_file, encoding = 'utf-8') as f:
                tokenizer = json.load(f)
            if self.load_merges:
                merges = tokenizer.get('model', {}).get('merges')
                if isinstance(merges, list) and merges and isinstance(merges[0], str):
                    self.merges = merges
            added_tokens = tokenizer.get('added_tokens', {})
        else:
            added_tokens = {}
        tokenizer_config_file = path / 'tokenizer_config.json'
        if not tokenizer_config_file.is_file():
            return True
        with open(tokenizer_config_file, encoding = 'utf-8') as f:
            tokenizer_config = json.load(f)
        chat_template = tokenizer_config.get('chat_template')
        if chat_template is None or isinstance(chat_template, (str, list)):
            self.chat_template = chat_template
        else:
            logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')
        for typ in self.special_token_types:
            add_entry = tokenizer_config.get(f'add_{typ}_token')
            if isinstance(add_entry, bool):
                self.add_special_token[typ] = add_entry
            entry = tokenizer_config.get(f'{typ}_token')
            if isinstance(entry, str):
                tc_content = entry
            elif isinstance(entry, dict):
                entry_content = entry.get('content')
                if not isinstance(entry_content, str):
                    continue
                tc_content = entry_content
            else:
                continue
            # We only need the first match here.
            maybe_token_id = next(
                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
                None,
            )
            self._set_special_token(typ, maybe_token_id)
        return True

    def _try_load_from_config_json(self, path: Path) -> bool:
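        """Read special token ids (e.g. bos_token_id, eos_token_id) directly
        from the model's config.json."""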
        config_file = path / 'config.json'
        if not config_file.is_file():
            return False
        with open(config_file, encoding = 'utf-8') as f:
            config = json.load(f)
        for typ in self.special_token_types:
            self._set_special_token(typ, config.get(f'{typ}_token_id'))
        return True