From 26e76f9704185d1ad44f5d245071bf8b93bce774 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 2 Oct 2024 19:12:57 +0100 Subject: [PATCH] `agent`: allow interactive chat by default, and don't reuse sessions --- examples/agent/run.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/examples/agent/run.py b/examples/agent/run.py index a897952b6..9b0fc0267 100644 --- a/examples/agent/run.py +++ b/examples/agent/run.py @@ -77,7 +77,7 @@ class OpenAPIMethod: if components: self.parameters_schema['components'] = components - async def __call__(self, session: aiohttp.ClientSession, **kwargs): + async def __call__(self, **kwargs): if self.body: body = kwargs.pop(self.body['name'], None) if self.body['required']: @@ -96,9 +96,10 @@ class OpenAPIMethod: params = '&'.join(f'{name}={urllib.parse.quote(str(value))}' for name, value in query_params.items() if value is not None) url = f'{self.url}?{params}' - async with session.post(url, json=body) as response: - response.raise_for_status() - response_json = await response.json() + async with aiohttp.ClientSession() as session: + async with session.post(url, json=body) as response: + response.raise_for_status() + response_json = await response.json() return response_json @@ -131,6 +132,7 @@ async def discover_tools(tool_endpoints: list[str], logger) -> tuple[dict, list] return tool_map, tools + def typer_async_workaround(): 'Adapted from https://github.com/fastapi/typer/issues/950#issuecomment-2351076467' def decorator(f): @@ -149,6 +151,7 @@ async def main( verbose: bool = False, cache_prompt: bool = True, seed: Optional[int] = None, + interactive: bool = True, openai: bool = False, endpoint: Optional[str] = None, api_key: Optional[str] = None, @@ -180,7 +183,7 @@ async def main( 'Content-Type': 'application/json', 'Authorization': f'Bearer {api_key}' } - async with aiohttp.ClientSession(headers=headers) as session: + async def run_turn(): for i in range(max_iterations or sys.maxsize): url = f'{endpoint}chat/completions' payload = dict( @@ -195,10 +198,11 @@ async def main( )) # type: ignore logger.debug('Calling %s with %s', url, json.dumps(payload, indent=2)) - async with session.post(url, json=payload) as response: - logger.debug('Response: %s', response) - response.raise_for_status() - response = await response.json() + async with aiohttp.ClientSession(headers=headers) as session: + async with session.post(url, json=payload) as response: + logger.debug('Response: %s', response) + response.raise_for_status() + response = await response.json() assert len(response['choices']) == 1 choice = response['choices'][0] @@ -216,7 +220,7 @@ async def main( pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})' logger.info(f'⚙️ {pretty_call}') sys.stdout.flush() - tool_result = await tool_map[name](session, **args) + tool_result = await tool_map[name](**args) tool_result_str = json.dumps(tool_result) logger.info(' → %d chars', len(tool_result_str)) logger.debug('%s', tool_result_str) @@ -233,5 +237,13 @@ async def main( if max_iterations is not None: raise Exception(f'Failed to get a valid response after {max_iterations} tool calls') + while interactive: + await run_turn() + messages.append(dict( + role='user', + content=input('💬 ') + )) + + if __name__ == '__main__': typer.run(main)