Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

me model tool addition #9

Merged
merged 12 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
- name: Set up environment
run: |
pip install --upgrade pip wheel setuptools
pip install bandit[toml]==1.7.4 ruff==0.5.5
pip install bandit[toml]==1.7.4 ruff==0.6.7
- name: Linting check
run: |
bandit -qr -c pyproject.toml src/
Expand Down
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,8 @@ cython_debug/
# static files generated from Django application using `collectstatic`
media
static

# database stuff
*db
*.db-shm
*.db-wal
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Add get morphoelectric (me) model tool

## [0.1.1] - 26.09.2024

### Fixed
Expand Down
9 changes: 9 additions & 0 deletions src/neuroagent/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ class SettingsGetMorpho(BaseModel):
model_config = ConfigDict(frozen=True)


class SettingsGetMEModel(BaseModel):
"""Get ME Model settings."""

search_size: int = 10

model_config = ConfigDict(frozen=True)


class SettingsKnowledgeGraph(BaseModel):
"""Knowledge graph API settings."""

Expand Down Expand Up @@ -157,6 +165,7 @@ class SettingsTools(BaseModel):
morpho: SettingsGetMorpho = SettingsGetMorpho()
trace: SettingsTrace = SettingsTrace()
kg_morpho_features: SettingsKGMorpho = SettingsKGMorpho()
me_model: SettingsGetMEModel = SettingsGetMEModel()

model_config = ConfigDict(frozen=True)

