llama.cpp/scripts/update_jinja_goldens.py

183 lines
6.3 KiB
Python
Raw Normal View History

#!/usr/bin/env uv run
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "jinja2",
# "huggingface_hub",
# ]
# ///
'''
Fetches the Jinja2 templates of a few known models and uses them to generate prompt goldens for a few predefined chat contexts.

Templates are written to tests/chat/templates/ and rendered goldens to tests/chat/goldens/
(one file per model template variant and per context in tests/chat/contexts/).

Examples:
    python ./scripts/update_jinja_goldens.py

See also:
    https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
'''
2024-09-26 02:59:38 +00:00
import logging
import datetime
import glob
import os
from huggingface_hub import hf_hub_download
import json
import jinja2
import jinja2.ext
import re
# import requests
# Emit bare log messages (no level or logger-name prefix) at INFO level and above.
logging.basicConfig(format='%(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
# Hugging Face repos whose tokenizer_config.json chat templates are snapshotted.
# main() processes them in exactly this order: public repos first, then gated ones.
_public_model_ids = [
    "abacusai/Fewshot-Metamath-OrcaVicuna-Mistral",
    "bofenghuang/vigogne-2-70b-chat",
    "deepseek-ai/deepseek-coder-33b-instruct",
    "deepseek-ai/DeepSeek-Coder-V2-Instruct",
    "deepseek-ai/DeepSeek-V2.5",
    "indischepartij/MiniCPM-3B-OpenHermes-2.5-v2",
    "meetkai/functionary-medium-v3.1",
    "meetkai/functionary-medium-v3.2",
    "microsoft/Phi-3-medium-4k-instruct",
    "microsoft/Phi-3-mini-4k-instruct",
    "microsoft/Phi-3-small-8k-instruct",
    "microsoft/Phi-3.5-mini-instruct",
    "microsoft/Phi-3.5-vision-instruct",
    "mlabonne/AlphaMonarch-7B",
    "CohereForAI/c4ai-command-r-plus",
    "NousResearch/Hermes-2-Pro-Llama-3-8B",
    "NousResearch/Hermes-2-Pro-Mistral-7B",
    "NousResearch/Hermes-3-Llama-3.1-8B",
    "openchat/openchat-3.5-0106",
    "OrionStarAI/Orion-14B-Chat",
    "Qwen/Qwen2-7B-Instruct",
    "Qwen/Qwen2-VL-7B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "Qwen/Qwen2.5-Math-7B-Instruct",
    "teknium/OpenHermes-2.5-Mistral-7B",
    "TheBloke/FusionNet_34Bx2_MoE-AWQ",
]

# Gated models: presumably require an authenticated Hub login to download (verify your token).
_gated_model_ids = [
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "mistralai/Mistral-Nemo-Instruct-2407",
    "google/gemma-7b-it",
    "google/gemma-2-2b-it",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
]

model_ids = _public_model_ids + _gated_model_ids
2024-09-26 01:30:17 +00:00
def raise_exception(message: str):
    """Jinja global that lets a template abort rendering with `message`.

    Raises:
        ValueError: always, carrying `message` as its text.
    """
    error = ValueError(message)
    raise error
2024-09-26 01:30:17 +00:00
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
    """Serialize `x` with `json.dumps`; registered as the templates' `tojson` filter.

    Non-ASCII characters are emitted as-is unless `ensure_ascii=True`; the other
    keyword arguments are forwarded unchanged to `json.dumps`.
    """
    dump_options = {
        'ensure_ascii': ensure_ascii,
        'indent': indent,
        'separators': separators,
        'sort_keys': sort_keys,
    }
    return json.dumps(x, **dump_options)
2024-09-26 01:30:17 +00:00
# Pinned "current" date (YYYY-MM-DD) reported to templates instead of the real clock,
# so that regenerated goldens do not change from day to day. Override via $TEST_DATE.
TEST_DATE = os.getenv('TEST_DATE', '2024-07-26')


def strftime_now(format):
    """Jinja global: format the pinned TEST_DATE (midnight) with `format`.

    TEST_DATE is parsed on every call, so reassigning it at runtime takes effect immediately.
    """
    return datetime.datetime.strptime(TEST_DATE, "%Y-%m-%d").strftime(format)
2024-09-26 01:30:17 +00:00
def _make_jinja_env():
    """Build the Jinja2 environment used to render every chat template."""
    env = jinja2.Environment(
        trim_blocks=True,
        lstrip_blocks=True,
        # keep_trailing_newline is left at Jinja's default (False).
        extensions=[jinja2.ext.loopcontrols],
    )
    # `safe` is a no-op here; `tojson` is the plain json.dumps wrapper defined in this file.
    env.filters['safe'] = lambda x: x
    env.filters['tojson'] = tojson
    # Helper functions that templates call as globals.
    env.globals['raise_exception'] = raise_exception
    env.globals['strftime_now'] = strftime_now
    return env


def _prepare_render_context(context, template_src):
    """Return a deep copy of `context`, adjusted for known template quirks.

    Rendering always works on a copy because the template (and the workarounds
    below) may modify the context in place.
    """
    render_context = json.loads(json.dumps(context))
    # Work around Llama-3.1 template quirk: it expects tool_call.function.arguments to be an
    # object rather than its JSON string representation.
    if 'tool_call.arguments | items' in template_src or 'tool_call.arguments | tojson' in template_src:
        for message in render_context['messages']:
            # `or []` tolerates an explicit "tool_calls": null.
            for tool_call in message.get('tool_calls') or []:
                if tool_call.get('type') == 'function':
                    arguments = tool_call['function']['arguments']
                    # Only decode JSON strings; arguments that are already objects are kept as-is.
                    if isinstance(arguments, str):
                        tool_call['function']['arguments'] = json.loads(arguments)
    return render_context


def handle_chat_template(model_id, variant, template_src):
    """Snapshot one chat template and render it against every test chat context.

    Writes `template_src` to tests/chat/templates/<base_name>.jinja, then renders it with
    each tests/chat/contexts/*.json context and writes the output (or an "ERROR: ..." line
    when rendering fails) to tests/chat/goldens/<base_name>-<context_name>.txt.
    <base_name> is `model_id` with "/" replaced by "-", suffixed with "-<variant>" when a
    variant is given. The output directories are expected to exist already.

    Args:
        model_id: Hugging Face repo id, e.g. "org/model".
        variant: Name of the template variant, or None when the model has a single template.
        template_src: Jinja2 source of the chat template.
    """
    logger.info(f"# {model_id}{' @ ' + variant if variant else ''}")
    model_name = model_id.replace("/", "-")
    base_name = f'{model_name}-{variant}' if variant else model_name
    template_file = f'tests/chat/templates/{base_name}.jinja'
    logger.info(f'- template_file: {template_file}')
    # Templates may contain non-ASCII text; don't depend on the platform's default encoding.
    with open(template_file, 'w', encoding='utf-8') as f:
        f.write(template_src)

    template = _make_jinja_env().from_string(template_src)

    # glob order is arbitrary; sort so runs process and log contexts in a stable order.
    for context_file in sorted(glob.glob('tests/chat/contexts/*.json')):
        # basename/splitext rather than split("/") so Windows path separators work too.
        context_name = os.path.splitext(os.path.basename(context_file))[0]
        with open(context_file, 'r', encoding='utf-8') as f:
            context = json.load(f)

        output_file = f'tests/chat/goldens/{base_name}-{context_name}.txt'
        logger.info(f"- {output_file}")

        render_context = _prepare_render_context(context, template_src)
        try:
            output = template.render(**render_context)
        except Exception as e1:
            # Some templates (e.g. Phi-3-medium-128k's) expect a non-null "content" key in each message.
            # Patch a fresh copy -- the one that is actually rendered next -- because the failed
            # render may have mutated its own copy.
            render_context = _prepare_render_context(context, template_src)
            for message in render_context["messages"]:
                if message.get("content") is None:
                    message["content"] = ""
            try:
                output = template.render(**render_context)
            except Exception as e2:
                logger.info(f" ERROR: {e2} (after first error: {e1})")
                # The failure is recorded in the golden itself.
                output = f"ERROR: {e2}"

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(output)

    logger.info('')
2024-09-26 01:30:17 +00:00
2024-09-26 18:06:29 +00:00
def main():
    """Regenerate chat templates and prompt goldens for every repo in `model_ids`.

    For each model, downloads its tokenizer_config.json via `hf_hub_download`, extracts
    the `chat_template` entry (either a single template string or a list of
    {"name": ..., "template": ...} variants) and passes each template to
    handle_chat_template(). Models without a chat template are logged and skipped.
    """
    # makedirs (unlike mkdir) also creates missing parents such as tests/chat.
    for out_dir in ['tests/chat/templates', 'tests/chat/goldens']:
        os.makedirs(out_dir, exist_ok=True)

    for model_id in model_ids:
        config_path = hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")
        # The config is UTF-8 JSON; don't depend on the platform's default encoding.
        with open(config_path, encoding='utf-8') as f:
            config_str = f.read()

        try:
            config = json.loads(config_str)
        except json.JSONDecodeError:
            # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
            # (Remove extra '}' near the end of the file)
            config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))

        chat_template = config.get('chat_template')
        if chat_template is None:
            logger.warning(f"# {model_id}: no chat_template in tokenizer_config.json, skipping")
            continue

        if isinstance(chat_template, str):
            handle_chat_template(model_id, None, chat_template)
        else:
            # A list of named template variants.
            for ct in chat_template:
                handle_chat_template(model_id, ct['name'], ct['template'])


if __name__ == '__main__':
    main()