Skip to content

Commit

Permalink
fixed streaming with chat agent (#5)
Browse files Browse the repository at this point in the history
* fixed streaming

* mypy
  • Loading branch information
BoBer78 authored Sep 19, 2024
1 parent bbeb09e commit 8f589be
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 36 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Migration to pydantic V2.

### Fixed
- Streaming with chat agent.
- Deleted some legacy code.
2 changes: 1 addition & 1 deletion src/neuroagent/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def arun(self, *args: Any, **kwargs: Any) -> AgentOutput:
"""Arun method of the service."""

@abstractmethod
async def astream(self, *args: Any, **kwargs: Any) -> AsyncIterator[str]:
def astream(self, *args: Any, **kwargs: Any) -> AsyncIterator[str]:
"""Astream method of the service."""

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion src/neuroagent/agents/simple_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def arun(self, query: str) -> Any:
result = await self.agent.ainvoke({"messages": [("human", query)]})
return self._process_output(result)

async def astream(self, query: str) -> AsyncIterator[str]: # type: ignore
async def astream(self, query: str) -> AsyncIterator[str]:
"""Run the agent against a query in streaming way.
Parameters
Expand Down
77 changes: 44 additions & 33 deletions src/neuroagent/agents/simple_chat_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Simple agent."""

import logging
from contextlib import AsyncExitStack
from typing import Any, AsyncIterator

from langchain_core.messages import AIMessage, HumanMessage
Expand Down Expand Up @@ -43,49 +44,59 @@ async def arun(self, thread_id: str, query: str) -> Any:
result = await self.agent.ainvoke({"messages": [input_message]}, config=config)
return self._process_output(result)

async def astream(self, thread_id: str, query: str) -> AsyncIterator[str]: # type: ignore
async def astream(
self, thread_id: str, query: str, connection_string: str | None = None
) -> AsyncIterator[str]:
"""Run the agent against a query in streaming way.
Parameters
----------
thread_id
ID of the thread of the chat.
query
Query of the user
Query of the user.
connection_string
Connection string for the checkpoint database.
Returns
-------
Yields
------
Iterator streaming the processed output of the LLM.
"""
config = {"configurable": {"thread_id": thread_id}}
streamed_response = self.agent.astream_events(
{"messages": query}, version="v2", config=config
)

async for event in streamed_response:
kind = event["event"]

# newline every time the model starts streaming.
if kind == "on_chat_model_start":
yield "\n\n"
# check for the model stream.
if kind == "on_chat_model_stream":
# check if we are calling the tools.
data_chunk = event["data"]["chunk"]
if "tool_calls" in data_chunk.additional_kwargs:
tool = data_chunk.additional_kwargs["tool_calls"]
if tool[0]["function"]["name"]:
yield (
f'\nCalling tool : {tool[0]["function"]["name"]} with'
" arguments : "
)
if tool[0]["function"]["arguments"]:
yield tool[0]["function"]["arguments"]

content = data_chunk.content
if content:
yield content
yield "\n"
async with (
self.agent.checkpointer.__class__.from_conn_string(connection_string)
if connection_string
else AsyncExitStack() as memory
):
if isinstance(memory, BaseCheckpointSaver):
self.agent.checkpointer = memory
config = {"configurable": {"thread_id": thread_id}}
streamed_response = self.agent.astream_events(
{"messages": query}, version="v2", config=config
)
async for event in streamed_response:
kind = event["event"]

# newline every time the model starts streaming.
if kind == "on_chat_model_start":
yield "\n\n"
# check for the model stream.
if kind == "on_chat_model_stream":
# check if we are calling the tools.
data_chunk = event["data"]["chunk"]
if "tool_calls" in data_chunk.additional_kwargs:
tool = data_chunk.additional_kwargs["tool_calls"]
if tool[0]["function"]["name"]:
yield (
f'\nCalling tool : {tool[0]["function"]["name"]} with'
" arguments : "
)
if tool[0]["function"]["arguments"]:
yield tool[0]["function"]["arguments"]

content = data_chunk.content
if content:
yield content
yield "\n"

@staticmethod
def _process_output(output: Any) -> AgentOutput:
Expand Down
10 changes: 9 additions & 1 deletion src/neuroagent/app/routers/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from neuroagent.app.dependencies import (
get_agent,
get_chat_agent,
get_connection_string,
get_user_id,
)
from neuroagent.app.routers.database.schemas import Threads
Expand Down Expand Up @@ -56,9 +57,16 @@ async def run_streamed_chat_agent(
request: AgentRequest,
_: Annotated[Threads, Depends(get_object)],
agent: Annotated[BaseAgent, Depends(get_chat_agent)],
connection_string: Annotated[str | None, Depends(get_connection_string)],
thread_id: str,
) -> StreamingResponse:
"""Run agent in streaming mode."""
logger.info("Running agent query.")
logger.info(f"User's query: {request.query}")
return StreamingResponse(agent.astream(query=request.query, thread_id=thread_id)) # type: ignore
return StreamingResponse(
agent.astream(
query=request.query,
thread_id=thread_id,
connection_string=connection_string,
)
)

0 comments on commit 8f589be

Please sign in to comment.