From 8f589be0acfbd9db7e48fda853e77b0d2bd6ff1a Mon Sep 17 00:00:00 2001 From: Boris Bergsma Date: Thu, 19 Sep 2024 11:45:42 +0200 Subject: [PATCH] fixed streaming with chat agent (#5) * fixed streaming * mypy --- CHANGELOG.md | 3 + src/neuroagent/agents/base_agent.py | 2 +- src/neuroagent/agents/simple_agent.py | 2 +- src/neuroagent/agents/simple_chat_agent.py | 77 ++++++++++++---------- src/neuroagent/app/routers/qa.py | 10 ++- 5 files changed, 58 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f8c2a9..64bdd46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,4 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Migration to pydantic V2. + +### Fixed +- Streaming with chat agent. - Deleted some legacy code. diff --git a/src/neuroagent/agents/base_agent.py b/src/neuroagent/agents/base_agent.py index b853dd8..9ecf545 100644 --- a/src/neuroagent/agents/base_agent.py +++ b/src/neuroagent/agents/base_agent.py @@ -40,7 +40,7 @@ async def arun(self, *args: Any, **kwargs: Any) -> AgentOutput: """Arun method of the service.""" @abstractmethod - async def astream(self, *args: Any, **kwargs: Any) -> AsyncIterator[str]: + def astream(self, *args: Any, **kwargs: Any) -> AsyncIterator[str]: """Astream method of the service.""" @staticmethod diff --git a/src/neuroagent/agents/simple_agent.py b/src/neuroagent/agents/simple_agent.py index 819bd38..3ea1bd2 100644 --- a/src/neuroagent/agents/simple_agent.py +++ b/src/neuroagent/agents/simple_agent.py @@ -58,7 +58,7 @@ async def arun(self, query: str) -> Any: result = await self.agent.ainvoke({"messages": [("human", query)]}) return self._process_output(result) - async def astream(self, query: str) -> AsyncIterator[str]: # type: ignore + async def astream(self, query: str) -> AsyncIterator[str]: """Run the agent against a query in streaming way. Parameters diff --git a/src/neuroagent/agents/simple_chat_agent.py b/src/neuroagent/agents/simple_chat_agent.py index 1262a2c..b331d75 100644 --- a/src/neuroagent/agents/simple_chat_agent.py +++ b/src/neuroagent/agents/simple_chat_agent.py @@ -1,6 +1,7 @@ """Simple agent.""" import logging +from contextlib import AsyncExitStack from typing import Any, AsyncIterator from langchain_core.messages import AIMessage, HumanMessage @@ -43,7 +44,9 @@ async def arun(self, thread_id: str, query: str) -> Any: result = await self.agent.ainvoke({"messages": [input_message]}, config=config) return self._process_output(result) - async def astream(self, thread_id: str, query: str) -> AsyncIterator[str]: # type: ignore + async def astream( + self, thread_id: str, query: str, connection_string: str | None = None + ) -> AsyncIterator[str]: """Run the agent against a query in streaming way. Parameters @@ -51,41 +54,49 @@ async def astream(self, thread_id: str, query: str) -> AsyncIterator[str]: # ty thread_id ID of the thread of the chat. query - Query of the user + Query of the user. + connection_string + connection string for the checkpoint database. - Returns - ------- + Yields + ------ Iterator streaming the processed output of the LLM """ - config = {"configurable": {"thread_id": thread_id}} - streamed_response = self.agent.astream_events( - {"messages": query}, version="v2", config=config - ) - - async for event in streamed_response: - kind = event["event"] - - # newline everytime model starts streaming. - if kind == "on_chat_model_start": - yield "\n\n" - # check for the model stream. - if kind == "on_chat_model_stream": - # check if we are calling the tools. - data_chunk = event["data"]["chunk"] - if "tool_calls" in data_chunk.additional_kwargs: - tool = data_chunk.additional_kwargs["tool_calls"] - if tool[0]["function"]["name"]: - yield ( - f'\nCalling tool : {tool[0]["function"]["name"]} with' - " arguments : " - ) - if tool[0]["function"]["arguments"]: - yield tool[0]["function"]["arguments"] - - content = data_chunk.content - if content: - yield content - yield "\n" + async with ( + self.agent.checkpointer.__class__.from_conn_string(connection_string) + if connection_string + else AsyncExitStack() as memory + ): + if isinstance(memory, BaseCheckpointSaver): + self.agent.checkpointer = memory + config = {"configurable": {"thread_id": thread_id}} + streamed_response = self.agent.astream_events( + {"messages": query}, version="v2", config=config + ) + async for event in streamed_response: + kind = event["event"] + + # newline everytime model starts streaming. + if kind == "on_chat_model_start": + yield "\n\n" + # check for the model stream. + if kind == "on_chat_model_stream": + # check if we are calling the tools. + data_chunk = event["data"]["chunk"] + if "tool_calls" in data_chunk.additional_kwargs: + tool = data_chunk.additional_kwargs["tool_calls"] + if tool[0]["function"]["name"]: + yield ( + f'\nCalling tool : {tool[0]["function"]["name"]} with' + " arguments : " + ) + if tool[0]["function"]["arguments"]: + yield tool[0]["function"]["arguments"] + + content = data_chunk.content + if content: + yield content + yield "\n" @staticmethod def _process_output(output: Any) -> AgentOutput: diff --git a/src/neuroagent/app/routers/qa.py b/src/neuroagent/app/routers/qa.py index 1104d00..b655c5a 100644 --- a/src/neuroagent/app/routers/qa.py +++ b/src/neuroagent/app/routers/qa.py @@ -14,6 +14,7 @@ from neuroagent.app.dependencies import ( get_agent, get_chat_agent, + get_connection_string, get_user_id, ) from neuroagent.app.routers.database.schemas import Threads @@ -56,9 +57,16 @@ async def run_streamed_chat_agent( request: AgentRequest, _: Annotated[Threads, Depends(get_object)], agent: Annotated[BaseAgent, Depends(get_chat_agent)], + connection_string: Annotated[str | None, Depends(get_connection_string)], thread_id: str, ) -> StreamingResponse: """Run agent in streaming mode.""" logger.info("Running agent query.") logger.info(f"User's query: {request.query}") - return StreamingResponse(agent.astream(query=request.query, thread_id=thread_id)) # type: ignore + return StreamingResponse( + agent.astream( + query=request.query, + thread_id=thread_id, + connection_string=connection_string, + ) + )