mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 04:00:16 +00:00
tool-call
: fix pyright type errors
This commit is contained in:
parent
059babdd9b
commit
4cd82d61dd
@ -1146,8 +1146,8 @@ async def oai_chat_completions(user_prompt,
|
||||
max_tokens=n_predict,
|
||||
stream=enable_streaming,
|
||||
response_format=payload.get('response_format') or openai.NOT_GIVEN,
|
||||
tools=payload.get('tools'),
|
||||
tool_choice=payload.get('tool_choice'),
|
||||
tools=payload.get('tools') or openai.NOT_GIVEN,
|
||||
tool_choice=payload.get('tool_choice') or openai.NOT_GIVEN,
|
||||
seed=seed,
|
||||
temperature=payload['temperature']
|
||||
)
|
||||
|
@ -15,6 +15,7 @@
|
||||
https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
|
||||
'''
|
||||
|
||||
import logging
|
||||
import datetime
|
||||
import glob
|
||||
import os
|
||||
@ -25,6 +26,8 @@ import jinja2.ext
|
||||
import re
|
||||
# import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
model_ids = [
|
||||
"NousResearch/Hermes-3-Llama-3.1-70B",
|
||||
"NousResearch/Hermes-2-Pro-Llama-3-8B",
|
||||
@ -76,19 +79,19 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
|
||||
|
||||
|
||||
def strftime_now(format):
|
||||
return datetime.now().strftime(format)
|
||||
return datetime.datetime.now().strftime(format)
|
||||
|
||||
|
||||
def handle_chat_template(model_id, variant, template_src):
|
||||
print(f"# {model_id} @ {variant}", flush=True)
|
||||
logger.info(f"# {model_id} @ {variant}")
|
||||
model_name = model_id.replace("/", "-")
|
||||
base_name = f'{model_name}-{variant}' if variant else model_name
|
||||
template_file = f'tests/chat/templates/{base_name}.jinja'
|
||||
print(f'template_file: {template_file}')
|
||||
logger.info(f'template_file: {template_file}')
|
||||
with open(template_file, 'w') as f:
|
||||
f.write(template_src)
|
||||
|
||||
print(f"- {template_file}", flush=True)
|
||||
logger.info(f"- {template_file}")
|
||||
|
||||
env = jinja2.Environment(
|
||||
trim_blocks=True,
|
||||
@ -119,7 +122,7 @@ def handle_chat_template(model_id, variant, template_src):
|
||||
continue
|
||||
|
||||
output_file = f'tests/chat/goldens/{base_name}-{context_name}.txt'
|
||||
print(f"- {output_file}", flush=True)
|
||||
logger.info(f"- {output_file}")
|
||||
try:
|
||||
output = template.render(**context)
|
||||
except Exception as e1:
|
||||
@ -131,14 +134,12 @@ def handle_chat_template(model_id, variant, template_src):
|
||||
try:
|
||||
output = template.render(**context)
|
||||
except Exception as e2:
|
||||
print(f" ERROR: {e2} (after first error: {e1})", flush=True)
|
||||
logger.info(f" ERROR: {e2} (after first error: {e1})")
|
||||
output = f"ERROR: {e2}"
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(output)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
for dir in ['tests/chat/templates', 'tests/chat/goldens']:
|
||||
|
Loading…
Reference in New Issue
Block a user