diff --git a/tests/update_jinja_goldens.py b/tests/update_jinja_goldens.py index fafa6dee0..faefc92e3 100644 --- a/tests/update_jinja_goldens.py +++ b/tests/update_jinja_goldens.py @@ -66,15 +66,19 @@ model_ids = [ "mistralai/Mixtral-8x7B-Instruct-v0.1", ] + def raise_exception(message: str): raise ValueError(message) + def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False): return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys) + def strftime_now(format): return datetime.now().strftime(format) + def handle_chat_template(model_id, variant, template_src): print(f"# {model_id} @ {variant}", flush=True) model_name = model_id.replace("/", "-") @@ -87,12 +91,12 @@ def handle_chat_template(model_id, variant, template_src): print(f"- {template_file}", flush=True) env = jinja2.Environment( - trim_blocks=True, - lstrip_blocks=True, - # keep_trailing_newline=False, - extensions=[ - jinja2.ext.loopcontrols - ]) + trim_blocks=True, + lstrip_blocks=True, + # keep_trailing_newline=False, + extensions=[ + jinja2.ext.loopcontrols + ]) env.filters['tojson'] = tojson env.globals['raise_exception'] = raise_exception env.globals['strftime_now'] = strftime_now @@ -118,7 +122,7 @@ def handle_chat_template(model_id, variant, template_src): print(f"- {output_file}", flush=True) try: output = template.render(**context) - except: + except Exception as e1: # Some templates (e.g. Phi-3-medium-128k's) expect a non-null "content" key in each message. for message in context["messages"]: if message.get("content") is None: @@ -126,15 +130,16 @@ def handle_chat_template(model_id, variant, template_src): try: output = template.render(**context) - except Exception as e: - print(f" ERROR: {e}", flush=True) - output = f"ERROR: {e}" + except Exception as e2: + print(f" ERROR: {e2} (after first error: {e1})", flush=True) + output = f"ERROR: {e2}" with open(output_file, 'w') as f: f.write(output) print() + def main(): for dir in ['tests/chat/templates', 'tests/chat/goldens']: if not os.path.isdir(dir): @@ -149,7 +154,7 @@ def main(): try: config = json.loads(config_str) - except json.JSONDecodeError as e: + except json.JSONDecodeError: # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json # (Remove extra '}' near the end of the file) config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str)) @@ -161,5 +166,6 @@ def main(): for ct in chat_template: handle_chat_template(model_id, ct['name'], ct['template']) + if __name__ == '__main__': main()