mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
fix flake8 lints
This commit is contained in:
parent
1b6280102b
commit
76d2938ef8
@ -66,15 +66,19 @@ model_ids = [
|
||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
]
|
||||
|
||||
|
||||
def raise_exception(message: str):
|
||||
raise ValueError(message)
|
||||
|
||||
|
||||
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
|
||||
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
|
||||
|
||||
|
||||
def strftime_now(format):
|
||||
return datetime.now().strftime(format)
|
||||
|
||||
|
||||
def handle_chat_template(model_id, variant, template_src):
|
||||
print(f"# {model_id} @ {variant}", flush=True)
|
||||
model_name = model_id.replace("/", "-")
|
||||
@ -87,12 +91,12 @@ def handle_chat_template(model_id, variant, template_src):
|
||||
print(f"- {template_file}", flush=True)
|
||||
|
||||
env = jinja2.Environment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
# keep_trailing_newline=False,
|
||||
extensions=[
|
||||
jinja2.ext.loopcontrols
|
||||
])
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
# keep_trailing_newline=False,
|
||||
extensions=[
|
||||
jinja2.ext.loopcontrols
|
||||
])
|
||||
env.filters['tojson'] = tojson
|
||||
env.globals['raise_exception'] = raise_exception
|
||||
env.globals['strftime_now'] = strftime_now
|
||||
@ -118,7 +122,7 @@ def handle_chat_template(model_id, variant, template_src):
|
||||
print(f"- {output_file}", flush=True)
|
||||
try:
|
||||
output = template.render(**context)
|
||||
except:
|
||||
except Exception as e1:
|
||||
# Some templates (e.g. Phi-3-medium-128k's) expect a non-null "content" key in each message.
|
||||
for message in context["messages"]:
|
||||
if message.get("content") is None:
|
||||
@ -126,15 +130,16 @@ def handle_chat_template(model_id, variant, template_src):
|
||||
|
||||
try:
|
||||
output = template.render(**context)
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}", flush=True)
|
||||
output = f"ERROR: {e}"
|
||||
except Exception as e2:
|
||||
print(f" ERROR: {e2} (after first error: {e1})", flush=True)
|
||||
output = f"ERROR: {e2}"
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(output)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
for dir in ['tests/chat/templates', 'tests/chat/goldens']:
|
||||
if not os.path.isdir(dir):
|
||||
@ -149,7 +154,7 @@ def main():
|
||||
|
||||
try:
|
||||
config = json.loads(config_str)
|
||||
except json.JSONDecodeError as e:
|
||||
except json.JSONDecodeError:
|
||||
# Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
|
||||
# (Remove extra '}' near the end of the file)
|
||||
config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
|
||||
@ -161,5 +166,6 @@ def main():
|
||||
for ct in chat_template:
|
||||
handle_chat_template(model_id, ct['name'], ct['template'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user