tool-call: fix pyright type errors

This commit is contained in:
ochafik 2024-09-26 03:59:38 +01:00
parent 059babdd9b
commit 4cd82d61dd
2 changed files with 11 additions and 10 deletions

View File

@ -1146,8 +1146,8 @@ async def oai_chat_completions(user_prompt,
max_tokens=n_predict,
stream=enable_streaming,
response_format=payload.get('response_format') or openai.NOT_GIVEN,
tools=payload.get('tools'),
tool_choice=payload.get('tool_choice'),
tools=payload.get('tools') or openai.NOT_GIVEN,
tool_choice=payload.get('tool_choice') or openai.NOT_GIVEN,
seed=seed,
temperature=payload['temperature']
)

View File

@ -15,6 +15,7 @@
https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
'''
import logging
import datetime
import glob
import os
@ -25,6 +26,8 @@ import jinja2.ext
import re
# import requests
logger = logging.getLogger(__name__)
model_ids = [
"NousResearch/Hermes-3-Llama-3.1-70B",
"NousResearch/Hermes-2-Pro-Llama-3-8B",
@ -76,19 +79,19 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
def strftime_now(format):
return datetime.now().strftime(format)
return datetime.datetime.now().strftime(format)
def handle_chat_template(model_id, variant, template_src):
print(f"# {model_id} @ {variant}", flush=True)
logger.info(f"# {model_id} @ {variant}")
model_name = model_id.replace("/", "-")
base_name = f'{model_name}-{variant}' if variant else model_name
template_file = f'tests/chat/templates/{base_name}.jinja'
print(f'template_file: {template_file}')
logger.info(f'template_file: {template_file}')
with open(template_file, 'w') as f:
f.write(template_src)
print(f"- {template_file}", flush=True)
logger.info(f"- {template_file}")
env = jinja2.Environment(
trim_blocks=True,
@ -119,7 +122,7 @@ def handle_chat_template(model_id, variant, template_src):
continue
output_file = f'tests/chat/goldens/{base_name}-{context_name}.txt'
print(f"- {output_file}", flush=True)
logger.info(f"- {output_file}")
try:
output = template.render(**context)
except Exception as e1:
@ -131,14 +134,12 @@ def handle_chat_template(model_id, variant, template_src):
try:
output = template.render(**context)
except Exception as e2:
print(f" ERROR: {e2} (after first error: {e1})", flush=True)
logger.info(f" ERROR: {e2} (after first error: {e1})")
output = f"ERROR: {e2}"
with open(output_file, 'w') as f:
f.write(output)
print()
def main():
for dir in ['tests/chat/templates', 'tests/chat/goldens']: