tool-call: make agent async

This commit is contained in:
ochafik 2024-09-28 19:11:09 +01:00
parent 05bbba9f8a
commit ef2a020276
3 changed files with 92 additions and 83 deletions

View File

@ -1,29 +1,30 @@
# /// script # /// script
# requires-python = ">=3.11" # requires-python = ">=3.11"
# dependencies = [ # dependencies = [
# "aiohttp",
# "fastapi", # "fastapi",
# "openai", # "openai",
# "pydantic", # "pydantic",
# "requests",
# "uvicorn",
# "typer", # "typer",
# "uvicorn",
# ] # ]
# /// # ///
import json import json
import openai import asyncio
import aiohttp
from functools import wraps
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
from pydantic import BaseModel from pydantic import BaseModel
import requests
import sys import sys
import typer import typer
from typing import Annotated, Optional from typing import Annotated, Optional
import urllib.parse import urllib.parse
class OpenAPIMethod: class OpenAPIMethod:
def __init__(self, url, name, descriptor, catalog): def __init__(self, url, name, descriptor, catalog):
''' '''
Wraps a remote OpenAPI method as a Python function. Wraps a remote OpenAPI method as an async Python function.
''' '''
self.url = url self.url = url
self.__name__ = name self.__name__ = name
@ -69,7 +70,7 @@ class OpenAPIMethod:
required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else []) required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
) )
def __call__(self, **kwargs): async def __call__(self, session: aiohttp.ClientSession, **kwargs):
if self.body: if self.body:
body = kwargs.pop(self.body['name'], None) body = kwargs.pop(self.body['name'], None)
if self.body['required']: if self.body['required']:
@ -86,38 +87,26 @@ class OpenAPIMethod:
assert param['in'] == 'query', 'Only query parameters are supported' assert param['in'] == 'query', 'Only query parameters are supported'
query_params[name] = value query_params[name] = value
params = "&".join(f"{name}={urllib.parse.quote(value)}" for name, value in query_params.items()) params = "&".join(f"{name}={urllib.parse.quote(str(value))}" for name, value in query_params.items() if value is not None)
url = f'{self.url}?{params}' url = f'{self.url}?{params}'
response = requests.post(url, json=body) async with session.post(url, json=body) as response:
response.raise_for_status() response.raise_for_status()
response_json = response.json() response_json = await response.json()
return response_json return response_json
async def discover_tools(tool_endpoints: list[str], verbose: bool = False) -> tuple[dict, list]:
def main(
goal: Annotated[str, typer.Option()],
api_key: str = '<unset>',
tool_endpoint: Optional[list[str]] = None,
max_iterations: Optional[int] = 10,
verbose: bool = False,
endpoint: str = "http://localhost:8080/v1/",
):
openai.api_key = api_key
openai.base_url = endpoint
tool_map = {} tool_map = {}
tools = [] tools = []
# Discover tools using OpenAPI catalogs at the provided endpoints. async with aiohttp.ClientSession() as session:
for url in (tool_endpoint or []): for url in tool_endpoints:
assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}' assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'
catalog_url = f'{url}/openapi.json' catalog_url = f'{url}/openapi.json'
catalog_response = requests.get(catalog_url) async with session.get(catalog_url) as response:
catalog_response.raise_for_status() response.raise_for_status()
catalog = catalog_response.json() catalog = await response.json()
for path, descriptor in catalog['paths'].items(): for path, descriptor in catalog['paths'].items():
fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog) fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
@ -134,6 +123,30 @@ def main(
) )
) )
return tool_map, tools
def typer_async_workaround():
'Adapted from https://github.com/fastapi/typer/issues/950#issuecomment-2351076467'
def decorator(f):
@wraps(f)
def wrapper(*args, **kwargs):
return asyncio.run(f(*args, **kwargs))
return wrapper
return decorator
@typer_async_workaround()
async def main(
goal: Annotated[str, typer.Option()],
api_key: str = '<unset>',
tool_endpoint: Optional[list[str]] = None,
max_iterations: Optional[int] = 10,
verbose: bool = False,
endpoint: str = "http://localhost:8080/v1/",
):
client = AsyncOpenAI(api_key=api_key, base_url=endpoint)
tool_map, tools = await discover_tools(tool_endpoint or [], verbose)
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n') sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
messages: list[ChatCompletionMessageParam] = [ messages: list[ChatCompletionMessageParam] = [
@ -143,10 +156,9 @@ def main(
) )
] ]
i = 0 async with aiohttp.ClientSession() as session:
while (max_iterations is None or i < max_iterations): for i in range(max_iterations or sys.maxsize):
response = await client.chat.completions.create(
response = openai.chat.completions.create(
model="gpt-4o", model="gpt-4o",
messages=messages, messages=messages,
tools=tools, tools=tools,
@ -170,22 +182,18 @@ def main(
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})' pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
sys.stdout.write(f'⚙️ {pretty_call}') sys.stdout.write(f'⚙️ {pretty_call}')
sys.stdout.flush() sys.stdout.flush()
tool_result = tool_map[tool_call.function.name](**args) tool_result = await tool_map[tool_call.function.name](session, **args)
sys.stdout.write(f"{tool_result}\n") sys.stdout.write(f"{tool_result}\n")
messages.append(ChatCompletionToolMessageParam( messages.append(ChatCompletionToolMessageParam(
tool_call_id=tool_call.id, tool_call_id=tool_call.id,
role="tool", role="tool",
# name=tool_call.function.name,
content=json.dumps(tool_result), content=json.dumps(tool_result),
# content=f'{pretty_call} = {tool_result}',
)) ))
else: else:
assert content assert content
print(content) print(content)
return return
i += 1
if max_iterations is not None: if max_iterations is not None:
raise Exception(f"Failed to get a valid response after {max_iterations} tool calls") raise Exception(f"Failed to get a valid response after {max_iterations} tool calls")

View File

@ -89,7 +89,7 @@ def python(code: str) -> str:
Returns: Returns:
str: The output of the executed code. str: The output of the executed code.
""" """
from IPython import InteractiveShell from IPython.core.interactiveshell import InteractiveShell
from io import StringIO from io import StringIO
import sys import sys

View File

@ -1,6 +1,7 @@
aiohttp
fastapi fastapi
ipython
openai openai
pydantic pydantic
requests
typer typer
uvicorn uvicorn