tool-call: fix agent type lints

This commit is contained in:
ochafik 2024-09-27 03:53:56 +01:00
parent 1e5c0e747e
commit 9295ca95db
4 changed files with 24 additions and 25 deletions

View File

@ -30,4 +30,4 @@
uv run examples/tool-call/agent.py \
--tool-endpoint http://localhost:8088 \
--goal "What is the sum of 2535 squared and 32222000403 then multiplied by one and a half. What's a third of the result?"
```
```

View File

@ -11,12 +11,13 @@
# ///
import json
import openai
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
from pydantic import BaseModel
import requests
import sys
import typer
from typing import Annotated, List, Optional
import urllib
from typing import Annotated, Optional
import urllib.parse
class OpenAPIMethod:
@ -94,7 +95,7 @@ class OpenAPIMethod:
def main(
goal: Annotated[str, typer.Option()],
api_key: Optional[str] = None,
tool_endpoint: Optional[List[str]] = None,
tool_endpoint: Optional[list[str]] = None,
format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
max_iterations: Optional[int] = 10,
parallel_calls: Optional[bool] = False,
@ -102,16 +103,16 @@ def main(
# endpoint: Optional[str] = None,
endpoint: str = "http://localhost:8080/v1/",
):
openai.api_key = api_key
openai.base_url = endpoint
tool_map = {}
tools = []
for url in (tool_endpoint or []):
assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'
catalog_url = f'{url}/openapi.json'
catalog_response = requests.get(catalog_url)
catalog_response.raise_for_status()
@ -131,11 +132,11 @@ def main(
)
)
)
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
messages = [
dict(
messages: list[ChatCompletionMessageParam] = [
ChatCompletionUserMessageParam(
role="user",
content=goal,
)
@ -143,7 +144,7 @@ def main(
i = 0
while (max_iterations is None or i < max_iterations):
response = openai.chat.completions.create(
model="gpt-4o",
messages=messages,
@ -152,13 +153,14 @@ def main(
if verbose:
sys.stderr.write(f'# RESPONSE: {response}\n')
assert len(response.choices) == 1
choice = response.choices[0]
content = choice.message.content
if choice.finish_reason == "tool_calls":
messages.append(choice.message)
messages.append(choice.message) # type: ignore
assert choice.message.tool_calls
for tool_call in choice.message.tool_calls:
if content:
print(f'💭 {content}')
@ -169,11 +171,11 @@ def main(
sys.stdout.flush()
tool_result = tool_map[tool_call.function.name](**args)
sys.stdout.write(f"{tool_result}\n")
messages.append(dict(
messages.append(ChatCompletionToolMessageParam(
tool_call_id=tool_call.id,
role="tool",
name=tool_call.function.name,
content=f'{tool_result}',
# name=tool_call.function.name,
content=json.dumps(tool_result),
# content=f'{pretty_call} = {tool_result}',
))
else:

View File

@ -59,8 +59,8 @@ def main(files: List[str], host: str = '0.0.0.0', port: int = 8000):
continue
vt = type(v)
if vt.__module__ == 'langchain_core.tools' and vt.__name__.endswith('Tool') and hasattr(v, 'func') and callable(v.func):
v = v.func
if vt.__module__ == 'langchain_core.tools' and vt.__name__.endswith('Tool') and hasattr(v, 'func') and callable(func := getattr(v, 'func')):
v = func
print(f'INFO: Binding /{k}')
try:
@ -73,4 +73,4 @@ def main(files: List[str], host: str = '0.0.0.0', port: int = 8000):
if __name__ == '__main__':
typer.run(main)
typer.run(main)

View File

@ -1,13 +1,10 @@
from datetime import date
import datetime
import json
from pydantic import BaseModel
import subprocess
import sys
import time
import typer
from typing import Union, Optional, Dict
import types
from typing import Union, Optional, Dict
class Duration(BaseModel):
@ -58,7 +55,7 @@ def wait_for_duration(duration: Duration) -> None:
time.sleep(duration.get_total_seconds)
@staticmethod
def wait_for_date(target_date: date) -> None:
def wait_for_date(target_date: datetime.date) -> None:
f'''
Wait until a specific date is reached before continuing.
Today's date is {datetime.date.today()}