agent: ditch openai dependency, use cache_prompt and expose seed

Olivier Chafik 2024-10-02 16:26:45 +01:00
parent b559d64ecc
commit 2428b73853
2 changed files with 31 additions and 26 deletions
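
For context: instead of going through the openai client library, the agent now posts directly to the server's OpenAI-compatible chat/completions endpoint, passing cache_prompt (a llama.cpp server extension that reuses the cached prompt prefix across turns) and seed (for reproducible sampling) in the JSON body. Below is a minimal self-contained sketch of that request shape, assuming a llama-server on the default local port; the helper name chat_once is hypothetical, and the model name is ignored by llama-server, kept only for API-shape compatibility:

import asyncio
import aiohttp

async def chat_once(messages, seed=None, cache_prompt=True):
    # OpenAI-compatible endpoint exposed by llama-server (assumed default URL).
    url = 'http://localhost:8080/v1/chat/completions'
    payload = dict(
        messages=messages,
        model='gpt-4o',             # ignored by llama-server; kept for API compatibility
        seed=seed,                  # sampling seed, exposed for reproducible runs
        cache_prompt=cache_prompt,  # llama.cpp extension: reuse the cached prompt prefix
    )
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as response:
            response.raise_for_status()
            return await response.json()

if __name__ == '__main__':
    result = asyncio.run(chat_once([dict(role='user', content='Hello!')], seed=42))
    print(result['choices'][0]['message']['content'])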


@@ -1,7 +1,6 @@
 aiohttp
 fastapi
 ipython
-openai
 pydantic
 typer
 uvicorn


@@ -3,7 +3,6 @@
 # dependencies = [
 #   "aiohttp",
 #   "fastapi",
-#   "openai",
 #   "pydantic",
 #   "typer",
 #   "uvicorn",
@@ -13,8 +12,6 @@ import json
 import asyncio
 import aiohttp
 from functools import wraps
-from openai import AsyncOpenAI
-from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
 from pydantic import BaseModel
 import sys
 import typer
@@ -141,51 +138,60 @@ async def main(
     tools: Optional[list[str]] = None,
     max_iterations: Optional[int] = 10,
     verbose: bool = False,
+    cache_prompt: bool = True,
+    seed: Optional[int] = None,
     endpoint: str = "http://localhost:8080/v1/",
 ):
-    client = AsyncOpenAI(api_key=api_key, base_url=endpoint)
     tool_map, tools = await discover_tools(tools or [], verbose)
     sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
 
-    messages: list[ChatCompletionMessageParam] = [
-        ChatCompletionUserMessageParam(
+    messages = [
+        dict(
             role="user",
             content=goal,
         )
     ]
 
-    async with aiohttp.ClientSession() as session:
+    headers = {
+        'Authorization': f'Bearer {api_key}'
+    }
+    async with aiohttp.ClientSession(headers=headers) as session:
         for i in range(max_iterations or sys.maxsize):
-            response = await client.chat.completions.create(
-                model="gpt-4o",
+            url = f'{endpoint}chat/completions'
+            payload = dict(
                 messages=messages,
+                model="gpt-4o",
                 tools=tools,
+                seed=seed,
+                cache_prompt=cache_prompt,
             )
-            if verbose:
-                sys.stderr.write(f'# RESPONSE: {response}\n')
-            assert len(response.choices) == 1
-            choice = response.choices[0]
+            async with session.post(url, json=payload) as response:
+                if verbose:
+                    sys.stderr.write(f'# RESPONSE: {response}\n')
+                response.raise_for_status()
+                response = await response.json()
+            assert len(response["choices"]) == 1
+            choice = response["choices"][0]
 
-            content = choice.message.content
-            if choice.finish_reason == "tool_calls":
-                messages.append(choice.message) # type: ignore
-                assert choice.message.tool_calls
-                for tool_call in choice.message.tool_calls:
+            content = choice['message']['content']
+            if choice['finish_reason'] == "tool_calls":
+                messages.append(choice['message'])
+                assert choice['message']['tool_calls']
+                for tool_call in choice['message']['tool_calls']:
                     if content:
                         print(f'💭 {content}')
 
-                    args = json.loads(tool_call.function.arguments)
-                    pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
+                    name = tool_call['function']['name']
+                    args = json.loads(tool_call['function']['arguments'])
+                    pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
                     sys.stdout.write(f'⚙️ {pretty_call}')
                     sys.stdout.flush()
-                    tool_result = await tool_map[tool_call.function.name](session, **args)
+                    tool_result = await tool_map[name](session, **args)
                     sys.stdout.write(f"{tool_result}\n")
-                    messages.append(ChatCompletionToolMessageParam(
-                        tool_call_id=tool_call.id,
+                    messages.append(dict(
+                        tool_call_id=tool_call.get('id'),
                         role="tool",
                         content=json.dumps(tool_result),
                    ))
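
Two small consequences of dropping the typed client seem worth noting: messages are now plain dicts in the exact JSON shape the server expects, so no openai/pydantic types need to stay in sync with the wire format, and using tool_call.get('id') rather than attribute access presumably tolerates responses that omit tool-call ids. The Authorization header is still sent, so the agent keeps working against endpoints that require an API key.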