Commit

remove unnecessary function helpers
andersy005 committed Sep 10, 2024
1 parent dd9d289 commit e3f948e
Showing 1 changed file with 19 additions and 139 deletions.
158 changes: 19 additions & 139 deletions offsets_db_api/routers/credits.py
@@ -1,10 +1,7 @@
import datetime
import json
import re

from fastapi import APIRouter, Depends, Query, Request
from fastapi_cache.decorator import cache
from pydantic import BaseModel
from sqlmodel import Session, col, func, or_, select

from offsets_db_api.cache import CACHE_NAMESPACE
@@ -18,74 +15,6 @@
router = APIRouter()
logger = get_logger()

# Helper functions


def normalize_text(text: str) -> str:
    return re.sub(r'[^\w\s]', '', text.lower()).strip()


ACRONYM_EXPANSIONS = {
    'jp': ['j p', 'j.p.'],
    'ms': ['microsoft'],
    'ibm': ['i b m', 'i.b.m.'],
    # Add more acronym expansions as needed
}


def expand_acronyms(text: str) -> list[str]:
    words = text.split()
    expansions = [text]

    for i, word in enumerate(words):
        if word in ACRONYM_EXPANSIONS:
            for expansion in ACRONYM_EXPANSIONS[word]:
                new_words = words.copy()
                new_words[i] = expansion
                expansions.append(' '.join(new_words))

    return expansions


COMPANY_ALIASES = {
    'apple': ['apple inc', 'apple incorporated'],
    'jp morgan': ['jpmorgan', 'jp morgan chase', 'chase bank', 'j p morgan', 'j.p. morgan'],
    'microsoft': ['microsoft corporation', 'ms'],
    # Add more aliases as needed
}


def get_aliases(term: str) -> list[str]:
    normalized_term = normalize_text(term)
    return next(
        (
            [key] + aliases
            for key, aliases in COMPANY_ALIASES.items()
            if normalized_term in [normalize_text(a) for a in [key] + aliases]
        ),
        [normalized_term],
    )


class SearchField(BaseModel):
    field: str
    weight: float


def parse_search_fields(
    search_fields_str: str = '[{"field":"retirement_beneficiary","weight":2.0},{"field":"retirement_account","weight":1.5},{"field":"retirement_note","weight":1.0},{"field":"retirement_reason","weight":1.0}]',
) -> list[SearchField]:
    try:
        search_fields = json.loads(search_fields_str)
        return [SearchField(**field) for field in search_fields]
    except json.JSONDecodeError:
        raise ValueError('Invalid JSON format for search_fields')
    except KeyError:
        raise ValueError("Each search field must have 'field' and 'weight' keys")


# Main endpoint


@router.get('/', summary='List credits', response_model=PaginatedCredits)
@cache(namespace=CACHE_NAMESPACE)
@@ -107,12 +36,14 @@ async def get_credits(
        None,
        description='Search string. Use "r:" prefix for regex search, "t:" for trigram search, "w:" for weighted search, or leave blank for case-insensitive partial match.',
    ),
    search_fields: str = Query(
        default='[{"field":"retirement_beneficiary","weight":2.0},{"field":"retirement_account","weight":1.5},{"field":"retirement_note","weight":1.0},{"field":"retirement_reason","weight":1.0}]',
        description='JSON string of fields to search in with their weights',
    ),
    similarity_threshold: float = Query(
        0.7, ge=0.0, le=1.0, description='similarity threshold (0-1)'
    search_fields: list[str] = Query(
        default=[
            'retirement_beneficiary',
            'retirement_account',
            'retirement_note',
            'retirement_reason',
        ],
        description='Fields to search in',
    ),
    sort: list[str] = Query(
        default=['project_id'],
@@ -154,70 +85,19 @@ async def get_credits(
        operation=operation,
    )

    # Handle advanced search
    if search:
        # Default to case-insensitive partial match
        search_term = f'%{search}%'
        search_conditions = []
        logger.info(f'Search string: {search}')
        logger.info(f'Search fields: {search_fields}')

        search_fields = parse_search_fields(search_fields)

        if search.startswith('r:'):
            # Regex search
            pattern = search[2:]  # Remove 'r:' prefix
            logger.info(f'Regex search pattern: {pattern}')
            for field_info in search_fields:
                field = field_info.field
                if field in Credit.__table__.columns:
                    search_conditions.append(col(getattr(Credit, field)).op('~*')(pattern))
                elif field in Project.__table__.columns:
                    search_conditions.append(col(getattr(Project, field)).op('~*')(pattern))
        elif search.startswith('t:'):
            # Trigram similarity search
            search_term = search[2:]  # Remove 't:' prefix
            logger.info(f'Trigram search term: {search_term}')
            for field_info in search_fields:
                field = field_info.field
                if field in Credit.__table__.columns:
                    search_conditions.append(
                        func.word_similarity(func.lower(getattr(Credit, field)), search_term)
                        > similarity_threshold
                    )
                elif field in Project.__table__.columns:
                    search_conditions.append(
                        func.word_similarity(func.lower(getattr(Project, field)), search_term)
                        > similarity_threshold
                    )
        elif search.startswith('w:'):
            # Weighted search with alias and acronym expansion
            search_term = search[2:]  # Remove 'w:' prefix
            logger.info(f'Weighted search term: {search_term}')
            variations = expand_acronyms(search_term)
            variations.extend(get_aliases(search_term))

            for variation in variations:
                for field_info in search_fields:
                    field = field_info.field
                    weight = field_info.weight
                    if field in Credit.__table__.columns:
                        search_conditions.append(
                            func.similarity(func.lower(getattr(Credit, field)), variation) * weight
                            > similarity_threshold
                        )
                    elif field in Project.__table__.columns:
                        search_conditions.append(
                            func.similarity(func.lower(getattr(Project, field)), variation) * weight
                            > similarity_threshold
                        )
        else:
            # Case-insensitive partial match (default behavior)
            search_pattern = f'%{search}%'
            for field_info in search_fields:
                field = field_info.field
                if field in Credit.__table__.columns:
                    search_conditions.append(col(getattr(Credit, field)).ilike(search_pattern))
                elif field in Project.__table__.columns:
                    search_conditions.append(col(getattr(Project, field)).ilike(search_pattern))
        for field in search_fields:
            if field in Credit.__table__.columns:
                search_conditions.append(
                    func.lower(getattr(Credit, field)).like(func.lower(search_term))
                )
            elif field in Project.__table__.columns:
                search_conditions.append(
                    func.lower(getattr(Project, field)).like(func.lower(search_term))
                )

        if search_conditions:
            statement = statement.where(or_(*search_conditions))
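For reference, a minimal sketch of the simplified matching this commit lands on: lower-case each requested column, lower-case the search term, and OR the per-field LIKE conditions together. The Credit model and sample rows below are made-up stand-ins, not the real offsets_db_api schema; they only show the condition-building pattern end to end.

# Illustrative sketch only -- not part of this commit. It mirrors the shape of
# the simplified search in get_credits, against a stand-in Credit model.
from sqlmodel import Field, Session, SQLModel, create_engine, func, or_, select


class Credit(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    retirement_beneficiary: str = ''
    retirement_note: str = ''


def build_search_conditions(search: str, search_fields: list[str]) -> list:
    # One case-insensitive LIKE condition per requested field.
    search_term = f'%{search}%'
    conditions = []
    for field in search_fields:
        if field in Credit.__table__.columns:
            conditions.append(
                func.lower(getattr(Credit, field)).like(func.lower(search_term))
            )
    return conditions


if __name__ == '__main__':
    engine = create_engine('sqlite://')
    SQLModel.metadata.create_all(engine)
    with Session(engine) as session:
        session.add(Credit(retirement_beneficiary='Microsoft', retirement_note='FY24 retirement'))
        session.add(Credit(retirement_beneficiary='Apple Inc', retirement_note='n/a'))
        session.commit()

        conditions = build_search_conditions('microsoft', ['retirement_beneficiary', 'retirement_note'])
        statement = select(Credit).where(or_(*conditions))
        print(session.exec(statement).all())  # matches the 'Microsoft' row only

Client-side, the new search_fields parameter is just a repeated query key, e.g. ?search=microsoft&search_fields=retirement_beneficiary&search_fields=retirement_note; FastAPI collects the repeated values into the list[str] argument.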
