Mirror of https://github.com/ggerganov/llama.cpp.git
commit 2428b73853 (parent b559d64ecc)

    agent: ditch openai dependency, use cache_prompt and expose seed
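The first hunk below trims the dependency list; the remaining hunks rework the agent script itself. In short, the agent stops going through the AsyncOpenAI client and POSTs directly to the server's OpenAI-compatible chat/completions route over aiohttp, which it was already using for tool calls. Talking raw HTTP makes room for llama.cpp-specific request fields that the openai client does not expose as parameters: cache_prompt asks the server to reuse its KV cache for the shared prompt prefix, which pays off in a loop that resends a growing messages list every iteration, and seed pins the sampler for reproducible runs. A minimal sketch of the raw call outside the agent loop, with an illustrative URL, key, and prompt:

import asyncio
from typing import Optional

import aiohttp


async def chat_once(prompt: str, seed: Optional[int] = None,
                    endpoint: str = 'http://localhost:8080/v1/',
                    api_key: str = 'unused') -> str:
    # Plain-dict payload, same shape as the agent's: the two extra fields are
    # llama.cpp server extensions not covered by the openai client's typed params.
    payload = dict(
        messages=[dict(role='user', content=prompt)],
        model='gpt-4o',      # llama.cpp serves whatever model it loaded; the name does not route
        cache_prompt=True,   # reuse the KV cache for the common prompt prefix
        seed=seed,           # fixed seed for reproducible sampling
    )
    headers = {'Authorization': f'Bearer {api_key}'}
    async with aiohttp.ClientSession(headers=headers) as session:
        async with session.post(f'{endpoint}chat/completions', json=payload) as response:
            response.raise_for_status()
            data = await response.json()
    return data['choices'][0]['message']['content']


if __name__ == '__main__':
    print(asyncio.run(chat_once('Say hi in five words.', seed=42)))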
@@ -1,7 +1,6 @@
 aiohttp
 fastapi
 ipython
-openai
 pydantic
 typer
 uvicorn
@@ -3,7 +3,6 @@
 # dependencies = [
 #   "aiohttp",
 #   "fastapi",
-#   "openai",
 #   "pydantic",
 #   "typer",
 #   "uvicorn",
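For context on this hunk: the commented block is PEP 723 inline script metadata, which lets a runner such as uv resolve the script's dependencies directly from the file. After the change the block presumably reads as follows; this is a reconstruction, and the surrounding "# /// script" delimiters are assumed rather than shown in the hunk:

# /// script
# dependencies = [
#   "aiohttp",
#   "fastapi",
#   "pydantic",
#   "typer",
#   "uvicorn",
# ]
# ///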
@@ -13,8 +12,6 @@ import json
 import asyncio
 import aiohttp
 from functools import wraps
-from openai import AsyncOpenAI
-from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
 from pydantic import BaseModel
 import sys
 import typer
@@ -141,51 +138,60 @@ async def main(
     tools: Optional[list[str]] = None,
     max_iterations: Optional[int] = 10,
     verbose: bool = False,
+    cache_prompt: bool = True,
+    seed: Optional[int] = None,
     endpoint: str = "http://localhost:8080/v1/",
 ):
-    client = AsyncOpenAI(api_key=api_key, base_url=endpoint)
-
     tool_map, tools = await discover_tools(tools or [], verbose)
 
     sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
 
-    messages: list[ChatCompletionMessageParam] = [
-        ChatCompletionUserMessageParam(
+    messages = [
+        dict(
             role="user",
             content=goal,
         )
     ]
 
-    async with aiohttp.ClientSession() as session:
+    headers = {
+        'Authorization': f'Bearer {api_key}'
+    }
+    async with aiohttp.ClientSession(headers=headers) as session:
         for i in range(max_iterations or sys.maxsize):
-            response = await client.chat.completions.create(
-                model="gpt-4o",
+            url = f'{endpoint}chat/completions'
+            payload = dict(
                 messages=messages,
+                model="gpt-4o",
                 tools=tools,
+                seed=seed,
+                cache_prompt=cache_prompt,
             )
-
-            if verbose:
-                sys.stderr.write(f'# RESPONSE: {response}\n')
-
-            assert len(response.choices) == 1
-            choice = response.choices[0]
+            async with session.post(url, json=payload) as response:
+                if verbose:
+                    sys.stderr.write(f'# RESPONSE: {response}\n')
+                response.raise_for_status()
+                response = await response.json()
 
-            content = choice.message.content
-            if choice.finish_reason == "tool_calls":
-                messages.append(choice.message) # type: ignore
-                assert choice.message.tool_calls
-                for tool_call in choice.message.tool_calls:
+            assert len(response["choices"]) == 1
+            choice = response["choices"][0]
+
+            content = choice['message']['content']
+            if choice['finish_reason'] == "tool_calls":
+                messages.append(choice['message'])
+                assert choice['message']['tool_calls']
+                for tool_call in choice['message']['tool_calls']:
                     if content:
                         print(f'💭 {content}')
 
-                    args = json.loads(tool_call.function.arguments)
-                    pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
+                    name = tool_call['function']['name']
+                    args = json.loads(tool_call['function']['arguments'])
+                    pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
                     sys.stdout.write(f'⚙️ {pretty_call}')
                     sys.stdout.flush()
-                    tool_result = await tool_map[tool_call.function.name](session, **args)
+                    tool_result = await tool_map[name](session, **args)
                     sys.stdout.write(f" → {tool_result}\n")
-                    messages.append(ChatCompletionToolMessageParam(
-                        tool_call_id=tool_call.id,
+                    messages.append(dict(
+                        tool_call_id=tool_call.get('id'),
                         role="tool",
                         content=json.dumps(tool_result),
                     ))
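On the exposed seed: against an otherwise idle server, two identical payloads with the same seed should sample the same completion, which makes agent runs replayable. A self-contained sketch of that check, assuming the same illustrative local URL and a server that seeds its sampler per request:

import asyncio

import aiohttp


async def sample(seed: int) -> str:
    payload = dict(
        messages=[dict(role='user', content='Pick a random city.')],
        model='gpt-4o',
        cache_prompt=True,   # the second, identical request should also hit the prompt cache
        seed=seed,
    )
    url = 'http://localhost:8080/v1/chat/completions'
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as response:
            response.raise_for_status()
            data = await response.json()
    return data['choices'][0]['message']['content']


async def main() -> None:
    first = await sample(seed=1)
    second = await sample(seed=1)
    assert first == second, 'same seed should reproduce the same completion'
    print('reproducible:', first)


asyncio.run(main())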