mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 04:00:16 +00:00
tool-call
: remove duplicate script to fetch templates
This commit is contained in:
parent
ec547e4137
commit
b51c71c734
@ -1,155 +0,0 @@
|
||||
#!/usr/bin/env uv run
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "jinja2",
|
||||
# "huggingface_hub",
|
||||
# ]
|
||||
# ///
|
||||
'''
|
||||
Fetches the Jinja2 templates of specified models and generates prompt goldens for predefined chat contexts.
|
||||
Outputs lines of arguments for a C++ test binary.
|
||||
All files are written to the specified output folder.
|
||||
|
||||
Usage:
|
||||
python ./update_jinja_goldens.py output_folder context1.json context2.json ... model_id1 model_id2 ...
|
||||
|
||||
Example:
|
||||
python ./update_jinja_goldens.py ./test_files "microsoft/Phi-3-medium-4k-instruct" "Qwen/Qwen2-7B-Instruct"
|
||||
'''
|
||||
|
||||
import logging
|
||||
import datetime
|
||||
import glob
|
||||
import os
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
import jinja2
|
||||
import jinja2.ext
|
||||
import re
|
||||
import argparse
|
||||
import shutil
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def raise_exception(message: str):
|
||||
raise ValueError(message)
|
||||
|
||||
|
||||
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
|
||||
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
|
||||
|
||||
|
||||
TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26')
|
||||
|
||||
|
||||
def strftime_now(format):
|
||||
now = datetime.datetime.strptime(TEST_DATE, "%Y-%m-%d")
|
||||
return now.strftime(format)
|
||||
|
||||
|
||||
def handle_chat_template(output_folder, model_id, variant, template_src):
|
||||
model_name = model_id.replace("/", "-")
|
||||
base_name = f'{model_name}-{variant}' if variant else model_name
|
||||
template_file = os.path.join(output_folder, f'{base_name}.jinja')
|
||||
|
||||
with open(template_file, 'w') as f:
|
||||
f.write(template_src)
|
||||
|
||||
env = jinja2.Environment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
extensions=[jinja2.ext.loopcontrols]
|
||||
)
|
||||
env.filters['safe'] = lambda x: x
|
||||
env.filters['tojson'] = tojson
|
||||
env.globals['raise_exception'] = raise_exception
|
||||
env.globals['strftime_now'] = strftime_now
|
||||
|
||||
template_handles_tools = 'tools' in template_src
|
||||
template_hates_the_system = 'System role not supported' in template_src
|
||||
|
||||
template = env.from_string(template_src)
|
||||
|
||||
context_files = glob.glob(os.path.join(output_folder, '*.json'))
|
||||
for context_file in context_files:
|
||||
context_name = os.path.basename(context_file).replace(".json", "")
|
||||
with open(context_file, 'r') as f:
|
||||
context = json.load(f)
|
||||
|
||||
if not template_handles_tools and 'tools' in context:
|
||||
continue
|
||||
|
||||
if template_hates_the_system and any(m['role'] == 'system' for m in context['messages']):
|
||||
continue
|
||||
|
||||
output_file = os.path.join(output_folder, f'{base_name}-{context_name}.txt')
|
||||
|
||||
render_context = json.loads(json.dumps(context))
|
||||
|
||||
if 'tool_call.arguments | items' in template_src or 'tool_call.arguments | tojson' in template_src:
|
||||
for message in render_context['messages']:
|
||||
if 'tool_calls' in message:
|
||||
for tool_call in message['tool_calls']:
|
||||
if tool_call.get('type') == 'function':
|
||||
arguments = tool_call['function']['arguments']
|
||||
tool_call['function']['arguments'] = json.loads(arguments)
|
||||
|
||||
try:
|
||||
output = template.render(**render_context)
|
||||
except Exception as e1:
|
||||
for message in context["messages"]:
|
||||
if message.get("content") is None:
|
||||
message["content"] = ""
|
||||
|
||||
try:
|
||||
output = template.render(**render_context)
|
||||
except Exception as e2:
|
||||
logger.info(f" ERROR: {e2} (after first error: {e1})")
|
||||
output = f"ERROR: {e2}"
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(output)
|
||||
|
||||
# Output the line of arguments for the C++ test binary
|
||||
print(f"{template_file} {context_file} {output_file}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Generate chat templates and output test arguments.")
|
||||
parser.add_argument("output_folder", help="Folder to store all output files")
|
||||
parser.add_argument("model_ids", nargs="+", help="List of model IDs to process")
|
||||
args = parser.parse_args()
|
||||
|
||||
output_folder = args.output_folder
|
||||
if not os.path.isdir(output_folder):
|
||||
os.makedirs(output_folder)
|
||||
|
||||
# Copy context files to the output folder
|
||||
for context_file in glob.glob('tests/chat/contexts/*.json'):
|
||||
shutil.copy(context_file, output_folder)
|
||||
|
||||
for model_id in args.model_ids:
|
||||
try:
|
||||
with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
|
||||
config_str = f.read()
|
||||
|
||||
try:
|
||||
config = json.loads(config_str)
|
||||
except json.JSONDecodeError:
|
||||
config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
|
||||
|
||||
chat_template = config['chat_template']
|
||||
if isinstance(chat_template, str):
|
||||
handle_chat_template(output_folder, model_id, None, chat_template)
|
||||
else:
|
||||
for ct in chat_template:
|
||||
handle_chat_template(output_folder, model_id, ct['name'], ct['template'])
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing model {model_id}: {e}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -73,7 +73,7 @@ static void test_jinja_templates() {
|
||||
return "tests/chat/goldens/" + golden_name + ".txt";
|
||||
};
|
||||
auto fail_with_golden_instructions = [&]() {
|
||||
throw std::runtime_error("To fetch templates and generate golden files, run `python update_templates_and_goldens.py`");
|
||||
throw std::runtime_error("To fetch templates and generate golden files, run `python scripts/update_jinja_goldens.py`");
|
||||
};
|
||||
if (jinja_template_files.empty()) {
|
||||
std::cerr << "No Jinja templates found in tests/chat/templates" << std::endl;
|
||||
|
Loading…
Reference in New Issue
Block a user