agent: disable brave_search when BRAVE_SEARCH_API_KEY unset

This commit is contained in:
ochafik 2024-10-07 02:22:52 +01:00
parent a151ddcd5a
commit 241acc2488
4 changed files with 33 additions and 15 deletions

View File

@ -80,7 +80,7 @@ class OpenAPIMethod:
if self.body:
body = kwargs.pop(self.body['name'], None)
if self.body['required']:
assert body is not None, f'Missing required body parameter: {self.body['name']}'
assert body is not None, f'Missing required body parameter: {self.body["name"]}'
else:
body = None
@ -174,6 +174,7 @@ async def main(
model: str = 'gpt-4o',
tools: Optional[list[str]] = None,
max_iterations: Optional[int] = 10,
system: Optional[str] = None,
verbose: bool = False,
cache_prompt: bool = True,
seed: Optional[int] = None,
@ -192,12 +193,18 @@ async def main(
sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n')
messages = [
messages = []
if system:
messages.append(dict(
role='system',
content=system,
))
messages.append(
dict(
role='user',
content=goal,
)
]
)
headers = {
'Content-Type': 'application/json',
@ -221,10 +228,10 @@ async def main(
print(f'Calling {url} with {json.dumps(payload, indent=2)}', file=sys.stderr)
async with aiohttp.ClientSession(headers=headers) as session:
async with session.post(url, json=payload) as response:
if verbose:
print(f'Response: {response}', file=sys.stderr)
response.raise_for_status()
response = await response.json()
if verbose:
print(f'Response: {json.dumps(response, indent=2)}', file=sys.stderr)
assert len(response['choices']) == 1
choice = response['choices'][0]

View File

@ -63,15 +63,12 @@ def main(host: str = '0.0.0.0', port: int = 8000, verbose: bool = False, include
return True
app = fastapi.FastAPI()
for name, fn in python_tools.items():
for name, fn in ALL_TOOLS.items():
if accept_tool(name):
app.post(f'/{name}')(fn)
if name != 'python':
python_tools[name] = fn
for name, fn in ALL_TOOLS.items():
app.post(f'/{name}')(fn)
uvicorn.run(app, host=host, port=port)

View File

@ -7,16 +7,27 @@
#
set -euo pipefail
if [[ -z "${BRAVE_SEARCH_API_KEY:-}" ]]; then
echo "Please set BRAVE_SEARCH_API_KEY environment variable in order to enable the brave_search tool" >&2
fi
PORT=${PORT:-8088}
BRAVE_SEARCH_API_KEY=${BRAVE_SEARCH_API_KEY:-}
docker run -p $PORT:$PORT \
excludes=()
if [[ -z "${BRAVE_SEARCH_API_KEY:-}" ]]; then
echo "#" >&2
echo "# Please set BRAVE_SEARCH_API_KEY environment variable in order to enable the brave_search tool" >&2
echo "#" >&2
excludes+=( "brave_search" )
fi
args=( --port $PORT "$@" )
if [[ "${#excludes[@]}" -gt 0 ]]; then
args+=( --exclude="$(IFS=\|; echo "${excludes[*]}")" )
fi
echo "# Running inside docker: serve_tools.py ${args[*]}"
docker run \
-p $PORT:$PORT \
-w /src \
-v $PWD/examples/agent:/src \
--env BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY \
--rm -it ghcr.io/astral-sh/uv:python3.12-alpine \
uv run serve_tools.py --port $PORT "$@"
uv run serve_tools.py "${args[@]}"

View File

@ -1,3 +1,4 @@
import sys
from pydantic import Field
import aiohttp
import itertools
@ -67,6 +68,8 @@ async def brave_search(*, query: str) -> List[Dict]:
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as res:
if not res.ok:
raise Exception(await res.text())
res.raise_for_status()
response = await res.json()