agent: ditch openai dependency, use cache_prompt and expose seed

This commit is contained in:
Olivier Chafik 2024-10-02 16:26:45 +01:00
parent b559d64ecc
commit 2428b73853
2 changed files with 31 additions and 26 deletions

View File

@ -1,7 +1,6 @@
aiohttp aiohttp
fastapi fastapi
ipython ipython
openai
pydantic pydantic
typer typer
uvicorn uvicorn

View File

@ -3,7 +3,6 @@
# dependencies = [ # dependencies = [
# "aiohttp", # "aiohttp",
# "fastapi", # "fastapi",
# "openai",
# "pydantic", # "pydantic",
# "typer", # "typer",
# "uvicorn", # "uvicorn",
@ -13,8 +12,6 @@ import json
import asyncio import asyncio
import aiohttp import aiohttp
from functools import wraps from functools import wraps
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
from pydantic import BaseModel from pydantic import BaseModel
import sys import sys
import typer import typer
@ -141,51 +138,60 @@ async def main(
tools: Optional[list[str]] = None, tools: Optional[list[str]] = None,
max_iterations: Optional[int] = 10, max_iterations: Optional[int] = 10,
verbose: bool = False, verbose: bool = False,
cache_prompt: bool = True,
seed: Optional[int] = None,
endpoint: str = "http://localhost:8080/v1/", endpoint: str = "http://localhost:8080/v1/",
): ):
client = AsyncOpenAI(api_key=api_key, base_url=endpoint)
tool_map, tools = await discover_tools(tools or [], verbose) tool_map, tools = await discover_tools(tools or [], verbose)
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n') sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
messages: list[ChatCompletionMessageParam] = [ messages = [
ChatCompletionUserMessageParam( dict(
role="user", role="user",
content=goal, content=goal,
) )
] ]
async with aiohttp.ClientSession() as session: headers = {
'Authorization': f'Bearer {api_key}'
}
async with aiohttp.ClientSession(headers=headers) as session:
for i in range(max_iterations or sys.maxsize): for i in range(max_iterations or sys.maxsize):
response = await client.chat.completions.create( url = f'{endpoint}chat/completions'
model="gpt-4o", payload = dict(
messages=messages, messages=messages,
model="gpt-4o",
tools=tools, tools=tools,
seed=seed,
cache_prompt=cache_prompt,
) )
async with session.post(url, json=payload) as response:
if verbose:
sys.stderr.write(f'# RESPONSE: {response}\n')
response.raise_for_status()
response = await response.json()
if verbose: assert len(response["choices"]) == 1
sys.stderr.write(f'# RESPONSE: {response}\n') choice = response["choices"][0]
assert len(response.choices) == 1 content = choice['message']['content']
choice = response.choices[0] if choice['finish_reason'] == "tool_calls":
messages.append(choice['message'])
content = choice.message.content assert choice['message']['tool_calls']
if choice.finish_reason == "tool_calls": for tool_call in choice['message']['tool_calls']:
messages.append(choice.message) # type: ignore
assert choice.message.tool_calls
for tool_call in choice.message.tool_calls:
if content: if content:
print(f'💭 {content}') print(f'💭 {content}')
args = json.loads(tool_call.function.arguments) name = tool_call['function']['name']
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})' args = json.loads(tool_call['function']['arguments'])
pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
sys.stdout.write(f'⚙️ {pretty_call}') sys.stdout.write(f'⚙️ {pretty_call}')
sys.stdout.flush() sys.stdout.flush()
tool_result = await tool_map[tool_call.function.name](session, **args) tool_result = await tool_map[name](session, **args)
sys.stdout.write(f"{tool_result}\n") sys.stdout.write(f"{tool_result}\n")
messages.append(ChatCompletionToolMessageParam( messages.append(dict(
tool_call_id=tool_call.id, tool_call_id=tool_call.get('id'),
role="tool", role="tool",
content=json.dumps(tool_result), content=json.dumps(tool_result),
)) ))