Skip to content

Commit

Permalink
Refactor search fields in credits endpoint to enable trigram similari…
Browse files Browse the repository at this point in the history
…ty search
  • Loading branch information
andersy005 committed Sep 10, 2024
1 parent 54367de commit 1ca249a
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@


def upgrade() -> None:
# Create pg_trgm extension
op.execute('CREATE EXTENSION IF NOT EXISTS pg_trgm')
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
'clip',
Expand Down Expand Up @@ -117,10 +119,30 @@ def upgrade() -> None:
op.create_index(
op.f('ix_credit_transaction_date'), 'credit', ['transaction_date'], unique=False
)

# Create GIN indexes for trigram similarity search
op.execute(
'CREATE INDEX idx_credit_retirement_beneficiary_trgm ON credit USING gin (retirement_beneficiary gin_trgm_ops)'
)
op.execute(
'CREATE INDEX idx_credit_retirement_note_trgm ON credit USING gin (retirement_note gin_trgm_ops)'
)
op.execute(
'CREATE INDEX idx_credit_retirement_account_trgm ON credit USING gin (retirement_account gin_trgm_ops)'
)
op.execute(
'CREATE INDEX idx_credit_retirement_reason_trgm ON credit USING gin (retirement_reason gin_trgm_ops)'
)
# ### end Alembic commands ###


def downgrade() -> None:
# Remove GIN indexes
op.execute('DROP INDEX IF EXISTS idx_credit_retirement_beneficiary_trgm')
op.execute('DROP INDEX IF EXISTS idx_credit_retirement_note_trgm')
op.execute('DROP INDEX IF EXISTS idx_credit_retirement_account_trgm')
op.execute('DROP INDEX IF EXISTS idx_credit_retirement_reason_trgm')

# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_credit_transaction_date'), table_name='credit')
op.drop_index(op.f('ix_credit_retirement_reason'), table_name='credit')
Expand All @@ -134,4 +156,7 @@ def downgrade() -> None:
op.drop_table('project')
op.drop_table('file')
op.drop_table('clip')

# Remove pg_trgm extension
op.execute('DROP EXTENSION IF EXISTS pg_trgm')
# ### end Alembic commands ###
135 changes: 123 additions & 12 deletions offsets_db_api/routers/credits.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import datetime
import json
import re

from fastapi import APIRouter, Depends, Query, Request
from fastapi_cache.decorator import cache
from sqlmodel import Session, col, or_, select
from pydantic import BaseModel
from sqlmodel import Session, col, func, or_, select

from offsets_db_api.cache import CACHE_NAMESPACE
from offsets_db_api.database import get_session
Expand All @@ -15,6 +18,74 @@
router = APIRouter()
logger = get_logger()

# Helper functions


def normalize_text(text: str) -> str:
    """Canonicalize *text* for fuzzy matching.

    Lowercases the input, strips every character that is not a word
    character or whitespace, and trims leading/trailing whitespace.
    Internal spacing is preserved.
    """
    lowered = text.lower()
    without_punctuation = re.sub(r'[^\w\s]', '', lowered)
    return without_punctuation.strip()


# Maps a lowercase acronym to its spelled-out / punctuated variants.
ACRONYM_EXPANSIONS = {
    'jp': ['j p', 'j.p.'],
    'ms': ['microsoft'],
    'ibm': ['i b m', 'i.b.m.'],
    # Add more acronym expansions as needed
}


def expand_acronyms(text: str) -> list[str]:
    """Return *text* plus one variant per known acronym expansion.

    For every word in *text* that appears in ``ACRONYM_EXPANSIONS``, a copy
    of the text with that single word replaced by each expansion is added.
    Expansions are not combined (one substitution per variant).

    Parameters
    ----------
    text : str
        Raw search term (any casing).

    Returns
    -------
    list[str]
        The original text first, followed by the expanded variants.
    """
    words = text.split()
    expansions = [text]

    for i, word in enumerate(words):
        # Lowercase the lookup: callers pass the raw search term, so an
        # uppercase acronym like "JP" must still match the lowercase keys.
        for expansion in ACRONYM_EXPANSIONS.get(word.lower(), []):
            new_words = words.copy()
            new_words[i] = expansion
            expansions.append(' '.join(new_words))

    return expansions


# Maps a canonical company name to its known alternate spellings.
# Both keys and aliases are compared via normalize_text() at lookup time
# (see get_aliases), so casing and punctuation differences are irrelevant.
COMPANY_ALIASES = {
    'apple': ['apple inc', 'apple incorporated'],
    'jp morgan': ['jpmorgan', 'jp morgan chase', 'chase bank', 'j p morgan', 'j.p. morgan'],
    'microsoft': ['microsoft corporation', 'ms'],
    # Add more aliases as needed
}


def get_aliases(term: str) -> list[str]:
    """Look up known company aliases for *term*.

    Normalizes the term, then scans ``COMPANY_ALIASES`` for the first entry
    whose canonical name or any alias matches after normalization, and
    returns the canonical name followed by all aliases.  When nothing
    matches, returns a single-element list containing the normalized term.
    """
    wanted = normalize_text(term)
    for canonical, alternate_names in COMPANY_ALIASES.items():
        candidates = [canonical] + alternate_names
        if any(normalize_text(candidate) == wanted for candidate in candidates):
            return candidates
    return [wanted]


class SearchField(BaseModel):
    """One searchable column paired with its relative ranking weight."""

    # Column name expected to exist on the Credit or Project model
    # (membership in __table__.columns is checked at query-build time).
    field: str
    # Multiplier applied to the similarity score in weighted ("w:") search.
    weight: float


def parse_search_fields(
    search_fields_str: str = '[{"field":"retirement_beneficiary","weight":2.0},{"field":"retirement_account","weight":1.5},{"field":"retirement_note","weight":1.0},{"field":"retirement_reason","weight":1.0}]',
) -> list[SearchField]:
    """Parse a JSON array of field/weight objects into SearchField models.

    Parameters
    ----------
    search_fields_str : str
        JSON string encoding a list of objects, each carrying 'field' and
        'weight' keys.

    Returns
    -------
    list[SearchField]
        One model per entry, in input order.

    Raises
    ------
    ValueError
        If the string is not valid JSON, or an entry is malformed.
    """
    try:
        raw_fields = json.loads(search_fields_str)
    except json.JSONDecodeError as exc:
        # Chain the original error so the JSON position info isn't lost.
        raise ValueError('Invalid JSON format for search_fields') from exc

    try:
        return [SearchField(**item) for item in raw_fields]
    except (KeyError, TypeError) as exc:
        # TypeError covers non-mapping entries (e.g. a bare string in the list).
        # NOTE(review): pydantic signals a missing 'field'/'weight' key with
        # ValidationError, not KeyError — confirm and catch ValidationError
        # here as well if this custom message should apply in that case.
        raise ValueError("Each search field must have 'field' and 'weight' keys") from exc


# Main endpoint


@router.get('/', summary='List credits', response_model=PaginatedCredits)
@cache(namespace=CACHE_NAMESPACE)
Expand All @@ -34,16 +105,14 @@ async def get_credits(
),
search: str | None = Query(
None,
description='Search string. Use "r:" prefix for regex search, or leave blank for case-insensitive partial match.',
description='Search string. Use "r:" prefix for regex search, "t:" for trigram search, "w:" for weighted search, or leave blank for case-insensitive partial match.',
),
search_fields: str = Query(
default='[{"field":"retirement_beneficiary","weight":2.0},{"field":"retirement_account","weight":1.5},{"field":"retirement_note","weight":1.0},{"field":"retirement_reason","weight":1.0}]',
description='JSON string of fields to search in with their weights',
),
search_fields: list[str] = Query(
default=[
'retirement_beneficiary',
'retirement_account',
'retirement_note',
'retirement_reason',
],
description='Fields to search in',
similarity_threshold: float = Query(
0.7, ge=0.0, le=1.0, description='similarity threshold (0-1)'
),
sort: list[str] = Query(
default=['project_id'],
Expand Down Expand Up @@ -90,19 +159,61 @@ async def get_credits(
search_conditions = []
logger.info(f'Search string: {search}')
logger.info(f'Search fields: {search_fields}')

search_fields = parse_search_fields(search_fields)

if search.startswith('r:'):
# Regex search
pattern = search[2:] # Remove 'r:' prefix
logger.info(f'Regex search pattern: {pattern}')
for field in search_fields:
for field_info in search_fields:
field = field_info.field
if field in Credit.__table__.columns:
search_conditions.append(col(getattr(Credit, field)).op('~*')(pattern))
elif field in Project.__table__.columns:
search_conditions.append(col(getattr(Project, field)).op('~*')(pattern))
elif search.startswith('t:'):
# Trigram similarity search
search_term = search[2:] # Remove 't:' prefix
logger.info(f'Trigram search term: {search_term}')
for field_info in search_fields:
field = field_info.field
if field in Credit.__table__.columns:
search_conditions.append(
func.word_similarity(func.lower(getattr(Credit, field)), search_term)
> similarity_threshold
)
elif field in Project.__table__.columns:
search_conditions.append(
func.word_similarity(func.lower(getattr(Project, field)), search_term)
> similarity_threshold
)
elif search.startswith('w:'):
# Weighted search with alias and acronym expansion
search_term = search[2:] # Remove 'w:' prefix
logger.info(f'Weighted search term: {search_term}')
variations = expand_acronyms(search_term)
variations.extend(get_aliases(search_term))

for variation in variations:
for field_info in search_fields:
field = field_info.field
weight = field_info.weight
if field in Credit.__table__.columns:
search_conditions.append(
func.similarity(func.lower(getattr(Credit, field)), variation) * weight
> similarity_threshold
)
elif field in Project.__table__.columns:
search_conditions.append(
func.similarity(func.lower(getattr(Project, field)), variation) * weight
> similarity_threshold
)
else:
# Case-insensitive partial match (default behavior)
search_pattern = f'%{search}%'
for field in search_fields:
for field_info in search_fields:
field = field_info.field
if field in Credit.__table__.columns:
search_conditions.append(col(getattr(Credit, field)).ilike(search_pattern))
elif field in Project.__table__.columns:
Expand Down

0 comments on commit 1ca249a

Please sign in to comment.