tool-call: make agent async

This commit is contained in:
ochafik 2024-09-28 19:11:09 +01:00
parent 05bbba9f8a
commit ef2a020276
3 changed files with 92 additions and 83 deletions

View File

@ -1,29 +1,30 @@
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "aiohttp",
# "fastapi",
# "openai",
# "pydantic",
# "requests",
# "uvicorn",
# "typer",
# "uvicorn",
# ]
# ///
import json
import openai
import asyncio
import aiohttp
from functools import wraps
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
from pydantic import BaseModel
import requests
import sys
import typer
from typing import Annotated, Optional
import urllib.parse
class OpenAPIMethod:
def __init__(self, url, name, descriptor, catalog):
'''
Wraps a remote OpenAPI method as a Python function.
Wraps a remote OpenAPI method as an async Python function.
'''
self.url = url
self.__name__ = name
@ -69,7 +70,7 @@ class OpenAPIMethod:
required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
)
def __call__(self, **kwargs):
async def __call__(self, session: aiohttp.ClientSession, **kwargs):
if self.body:
body = kwargs.pop(self.body['name'], None)
if self.body['required']:
@ -86,38 +87,26 @@ class OpenAPIMethod:
assert param['in'] == 'query', 'Only query parameters are supported'
query_params[name] = value
params = "&".join(f"{name}={urllib.parse.quote(value)}" for name, value in query_params.items())
params = "&".join(f"{name}={urllib.parse.quote(str(value))}" for name, value in query_params.items() if value is not None)
url = f'{self.url}?{params}'
response = requests.post(url, json=body)
async with session.post(url, json=body) as response:
response.raise_for_status()
response_json = response.json()
response_json = await response.json()
return response_json
def main(
goal: Annotated[str, typer.Option()],
api_key: str = '<unset>',
tool_endpoint: Optional[list[str]] = None,
max_iterations: Optional[int] = 10,
verbose: bool = False,
endpoint: str = "http://localhost:8080/v1/",
):
openai.api_key = api_key
openai.base_url = endpoint
async def discover_tools(tool_endpoints: list[str], verbose: bool = False) -> tuple[dict, list]:
tool_map = {}
tools = []
# Discover tools using OpenAPI catalogs at the provided endpoints.
for url in (tool_endpoint or []):
async with aiohttp.ClientSession() as session:
for url in tool_endpoints:
assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'
catalog_url = f'{url}/openapi.json'
catalog_response = requests.get(catalog_url)
catalog_response.raise_for_status()
catalog = catalog_response.json()
async with session.get(catalog_url) as response:
response.raise_for_status()
catalog = await response.json()
for path, descriptor in catalog['paths'].items():
fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
@ -134,6 +123,30 @@ def main(
)
)
return tool_map, tools
def typer_async_workaround():
'Adapted from https://github.com/fastapi/typer/issues/950#issuecomment-2351076467'
def decorator(f):
@wraps(f)
def wrapper(*args, **kwargs):
return asyncio.run(f(*args, **kwargs))
return wrapper
return decorator
@typer_async_workaround()
async def main(
goal: Annotated[str, typer.Option()],
api_key: str = '<unset>',
tool_endpoint: Optional[list[str]] = None,
max_iterations: Optional[int] = 10,
verbose: bool = False,
endpoint: str = "http://localhost:8080/v1/",
):
client = AsyncOpenAI(api_key=api_key, base_url=endpoint)
tool_map, tools = await discover_tools(tool_endpoint or [], verbose)
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
messages: list[ChatCompletionMessageParam] = [
@ -143,10 +156,9 @@ def main(
)
]
i = 0
while (max_iterations is None or i < max_iterations):
response = openai.chat.completions.create(
async with aiohttp.ClientSession() as session:
for i in range(max_iterations or sys.maxsize):
response = await client.chat.completions.create(
model="gpt-4o",
messages=messages,
tools=tools,
@ -170,22 +182,18 @@ def main(
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
sys.stdout.write(f'⚙️ {pretty_call}')
sys.stdout.flush()
tool_result = tool_map[tool_call.function.name](**args)
tool_result = await tool_map[tool_call.function.name](session, **args)
sys.stdout.write(f"{tool_result}\n")
messages.append(ChatCompletionToolMessageParam(
tool_call_id=tool_call.id,
role="tool",
# name=tool_call.function.name,
content=json.dumps(tool_result),
# content=f'{pretty_call} = {tool_result}',
))
else:
assert content
print(content)
return
i += 1
if max_iterations is not None:
raise Exception(f"Failed to get a valid response after {max_iterations} tool calls")

View File

@ -89,7 +89,7 @@ def python(code: str) -> str:
Returns:
str: The output of the executed code.
"""
from IPython import InteractiveShell
from IPython.core.interactiveshell import InteractiveShell
from io import StringIO
import sys

View File

@ -1,6 +1,7 @@
aiohttp
fastapi
ipython
openai
pydantic
requests
typer
uvicorn