tool-call: fix pyright type errors

This commit is contained in:
ochafik 2024-09-26 03:59:38 +01:00
parent 059babdd9b
commit 4cd82d61dd
2 changed files with 11 additions and 10 deletions

View File

@ -1146,8 +1146,8 @@ async def oai_chat_completions(user_prompt,
max_tokens=n_predict, max_tokens=n_predict,
stream=enable_streaming, stream=enable_streaming,
response_format=payload.get('response_format') or openai.NOT_GIVEN, response_format=payload.get('response_format') or openai.NOT_GIVEN,
tools=payload.get('tools'), tools=payload.get('tools') or openai.NOT_GIVEN,
tool_choice=payload.get('tool_choice'), tool_choice=payload.get('tool_choice') or openai.NOT_GIVEN,
seed=seed, seed=seed,
temperature=payload['temperature'] temperature=payload['temperature']
) )

View File

@ -15,6 +15,7 @@
https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
''' '''
import logging
import datetime import datetime
import glob import glob
import os import os
@ -25,6 +26,8 @@ import jinja2.ext
import re import re
# import requests # import requests
logger = logging.getLogger(__name__)
model_ids = [ model_ids = [
"NousResearch/Hermes-3-Llama-3.1-70B", "NousResearch/Hermes-3-Llama-3.1-70B",
"NousResearch/Hermes-2-Pro-Llama-3-8B", "NousResearch/Hermes-2-Pro-Llama-3-8B",
@ -76,19 +79,19 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
def strftime_now(format): def strftime_now(format):
return datetime.now().strftime(format) return datetime.datetime.now().strftime(format)
def handle_chat_template(model_id, variant, template_src): def handle_chat_template(model_id, variant, template_src):
print(f"# {model_id} @ {variant}", flush=True) logger.info(f"# {model_id} @ {variant}")
model_name = model_id.replace("/", "-") model_name = model_id.replace("/", "-")
base_name = f'{model_name}-{variant}' if variant else model_name base_name = f'{model_name}-{variant}' if variant else model_name
template_file = f'tests/chat/templates/{base_name}.jinja' template_file = f'tests/chat/templates/{base_name}.jinja'
print(f'template_file: {template_file}') logger.info(f'template_file: {template_file}')
with open(template_file, 'w') as f: with open(template_file, 'w') as f:
f.write(template_src) f.write(template_src)
print(f"- {template_file}", flush=True) logger.info(f"- {template_file}")
env = jinja2.Environment( env = jinja2.Environment(
trim_blocks=True, trim_blocks=True,
@ -119,7 +122,7 @@ def handle_chat_template(model_id, variant, template_src):
continue continue
output_file = f'tests/chat/goldens/{base_name}-{context_name}.txt' output_file = f'tests/chat/goldens/{base_name}-{context_name}.txt'
print(f"- {output_file}", flush=True) logger.info(f"- {output_file}")
try: try:
output = template.render(**context) output = template.render(**context)
except Exception as e1: except Exception as e1:
@ -131,14 +134,12 @@ def handle_chat_template(model_id, variant, template_src):
try: try:
output = template.render(**context) output = template.render(**context)
except Exception as e2: except Exception as e2:
print(f" ERROR: {e2} (after first error: {e1})", flush=True) logger.info(f" ERROR: {e2} (after first error: {e1})")
output = f"ERROR: {e2}" output = f"ERROR: {e2}"
with open(output_file, 'w') as f: with open(output_file, 'w') as f:
f.write(output) f.write(output)
print()
def main(): def main():
for dir in ['tests/chat/templates', 'tests/chat/goldens']: for dir in ['tests/chat/templates', 'tests/chat/goldens']: