improve multi chat agent
Mustafa Kerem Kurban committed Oct 3, 2024
1 parent 4ff3e1e commit 21e0ed4
Showing 3 changed files with 69 additions and 38 deletions.
22 changes: 11 additions & 11 deletions src/neuroagent/app/dependencies.py
@@ -498,17 +498,17 @@ def get_chat_agent(

if settings.agent.model == "multi-hierarchical":
logger.info("Load multi-agent (hierarchical teams) chat")
-tools = [
-    br_resolver_tool,
-    morpho_tool,
-    morphology_feature_tool,
-    kg_morpho_feature_tool,
-    literature_tool,
-    electrophys_feature_tool,
-    traces_tool,
-    me_model_tool,
-    bluenaas_tool,
-]
+tools = {
+    "br_resolver_tool": br_resolver_tool,
+    "morpho_tool": morpho_tool,
+    "morphology_feature_tool": morphology_feature_tool,
+    "kg_morpho_feature_tool": kg_morpho_feature_tool,
+    "literature_tool": literature_tool,
+    "electrophys_feature_tool": electrophys_feature_tool,
+    "traces_tool": traces_tool,
+    "me_model_tool": me_model_tool,
+    "bluenaas_tool": bluenaas_tool,
+}
return HierarchicalTeamAgent(llm=llm, tools=tools, memory=memory) # type: ignore
elif settings.agent.model == "simple":
logger.info("Load simple chat")
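
The list-to-dict change above lets team builders select tools by name instead of by position (see the `self.tools["br_resolver_tool"]` lookups in hierarchical_multi_agent.py below). A minimal sketch of that lookup pattern; `pick_tools` is a hypothetical helper, not part of this commit:

from typing import Any, Dict

def pick_tools(tools: Dict[str, Any], names: list[str]) -> list[Any]:
    """Select a team's tools by key; a KeyError surfaces typos immediately."""
    return [tools[name] for name in names]

# e.g. for the simulation team, with keys as registered in dependencies.py above:
# simulation_tools = pick_tools(tools, ["br_resolver_tool", "me_model_tool", "bluenaas_tool"])
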
4 changes: 2 additions & 2 deletions src/neuroagent/multi_agents/base_multi_agent.py
@@ -14,8 +14,8 @@ class BaseMultiAgent(BaseModel, ABC):
"""Base class for multi agents."""

llm: BaseChatModel
-main_agent: Any
-agents: list[tuple[str, list[BasicTool]]]
+main_agent: Any | None = None
+agents: list[tuple[str, list[BasicTool]]] | None = None

model_config = ConfigDict(arbitrary_types_allowed=True)

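
Loosening `main_agent` and `agents` to `| None = None` lets subclasses that assemble their teams internally, such as HierarchicalTeamAgent below, call the base constructor without those arguments. A stand-alone sketch of the pydantic pattern (`Stub` is illustrative, not from the repository):

from typing import Any
from pydantic import BaseModel, ConfigDict

class Stub(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    main_agent: Any | None = None  # optional; a subclass may fill it in later

Stub()                     # valid now that the field has a default
Stub(main_agent=object())  # explicit values are still accepted
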
81 changes: 56 additions & 25 deletions src/neuroagent/multi_agents/hierarchical_multi_agent.py
@@ -2,18 +2,20 @@
import operator
import logging
from typing import (Annotated, Any, AsyncIterator, Hashable, List, Sequence,
-                    TypedDict)
+                    TypedDict, Dict)
+from contextlib import AsyncExitStack

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, trim_messages
from langchain_core.output_parsers.openai_functions import \
    JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_core.runnables import RunnableConfig
+from langchain_core.runnables import RunnableConfig, RunnablePassthrough
from langchain_openai.chat_models import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import create_react_agent
-from pydantic import ConfigDict, model_validator
+from pydantic import ConfigDict, model_validator, Field
+from langgraph.checkpoint.base import BaseCheckpointSaver

from neuroagent.agents import AgentOutput
from neuroagent.multi_agents.base_multi_agent import BaseMultiAgent
@@ -25,21 +27,27 @@ class HierarchicalTeamAgent(BaseMultiAgent):
"""Hierarchical Team Agent managing multiple teams."""

model_config = ConfigDict(arbitrary_types_allowed=True)
tools: Dict[str, BasicTool] = Field(default_factory=list)
memory: BaseCheckpointSaver[Any] | None = None
top_level_chain: Any = None
trimmer: Any = None

def __init__(self,
llm: Any,
tools: dict[str, Any],
agents: list[tuple[str, list[BasicTool]]] = None):
llm: Any,
tools: Dict[str, BasicTool],
agents: list[tuple[str, list[BasicTool]]] = None,
memory: BaseCheckpointSaver[Any] = None):
super().__init__(llm=llm, agents=agents)
self.llm = llm
self.tools = tools
self.top_level_chain = self.create_graph()
self.memory = memory
self.trimmer = trim_messages(
max_tokens=100000,
strategy="last",
token_counter=self.llm,
include_system=True,
)
self.top_level_chain = self.create_graph()
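
For reference, `trim_messages` can also be applied eagerly to a plain message list; this sketch swaps the model-based `token_counter` used above for a crude word-count stand-in so it runs without an LLM:

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, trim_messages

history = [
    SystemMessage(content="You are a neuroscience assistant."),
    HumanMessage(content="Fetch a morphology for layer 5."),
    AIMessage(content="Retrieved one reconstruction."),
]

# strategy="last" keeps the newest messages under the budget;
# include_system=True always preserves the system prompt.
trimmed = trim_messages(
    history,
    max_tokens=50,
    strategy="last",
    token_counter=lambda msgs: sum(len(m.content.split()) for m in msgs),
    include_system=True,
)
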

@staticmethod
def agent_node(state, agent, name):
@@ -109,7 +117,6 @@ class SimulationTeamState(TypedDict):
# Used to route work. The supervisor calls a function
# that will update this every time it makes a decision
next: str

# Define tools
simulation_tools = [
self.tools["br_resolver_tool"],
@@ -125,7 +132,8 @@ class SimulationTeamState(TypedDict):
] # Add other bluenaas endpoints later on..

# Create agents
-simulation_agent = create_react_agent(self.llm, tools=simulation_tools)
+simulation_agent = create_react_agent(self.llm, tools=simulation_tools, checkpointer=self.memory)

simulation_node = functools.partial(
self.agent_node, agent=simulation_agent, name="SimulationAgent"
) # might need to rename to SingleCellSimAgent when circuit level tools come
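
Threading `checkpointer=self.memory` into each `create_react_agent` call is what persists per-thread state for the sub-agents. A minimal sketch of the same wiring with an in-memory saver and a placeholder tool (both stand-ins for the commit's real tools and checkpointer):

from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

@tool
def echo(text: str) -> str:
    """Return the input unchanged (placeholder tool)."""
    return text

agent = create_react_agent(ChatOpenAI(model="gpt-4o-mini"), tools=[echo], checkpointer=MemorySaver())
result = agent.invoke(
    {"messages": [("user", "say hi")]},
    config={"configurable": {"thread_id": "demo"}},  # thread_id scopes the saved state
)
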
@@ -196,13 +204,13 @@ class AnalysisTeamState(TypedDict):
] # Replace with your actual tools

# Create agents
-morphology_agent = create_react_agent(self.llm, tools=morphology_tools)
+morphology_agent = create_react_agent(self.llm, tools=morphology_tools, checkpointer=self.memory)
morphology_node = functools.partial(
self.agent_node, agent=morphology_agent, name="MorphologyAgent"
)

electrophysiology_agent = create_react_agent(
-    self.llm, tools=electrophysiology_tools
+    self.llm, tools=electrophysiology_tools, checkpointer=self.memory
)
electrophysiology_node = functools.partial(
self.agent_node, agent=electrophysiology_agent, name="ElectrophysiologyAgent"
@@ -271,12 +279,17 @@ def create_team_supervisor(self, system_prompt: str, members: List[str]):
),
),
]).partial(options=str(options), team_members=", ".join(members))
-return (
-    prompt
+chain = (
+    RunnablePassthrough()
+    | prompt
+    | self.trimmer
| self.llm.bind_functions(functions=[function_def], function_call="route")
| JsonOutputFunctionsParser()
)

+return chain
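
The supervisor chain forces the LLM to call a `route` function (via `function_call="route"`), and `JsonOutputFunctionsParser` unwraps its arguments into a dict such as `{"next": "SimulationAgent"}` that drives the graph's conditional edges. The `function_def` it binds sits above this hunk; a plausible reconstruction following the LangGraph hierarchical-teams recipe (illustrative, not copied from the repository):

members = ["SimulationAgent", "MorphologyAgent", "ElectrophysiologyAgent"]
options = ["FINISH"] + members
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "type": "object",
        "properties": {
            "next": {"anyOf": [{"enum": options}]},
        },
        "required": ["next"],
    },
}
# The parser then yields e.g. {"next": "FINISH"} when the supervisor is done.
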


def run(self, query: str, thread_id: str) -> AgentOutput:
res = self.top_level_chain.invoke(
@@ -292,19 +305,37 @@ async def arun(self, query: str, thread_id: str) -> AgentOutput:
)
return self._process_output(res)

-async def astream(self, query: str, thread_id: str) -> AsyncIterator[str]:  # type: ignore
-    """Astream method of the service."""
-    graph = self.create_graph()
-    config = RunnableConfig(configurable={"thread_id": thread_id})
-    async for chunk in graph.astream(
-        input={"messages": [HumanMessage(content=query)]}, config=config
-    ):
-        if "Supervisor" in chunk.keys() and chunk["Supervisor"]["next"] != "FINISH":
-            yield f'\nCalling agent : {chunk["Supervisor"]["next"]}\n'
-        else:
-            values = [i for i in chunk.values()]  # noqa: C416
-            if "messages" in values[0]:
-                yield f'\n {values[0]["messages"][0].content}'
+async def astream(
+    self, thread_id: str, query: str, connection_string: str | None = None
+) -> AsyncIterator[str]:
+    """Run the agent against a query in a streaming way.
+
+    Parameters
+    ----------
+    thread_id
+        ID of the thread of the chat.
+    query
+        Query of the user.
+    connection_string
+        Connection string for the checkpoint database.
+
+    Yields
+    ------
+    Iterator streaming the processed output of the LLM
+    """
+    async with (
+        self.agent.checkpointer.__class__.from_conn_string(connection_string)
+        if connection_string
+        else AsyncExitStack() as memory
+    ):
+        if isinstance(memory, BaseCheckpointSaver):
+            self.agent.checkpointer = memory
+        config = {"configurable": {"thread_id": thread_id}}
+        streamed_response = self.agent.astream_events(
+            {"messages": query}, version="v2", config=config
+        )
+        async for event in streamed_response:
+            yield event
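
With this change, `astream` yields raw `astream_events` (v2) payloads instead of pre-formatted strings, so callers filter events themselves. A sketch of consuming the stream and printing only model tokens (the event names are LangChain's standard v2 ones; `agent` stands for an instance of this class):

async def print_tokens(agent, thread_id: str, query: str) -> None:
    """Print streamed LLM tokens, ignoring tool and graph events."""
    async for event in agent.astream(thread_id=thread_id, query=query):
        if event["event"] == "on_chat_model_stream":
            chunk = event["data"]["chunk"]  # an AIMessageChunk
            print(chunk.content, end="", flush=True)
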

@staticmethod
def _process_output(output: Any) -> AgentOutput:
