diff --git a/fetch_templates_and_goldens.py b/fetch_templates_and_goldens.py deleted file mode 100644 index a6a1ed209..000000000 --- a/fetch_templates_and_goldens.py +++ /dev/null @@ -1,155 +0,0 @@ -#!/usr/bin/env uv run -# /// script -# requires-python = ">=3.10" -# dependencies = [ -# "jinja2", -# "huggingface_hub", -# ] -# /// -''' - Fetches the Jinja2 templates of specified models and generates prompt goldens for predefined chat contexts. - Outputs lines of arguments for a C++ test binary. - All files are written to the specified output folder. - - Usage: - python ./update_jinja_goldens.py output_folder context1.json context2.json ... model_id1 model_id2 ... - - Example: - python ./update_jinja_goldens.py ./test_files "microsoft/Phi-3-medium-4k-instruct" "Qwen/Qwen2-7B-Instruct" -''' - -import logging -import datetime -import glob -import os -from huggingface_hub import hf_hub_download -import json -import jinja2 -import jinja2.ext -import re -import argparse -import shutil - -logging.basicConfig(level=logging.INFO, format='%(message)s') -logger = logging.getLogger(__name__) - - -def raise_exception(message: str): - raise ValueError(message) - - -def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False): - return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys) - - -TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26') - - -def strftime_now(format): - now = datetime.datetime.strptime(TEST_DATE, "%Y-%m-%d") - return now.strftime(format) - - -def handle_chat_template(output_folder, model_id, variant, template_src): - model_name = model_id.replace("/", "-") - base_name = f'{model_name}-{variant}' if variant else model_name - template_file = os.path.join(output_folder, f'{base_name}.jinja') - - with open(template_file, 'w') as f: - f.write(template_src) - - env = jinja2.Environment( - trim_blocks=True, - lstrip_blocks=True, - extensions=[jinja2.ext.loopcontrols] - ) - env.filters['safe'] = lambda x: x - env.filters['tojson'] = tojson - env.globals['raise_exception'] = raise_exception - env.globals['strftime_now'] = strftime_now - - template_handles_tools = 'tools' in template_src - template_hates_the_system = 'System role not supported' in template_src - - template = env.from_string(template_src) - - context_files = glob.glob(os.path.join(output_folder, '*.json')) - for context_file in context_files: - context_name = os.path.basename(context_file).replace(".json", "") - with open(context_file, 'r') as f: - context = json.load(f) - - if not template_handles_tools and 'tools' in context: - continue - - if template_hates_the_system and any(m['role'] == 'system' for m in context['messages']): - continue - - output_file = os.path.join(output_folder, f'{base_name}-{context_name}.txt') - - render_context = json.loads(json.dumps(context)) - - if 'tool_call.arguments | items' in template_src or 'tool_call.arguments | tojson' in template_src: - for message in render_context['messages']: - if 'tool_calls' in message: - for tool_call in message['tool_calls']: - if tool_call.get('type') == 'function': - arguments = tool_call['function']['arguments'] - tool_call['function']['arguments'] = json.loads(arguments) - - try: - output = template.render(**render_context) - except Exception as e1: - for message in context["messages"]: - if message.get("content") is None: - message["content"] = "" - - try: - output = template.render(**render_context) - except Exception as e2: - logger.info(f" ERROR: {e2} (after first error: {e1})") - output = f"ERROR: {e2}" - - with open(output_file, 'w') as f: - f.write(output) - - # Output the line of arguments for the C++ test binary - print(f"{template_file} {context_file} {output_file}") - - -def main(): - parser = argparse.ArgumentParser(description="Generate chat templates and output test arguments.") - parser.add_argument("output_folder", help="Folder to store all output files") - parser.add_argument("model_ids", nargs="+", help="List of model IDs to process") - args = parser.parse_args() - - output_folder = args.output_folder - if not os.path.isdir(output_folder): - os.makedirs(output_folder) - - # Copy context files to the output folder - for context_file in glob.glob('tests/chat/contexts/*.json'): - shutil.copy(context_file, output_folder) - - for model_id in args.model_ids: - try: - with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f: - config_str = f.read() - - try: - config = json.loads(config_str) - except json.JSONDecodeError: - config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str)) - - chat_template = config['chat_template'] - if isinstance(chat_template, str): - handle_chat_template(output_folder, model_id, None, chat_template) - else: - for ct in chat_template: - handle_chat_template(output_folder, model_id, ct['name'], ct['template']) - except Exception as e: - logger.error(f"Error processing model {model_id}: {e}") - - -if __name__ == '__main__': - main() diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 5e0abc0ca..ab7746248 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -73,7 +73,7 @@ static void test_jinja_templates() { return "tests/chat/goldens/" + golden_name + ".txt"; }; auto fail_with_golden_instructions = [&]() { - throw std::runtime_error("To fetch templates and generate golden files, run `python update_templates_and_goldens.py`"); + throw std::runtime_error("To fetch templates and generate golden files, run `python scripts/update_jinja_goldens.py`"); }; if (jinja_template_files.empty()) { std::cerr << "No Jinja templates found in tests/chat/templates" << std::endl;