Skip to content

Commit

Permalink
Merge branch 'failing_ci' into unit_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cszsol committed Sep 30, 2024
2 parents 2efc883 + 74036cd commit fcebe80
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ jobs:
run: |
pip install --upgrade pip
pip install mypy==1.8.0
pip install -e ".[dev]"
pip install ".[dev]"
- name: Running mypy and tests
run: |
mypy src/
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.1.1] - 26.09.2024

### Fixed
- Fixed a bug that prevented AsyncSqlite checkpoint to access the DB in streamed endpoints.
- Fixed a bug that caused some unit tests to fail due to a change in how httpx_mock works in version 0.32

## [0.1.0] - 19.09.2024

### Added
Expand Down
2 changes: 1 addition & 1 deletion src/neuroagent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Neuroagent package."""

__version__ = "0.1.0"
__version__ = "0.1.1"
24 changes: 24 additions & 0 deletions src/neuroagent/agents/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Base agent."""

from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator

from langchain.chat_models.base import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from pydantic import BaseModel, ConfigDict


Expand Down Expand Up @@ -47,3 +49,25 @@ def astream(self, *args: Any, **kwargs: Any) -> AsyncIterator[str]:
@abstractmethod
def _process_output(*args: Any, **kwargs: Any) -> AgentOutput:
"""Format the output."""


class AsyncSqliteSaverWithPrefix(AsyncSqliteSaver):
"""Wrapper around the AsyncSqliteSaver that accepts a connection string with prefix."""

@classmethod
@asynccontextmanager
async def from_conn_string(
cls, conn_string: str
) -> AsyncIterator["AsyncSqliteSaver"]:
"""Create a new AsyncSqliteSaver instance from a connection string.
Args:
conn_string (str): The SQLite connection string. It can have the 'sqlite:///' prefix.
Yields
------
AsyncSqliteSaverWithPrefix: A new AsyncSqliteSaverWithPrefix instance.
"""
conn_string = conn_string.split("///")[-1]
async with super().from_conn_string(conn_string) as memory:
yield AsyncSqliteSaverWithPrefix(memory.conn)
10 changes: 5 additions & 5 deletions src/neuroagent/agents/simple_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class SimpleChatAgent(BaseAgent):
"""Simple Agent class."""

memory: BaseCheckpointSaver
memory: BaseCheckpointSaver[Any]

@model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -73,12 +73,9 @@ async def astream(
streamed_response = self.agent.astream_events(
{"messages": query}, version="v2", config=config
)
is_streaming = False
async for event in streamed_response:
kind = event["event"]

# newline everytime model starts streaming.
if kind == "on_chat_model_start":
yield "\n\n"
# check for the model stream.
if kind == "on_chat_model_stream":
# check if we are calling the tools.
Expand All @@ -95,6 +92,9 @@ async def astream(

content = data_chunk.content
if content:
if not is_streaming:
yield "\n<begin_llm_response>\n"
is_streaming = True
yield content
yield "\n"

Expand Down
10 changes: 5 additions & 5 deletions src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
Expand All @@ -23,6 +22,7 @@
SimpleAgent,
SimpleChatAgent,
)
from neuroagent.agents.base_agent import AsyncSqliteSaverWithPrefix
from neuroagent.app.config import Settings
from neuroagent.cell_types import CellTypesMeta
from neuroagent.multi_agents import BaseMultiAgent, SupervisorMultiAgent
Expand Down Expand Up @@ -320,12 +320,12 @@ def get_language_model(

async def get_agent_memory(
connection_string: Annotated[str | None, Depends(get_connection_string)],
) -> AsyncIterator[BaseCheckpointSaver | None]:
) -> AsyncIterator[BaseCheckpointSaver[Any] | None]:
"""Get the agent checkpointer."""
if connection_string:
if connection_string.startswith("sqlite"):
async with AsyncSqliteSaver.from_conn_string(
connection_string.split("///")[-1]
async with AsyncSqliteSaverWithPrefix.from_conn_string(
connection_string
) as memory:
await memory.setup()
yield memory
Expand Down Expand Up @@ -403,7 +403,7 @@ def get_agent(

def get_chat_agent(
llm: Annotated[ChatOpenAI, Depends(get_language_model)],
memory: Annotated[BaseCheckpointSaver, Depends(get_agent_memory)],
memory: Annotated[BaseCheckpointSaver[Any], Depends(get_agent_memory)],
literature_tool: Annotated[LiteratureSearchTool, Depends(get_literature_tool)],
br_resolver_tool: Annotated[
ResolveBrainRegionTool, Depends(get_brain_region_resolver_tool)
Expand Down
1 change: 1 addition & 0 deletions src/neuroagent/scripts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Neuroagent scripts."""
6 changes: 4 additions & 2 deletions tests/agents/test_simple_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from neuroagent.agents import AgentOutput, AgentStep, SimpleChatAgent


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_arun(fake_llm_with_tools, httpx_mock):
llm, tools, fake_responses = await anext(fake_llm_with_tools)
Expand Down Expand Up @@ -64,6 +65,7 @@ async def test_arun(fake_llm_with_tools, httpx_mock):
assert len(messages_list) == 10


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_astream(fake_llm_with_tools, httpx_mock):
llm, tools, fake_responses = await anext(fake_llm_with_tools)
Expand All @@ -84,8 +86,8 @@ async def test_astream(fake_llm_with_tools, httpx_mock):

msg_list = "".join([el async for el in response])
assert (
msg_list == "\n\n\nCalling tool : get-morpho-tool with arguments :"
' {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}\n\nGreat'
msg_list == "\nCalling tool : get-morpho-tool with arguments :"
' {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}\n<begin_llm_response>\nGreat'
" answer\n"
)

Expand Down
1 change: 1 addition & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ async def test_get_kg_data_errors(httpx_mock):
)


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_get_kg_data(httpx_mock):
url = "http://fake_url"
Expand Down
1 change: 1 addition & 0 deletions tests/tools/test_electrophys_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


class TestElectrophysTool:
@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_arun(self, httpx_mock):
url = "http://fake_url"
Expand Down
1 change: 1 addition & 0 deletions tests/tools/test_traces_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


class TestTracesTool:
@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_arun(self, httpx_mock, brain_region_json_path):
url = "http://fake_url"
Expand Down

0 comments on commit fcebe80

Please sign in to comment.