agent: support more providers (+ extract serve_tools_inside_docker.sh)

update readme
Olivier Chafik 2024-10-03 19:18:47 +01:00
parent b4fc1e8ba7
commit da02397f7f
3 changed files with 64 additions and 25 deletions

examples/agent/README.md

@@ -39,6 +39,7 @@
 - Run the tools in [examples/agent/tools](./examples/agent/tools) inside a docker container (check http://localhost:8088/docs once running):
 ```bash
+# Shorthand: ./examples/agent/serve_tools_inside_docker.sh
 docker run -p 8088:8088 -w /src -v $PWD/examples/agent:/src \
     --env BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY \
     --rm -it ghcr.io/astral-sh/uv:python3.12-alpine \
@@ -99,13 +100,15 @@
 </details>

-- To compare the above results w/ OpenAI's tool usage behaviour, just add `--openai` to the agent invocation (other providers can easily be added, just use the `--endpoint`, `--api-key`, and `--model` flags)
+- To compare the above results w/ a cloud provider's tool usage behaviour, just set the `--provider` flag (accepts `openai`, `together`, `groq`) and/or use `--endpoint`, `--api-key`, and `--model`
 ```bash
-export OPENAI_API_KEY=...
+export OPENAI_API_KEY=... # for --provider=openai
+# export TOGETHER_API_KEY=... # for --provider=together
+# export GROQ_API_KEY=... # for --provider=groq
 uv run examples/agent/run.py --tools http://localhost:8088 \
     "Search for, fetch and summarize the homepage of llama.cpp" \
-    --openai
+    --provider=openai
 ```

 ## TODO

examples/agent/run.py

@@ -12,12 +12,11 @@ import aiohttp
 import asyncio
 from functools import wraps
 import json
-import logging
 import os
 from pydantic import BaseModel
 import sys
 import typer
-from typing import Optional
+from typing import Annotated, Literal, Optional
 import urllib.parse

 class OpenAPIMethod:
@@ -103,7 +102,7 @@ class OpenAPIMethod:
         return response_json

-async def discover_tools(tool_endpoints: list[str], logger) -> tuple[dict, list]:
+async def discover_tools(tool_endpoints: list[str], verbose) -> tuple[dict, list]:
     tool_map = {}
     tools = []
@@ -119,7 +118,8 @@ async def discover_tools(tool_endpoints: list[str], logger) -> tuple[dict, list]:
         for path, descriptor in catalog['paths'].items():
             fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
             tool_map[fn.__name__] = fn
-            logger.debug('Function %s: params schema: %s', fn.__name__, fn.parameters_schema)
+            if verbose:
+                print(f'Function {fn.__name__}: params schema: {fn.parameters_schema}', file=sys.stderr)
             tools.append(dict(
                 type='function',
                 function=dict(
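Note: the `tools.append(...)` call is truncated by the hunk boundary, but it is building OpenAI-style function-calling entries. A hypothetical example of one finished entry, assuming the usual shape of that schema (the name, description, and parameter schema below are illustrative, not from the commit):

```python
# Hypothetical example of one entry appended to `tools`, following the
# OpenAI function-calling schema that the truncated dict above suggests.
example_tool = {
    'type': 'function',
    'function': {
        'name': 'fetch_page',          # derived from the OpenAPI path, e.g. '/fetch_page'
        'description': 'Fetch a URL',  # assumed: taken from the endpoint's descriptor
        'parameters': {                # JSON Schema for the tool's arguments
            'type': 'object',
            'properties': {'url': {'type': 'string'}},
            'required': ['url'],
        },
    },
}
```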
@@ -142,6 +142,30 @@ def typer_async_workaround():
         return wrapper
     return decorator

+_PROVIDERS = {
+    'llama.cpp': {
+        'endpoint': 'http://localhost:8080/v1/',
+        'api_key_env': 'LLAMACPP_API_KEY',
+    },
+    'openai': {
+        'endpoint': 'https://api.openai.com/v1/',
+        'default_model': 'gpt-4o',
+        'api_key_env': 'OPENAI_API_KEY',
+    },
+    'together': {
+        'endpoint': 'https://api.together.xyz',
+        'default_model': 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo',
+        'api_key_env': 'TOGETHER_API_KEY',
+    },
+    'groq': {
+        'endpoint': 'https://api.groq.com/openai',
+        'default_model': 'llama-3.1-70b-versatile',
+        'api_key_env': 'GROQ_API_KEY',
+    },
+}
+
 @typer_async_workaround()
 async def main(
     goal: str,
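Since each `_PROVIDERS` entry is just an OpenAI-compatible base URL plus an environment variable name (and an optional default model), adding another provider should only take a new dict entry. A hypothetical sketch; the provider name, URL, and model id below are made up:

```python
# Hypothetical: any OpenAI-compatible service could be registered like this.
_PROVIDERS['my-provider'] = {
    'endpoint': 'https://api.example.com/v1/',  # made-up base URL
    'default_model': 'example-70b-instruct',    # made-up model id
    'api_key_env': 'MY_PROVIDER_API_KEY',       # env var read when --api-key is absent
}
```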
@@ -152,23 +176,17 @@ async def main(
     cache_prompt: bool = True,
     seed: Optional[int] = None,
     interactive: bool = True,
-    openai: bool = False,
+    provider: Annotated[str, Literal['llama.cpp', 'openai', 'together', 'groq']] = 'llama.cpp',
     endpoint: Optional[str] = None,
     api_key: Optional[str] = None,
 ):
-    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO, format='%(message)s')
-    logger = logging.getLogger(__name__)
+    provider_info = _PROVIDERS[provider]

     if endpoint is None:
-        if openai:
-            endpoint = 'https://api.openai.com/v1/'
-        else:
-            endpoint = 'http://localhost:8080/v1/'
+        endpoint = provider_info['endpoint']

     if api_key is None:
-        if openai:
-            api_key = os.environ.get('OPENAI_API_KEY')
+        api_key = os.environ.get(provider_info['api_key_env'])

-    tool_map, tools = await discover_tools(tools or [], logger=logger)
+    tool_map, tools = await discover_tools(tools or [], verbose)

     sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n')
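The hunk above implies a simple precedence: an explicit `--endpoint` or `--api-key` flag wins, otherwise the selected provider's default endpoint and environment variable apply. A standalone restatement of that fallback order (the function name is ours, not the commit's):

```python
import os

def resolve_connection(provider: str, endpoint=None, api_key=None):
    # Mirrors the fallback order in main(): explicit flag > provider default / env var.
    info = _PROVIDERS[provider]
    endpoint = endpoint or info['endpoint']
    api_key = api_key or os.environ.get(info['api_key_env'])
    return endpoint, api_key

# e.g. resolve_connection('groq') -> ('https://api.groq.com/openai', <value of $GROQ_API_KEY>)
```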
@@ -191,16 +209,18 @@ async def main(
             model=model,
             tools=tools,
         )
-        if not openai:
+        if provider == 'llama.cpp':
             payload.update(dict(
                 seed=seed,
                 cache_prompt=cache_prompt,
             ))  # type: ignore

-        logger.debug('Calling %s with %s', url, json.dumps(payload, indent=2))
+        if verbose:
+            print(f'Calling {url} with {json.dumps(payload, indent=2)}', file=sys.stderr)
         async with aiohttp.ClientSession(headers=headers) as session:
             async with session.post(url, json=payload) as response:
-                logger.debug('Response: %s', response)
+                if verbose:
+                    print(f'Response: {response}', file=sys.stderr)
                 response.raise_for_status()
                 response = await response.json()
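Note that only the `llama.cpp` provider receives the extra fields: `seed` and `cache_prompt` are llama.cpp server extensions, not standard OpenAI chat-completion parameters. A rough sketch of the resulting request body under that branch (the message and model values are placeholders):

```python
# Sketch of the body POSTed to the llama.cpp server; values are placeholders,
# and 'seed'/'cache_prompt' would be omitted for the other providers.
payload = {
    'messages': [{'role': 'user', 'content': '...'}],  # conversation so far
    'model': '...',                                    # from --model
    'tools': [],                                       # discovered tool schemas
    'seed': None,                                      # llama.cpp extension
    'cache_prompt': True,                              # llama.cpp extension
}
```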
@@ -213,17 +233,22 @@ async def main(
             assert choice['message']['tool_calls']
             for tool_call in choice['message']['tool_calls']:
                 if content:
-                    print(f'💭 {content}')
+                    print(f'💭 {content}', file=sys.stderr)

                 name = tool_call['function']['name']
                 args = json.loads(tool_call['function']['arguments'])
                 pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
-                logger.info(f'⚙️ {pretty_call}')
+                print(f'⚙️ {pretty_call}', file=sys.stderr, end=None)
                 sys.stdout.flush()
                 tool_result = await tool_map[name](**args)
                 tool_result_str = json.dumps(tool_result)
-                logger.info('%d chars', len(tool_result_str))
-                logger.debug('%s', tool_result_str)
+
+                def describe(res, res_str):
+                    if isinstance(res, list):
+                        return f'{len(res)} items'
+                    return f'{len(res_str)} chars'
+
+                print(f'{describe(tool_result, tool_result_str)}', file=sys.stderr)
+                if verbose:
+                    print(tool_result_str, file=sys.stderr)
                 messages.append(dict(
                     tool_call_id=tool_call.get('id'),
                     role='tool',
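The new `describe` helper keeps the per-call log terse: list results are summarized by item count, everything else by serialized length. A small self-contained demonstration of the same logic:

```python
import json

def describe(res, res_str):
    # Summarize a tool result: item count for lists, character count otherwise.
    if isinstance(res, list):
        return f'{len(res)} items'
    return f'{len(res_str)} chars'

print(describe([1, 2, 3], json.dumps([1, 2, 3])))        # -> 3 items
print(describe({'ok': True}, json.dumps({'ok': True})))  # -> 12 chars
```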

examples/agent/serve_tools_inside_docker.sh

@@ -0,0 +1,11 @@
+#!/bin/bash
+set -euo pipefail
+
+PORT=${PORT:-8088}
+
+docker run -p $PORT:$PORT \
+    -w /src \
+    -v $PWD/examples/agent:/src \
+    --env BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY \
+    --rm -it ghcr.io/astral-sh/uv:python3.12-alpine \
+    uv run serve_tools.py --port $PORT