From ef2a0202765e0f466bf937a8d946a661e443699b Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 28 Sep 2024 19:11:09 +0100 Subject: [PATCH] `tool-call`: make agent async --- examples/agent/run.py | 170 +++++++++++++++------------- examples/agent/tools.py | 2 +- requirements/requirements-agent.txt | 3 +- 3 files changed, 92 insertions(+), 83 deletions(-) diff --git a/examples/agent/run.py b/examples/agent/run.py index 912e3e9ef..c092a6d45 100644 --- a/examples/agent/run.py +++ b/examples/agent/run.py @@ -1,29 +1,30 @@ # /// script # requires-python = ">=3.11" # dependencies = [ +# "aiohttp", # "fastapi", # "openai", # "pydantic", -# "requests", -# "uvicorn", # "typer", +# "uvicorn", # ] # /// import json -import openai +import asyncio +import aiohttp +from functools import wraps +from openai import AsyncOpenAI from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam from pydantic import BaseModel -import requests import sys import typer from typing import Annotated, Optional import urllib.parse - class OpenAPIMethod: def __init__(self, url, name, descriptor, catalog): ''' - Wraps a remote OpenAPI method as a Python function. + Wraps a remote OpenAPI method as an async Python function. ''' self.url = url self.__name__ = name @@ -69,7 +70,7 @@ class OpenAPIMethod: required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else []) ) - def __call__(self, **kwargs): + async def __call__(self, session: aiohttp.ClientSession, **kwargs): if self.body: body = kwargs.pop(self.body['name'], None) if self.body['required']: @@ -86,16 +87,55 @@ class OpenAPIMethod: assert param['in'] == 'query', 'Only query parameters are supported' query_params[name] = value - params = "&".join(f"{name}={urllib.parse.quote(value)}" for name, value in query_params.items()) + params = "&".join(f"{name}={urllib.parse.quote(str(value))}" for name, value in query_params.items() if value is not None) url = f'{self.url}?{params}' - response = requests.post(url, json=body) - response.raise_for_status() - response_json = response.json() + async with session.post(url, json=body) as response: + response.raise_for_status() + response_json = await response.json() return response_json +async def discover_tools(tool_endpoints: list[str], verbose: bool = False) -> tuple[dict, list]: + tool_map = {} + tools = [] -def main( + async with aiohttp.ClientSession() as session: + for url in tool_endpoints: + assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}' + + catalog_url = f'{url}/openapi.json' + async with session.get(catalog_url) as response: + response.raise_for_status() + catalog = await response.json() + + for path, descriptor in catalog['paths'].items(): + fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog) + tool_map[fn.__name__] = fn + if verbose: + sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(fn.parameters_schema, indent=2)}\n') + tools.append(dict( + type="function", + function=dict( + name=fn.__name__, + description=fn.__doc__ or '', + parameters=fn.parameters_schema, + ) + ) + ) + + return tool_map, tools + +def typer_async_workaround(): + 'Adapted from https://github.com/fastapi/typer/issues/950#issuecomment-2351076467' + def decorator(f): + @wraps(f) + def wrapper(*args, **kwargs): + return asyncio.run(f(*args, **kwargs)) + return wrapper + return decorator + +@typer_async_workaround() +async def main( goal: Annotated[str, typer.Option()], api_key: str = '', tool_endpoint: Optional[list[str]] = None, @@ -103,36 +143,9 @@ def main( verbose: bool = False, endpoint: str = "http://localhost:8080/v1/", ): + client = AsyncOpenAI(api_key=api_key, base_url=endpoint) - openai.api_key = api_key - openai.base_url = endpoint - - tool_map = {} - tools = [] - - # Discover tools using OpenAPI catalogs at the provided endpoints. - for url in (tool_endpoint or []): - assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}' - - catalog_url = f'{url}/openapi.json' - catalog_response = requests.get(catalog_url) - catalog_response.raise_for_status() - catalog = catalog_response.json() - - for path, descriptor in catalog['paths'].items(): - fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog) - tool_map[fn.__name__] = fn - if verbose: - sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(fn.parameters_schema, indent=2)}\n') - tools.append(dict( - type="function", - function=dict( - name=fn.__name__, - description=fn.__doc__ or '', - parameters=fn.parameters_schema, - ) - ) - ) + tool_map, tools = await discover_tools(tool_endpoint or [], verbose) sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n') @@ -143,51 +156,46 @@ def main( ) ] - i = 0 - while (max_iterations is None or i < max_iterations): + async with aiohttp.ClientSession() as session: + for i in range(max_iterations or sys.maxsize): + response = await client.chat.completions.create( + model="gpt-4o", + messages=messages, + tools=tools, + ) - response = openai.chat.completions.create( - model="gpt-4o", - messages=messages, - tools=tools, - ) + if verbose: + sys.stderr.write(f'# RESPONSE: {response}\n') - if verbose: - sys.stderr.write(f'# RESPONSE: {response}\n') + assert len(response.choices) == 1 + choice = response.choices[0] - assert len(response.choices) == 1 - choice = response.choices[0] + content = choice.message.content + if choice.finish_reason == "tool_calls": + messages.append(choice.message) # type: ignore + assert choice.message.tool_calls + for tool_call in choice.message.tool_calls: + if content: + print(f'💭 {content}') - content = choice.message.content - if choice.finish_reason == "tool_calls": - messages.append(choice.message) # type: ignore - assert choice.message.tool_calls - for tool_call in choice.message.tool_calls: - if content: - print(f'💭 {content}') + args = json.loads(tool_call.function.arguments) + pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})' + sys.stdout.write(f'⚙️ {pretty_call}') + sys.stdout.flush() + tool_result = await tool_map[tool_call.function.name](session, **args) + sys.stdout.write(f" → {tool_result}\n") + messages.append(ChatCompletionToolMessageParam( + tool_call_id=tool_call.id, + role="tool", + content=json.dumps(tool_result), + )) + else: + assert content + print(content) + return - args = json.loads(tool_call.function.arguments) - pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})' - sys.stdout.write(f'⚙️ {pretty_call}') - sys.stdout.flush() - tool_result = tool_map[tool_call.function.name](**args) - sys.stdout.write(f" → {tool_result}\n") - messages.append(ChatCompletionToolMessageParam( - tool_call_id=tool_call.id, - role="tool", - # name=tool_call.function.name, - content=json.dumps(tool_result), - # content=f'{pretty_call} = {tool_result}', - )) - else: - assert content - print(content) - return - - i += 1 - - if max_iterations is not None: - raise Exception(f"Failed to get a valid response after {max_iterations} tool calls") + if max_iterations is not None: + raise Exception(f"Failed to get a valid response after {max_iterations} tool calls") if __name__ == '__main__': typer.run(main) diff --git a/examples/agent/tools.py b/examples/agent/tools.py index ff48464cf..b91595778 100644 --- a/examples/agent/tools.py +++ b/examples/agent/tools.py @@ -89,7 +89,7 @@ def python(code: str) -> str: Returns: str: The output of the executed code. """ - from IPython import InteractiveShell + from IPython.core.interactiveshell import InteractiveShell from io import StringIO import sys diff --git a/requirements/requirements-agent.txt b/requirements/requirements-agent.txt index 639f0111f..e9de760fb 100644 --- a/requirements/requirements-agent.txt +++ b/requirements/requirements-agent.txt @@ -1,6 +1,7 @@ +aiohttp fastapi +ipython openai pydantic -requests typer uvicorn