fix flake8 lints

This commit is contained in:
ochafik 2024-09-26 02:30:17 +01:00
parent 1b6280102b
commit 76d2938ef8

View File

@ -66,15 +66,19 @@ model_ids = [
"mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1",
] ]
def raise_exception(message: str): def raise_exception(message: str):
raise ValueError(message) raise ValueError(message)
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False): def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys) return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
def strftime_now(format): def strftime_now(format):
return datetime.now().strftime(format) return datetime.now().strftime(format)
def handle_chat_template(model_id, variant, template_src): def handle_chat_template(model_id, variant, template_src):
print(f"# {model_id} @ {variant}", flush=True) print(f"# {model_id} @ {variant}", flush=True)
model_name = model_id.replace("/", "-") model_name = model_id.replace("/", "-")
@ -87,12 +91,12 @@ def handle_chat_template(model_id, variant, template_src):
print(f"- {template_file}", flush=True) print(f"- {template_file}", flush=True)
env = jinja2.Environment( env = jinja2.Environment(
trim_blocks=True, trim_blocks=True,
lstrip_blocks=True, lstrip_blocks=True,
# keep_trailing_newline=False, # keep_trailing_newline=False,
extensions=[ extensions=[
jinja2.ext.loopcontrols jinja2.ext.loopcontrols
]) ])
env.filters['tojson'] = tojson env.filters['tojson'] = tojson
env.globals['raise_exception'] = raise_exception env.globals['raise_exception'] = raise_exception
env.globals['strftime_now'] = strftime_now env.globals['strftime_now'] = strftime_now
@ -118,7 +122,7 @@ def handle_chat_template(model_id, variant, template_src):
print(f"- {output_file}", flush=True) print(f"- {output_file}", flush=True)
try: try:
output = template.render(**context) output = template.render(**context)
except: except Exception as e1:
# Some templates (e.g. Phi-3-medium-128k's) expect a non-null "content" key in each message. # Some templates (e.g. Phi-3-medium-128k's) expect a non-null "content" key in each message.
for message in context["messages"]: for message in context["messages"]:
if message.get("content") is None: if message.get("content") is None:
@ -126,15 +130,16 @@ def handle_chat_template(model_id, variant, template_src):
try: try:
output = template.render(**context) output = template.render(**context)
except Exception as e: except Exception as e2:
print(f" ERROR: {e}", flush=True) print(f" ERROR: {e2} (after first error: {e1})", flush=True)
output = f"ERROR: {e}" output = f"ERROR: {e2}"
with open(output_file, 'w') as f: with open(output_file, 'w') as f:
f.write(output) f.write(output)
print() print()
def main(): def main():
for dir in ['tests/chat/templates', 'tests/chat/goldens']: for dir in ['tests/chat/templates', 'tests/chat/goldens']:
if not os.path.isdir(dir): if not os.path.isdir(dir):
@ -149,7 +154,7 @@ def main():
try: try:
config = json.loads(config_str) config = json.loads(config_str)
except json.JSONDecodeError as e: except json.JSONDecodeError:
# Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
# (Remove extra '}' near the end of the file) # (Remove extra '}' near the end of the file)
config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str)) config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
@ -161,5 +166,6 @@ def main():
for ct in chat_template: for ct in chat_template:
handle_chat_template(model_id, ct['name'], ct['template']) handle_chat_template(model_id, ct['name'], ct['template'])
if __name__ == '__main__': if __name__ == '__main__':
main() main()