llama.cpp/scripts/get_hf_chat_template.py
2024-10-02 19:13:28 +01:00

70 lines
2.7 KiB
Python

'''
Fetches the Jinja chat template of a HuggingFace model.
If a model
Syntax:
get_hf_chat_template.py model_id [variant]
Examples:
python ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
python ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-70B tool_use
python ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
'''
import json
import re
import sys
def main(args):
if len(args) < 1:
raise ValueError("Please provide a model ID and an optional variant name")
model_id = args[0]
variant = None if len(args) < 2 else args[1]
try:
# Use huggingface_hub library if available.
# Allows access to gated models if the user has access and ran `huggingface-cli login`.
from huggingface_hub import hf_hub_download
with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
config_str = f.read()
except ImportError:
import requests
assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}"
response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json")
if response.status_code == 401:
raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`')
response.raise_for_status()
config_str = response.text
try:
config = json.loads(config_str)
except json.JSONDecodeError:
# Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
# (Remove extra '}' near the end of the file)
config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
chat_template = config['chat_template']
if isinstance(chat_template, str):
print(chat_template, end=None)
else:
variants = {
ct['name']: ct['template']
for ct in chat_template
}
format_variants = lambda: ', '.join(f'"{v}"' for v in variants.keys())
if variant is None:
if 'default' not in variants:
raise Exception(f'Please specify a chat template variant (one of {format_variants()})')
variant = 'default'
print(f'Note: picked "default" chat template variant (out of {format_variants()})', file=sys.stderr)
elif variant not in variants:
raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})")
print(variants[variant], end=None)
if __name__ == '__main__':
main(sys.argv[1:])