mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
tool-call
: make agent async
This commit is contained in:
parent
05bbba9f8a
commit
ef2a020276
@ -1,29 +1,30 @@
|
|||||||
# /// script
|
# /// script
|
||||||
# requires-python = ">=3.11"
|
# requires-python = ">=3.11"
|
||||||
# dependencies = [
|
# dependencies = [
|
||||||
|
# "aiohttp",
|
||||||
# "fastapi",
|
# "fastapi",
|
||||||
# "openai",
|
# "openai",
|
||||||
# "pydantic",
|
# "pydantic",
|
||||||
# "requests",
|
|
||||||
# "uvicorn",
|
|
||||||
# "typer",
|
# "typer",
|
||||||
|
# "uvicorn",
|
||||||
# ]
|
# ]
|
||||||
# ///
|
# ///
|
||||||
import json
|
import json
|
||||||
import openai
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from functools import wraps
|
||||||
|
from openai import AsyncOpenAI
|
||||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import requests
|
|
||||||
import sys
|
import sys
|
||||||
import typer
|
import typer
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated, Optional
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
|
|
||||||
class OpenAPIMethod:
|
class OpenAPIMethod:
|
||||||
def __init__(self, url, name, descriptor, catalog):
|
def __init__(self, url, name, descriptor, catalog):
|
||||||
'''
|
'''
|
||||||
Wraps a remote OpenAPI method as a Python function.
|
Wraps a remote OpenAPI method as an async Python function.
|
||||||
'''
|
'''
|
||||||
self.url = url
|
self.url = url
|
||||||
self.__name__ = name
|
self.__name__ = name
|
||||||
@ -69,7 +70,7 @@ class OpenAPIMethod:
|
|||||||
required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
|
required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, **kwargs):
|
async def __call__(self, session: aiohttp.ClientSession, **kwargs):
|
||||||
if self.body:
|
if self.body:
|
||||||
body = kwargs.pop(self.body['name'], None)
|
body = kwargs.pop(self.body['name'], None)
|
||||||
if self.body['required']:
|
if self.body['required']:
|
||||||
@ -86,38 +87,26 @@ class OpenAPIMethod:
|
|||||||
assert param['in'] == 'query', 'Only query parameters are supported'
|
assert param['in'] == 'query', 'Only query parameters are supported'
|
||||||
query_params[name] = value
|
query_params[name] = value
|
||||||
|
|
||||||
params = "&".join(f"{name}={urllib.parse.quote(value)}" for name, value in query_params.items())
|
params = "&".join(f"{name}={urllib.parse.quote(str(value))}" for name, value in query_params.items() if value is not None)
|
||||||
url = f'{self.url}?{params}'
|
url = f'{self.url}?{params}'
|
||||||
response = requests.post(url, json=body)
|
async with session.post(url, json=body) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
response_json = response.json()
|
response_json = await response.json()
|
||||||
|
|
||||||
return response_json
|
return response_json
|
||||||
|
|
||||||
|
async def discover_tools(tool_endpoints: list[str], verbose: bool = False) -> tuple[dict, list]:
|
||||||
def main(
|
|
||||||
goal: Annotated[str, typer.Option()],
|
|
||||||
api_key: str = '<unset>',
|
|
||||||
tool_endpoint: Optional[list[str]] = None,
|
|
||||||
max_iterations: Optional[int] = 10,
|
|
||||||
verbose: bool = False,
|
|
||||||
endpoint: str = "http://localhost:8080/v1/",
|
|
||||||
):
|
|
||||||
|
|
||||||
openai.api_key = api_key
|
|
||||||
openai.base_url = endpoint
|
|
||||||
|
|
||||||
tool_map = {}
|
tool_map = {}
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
# Discover tools using OpenAPI catalogs at the provided endpoints.
|
async with aiohttp.ClientSession() as session:
|
||||||
for url in (tool_endpoint or []):
|
for url in tool_endpoints:
|
||||||
assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'
|
assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'
|
||||||
|
|
||||||
catalog_url = f'{url}/openapi.json'
|
catalog_url = f'{url}/openapi.json'
|
||||||
catalog_response = requests.get(catalog_url)
|
async with session.get(catalog_url) as response:
|
||||||
catalog_response.raise_for_status()
|
response.raise_for_status()
|
||||||
catalog = catalog_response.json()
|
catalog = await response.json()
|
||||||
|
|
||||||
for path, descriptor in catalog['paths'].items():
|
for path, descriptor in catalog['paths'].items():
|
||||||
fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
|
fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
|
||||||
@ -134,6 +123,30 @@ def main(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return tool_map, tools
|
||||||
|
|
||||||
|
def typer_async_workaround():
|
||||||
|
'Adapted from https://github.com/fastapi/typer/issues/950#issuecomment-2351076467'
|
||||||
|
def decorator(f):
|
||||||
|
@wraps(f)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return asyncio.run(f(*args, **kwargs))
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@typer_async_workaround()
|
||||||
|
async def main(
|
||||||
|
goal: Annotated[str, typer.Option()],
|
||||||
|
api_key: str = '<unset>',
|
||||||
|
tool_endpoint: Optional[list[str]] = None,
|
||||||
|
max_iterations: Optional[int] = 10,
|
||||||
|
verbose: bool = False,
|
||||||
|
endpoint: str = "http://localhost:8080/v1/",
|
||||||
|
):
|
||||||
|
client = AsyncOpenAI(api_key=api_key, base_url=endpoint)
|
||||||
|
|
||||||
|
tool_map, tools = await discover_tools(tool_endpoint or [], verbose)
|
||||||
|
|
||||||
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
|
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
|
||||||
|
|
||||||
messages: list[ChatCompletionMessageParam] = [
|
messages: list[ChatCompletionMessageParam] = [
|
||||||
@ -143,10 +156,9 @@ def main(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
i = 0
|
async with aiohttp.ClientSession() as session:
|
||||||
while (max_iterations is None or i < max_iterations):
|
for i in range(max_iterations or sys.maxsize):
|
||||||
|
response = await client.chat.completions.create(
|
||||||
response = openai.chat.completions.create(
|
|
||||||
model="gpt-4o",
|
model="gpt-4o",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
@ -170,22 +182,18 @@ def main(
|
|||||||
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
|
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
|
||||||
sys.stdout.write(f'⚙️ {pretty_call}')
|
sys.stdout.write(f'⚙️ {pretty_call}')
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
tool_result = tool_map[tool_call.function.name](**args)
|
tool_result = await tool_map[tool_call.function.name](session, **args)
|
||||||
sys.stdout.write(f" → {tool_result}\n")
|
sys.stdout.write(f" → {tool_result}\n")
|
||||||
messages.append(ChatCompletionToolMessageParam(
|
messages.append(ChatCompletionToolMessageParam(
|
||||||
tool_call_id=tool_call.id,
|
tool_call_id=tool_call.id,
|
||||||
role="tool",
|
role="tool",
|
||||||
# name=tool_call.function.name,
|
|
||||||
content=json.dumps(tool_result),
|
content=json.dumps(tool_result),
|
||||||
# content=f'{pretty_call} = {tool_result}',
|
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
assert content
|
assert content
|
||||||
print(content)
|
print(content)
|
||||||
return
|
return
|
||||||
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
if max_iterations is not None:
|
if max_iterations is not None:
|
||||||
raise Exception(f"Failed to get a valid response after {max_iterations} tool calls")
|
raise Exception(f"Failed to get a valid response after {max_iterations} tool calls")
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ def python(code: str) -> str:
|
|||||||
Returns:
|
Returns:
|
||||||
str: The output of the executed code.
|
str: The output of the executed code.
|
||||||
"""
|
"""
|
||||||
from IPython import InteractiveShell
|
from IPython.core.interactiveshell import InteractiveShell
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
|
aiohttp
|
||||||
fastapi
|
fastapi
|
||||||
|
ipython
|
||||||
openai
|
openai
|
||||||
pydantic
|
pydantic
|
||||||
requests
|
|
||||||
typer
|
typer
|
||||||
uvicorn
|
uvicorn
|
||||||
|
Loading…
Reference in New Issue
Block a user