diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index dfbb0ab..369e6de 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -84,7 +84,7 @@ jobs: run: | pip install --upgrade pip pip install mypy==1.8.0 - pip install -e ".[dev]" + pip install ".[dev]" - name: Running mypy and tests run: | mypy src/ diff --git a/CHANGELOG.md b/CHANGELOG.md index e809514..6b5cb48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.1] - 26.09.2024 + +### Fixed +- Fixed a bug that prevented AsyncSqlite checkpoint to access the DB in streamed endpoints. +- Fixed a bug that caused some unit tests to fail due to a change in how httpx_mock works in version 0.32 + ## [0.1.0] - 19.09.2024 ### Added diff --git a/src/neuroagent/__init__.py b/src/neuroagent/__init__.py index b4d4ea2..508d96b 100644 --- a/src/neuroagent/__init__.py +++ b/src/neuroagent/__init__.py @@ -1,3 +1,3 @@ """Neuroagent package.""" -__version__ = "0.1.0" +__version__ = "0.1.1" diff --git a/src/neuroagent/agents/base_agent.py b/src/neuroagent/agents/base_agent.py index 9ecf545..347e3bd 100644 --- a/src/neuroagent/agents/base_agent.py +++ b/src/neuroagent/agents/base_agent.py @@ -1,10 +1,12 @@ """Base agent.""" from abc import ABC, abstractmethod +from contextlib import asynccontextmanager from typing import Any, AsyncIterator from langchain.chat_models.base import BaseChatModel from langchain_core.tools import BaseTool +from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from pydantic import BaseModel, ConfigDict @@ -47,3 +49,25 @@ def astream(self, *args: Any, **kwargs: Any) -> AsyncIterator[str]: @abstractmethod def _process_output(*args: Any, **kwargs: Any) -> AgentOutput: """Format the output.""" + + +class AsyncSqliteSaverWithPrefix(AsyncSqliteSaver): + """Wrapper around the AsyncSqliteSaver that accepts a connection string with prefix.""" + + @classmethod + @asynccontextmanager + async def from_conn_string( + cls, conn_string: str + ) -> AsyncIterator["AsyncSqliteSaver"]: + """Create a new AsyncSqliteSaver instance from a connection string. + + Args: + conn_string (str): The SQLite connection string. It can have the 'sqlite:///' prefix. + + Yields + ------ + AsyncSqliteSaverWithPrefix: A new AsyncSqliteSaverWithPrefix instance. + """ + conn_string = conn_string.split("///")[-1] + async with super().from_conn_string(conn_string) as memory: + yield AsyncSqliteSaverWithPrefix(memory.conn) diff --git a/src/neuroagent/agents/simple_chat_agent.py b/src/neuroagent/agents/simple_chat_agent.py index b331d75..882b8d7 100644 --- a/src/neuroagent/agents/simple_chat_agent.py +++ b/src/neuroagent/agents/simple_chat_agent.py @@ -17,7 +17,7 @@ class SimpleChatAgent(BaseAgent): """Simple Agent class.""" - memory: BaseCheckpointSaver + memory: BaseCheckpointSaver[Any] @model_validator(mode="before") @classmethod @@ -73,12 +73,9 @@ async def astream( streamed_response = self.agent.astream_events( {"messages": query}, version="v2", config=config ) + is_streaming = False async for event in streamed_response: kind = event["event"] - - # newline everytime model starts streaming. - if kind == "on_chat_model_start": - yield "\n\n" # check for the model stream. if kind == "on_chat_model_stream": # check if we are calling the tools. @@ -95,6 +92,9 @@ async def astream( content = data_chunk.content if content: + if not is_streaming: + yield "\n\n" + is_streaming = True yield content yield "\n" diff --git a/src/neuroagent/app/dependencies.py b/src/neuroagent/app/dependencies.py index 75a619e..78d8f13 100644 --- a/src/neuroagent/app/dependencies.py +++ b/src/neuroagent/app/dependencies.py @@ -11,7 +11,6 @@ from langchain_openai import ChatOpenAI from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver -from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError @@ -23,6 +22,7 @@ SimpleAgent, SimpleChatAgent, ) +from neuroagent.agents.base_agent import AsyncSqliteSaverWithPrefix from neuroagent.app.config import Settings from neuroagent.cell_types import CellTypesMeta from neuroagent.multi_agents import BaseMultiAgent, SupervisorMultiAgent @@ -320,12 +320,12 @@ def get_language_model( async def get_agent_memory( connection_string: Annotated[str | None, Depends(get_connection_string)], -) -> AsyncIterator[BaseCheckpointSaver | None]: +) -> AsyncIterator[BaseCheckpointSaver[Any] | None]: """Get the agent checkpointer.""" if connection_string: if connection_string.startswith("sqlite"): - async with AsyncSqliteSaver.from_conn_string( - connection_string.split("///")[-1] + async with AsyncSqliteSaverWithPrefix.from_conn_string( + connection_string ) as memory: await memory.setup() yield memory @@ -403,7 +403,7 @@ def get_agent( def get_chat_agent( llm: Annotated[ChatOpenAI, Depends(get_language_model)], - memory: Annotated[BaseCheckpointSaver, Depends(get_agent_memory)], + memory: Annotated[BaseCheckpointSaver[Any], Depends(get_agent_memory)], literature_tool: Annotated[LiteratureSearchTool, Depends(get_literature_tool)], br_resolver_tool: Annotated[ ResolveBrainRegionTool, Depends(get_brain_region_resolver_tool) diff --git a/src/neuroagent/scripts/__init__.py b/src/neuroagent/scripts/__init__.py new file mode 100644 index 0000000..cc662d0 --- /dev/null +++ b/src/neuroagent/scripts/__init__.py @@ -0,0 +1 @@ +"""Neuroagent scripts.""" diff --git a/tests/agents/test_simple_chat_agent.py b/tests/agents/test_simple_chat_agent.py index 6ec5474..c7e0a92 100644 --- a/tests/agents/test_simple_chat_agent.py +++ b/tests/agents/test_simple_chat_agent.py @@ -9,6 +9,7 @@ from neuroagent.agents import AgentOutput, AgentStep, SimpleChatAgent +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) @pytest.mark.asyncio async def test_arun(fake_llm_with_tools, httpx_mock): llm, tools, fake_responses = await anext(fake_llm_with_tools) @@ -64,6 +65,7 @@ async def test_arun(fake_llm_with_tools, httpx_mock): assert len(messages_list) == 10 +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) @pytest.mark.asyncio async def test_astream(fake_llm_with_tools, httpx_mock): llm, tools, fake_responses = await anext(fake_llm_with_tools) @@ -84,8 +86,8 @@ async def test_astream(fake_llm_with_tools, httpx_mock): msg_list = "".join([el async for el in response]) assert ( - msg_list == "\n\n\nCalling tool : get-morpho-tool with arguments :" - ' {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}\n\nGreat' + msg_list == "\nCalling tool : get-morpho-tool with arguments :" + ' {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}\n\nGreat' " answer\n" ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 152232c..9e9af8b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -313,6 +313,7 @@ async def test_get_kg_data_errors(httpx_mock): ) +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) @pytest.mark.asyncio async def test_get_kg_data(httpx_mock): url = "http://fake_url" diff --git a/tests/tools/test_electrophys_tool.py b/tests/tools/test_electrophys_tool.py index e7729fc..6705c2a 100644 --- a/tests/tools/test_electrophys_tool.py +++ b/tests/tools/test_electrophys_tool.py @@ -15,6 +15,7 @@ class TestElectrophysTool: + @pytest.mark.httpx_mock(can_send_already_matched_responses=True) @pytest.mark.asyncio async def test_arun(self, httpx_mock): url = "http://fake_url" diff --git a/tests/tools/test_traces_tool.py b/tests/tools/test_traces_tool.py index 0bae056..d0689aa 100644 --- a/tests/tools/test_traces_tool.py +++ b/tests/tools/test_traces_tool.py @@ -11,6 +11,7 @@ class TestTracesTool: + @pytest.mark.httpx_mock(can_send_already_matched_responses=True) @pytest.mark.asyncio async def test_arun(self, httpx_mock, brain_region_json_path): url = "http://fake_url"