mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 04:00:16 +00:00
tool-call: make agent async
This commit is contained in:
parent
05bbba9f8a
commit
ef2a020276
@@ -1,29 +1,30 @@
|
||||
# /// script
|
||||
# requires-python = ">=3.11"
|
||||
# dependencies = [
|
||||
# "aiohttp",
|
||||
# "fastapi",
|
||||
# "openai",
|
||||
# "pydantic",
|
||||
# "requests",
|
||||
# "uvicorn",
|
||||
# "typer",
|
||||
# "uvicorn",
|
||||
# ]
|
||||
# ///
|
||||
import json
|
||||
import openai
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from functools import wraps
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
|
||||
from pydantic import BaseModel
|
||||
import requests
|
||||
import sys
|
||||
import typer
|
||||
from typing import Annotated, Optional
|
||||
import urllib.parse
|
||||
|
||||
|
||||
class OpenAPIMethod:
|
||||
def __init__(self, url, name, descriptor, catalog):
|
||||
'''
|
||||
Wraps a remote OpenAPI method as a Python function.
|
||||
Wraps a remote OpenAPI method as an async Python function.
|
||||
'''
|
||||
self.url = url
|
||||
self.__name__ = name
|
||||
@@ -69,7 +70,7 @@ class OpenAPIMethod:
|
||||
required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
|
||||
)
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
async def __call__(self, session: aiohttp.ClientSession, **kwargs):
|
||||
if self.body:
|
||||
body = kwargs.pop(self.body['name'], None)
|
||||
if self.body['required']:
|
||||
@@ -86,38 +87,26 @@ class OpenAPIMethod:
|
||||
assert param['in'] == 'query', 'Only query parameters are supported'
|
||||
query_params[name] = value
|
||||
|
||||
params = "&".join(f"{name}={urllib.parse.quote(value)}" for name, value in query_params.items())
|
||||
params = "&".join(f"{name}={urllib.parse.quote(str(value))}" for name, value in query_params.items() if value is not None)
|
||||
url = f'{self.url}?{params}'
|
||||
response = requests.post(url, json=body)
|
||||
async with session.post(url, json=body) as response:
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
response_json = await response.json()
|
||||
|
||||
return response_json
|
||||
|
||||
|
||||
def main(
|
||||
goal: Annotated[str, typer.Option()],
|
||||
api_key: str = '<unset>',
|
||||
tool_endpoint: Optional[list[str]] = None,
|
||||
max_iterations: Optional[int] = 10,
|
||||
verbose: bool = False,
|
||||
endpoint: str = "http://localhost:8080/v1/",
|
||||
):
|
||||
|
||||
openai.api_key = api_key
|
||||
openai.base_url = endpoint
|
||||
|
||||
async def discover_tools(tool_endpoints: list[str], verbose: bool = False) -> tuple[dict, list]:
|
||||
tool_map = {}
|
||||
tools = []
|
||||
|
||||
# Discover tools using OpenAPI catalogs at the provided endpoints.
|
||||
for url in (tool_endpoint or []):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for url in tool_endpoints:
|
||||
assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'
|
||||
|
||||
catalog_url = f'{url}/openapi.json'
|
||||
catalog_response = requests.get(catalog_url)
|
||||
catalog_response.raise_for_status()
|
||||
catalog = catalog_response.json()
|
||||
async with session.get(catalog_url) as response:
|
||||
response.raise_for_status()
|
||||
catalog = await response.json()
|
||||
|
||||
for path, descriptor in catalog['paths'].items():
|
||||
fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
|
||||
@@ -134,6 +123,30 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
return tool_map, tools
|
||||
|
||||
def typer_async_workaround():
|
||||
'Adapted from https://github.com/fastapi/typer/issues/950#issuecomment-2351076467'
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
return asyncio.run(f(*args, **kwargs))
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
@typer_async_workaround()
|
||||
async def main(
|
||||
goal: Annotated[str, typer.Option()],
|
||||
api_key: str = '<unset>',
|
||||
tool_endpoint: Optional[list[str]] = None,
|
||||
max_iterations: Optional[int] = 10,
|
||||
verbose: bool = False,
|
||||
endpoint: str = "http://localhost:8080/v1/",
|
||||
):
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=endpoint)
|
||||
|
||||
tool_map, tools = await discover_tools(tool_endpoint or [], verbose)
|
||||
|
||||
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
|
||||
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
@@ -143,10 +156,9 @@ def main(
|
||||
)
|
||||
]
|
||||
|
||||
i = 0
|
||||
while (max_iterations is None or i < max_iterations):
|
||||
|
||||
response = openai.chat.completions.create(
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for i in range(max_iterations or sys.maxsize):
|
||||
response = await client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
@@ -170,22 +182,18 @@ def main(
|
||||
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
|
||||
sys.stdout.write(f'⚙️ {pretty_call}')
|
||||
sys.stdout.flush()
|
||||
tool_result = tool_map[tool_call.function.name](**args)
|
||||
tool_result = await tool_map[tool_call.function.name](session, **args)
|
||||
sys.stdout.write(f" → {tool_result}\n")
|
||||
messages.append(ChatCompletionToolMessageParam(
|
||||
tool_call_id=tool_call.id,
|
||||
role="tool",
|
||||
# name=tool_call.function.name,
|
||||
content=json.dumps(tool_result),
|
||||
# content=f'{pretty_call} = {tool_result}',
|
||||
))
|
||||
else:
|
||||
assert content
|
||||
print(content)
|
||||
return
|
||||
|
||||
i += 1
|
||||
|
||||
if max_iterations is not None:
|
||||
raise Exception(f"Failed to get a valid response after {max_iterations} tool calls")
|
||||
|
||||
|
@@ -89,7 +89,7 @@ def python(code: str) -> str:
|
||||
Returns:
|
||||
str: The output of the executed code.
|
||||
"""
|
||||
from IPython import InteractiveShell
|
||||
from IPython.core.interactiveshell import InteractiveShell
|
||||
from io import StringIO
|
||||
import sys
|
||||
|
||||
|
@@ -1,6 +1,7 @@
|
||||
aiohttp
|
||||
fastapi
|
||||
ipython
|
||||
openai
|
||||
pydantic
|
||||
requests
|
||||
typer
|
||||
uvicorn
|
||||
|
Loading…
Reference in New Issue
Block a user