mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
tool-call
: fix agent type lints
This commit is contained in:
parent
1e5c0e747e
commit
9295ca95db
@ -30,4 +30,4 @@
|
||||
uv run examples/tool-call/agent.py \
|
||||
--tool-endpoint http://localhost:8088 \
|
||||
--goal "What is the sum of 2535 squared and 32222000403 then multiplied by one and a half. What's a third of the result?"
|
||||
```
|
||||
```
|
||||
|
@ -11,12 +11,13 @@
|
||||
# ///
|
||||
import json
|
||||
import openai
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
|
||||
from pydantic import BaseModel
|
||||
import requests
|
||||
import sys
|
||||
import typer
|
||||
from typing import Annotated, List, Optional
|
||||
import urllib
|
||||
from typing import Annotated, Optional
|
||||
import urllib.parse
|
||||
|
||||
|
||||
class OpenAPIMethod:
|
||||
@ -94,7 +95,7 @@ class OpenAPIMethod:
|
||||
def main(
|
||||
goal: Annotated[str, typer.Option()],
|
||||
api_key: Optional[str] = None,
|
||||
tool_endpoint: Optional[List[str]] = None,
|
||||
tool_endpoint: Optional[list[str]] = None,
|
||||
format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None,
|
||||
max_iterations: Optional[int] = 10,
|
||||
parallel_calls: Optional[bool] = False,
|
||||
@ -102,16 +103,16 @@ def main(
|
||||
# endpoint: Optional[str] = None,
|
||||
endpoint: str = "http://localhost:8080/v1/",
|
||||
):
|
||||
|
||||
|
||||
openai.api_key = api_key
|
||||
openai.base_url = endpoint
|
||||
|
||||
|
||||
tool_map = {}
|
||||
tools = []
|
||||
|
||||
|
||||
for url in (tool_endpoint or []):
|
||||
assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'
|
||||
|
||||
|
||||
catalog_url = f'{url}/openapi.json'
|
||||
catalog_response = requests.get(catalog_url)
|
||||
catalog_response.raise_for_status()
|
||||
@ -131,11 +132,11 @@ def main(
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')
|
||||
|
||||
messages = [
|
||||
dict(
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
ChatCompletionUserMessageParam(
|
||||
role="user",
|
||||
content=goal,
|
||||
)
|
||||
@ -143,7 +144,7 @@ def main(
|
||||
|
||||
i = 0
|
||||
while (max_iterations is None or i < max_iterations):
|
||||
|
||||
|
||||
response = openai.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=messages,
|
||||
@ -152,13 +153,14 @@ def main(
|
||||
|
||||
if verbose:
|
||||
sys.stderr.write(f'# RESPONSE: {response}\n')
|
||||
|
||||
|
||||
assert len(response.choices) == 1
|
||||
choice = response.choices[0]
|
||||
|
||||
content = choice.message.content
|
||||
if choice.finish_reason == "tool_calls":
|
||||
messages.append(choice.message)
|
||||
messages.append(choice.message) # type: ignore
|
||||
assert choice.message.tool_calls
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if content:
|
||||
print(f'💭 {content}')
|
||||
@ -169,11 +171,11 @@ def main(
|
||||
sys.stdout.flush()
|
||||
tool_result = tool_map[tool_call.function.name](**args)
|
||||
sys.stdout.write(f" → {tool_result}\n")
|
||||
messages.append(dict(
|
||||
messages.append(ChatCompletionToolMessageParam(
|
||||
tool_call_id=tool_call.id,
|
||||
role="tool",
|
||||
name=tool_call.function.name,
|
||||
content=f'{tool_result}',
|
||||
# name=tool_call.function.name,
|
||||
content=json.dumps(tool_result),
|
||||
# content=f'{pretty_call} = {tool_result}',
|
||||
))
|
||||
else:
|
||||
|
@ -59,8 +59,8 @@ def main(files: List[str], host: str = '0.0.0.0', port: int = 8000):
|
||||
continue
|
||||
|
||||
vt = type(v)
|
||||
if vt.__module__ == 'langchain_core.tools' and vt.__name__.endswith('Tool') and hasattr(v, 'func') and callable(v.func):
|
||||
v = v.func
|
||||
if vt.__module__ == 'langchain_core.tools' and vt.__name__.endswith('Tool') and hasattr(v, 'func') and callable(func := getattr(v, 'func')):
|
||||
v = func
|
||||
|
||||
print(f'INFO: Binding /{k}')
|
||||
try:
|
||||
@ -73,4 +73,4 @@ def main(files: List[str], host: str = '0.0.0.0', port: int = 8000):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
typer.run(main)
|
||||
typer.run(main)
|
||||
|
@ -1,13 +1,10 @@
|
||||
from datetime import date
|
||||
import datetime
|
||||
import json
|
||||
from pydantic import BaseModel
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import typer
|
||||
from typing import Union, Optional, Dict
|
||||
import types
|
||||
from typing import Union, Optional, Dict
|
||||
|
||||
|
||||
class Duration(BaseModel):
|
||||
@ -58,7 +55,7 @@ def wait_for_duration(duration: Duration) -> None:
|
||||
time.sleep(duration.get_total_seconds)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_date(target_date: date) -> None:
|
||||
def wait_for_date(target_date: datetime.date) -> None:
|
||||
f'''
|
||||
Wait until a specific date is reached before continuing.
|
||||
Today's date is {datetime.date.today()}
|
||||
|
Loading…
Reference in New Issue
Block a user