From 241acc24880b2a86494300a67becae53561e53ac Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 7 Oct 2024 02:22:52 +0100 Subject: [PATCH] `agent`: disable brave_search when BRAVE_SEARCH_API_KEY unset --- examples/agent/run.py | 17 ++++++++++----- examples/agent/serve_tools.py | 5 +---- examples/agent/serve_tools_inside_docker.sh | 23 +++++++++++++++------ examples/agent/tools/search.py | 3 +++ 4 files changed, 33 insertions(+), 15 deletions(-) diff --git a/examples/agent/run.py b/examples/agent/run.py index 287262035..bc2322bc4 100644 --- a/examples/agent/run.py +++ b/examples/agent/run.py @@ -80,7 +80,7 @@ class OpenAPIMethod: if self.body: body = kwargs.pop(self.body['name'], None) if self.body['required']: - assert body is not None, f'Missing required body parameter: {self.body['name']}' + assert body is not None, f'Missing required body parameter: {self.body["name"]}' else: body = None @@ -174,6 +174,7 @@ async def main( model: str = 'gpt-4o', tools: Optional[list[str]] = None, max_iterations: Optional[int] = 10, + system: Optional[str] = None, verbose: bool = False, cache_prompt: bool = True, seed: Optional[int] = None, @@ -192,12 +193,18 @@ async def main( sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else ""}\n') - messages = [ + messages = [] + if system: + messages.append(dict( + role='system', + content=system, + )) + messages.append( dict( role='user', content=goal, ) - ] + ) headers = { 'Content-Type': 'application/json', @@ -221,10 +228,10 @@ async def main( print(f'Calling {url} with {json.dumps(payload, indent=2)}', file=sys.stderr) async with aiohttp.ClientSession(headers=headers) as session: async with session.post(url, json=payload) as response: - if verbose: - print(f'Response: {response}', file=sys.stderr) response.raise_for_status() response = await response.json() + if verbose: + print(f'Response: {json.dumps(response, indent=2)}', file=sys.stderr) assert len(response['choices']) == 1 choice = response['choices'][0] diff --git a/examples/agent/serve_tools.py b/examples/agent/serve_tools.py index 64f15a580..197944073 100644 --- a/examples/agent/serve_tools.py +++ b/examples/agent/serve_tools.py @@ -63,15 +63,12 @@ def main(host: str = '0.0.0.0', port: int = 8000, verbose: bool = False, include return True app = fastapi.FastAPI() - for name, fn in python_tools.items(): + for name, fn in ALL_TOOLS.items(): if accept_tool(name): app.post(f'/{name}')(fn) if name != 'python': python_tools[name] = fn - for name, fn in ALL_TOOLS.items(): - app.post(f'/{name}')(fn) - uvicorn.run(app, host=host, port=port) diff --git a/examples/agent/serve_tools_inside_docker.sh b/examples/agent/serve_tools_inside_docker.sh index aad700f6c..898241c79 100755 --- a/examples/agent/serve_tools_inside_docker.sh +++ b/examples/agent/serve_tools_inside_docker.sh @@ -7,16 +7,27 @@ # set -euo pipefail -if [[ -z "${BRAVE_SEARCH_API_KEY:-}" ]]; then - echo "Please set BRAVE_SEARCH_API_KEY environment variable in order to enable the brave_search tool" >&2 -fi - PORT=${PORT:-8088} BRAVE_SEARCH_API_KEY=${BRAVE_SEARCH_API_KEY:-} -docker run -p $PORT:$PORT \ +excludes=() +if [[ -z "${BRAVE_SEARCH_API_KEY:-}" ]]; then + echo "#" >&2 + echo "# Please set BRAVE_SEARCH_API_KEY environment variable in order to enable the brave_search tool" >&2 + echo "#" >&2 + excludes+=( "brave_search" ) +fi + +args=( --port $PORT "$@" ) +if [[ "${#excludes[@]}" -gt 0 ]]; then + args+=( --exclude="$(IFS=\|; echo "${excludes[*]}")" ) +fi + +echo "# Running inside docker: serve_tools.py ${args[*]}" +docker run \ + -p $PORT:$PORT \ -w /src \ -v $PWD/examples/agent:/src \ --env BRAVE_SEARCH_API_KEY=$BRAVE_SEARCH_API_KEY \ --rm -it ghcr.io/astral-sh/uv:python3.12-alpine \ - uv run serve_tools.py --port $PORT "$@" + uv run serve_tools.py "${args[@]}" diff --git a/examples/agent/tools/search.py b/examples/agent/tools/search.py index 5bcddc438..63c92d8a1 100644 --- a/examples/agent/tools/search.py +++ b/examples/agent/tools/search.py @@ -1,3 +1,4 @@ +import sys from pydantic import Field import aiohttp import itertools @@ -67,6 +68,8 @@ async def brave_search(*, query: str) -> List[Dict]: async with aiohttp.ClientSession() as session: async with session.get(url, headers=headers) as res: + if not res.ok: + raise Exception(await res.text()) res.raise_for_status() response = await res.json()