Expand Down
22 changes: 22 additions & 0 deletions src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from neuroagent.multi_agents import BaseMultiAgent, SupervisorMultiAgent
from neuroagent.tools import (
ElectrophysFeatureTool,
GetMEModelTool,
GetMorphoTool,
GetTracesTool,
KGMorphoFeatureTool,
Expand Down Expand Up @@ -304,6 +305,25 @@ def get_morphology_feature_tool(
return tool


def get_me_model_tool(
settings: Annotated[Settings, Depends(get_settings)],
token: Annotated[str, Depends(get_kg_token)],
httpx_client: Annotated[AsyncClient, Depends(get_httpx_client)],
) -> GetMEModelTool:
"""Load get ME model tool."""
tool = GetMEModelTool(
metadata={
"url": settings.knowledge_graph.url,
"token": token,
"httpx_client": httpx_client,
"search_size": settings.tools.me_model.search_size,
"brainregion_path": settings.knowledge_graph.br_saving_path,
"celltypes_path": settings.knowledge_graph.ct_saving_path,
}
)
return tool


def get_language_model(
settings: Annotated[Settings, Depends(get_settings)],
) -> ChatOpenAI:
Expand Down Expand Up @@ -369,6 +389,7 @@ def get_agent(
ElectrophysFeatureTool, Depends(get_electrophys_feature_tool)
],
traces_tool: Annotated[GetTracesTool, Depends(get_traces_tool)],
me_model_tool: Annotated[GetMEModelTool, Depends(get_me_model_tool)],
settings: Annotated[Settings, Depends(get_settings)],
) -> BaseAgent | BaseMultiAgent:
"""Get the generative question answering service."""
Expand Down Expand Up @@ -397,6 +418,7 @@ def get_agent(
kg_morpho_feature_tool,
electrophys_feature_tool,
traces_tool,
me_model_tool,
]
logger.info("Load simple agent")
return SimpleAgent(llm=llm, tools=tools) # type: ignore
Expand Down
2 changes: 2 additions & 0 deletions src/neuroagent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tools folder."""

from neuroagent.tools.electrophys_tool import ElectrophysFeatureTool, FeaturesOutput
from neuroagent.tools.get_me_model_tool import GetMEModelTool
from neuroagent.tools.get_morpho_tool import GetMorphoTool, KnowledgeGraphOutput
from neuroagent.tools.kg_morpho_features_tool import (
KGMorphoFeatureOutput,
Expand Down Expand Up @@ -35,4 +36,5 @@
"ParagraphMetadata",
"ResolveBrainRegionTool",
"TracesOutput",
"GetMEModelTool",
]
266 changes: 266 additions & 0 deletions src/neuroagent/tools/get_me_model_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
"""Module defining the Get ME Model tool."""

import logging
from typing import Any, Literal, Optional, Type

from langchain_core.tools import ToolException
from pydantic import BaseModel, Field

from neuroagent.cell_types import get_celltypes_descendants
from neuroagent.tools.base_tool import BaseToolOutput, BasicTool
from neuroagent.utils import get_descendants_id

logger = logging.getLogger(__name__)


class InputGetMEModel(BaseModel):
"""Inputs of the knowledge graph API."""

brain_region_id: str = Field(description="ID of the brain region of interest.")
mtype_id: Optional[str] = Field(
default=None, description="ID of the M-type of interest."
)
etype_id: Optional[
Literal[
"bAC",
"bIR",
"bNAC",
"bSTUT",
"cAC",
"cIR",
"cNAC",
"cSTUT",
"dNAC",
"dSTUT",
]
] = Field(default=None, description="ID of the E-type of interest.")


class MEModelOutput(BaseToolOutput):
"""Output schema for the knowledge graph API."""

me_model_id: str
me_model_name: str | None
me_model_description: str | None
mtype: str | None
etype: str | None

brain_region_id: str
brain_region_label: str | None

subject_species_label: str | None
subject_age: str | None


class GetMEModelTool(BasicTool):
"""Class defining the Get ME Model logic."""

name: str = "get-me-model-tool"
description: str = """Searches a neuroscience based knowledge graph to retrieve neuron morpho-electric model names, IDs and descriptions.
Requires a 'brain_region_id' which is the ID of the brain region of interest as registered in the knowledge graph. To get this ID, please use the `resolve-brain-region-tool` first.
Ideally, the user should also provide an 'mtype_id' and/or an 'etype_id' to filter the search results. But in case they are not provided, the search will return all models that match the brain region.
The output is a list of ME models, containing:
- The brain region ID.
- The brain region name.
- The subject species name.
- The subject age.
- The model ID.
- The model name.
- The model description.
The model ID is in the form of an HTTP(S) link such as 'https://bbp.epfl.ch/data/bbp/mmb-point-neuron-framework-model/...'."""
metadata: dict[str, Any]
args_schema: Type[BaseModel] = InputGetMEModel

def _run(self) -> None:
pass

async def _arun(
self,
brain_region_id: str,
mtype_id: str | None = None,
etype_id: str | None = None,
) -> list[MEModelOutput]:
"""From a brain region ID, extract ME models.

Parameters
----------
brain_region_id
ID of the brain region of interest (of the form http://api.brain-map.org/api/v2/data/Structure/...)
mtype_id
ID of the mtype of the model
etype_id
ID of the etype of the model

Returns
-------
list of MEModelOutput to describe the model and its metadata, or an error dict.
"""
logger.info(
f"Entering Get ME Model tool. Inputs: {brain_region_id=}, {mtype_id=}, {etype_id=}"
)
try:
# From the brain region ID, get the descendants.
hierarchy_ids = get_descendants_id(
brain_region_id, json_path=self.metadata["brainregion_path"]
)
logger.info(
f"Found {len(list(hierarchy_ids))} children of the brain ontology."
)

if mtype_id:
mtype_ids = set(
get_celltypes_descendants(mtype_id, self.metadata["celltypes_path"])
)
logger.info(
f"Found {len(list(mtype_ids))} children of the cell types ontology for mtype."
)
else:
mtype_ids = None

if etype_id:
etype_ids = set(
get_celltypes_descendants(etype_id, self.metadata["celltypes_path"])
)
logger.info(
f"Found {len(list(etype_ids))} children of the cell types ontology for etype."
)
else:
etype_ids = None

# Create the ES query to query the KG.
entire_query = self.create_query(
brain_regions_ids=hierarchy_ids,
mtype_ids=mtype_ids,
etype_ids=etype_ids,
)

# Send the query to get ME models.
response = await self.metadata["httpx_client"].post(
url=self.metadata["url"],
headers={"Authorization": f"Bearer {self.metadata['token']}"},
json=entire_query,
)

# Process the output and return.
return self._process_output(response.json())

except Exception as e:
raise ToolException(str(e), self.name)

def create_query(
self,
brain_regions_ids: set[str],
mtype_ids: set[str] | None = None,
etype_ids: set[str] | None = None,
) -> dict[str, Any]:
"""Create ES query out of the BR, mtype, and etype IDs.

Parameters
----------
brain_regions_ids
IDs of the brain region of interest (of the form http://api.brain-map.org/api/v2/data/Structure/...)
mtype_id
ID of the mtype of the model
etype_id
ID of the etype of the model

Returns
-------
dict containing the elasticsearch query to send to the KG.
"""
# At least one of the children brain region should match.
conditions = [
{
"bool": {
"should": [
{"term": {"[email protected]": hierarchy_id}}
for hierarchy_id in brain_regions_ids
]
}
},
{"term": {"@type.keyword": "https://neuroshapes.org/MEModel"}},
{"term": {"deprecated": False}},
]

if mtype_ids:
# The correct mtype should match. For now
# It is a one term should condition, but eventually
# we will resolve the subclasses of the mtypes.
# They will all be appended here.
conditions.append(
{
"bool": {
"should": [
{"match": {"mType.label": mtype_id}}
for mtype_id in mtype_ids
]
}
}
)

if etype_ids:
# The correct etype should match.
conditions.append(
{
"bool": {
"should": [
{"match": {"eType.label": etype_id}}
for etype_id in etype_ids
]
}
}
)

# Assemble the query to return ME models.
entire_query = {
"size": self.metadata["search_size"],
"track_total_hits": True,
"query": {"bool": {"must": conditions}},
"sort": {"createdAt": {"order": "desc", "unmapped_type": "keyword"}},
}
return entire_query

@staticmethod
def _process_output(output: Any) -> list[MEModelOutput]:
"""Process output to fit the MEModelOutput pydantic class defined above.

Parameters
----------
output
Raw output of the _arun method, which comes from the KG

Returns
-------
list of MEModelOutput to describe the model and its metadata.
"""
formatted_output = [
MEModelOutput(
me_model_id=res["_source"]["@id"],
me_model_name=res["_source"].get("name"),
me_model_description=res["_source"].get("description"),
mtype=(
res["_source"]["mType"].get("label")
if "mType" in res["_source"]
else None
),
etype=(
res["_source"]["eType"].get("label")
if "eType" in res["_source"]
else None
),
brain_region_id=res["_source"]["brainRegion"]["@id"],
brain_region_label=res["_source"]["brainRegion"].get("label"),
subject_species_label=(
res["_source"]["subjectSpecies"].get("label")
if "subjectSpecies" in res["_source"]
else None
),
subject_age=(
res["_source"]["subjectAge"].get("label")
if "subjectAge" in res["_source"]
else None
),
)
for res in output["hits"]["hits"]
]
return formatted_output
1 change: 1 addition & 0 deletions tests/agents/test_simple_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path

import pytest

from neuroagent.agents import AgentOutput, AgentStep, SimpleAgent


Expand Down
1 change: 1 addition & 0 deletions tests/agents/test_simple_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from langchain_core.messages import HumanMessage, ToolMessage
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

from neuroagent.agents import AgentOutput, AgentStep, SimpleChatAgent


Expand Down
Loading
Loading