diff --git a/offsets_db_api/cache.py b/offsets_db_api/cache.py index 661ce14..dee421e 100644 --- a/offsets_db_api/cache.py +++ b/offsets_db_api/cache.py @@ -4,8 +4,8 @@ from fastapi import Request, Response from fastapi_cache import FastAPICache -from .logging import get_logger -from .query_helpers import _convert_query_params_to_dict +from offsets_db_api.log import get_logger +from offsets_db_api.query_helpers import _convert_query_params_to_dict logger = get_logger() diff --git a/offsets_db_api/database.py b/offsets_db_api/database.py index f5c22f5..9e7b902 100644 --- a/offsets_db_api/database.py +++ b/offsets_db_api/database.py @@ -2,7 +2,7 @@ from sqlmodel import Session, create_engine -from .settings import get_settings +from offsets_db_api.settings import get_settings def get_engine(*, database_url: str): diff --git a/offsets_db_api/logging.py b/offsets_db_api/log.py similarity index 100% rename from offsets_db_api/logging.py rename to offsets_db_api/log.py diff --git a/offsets_db_api/main.py b/offsets_db_api/main.py index 617e1d1..0b9bcb3 100644 --- a/offsets_db_api/main.py +++ b/offsets_db_api/main.py @@ -10,10 +10,10 @@ from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer -from .app_metadata import metadata -from .cache import clear_cache, request_key_builder, watch_dog_dir, watch_dog_file -from .logging import get_logger -from .routers import charts, clips, credits, files, health, projects +from offsets_db_api.app_metadata import metadata +from offsets_db_api.cache import clear_cache, request_key_builder, watch_dog_dir, watch_dog_file +from offsets_db_api.log import get_logger +from offsets_db_api.routers import charts, clips, credits, files, health, projects logger = get_logger() diff --git a/offsets_db_api/models.py b/offsets_db_api/models.py index a4689f0..b1c731f 100644 --- a/offsets_db_api/models.py +++ b/offsets_db_api/models.py @@ -5,7 +5,7 @@ from sqlalchemy.dialects import postgresql from sqlmodel import BigInteger, Column, Field, Relationship, SQLModel, String -from .schemas import FileCategory, FileStatus, Pagination +from offsets_db_api.schemas import FileCategory, FileStatus, Pagination class File(SQLModel, table=True): @@ -140,12 +140,12 @@ class CreditWithCategory(CreditBase): class PaginatedProjects(pydantic.BaseModel): pagination: Pagination - data: list[ProjectWithClips] + data: list[ProjectWithClips] | list[dict[str, typing.Any]] class PaginatedCredits(pydantic.BaseModel): pagination: Pagination - data: list[CreditWithCategory] + data: list[CreditWithCategory] | list[dict[str, typing.Any]] class BinnedValues(pydantic.BaseModel): @@ -207,3 +207,8 @@ class PaginatedBinnedCreditTotals(pydantic.BaseModel): class PaginatedClips(pydantic.BaseModel): pagination: Pagination data: list[ClipwithProjects] | list[dict[str, typing.Any]] + + +class PaginatedFiles(pydantic.BaseModel): + pagination: Pagination + data: list[File] | list[dict[str, typing.Any]] diff --git a/offsets_db_api/query_helpers.py b/offsets_db_api/query_helpers.py index 2441f4c..b2dd86b 100644 --- a/offsets_db_api/query_helpers.py +++ b/offsets_db_api/query_helpers.py @@ -6,8 +6,8 @@ from sqlalchemy.orm import Query from sqlmodel import and_, asc, desc, distinct, func, nullslast, or_, select -from .logging import get_logger -from .models import Clip, ClipProject, Credit, Project +from offsets_db_api.log import get_logger +from offsets_db_api.models import Clip, ClipProject, Credit, Project logger = get_logger() diff --git a/offsets_db_api/routers/charts.py 
b/offsets_db_api/routers/charts.py index ff95780..fce3086 100644 --- a/offsets_db_api/routers/charts.py +++ b/offsets_db_api/routers/charts.py @@ -5,12 +5,12 @@ import pandas as pd from fastapi import APIRouter, Depends, Query, Request from fastapi_cache.decorator import cache -from sqlmodel import Session, col, or_ +from sqlmodel import Date, Session, and_, case, cast, col, func, or_, select -from ..cache import CACHE_NAMESPACE -from ..database import get_engine, get_session -from ..logging import get_logger -from ..models import ( +from offsets_db_api.cache import CACHE_NAMESPACE +from offsets_db_api.database import get_session +from offsets_db_api.log import get_logger +from offsets_db_api.models import ( Credit, PaginatedBinnedCreditTotals, PaginatedBinnedValues, @@ -19,44 +19,14 @@ PaginatedProjectCreditTotals, Project, ) -from ..query_helpers import apply_filters -from ..schemas import Pagination, Registries -from ..security import check_api_key -from ..settings import get_settings +from offsets_db_api.schemas import Pagination, Registries +from offsets_db_api.security import check_api_key +from offsets_db_api.sql_helpers import apply_filters router = APIRouter() logger = get_logger() -def filter_valid_projects(df: pd.DataFrame, categories: list | None = None) -> pd.DataFrame: - if categories is None: - return df - # Filter the dataframe to include only rows with the specified categories - valid_projects = df[df['category'].isin(categories)] - return valid_projects - - -def projects_by_category( - *, df: pd.DataFrame, categories: list | None = None -) -> list[dict[str, int]]: - valid_projects = filter_valid_projects(df, categories) - counts = valid_projects.groupby('category').count()['project_id'] - return [{'category': category, 'value': count} for category, count in counts.items()] - - -def credits_by_category( - *, df: pd.DataFrame, categories: list | None = None -) -> list[dict[str, int]]: - valid_projects = filter_valid_projects(df, categories) - credits = ( - valid_projects.groupby('category').agg({'issued': 'sum', 'retired': 'sum'}).reset_index() - ) - return [ - {'category': row['category'], 'issued': row['issued'], 'retired': row['retired']} - for _, row in credits.iterrows() - ] - - def calculate_end_date(start_date, freq): """Calculate the end date based on the start date and frequency.""" @@ -164,179 +134,6 @@ def generate_dynamic_numeric_bins(*, min_value, max_value, bin_width=None): return numeric_bins -def projects_counts_by_listing_date( - *, - df: pd.DataFrame, - freq: typing.Literal['D', 'W', 'M', 'Y'] = 'Y', - categories: list[str] | None = None, -) -> list[dict[str, typing.Any]]: - """ - Generate project counts by listing date. 
- """ - logger.info('📊 Generating project counts by listing date...') - valid_df = filter_valid_projects(df, categories=categories) - min_value, max_value = valid_df['listed_at'].agg(['min', 'max']) - - if pd.isna(min_value) or pd.isna(max_value): - logger.info('✅ No data to bin!') - return [] - - date_bins = generate_date_bins(min_value=min_value, max_value=max_value, freq=freq) - valid_df['bin'] = pd.cut( - valid_df['listed_at'], bins=date_bins, labels=date_bins[:-1], right=False - ) - valid_df['bin'] = valid_df['bin'].astype(str) - - # Aggregate the data - grouped = valid_df.groupby(['bin', 'category'])['project_id'].count().reset_index() - - formatted_results = [] - for _, row in grouped.iterrows(): - bin_label = row['bin'] - category = row['category'] - value = row['project_id'] - - start_date = pd.Timestamp(bin_label).date() if bin_label else None - end_date = calculate_end_date(start_date, freq).date() if start_date else None - - formatted_results.append( - dict(start=start_date, end=end_date, category=category, value=value) - ) - - logger.info('✅ Binned data generated successfully!') - return formatted_results - - -def projects_by_credit_totals( - *, df: pd.DataFrame, credit_type: str, bin_width=None, categories: list[str] | None = None -) -> list[dict[str, typing.Any]]: - """Generate binned data based on the given attribute and frequency.""" - logger.info(f'📊 Generating binned data based on {credit_type}...') - valid_df = filter_valid_projects(df, categories=categories) - min_value, max_value = valid_df[credit_type].agg(['min', 'max']) - - if pd.isna(min_value) or pd.isna(max_value): - logger.info('✅ No data to bin!') - return [] - - bins = generate_dynamic_numeric_bins( - min_value=min_value, max_value=max_value, bin_width=bin_width - ).tolist() - valid_df['bin'] = pd.cut(valid_df[credit_type], bins=bins, labels=bins[:-1], right=False) - valid_df['bin'] = valid_df['bin'].astype(str) - grouped = valid_df.groupby(['bin', 'category'])['project_id'].count().reset_index() - formatted_results = [] - for _, row in grouped.iterrows(): - bin_label = row['bin'] - category = row['category'] - value = row['project_id'] - start_value = int(bin_label) - index = bins.index(start_value) - end_value = bins[index + 1] - formatted_results.append( - dict(start=start_value, end=end_value, category=category, value=value) - ) - logger.info(f'✅ {len(formatted_results)} bins generated') - - return formatted_results - - -def single_project_credits_by_transaction_date( - *, df: pd.DataFrame, freq: typing.Literal['D', 'W', 'M', 'Y'] | None = None -) -> list[dict[str, typing.Any]]: - min_date, max_date = df.transaction_date.agg(['min', 'max']) - if pd.isna(min_date) or pd.isna(max_date): - logger.info('✅ No data to bin!') - return [] - - if freq is None: - date_diff = max_date - min_date - if date_diff < datetime.timedelta(days=7): - freq = 'D' - elif date_diff < datetime.timedelta( - days=30 - ): # Approximating a month to 30 days for simplicity - freq = 'W' - elif date_diff < datetime.timedelta( - days=365 - ): # Approximating a year to 365 days for simplicity - freq = 'M' - else: - freq = 'Y' - - # Check if all events fall within the same year or month - if min_date.year == max_date.year: - freq = 'Y' - if min_date.month == max_date.month: - freq = 'M' - - date_bins = generate_date_bins(min_value=min_date, max_value=max_date, freq=freq) - - # Binning logic - df['bin'] = pd.cut(df['transaction_date'], bins=date_bins, labels=date_bins[:-1], right=False) - df['bin'] = df['bin'].astype(str) - grouped = 
df.groupby(['bin'])['quantity'].sum().reset_index() - formatted_results = [] - for _, row in grouped.iterrows(): - bin_label = row['bin'] - value = row['quantity'] - - start_date = pd.Timestamp(bin_label).date() if bin_label else None - end_date = calculate_end_date(start_date, freq).date() if start_date else None - - formatted_results.append(dict(start=start_date, end=end_date, value=value)) - return formatted_results - - -def credits_by_transaction_date( - *, - df: pd.DataFrame, - freq: typing.Literal['D', 'W', 'M', 'Y'] = 'Y', - num_bins: int | None = None, - categories: list[str] | None = None, -) -> list[dict[str, typing.Any]]: - """ - Get credits by transaction date. - """ - valid_df = filter_valid_projects(df, categories=categories) - - min_date, max_date = valid_df.transaction_date.agg(['min', 'max']) - - if pd.isna(min_date) or pd.isna(max_date): - logger.info('✅ No data to bin!') - return [] - if num_bins: - date_bins = generate_date_bins( - min_value=min_date, max_value=max_date, num_bins=num_bins - ) # Assuming this function returns a list of date ranges - else: - date_bins = generate_date_bins(min_value=min_date, max_value=max_date, freq=freq) - - # Binning logic - valid_df['bin'] = pd.cut( - valid_df['transaction_date'], bins=date_bins, labels=date_bins[:-1], right=False - ) - valid_df['bin'] = valid_df['bin'].astype(str) - - # Aggregate the data - grouped = valid_df.groupby(['bin', 'category'])['quantity'].sum().reset_index() - - # Formatting the results - formatted_results = [] - for _, row in grouped.iterrows(): - bin_label = row['bin'] - category = row['category'] - value = row['quantity'] - - start_date = pd.Timestamp(bin_label).date() if bin_label else None - end_date = calculate_end_date(start_date, freq).date() if start_date else None - - formatted_results.append( - dict(start=start_date, end=end_date, category=category, value=value) - ) - return formatted_results - - @router.get('/projects_by_listing_date', response_model=PaginatedBinnedValues) @cache(namespace=CACHE_NAMESPACE) async def get_projects_by_listing_date( @@ -364,18 +161,26 @@ async def get_projects_by_listing_date( current_page: int = Query(1, description='Page number', ge=1), per_page: int = Query(100, description='Items per page', le=200, ge=1), session: Session = Depends(get_session), - # authorized_user: bool = Depends(check_api_key), + authorized_user: bool = Depends(check_api_key), ): """Get aggregated project registration data""" + logger.info(f'Getting project registration data: {request.url}') - query = session.query(Project) + # Base query + subquery = select( + col(Project.listed_at), + col(Project.project_id), + func.unnest(Project.category).label('category'), + ).alias('subquery') + + query = select(subquery) + # Apply filters filters = [ ('registry', registry, 'ilike', Project), ('country', country, 'ilike', Project), ('protocol', protocol, 'ANY', Project), - ('category', category, 'ANY', Project), ('is_compliance', is_compliance, '==', Project), ('listed_at', listed_at_from, '>=', Project), ('listed_at', listed_at_to, '<=', Project), @@ -387,39 +192,96 @@ async def get_projects_by_listing_date( for attribute, values, operation, model in filters: query = apply_filters( - query=query, model=model, attribute=attribute, values=values, operation=operation + statement=query, model=model, attribute=attribute, values=values, operation=operation ) # Handle 'search' filter separately due to its unique logic if search: search_pattern = f'%{search}%' - query = query.filter( + query = query.where( or_( 
col(Project.project_id).ilike(search_pattern), col(Project.name).ilike(search_pattern), ) ) - settings = get_settings() - engine = get_engine(database_url=settings.database_url) - logger.info(f'Query statement: {query.statement}') + # Apply category filter + if category: + query = query.where(subquery.c.category.in_(category)) + + # Get min and max listing dates + min_max_query = select(func.min(subquery.c.listed_at), func.max(subquery.c.listed_at)) + min_date, max_date = session.exec(min_max_query.select_from(subquery)).fetchone() + + if min_date is None or max_date is None: + logger.info('✅ No data to bin!') + return PaginatedBinnedValues( + pagination=Pagination( + total_entries=0, + total_pages=1, + next_page=None, + current_page=current_page, + ), + data=[], + ) + + # Generate date bins using the original function + date_bins = generate_date_bins(min_value=min_date, max_value=max_date, freq=freq).tolist() + + # Create a CASE statement for binning + bin_case = case( + *[ + ( + and_(subquery.c.listed_at >= bin_start, subquery.c.listed_at < bin_end), + cast(bin_start, Date), + ) + for bin_start, bin_end in zip(date_bins[:-1], date_bins[1:]) + ], + else_=cast(date_bins[-1], Date), + ).label('bin') + + # Add binning to the query and aggregate + binned_query = ( + select( + bin_case, + subquery.c.category, + func.count(subquery.c.project_id.distinct()).label('value'), + ) + .select_from(subquery) + .group_by(bin_case, subquery.c.category) + ) + + # Execute the query + results = session.exec(binned_query).fetchall() - df = pd.read_sql_query(query.statement, engine).explode('category') - df = df.astype({'listed_at': 'datetime64[ns]'}) - logger.info(f'Sample of the dataframe with size: {df.shape}\n{df.head()}') - results = projects_counts_by_listing_date(df=df, freq=freq, categories=category) - total_entries = len(results) - total_pages = 1 - next_page = None + # Format the results + formatted_results = [] + current_year = datetime.datetime.now().year + for row in results: + start_date = row.bin + if start_date.year > current_year: + continue # Skip future dates + end_date = calculate_end_date(start_date, freq) + formatted_results.append( + { + 'start': start_date.strftime('%Y-%m-%d'), + 'end': end_date.strftime('%Y-%m-%d'), + 'category': row.category, + 'value': int(row.value), + } + ) + + # Sort the results + formatted_results.sort(key=lambda x: (x['start'], x['category'])) return PaginatedBinnedValues( pagination=Pagination( - total_entries=total_entries, - total_pages=total_pages, - next_page=next_page, + total_entries=len(formatted_results), + total_pages=1, + next_page=None, current_page=current_page, ), - data=results, + data=formatted_results, ) @@ -451,19 +313,28 @@ async def get_credits_by_transaction_date( authorized_user: bool = Depends(check_api_key), ): """Get aggregated credit transaction data""" + logger.info(f'Getting credit transaction data: {request.url}') - # join Credit with Project on project_id - query = session.query(Credit, Project.category).join( - Project, Credit.project_id == Project.project_id, isouter=True + # Base query + subquery = ( + select( + col(Credit.transaction_date), + col(Credit.quantity), + func.unnest(Project.category).label('category'), + ) + .join(Project, col(Credit.project_id) == col(Project.project_id)) + .alias('subquery') ) + query = select(subquery) + + # Apply filters filters = [ ('registry', registry, 'ilike', Project), ('country', country, 'ilike', Project), ('transaction_type', transaction_type, 'ilike', Credit), ('protocol', protocol, 
'ANY', Project), - ('category', category, 'ANY', Project), ('is_compliance', is_compliance, '==', Project), ('vintage', vintage, '==', Credit), ('transaction_date', transaction_date_from, '>=', Credit), @@ -472,41 +343,96 @@ async def get_credits_by_transaction_date( for attribute, values, operation, model in filters: query = apply_filters( - query=query, model=model, attribute=attribute, values=values, operation=operation + statement=query, model=model, attribute=attribute, values=values, operation=operation ) # Handle 'search' filter separately due to its unique logic if search: search_pattern = f'%{search}%' - query = query.filter( + query = query.where( or_( col(Project.project_id).ilike(search_pattern), col(Project.name).ilike(search_pattern), ) ) - settings = get_settings() - engine = get_engine(database_url=settings.database_url) + # Apply category filter + if category: + query = query.where(subquery.c.category.in_(category)) + + # Get min and max transaction dates + min_max_query = select( + func.min(subquery.c.transaction_date), func.max(subquery.c.transaction_date) + ) + min_date, max_date = session.exec(min_max_query.select_from(subquery)).fetchone() + + if min_date is None or max_date is None: + logger.info('✅ No data to bin!') + return PaginatedBinnedValues( + pagination=Pagination( + total_entries=0, + total_pages=1, + next_page=None, + current_page=current_page, + ), + data=[], + ) + + # Generate date bins using the original function + date_bins = generate_date_bins(min_value=min_date, max_value=max_date, freq=freq).tolist() + + # Create a CASE statement for binning + bin_case = case( + *[ + ( + and_( + subquery.c.transaction_date >= bin_start, subquery.c.transaction_date < bin_end + ), + cast(bin_start, Date), + ) + for bin_start, bin_end in zip(date_bins[:-1], date_bins[1:]) + ], + else_=cast(date_bins[-1], Date), + ).label('bin') + + # Add binning to the query and aggregate + binned_query = ( + select(bin_case, subquery.c.category, func.sum(subquery.c.quantity).label('value')) + .select_from(subquery) + .group_by(bin_case, subquery.c.category) + ) + + # Execute the query + results = session.exec(binned_query).fetchall() - logger.info(f'Query statement: {query.statement}') + # Format the results + formatted_results = [] + current_year = datetime.datetime.now().year + for row in results: + start_date = row.bin + if start_date.year > current_year: + continue # Skip future dates + end_date = calculate_end_date(start_date, freq) + formatted_results.append( + { + 'start': start_date.strftime('%Y-%m-%d'), + 'end': end_date.strftime('%Y-%m-%d'), + 'category': row.category, + 'value': int(row.value), + } + ) - df = pd.read_sql_query(query.statement, engine).explode('category') - logger.info(f'Sample of the dataframe with size: {df.shape}\n{df.head()}') - # fix the data types - df = df.astype({'transaction_date': 'datetime64[ns]'}) - results = credits_by_transaction_date(df=df, freq=freq, categories=category) + # Sort the results + formatted_results.sort(key=lambda x: (x['start'], x['category'])) - total_entries = len(results) - total_pages = 1 - next_page = None return PaginatedBinnedValues( pagination=Pagination( - total_entries=total_entries, - total_pages=total_pages, - next_page=next_page, + total_entries=len(formatted_results), + total_pages=1, + next_page=None, current_page=current_page, ), - data=results, + data=formatted_results, ) @@ -532,14 +458,17 @@ async def get_credits_by_project_id( authorized_user: bool = Depends(check_api_key), ): """Get aggregated credit 
transaction data""" + logger.info(f'Getting credit transaction data: {request.url}') - # Join Credit with Project and filter by project_id - query = ( - session.query(Credit, Project.category, Project.listed_at) - .join(Project, Credit.project_id == Project.project_id) - .filter(Project.project_id == project_id) + + # Base query + statement = ( + select(col(Credit.transaction_date), col(Credit.quantity)) + .join(Project) + .where(Project.project_id == project_id) ) + # Apply filters filters = [ ('transaction_type', transaction_type, 'ilike', Credit), ('transaction_date', transaction_date_from, '>=', Credit), @@ -548,31 +477,84 @@ async def get_credits_by_project_id( ] for attribute, values, operation, model in filters: - query = apply_filters( - query=query, model=model, attribute=attribute, values=values, operation=operation + statement = apply_filters( + statement=statement, + model=model, + attribute=attribute, + values=values, + operation=operation, ) - settings = get_settings() - engine = get_engine(database_url=settings.database_url) + # Get min and max transaction dates + min_max_query = select(func.min(Credit.transaction_date), func.max(Credit.transaction_date)) + min_date, max_date = session.exec(min_max_query).fetchone() - logger.info(f'Query statement: {query.statement}') + if min_date is None or max_date is None: + logger.info('✅ No data to bin!') + return PaginatedProjectCreditTotals( + pagination=Pagination( + total_entries=0, + total_pages=1, + next_page=None, + current_page=current_page, + ), + data=[], + ) - df = pd.read_sql_query(query.statement, engine) - # fix the data types - df = df.astype({'transaction_date': 'datetime64[ns]'}) - results = single_project_credits_by_transaction_date(df=df, freq=freq) + # Determine frequency if not provided + if freq is None: + date_diff = max_date - min_date + if date_diff < datetime.timedelta(days=7): + freq = 'D' + elif date_diff < datetime.timedelta(days=30): + freq = 'W' + elif date_diff < datetime.timedelta(days=365): + freq = 'M' + else: + freq = 'Y' + + if min_date.year == max_date.year: + freq = 'M' if min_date.month == max_date.month else 'Y' + # Generate date bins + date_bins = generate_date_bins(min_value=min_date, max_value=max_date, freq=freq).tolist() + + # Create a CASE statement for binning + bin_case = case( + *[ + ( + and_(Credit.transaction_date >= bin_start, Credit.transaction_date < bin_end), + cast(bin_start, Date), + ) + for bin_start, bin_end in zip(date_bins[:-1], date_bins[1:]) + ], + else_=cast(date_bins[-1], Date), + ).label('bin') + + # Add binning to the query and aggregate + binned_query = ( + statement.add_columns(bin_case) + .group_by('bin') + .with_only_columns(bin_case, func.sum(Credit.quantity).label('value')) + ) + + # Execute the query + results = session.exec(binned_query).fetchall() + + # Format the results + formatted_results = [] + for row in results: + start_date = row.bin + end_date = calculate_end_date(start_date, freq) + formatted_results.append({'start': start_date, 'end': end_date, 'value': row.value}) - total_entries = len(results) - total_pages = 1 - next_page = None return PaginatedProjectCreditTotals( pagination=Pagination( - total_entries=total_entries, - total_pages=total_pages, - next_page=next_page, + total_entries=len(formatted_results), + total_pages=1, + next_page=None, current_page=current_page, ), - data=results, + data=formatted_results, ) @@ -613,15 +595,15 @@ async def get_projects_by_credit_totals( authorized_user: bool = Depends(check_api_key), ): """Get aggregated 
project credit totals""" + logger.info(f'📊 Generating projects by {credit_type} totals...: {request.url}') - query = session.query(Project) + statement = select(Project) filters = [ ('registry', registry, 'ilike', Project), ('country', country, 'ilike', Project), ('protocol', protocol, 'ANY', Project), - ('category', category, 'ANY', Project), ('is_compliance', is_compliance, '==', Project), ('listed_at', listed_at_from, '>=', Project), ('listed_at', listed_at_to, '<=', Project), @@ -634,39 +616,104 @@ async def get_projects_by_credit_totals( ] for attribute, values, operation, model in filters: - query = apply_filters( - query=query, model=model, attribute=attribute, values=values, operation=operation + statement = apply_filters( + statement=statement, + model=model, + attribute=attribute, + values=values, + operation=operation, ) # Handle 'search' filter separately due to its unique logic if search: search_pattern = f'%{search}%' - query = query.filter( + statement = statement.where( or_( col(Project.project_id).ilike(search_pattern), col(Project.name).ilike(search_pattern), ) ) - settings = get_settings() - engine = get_engine(database_url=settings.database_url) - logger.info(f'Query statement: {query.statement}') + # Explode the category column + subquery = statement.subquery() + exploded = ( + select( + func.unnest(subquery.c.category).label('category'), + getattr(subquery.c, credit_type).label('credit_value'), + subquery.c.project_id, + ) + .select_from(subquery) + .subquery() + ) + + # Apply category filter after unnesting + if category: + exploded = select(exploded).where(exploded.c.category.in_(category)).subquery() + + # Get min and max values for binning + min_max_query = select( + func.min(exploded.c.credit_value).label('min_value'), + func.max(exploded.c.credit_value).label('max_value'), + ) + min_max_result = session.exec(min_max_query).fetchone() + min_value, max_value = min_max_result.min_value, min_max_result.max_value + + if min_value is None or max_value is None: + logger.info('✅ No data to bin!') + return PaginatedBinnedCreditTotals( + pagination=Pagination( + total_entries=0, + total_pages=1, + next_page=None, + current_page=current_page, + ), + data=[], + ) + + # Generate bins + bins = generate_dynamic_numeric_bins( + min_value=min_value, max_value=max_value, bin_width=bin_width + ).tolist() + # Create a CASE statement for binning + bin_case = case( + *[ + ( + and_(exploded.c.credit_value >= bin_start, exploded.c.credit_value < bin_end), + bin_start, + ) + for bin_start, bin_end in zip(bins[:-1], bins[1:]) + ], + else_=bins[-2], # Use the last bin start for values >= the last bin start + ).label('bin') + + # Count projects by bin and category + binned_query = select( + bin_case, exploded.c.category, func.count(exploded.c.project_id.distinct()).label('value') + ).group_by(bin_case, exploded.c.category) - df = pd.read_sql_query(query.statement, engine).explode('category') - logger.info(f'Sample of the dataframe with size: {df.shape}\n{df.head()}') - results = projects_by_credit_totals(df=df, credit_type=credit_type, bin_width=bin_width) + # Execute the query + results = session.exec(binned_query).fetchall() + + # Format the results + formatted_results = [] + for row in results: + bin_start = row.bin + bin_index = bins.index(bin_start) + bin_end = bins[bin_index + 1] if bin_index < len(bins) - 1 else None + formatted_results.append( + {'start': bin_start, 'end': bin_end, 'category': row.category, 'value': row.value} + ) + + logger.info(f'✅ {len(formatted_results)} bins 
generated') - total_entries = len(results) - total_pages = 1 - next_page = None return PaginatedBinnedCreditTotals( pagination=Pagination( - total_entries=total_entries, - total_pages=total_pages, - next_page=next_page, + total_entries=len(formatted_results), + total_pages=1, + next_page=None, current_page=current_page, ), - data=results, + data=formatted_results, ) @@ -699,15 +746,15 @@ async def get_projects_by_category( authorized_user: bool = Depends(check_api_key), ): """Get project counts by category""" + logger.info(f'Getting project count by category: {request.url}') - query = session.query(Project) + statement = select(Project) filters = [ ('registry', registry, 'ilike', Project), ('country', country, 'ilike', Project), ('protocol', protocol, 'ANY', Project), - ('category', category, 'ANY', Project), ('is_compliance', is_compliance, '==', Project), ('listed_at', listed_at_from, '>=', Project), ('listed_at', listed_at_to, '<=', Project), @@ -718,29 +765,46 @@ async def get_projects_by_category( ] for attribute, values, operation, model in filters: - query = apply_filters( - query=query, model=model, attribute=attribute, values=values, operation=operation + statement = apply_filters( + statement=statement, + model=model, + attribute=attribute, + values=values, + operation=operation, ) # Handle 'search' filter separately due to its unique logic if search: search_pattern = f'%{search}%' - query = query.filter( + statement = statement.where( or_( col(Project.project_id).ilike(search_pattern), col(Project.name).ilike(search_pattern), ) ) - settings = get_settings() - engine = get_engine(database_url=settings.database_url) + subquery = statement.subquery() + exploded = ( + select(func.unnest(subquery.c.category).label('category'), subquery.c.project_id) + .select_from(subquery) + .subquery() + ) + + # apply category filter after unnesting + if category: + exploded = select(exploded).where(exploded.c.category.in_(category)).subquery() - df = pd.read_sql_query(query.statement, engine).explode('category') - logger.info(f'Sample of the dataframe with size: {df.shape}\n{df.head()}') - results = projects_by_category(df=df, categories=category) + # count projects by category + projects_count_query = select( + exploded.c.category, func.count(exploded.c.project_id.distinct()).label('value') + ).group_by(exploded.c.category) + + results = session.exec(projects_count_query).fetchall() + + formatted_results = [{'category': row.category, 'value': row.value} for row in results] return PaginatedProjectCounts( - data=results, + data=formatted_results, pagination=Pagination( current_page=current_page, next_page=None, total_entries=len(results), total_pages=1 ), @@ -776,15 +840,17 @@ async def get_credits_by_category( authorized_user: bool = Depends(check_api_key), ): """Get project counts by category""" + logger.info(f'Getting project count by category: {request.url}') - query = session.query(Project) + # Start with a base query + query = select(Project) + # Apply filters filters = [ ('registry', registry, 'ilike', Project), ('country', country, 'ilike', Project), ('protocol', protocol, 'ANY', Project), - ('category', category, 'ANY', Project), ('is_compliance', is_compliance, '==', Project), ('listed_at', listed_at_from, '>=', Project), ('listed_at', listed_at_to, '<=', Project), @@ -796,30 +862,57 @@ async def get_credits_by_category( for attribute, values, operation, model in filters: query = apply_filters( - query=query, model=model, attribute=attribute, values=values, operation=operation + 
statement=query, model=model, attribute=attribute, values=values, operation=operation ) # Handle 'search' filter separately due to its unique logic if search: search_pattern = f'%{search}%' - query = query.filter( + query = query.where( or_( col(Project.project_id).ilike(search_pattern), col(Project.name).ilike(search_pattern), ) ) - settings = get_settings() - engine = get_engine(database_url=settings.database_url) + # Explode the category column + subquery = query.subquery() + exploded = ( + select( + func.unnest(subquery.c.category).label('category'), + subquery.c.issued, + subquery.c.retired, + ) + .select_from(subquery) + .subquery() + ) + + # Apply category filter after unnesting + if category: + exploded = select(exploded).where(exploded.c.category.in_(category)).subquery() + + # Group by category and sum issued and retired credits + credits_query = select( + exploded.c.category, + func.sum(exploded.c.issued).label('issued'), + func.sum(exploded.c.retired).label('retired'), + ).group_by(exploded.c.category) - df = pd.read_sql_query(query.statement, engine).explode('category') - logger.info(f'Sample of the dataframe with size: {df.shape}\n{df.head()}') + # Execute the query + results = session.exec(credits_query).fetchall() - results = credits_by_category(df=df, categories=category) + # Format the results + formatted_results = [ + {'category': row.category, 'issued': int(row.issued), 'retired': int(row.retired)} + for row in results + ] return PaginatedCreditCounts( - data=results, + data=formatted_results, pagination=Pagination( - current_page=current_page, next_page=None, total_entries=len(results), total_pages=1 + current_page=current_page, + next_page=None, + total_entries=len(formatted_results), + total_pages=1, ), ) diff --git a/offsets_db_api/routers/clips.py b/offsets_db_api/routers/clips.py index 450f5dd..3635b46 100644 --- a/offsets_db_api/routers/clips.py +++ b/offsets_db_api/routers/clips.py @@ -5,12 +5,12 @@ from sqlalchemy.orm import aliased from sqlmodel import Session, col, func, or_, select -from ..cache import CACHE_NAMESPACE -from ..database import get_session -from ..logging import get_logger -from ..models import Clip, ClipProject, PaginatedClips, Pagination, Project -from ..query_helpers import apply_filters, apply_sorting, handle_pagination -from ..security import check_api_key +from offsets_db_api.cache import CACHE_NAMESPACE +from offsets_db_api.database import get_session +from offsets_db_api.log import get_logger +from offsets_db_api.models import Clip, ClipProject, PaginatedClips, Pagination, Project +from offsets_db_api.security import check_api_key +from offsets_db_api.sql_helpers import apply_filters, apply_sorting, handle_pagination router = APIRouter() logger = get_logger() @@ -74,7 +74,7 @@ async def get_clips( project_data_subquery_alias = aliased(project_data_subquery, name='project_data') # construct the main query - query = select( + statement = select( Clip.date, Clip.title, Clip.url, @@ -88,15 +88,19 @@ async def get_clips( ).join(project_data_subquery_alias, col(Clip.id) == col(project_data_subquery_alias.c.clip_id)) for attribute, values, operation, model in filters: - query = apply_filters( - query=query, model=model, attribute=attribute, values=values, operation=operation + statement = apply_filters( + statement=statement, + model=model, + attribute=attribute, + values=values, + operation=operation, ) # Handle 'search' filter separately due to its unique logic if search: search_pattern = f'%{search}%' clip_project_alias = 
aliased(ClipProject) - query = query.join( + statement = statement.join( clip_project_alias, Clip.id == clip_project_alias.clip_id, ).filter( @@ -107,10 +111,10 @@ async def get_clips( ) if sort: - query = apply_sorting(query=query, sort=sort, model=Clip, primary_key='id') + statement = apply_sorting(statement=statement, sort=sort, model=Clip, primary_key='id') total_entries, current_page, total_pages, next_page, query_results = handle_pagination( - query=query, + statement=statement, primary_key=Clip.id, current_page=current_page, per_page=per_page, diff --git a/offsets_db_api/routers/credits.py b/offsets_db_api/routers/credits.py index 51a8c28..b7cf5d9 100644 --- a/offsets_db_api/routers/credits.py +++ b/offsets_db_api/routers/credits.py @@ -2,15 +2,15 @@ from fastapi import APIRouter, Depends, Query, Request from fastapi_cache.decorator import cache -from sqlmodel import Session, or_ +from sqlmodel import Session, col, or_, select -from ..cache import CACHE_NAMESPACE -from ..database import get_session -from ..logging import get_logger -from ..models import Credit, PaginatedCredits, Project -from ..query_helpers import apply_filters, apply_sorting, handle_pagination -from ..schemas import Pagination, Registries -from ..security import check_api_key +from offsets_db_api.cache import CACHE_NAMESPACE +from offsets_db_api.database import get_session +from offsets_db_api.log import get_logger +from offsets_db_api.models import Credit, PaginatedCredits, Project +from offsets_db_api.schemas import Pagination, Registries +from offsets_db_api.security import check_api_key +from offsets_db_api.sql_helpers import apply_filters, apply_sorting, handle_pagination router = APIRouter() logger = get_logger() @@ -49,8 +49,8 @@ async def get_credits( logger.info(f'Getting credits: {request.url}') # Outer join to get all credits, even if they don't have a project - query = session.query(Credit, Project.category).join( - Project, Credit.project_id == Project.project_id, isouter=True + statement = select(Credit, Project.category).join( + Project, col(Credit.project_id) == col(Project.project_id), isouter=True ) filters = [ @@ -65,30 +65,37 @@ async def get_credits( # Filter for project_id if project_id: - # insert at the beginning of the list to ensure that it is applied first filters.insert(0, ('project_id', project_id, '==', Project)) for attribute, values, operation, model in filters: - query = apply_filters( - query=query, model=model, attribute=attribute, values=values, operation=operation + statement = apply_filters( + statement=statement, + model=model, + attribute=attribute, + values=values, + operation=operation, ) # Handle 'search' filter separately due to its unique logic if search: search_pattern = f'%{search}%' - query = query.filter( - or_(Project.project_id.ilike(search_pattern), Project.name.ilike(search_pattern)) + statement = statement.where( + or_( + col(Project.project_id).ilike(search_pattern), + col(Project.name).ilike(search_pattern), + ) ) if sort: - query = apply_sorting(query=query, sort=sort, model=Credit, primary_key='id') + statement = apply_sorting(statement=statement, sort=sort, model=Credit, primary_key='id') total_entries, current_page, total_pages, next_page, results = handle_pagination( - query=query, + statement=statement, primary_key=Credit.id, current_page=current_page, per_page=per_page, request=request, + session=session, ) credits_with_category = [ diff --git a/offsets_db_api/routers/files.py b/offsets_db_api/routers/files.py index d223b54..b6ea613 100644 --- 
a/offsets_db_api/routers/files.py +++ b/offsets_db_api/routers/files.py @@ -1,17 +1,18 @@ import datetime -from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request, status from fastapi_cache.decorator import cache -from sqlmodel import Session - -from ..cache import CACHE_NAMESPACE -from ..database import get_engine, get_session -from ..logging import get_logger -from ..models import File, FileCategory, FileStatus -from ..schemas import FileURLPayload -from ..security import check_api_key -from ..settings import get_settings -from ..tasks import process_files +from sqlmodel import Session, select + +from offsets_db_api.cache import CACHE_NAMESPACE +from offsets_db_api.database import get_engine, get_session +from offsets_db_api.log import get_logger +from offsets_db_api.models import File, FileCategory, FileStatus, PaginatedFiles, Pagination +from offsets_db_api.schemas import FileURLPayload +from offsets_db_api.security import check_api_key +from offsets_db_api.settings import get_settings +from offsets_db_api.sql_helpers import apply_filters, apply_sorting, handle_pagination +from offsets_db_api.tasks import process_files router = APIRouter() logger = get_logger() @@ -61,54 +62,71 @@ async def get_file( """Get a file by id""" logger.info('Getting file %s', file_id) - if file_obj := session.query(File).get(file_id): + statement = select(File).where(File.id == file_id) + if file_obj := session.exec(statement).one_or_none(): return file_obj else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f'file {file_id} not found', + detail=f'File {file_id} not found', ) -@router.get('/', response_model=list[File], summary='List files') +@router.get('/', response_model=PaginatedFiles, summary='List files') @cache(namespace=CACHE_NAMESPACE) async def get_files( + request: Request, category: FileCategory | None = None, status: FileStatus | None = None, recorded_at_from: datetime.datetime | None = None, recorded_at_to: datetime.datetime | None = None, - limit: int = 100, - offset: int = 0, + sort: list[str] = Query( + default=['recorded_at'], + description='List of sorting parameters in the format `field_name` or `+field_name` for ascending order or `-field_name` for descending order.', + ), + current_page: int = Query(1, description='Page number', ge=1), + per_page: int = Query(100, description='Items per page', le=200, ge=1), session: Session = Depends(get_session), authorized_user: bool = Depends(check_api_key), ): """Get files""" - logger.info( - 'Getting files with filter: category=%s, status=%s, recorded_at_from=%s, recorded_at_to=%s, limit=%d, offset=%d', - category, - status, - recorded_at_from, - recorded_at_to, - limit, - offset, - ) - - query = session.query(File) - - if category: - query = query.filter_by(category=category) - - if status: - query = query.filter_by(status=status) - - if recorded_at_from: - query = query.filter(File.recorded_at >= recorded_at_from) - - if recorded_at_to: - query = query.filter(File.recorded_at <= recorded_at_to) + logger.info(f'Getting files with filter: {request.url}') + + filters = [ + ('category', category, '==', File), + ('status', status, '==', File), + ('recorded_at', recorded_at_from, '>=', File), + ('recorded_at', recorded_at_to, '<=', File), + ] + + statement = select(File) + for attribute, values, operation, model in filters: + statement = apply_filters( + statement=statement, + model=model, + attribute=attribute, + 
values=values, + operation=operation, + ) - files = query.limit(limit).offset(offset).all() + if sort: + statement = apply_sorting(statement=statement, sort=sort, model=File, primary_key='id') - logger.info('Found %d files', len(files)) + total_entries, current_page, total_pages, next_page, results = handle_pagination( + statement=statement, + primary_key=File.id, + current_page=current_page, + per_page=per_page, + request=request, + session=session, + ) - return files + return PaginatedFiles( + pagination=Pagination( + total_entries=total_entries, + current_page=current_page, + total_pages=total_pages, + next_page=next_page, + ), + data=results, + ) diff --git a/offsets_db_api/routers/health.py b/offsets_db_api/routers/health.py index f3b7db0..070c1e8 100644 --- a/offsets_db_api/routers/health.py +++ b/offsets_db_api/routers/health.py @@ -5,12 +5,12 @@ from fastapi_cache.decorator import cache from sqlmodel import Session, col, select -from ..cache import CACHE_NAMESPACE -from ..database import get_session -from ..logging import get_logger -from ..models import File, FileCategory, FileStatus -from ..security import check_api_key -from ..settings import Settings, get_settings +from offsets_db_api.cache import CACHE_NAMESPACE +from offsets_db_api.database import get_session +from offsets_db_api.log import get_logger +from offsets_db_api.models import File, FileCategory, FileStatus +from offsets_db_api.security import check_api_key +from offsets_db_api.settings import Settings, get_settings router = APIRouter() logger = get_logger() @@ -51,13 +51,13 @@ async def db_status( {'date': recorded_at.strftime('%a, %b %d %Y %H:%M:%S UTC'), 'url': url} ) - db_latest_update = {} - for category, entries in grouped_files.items(): - db_latest_update[category] = { + db_latest_update = { + category: { 'date': entries[0]['date'], 'url': entries[0]['url'], } - + for category, entries in grouped_files.items() + } return { 'status': 'ok', 'staging': settings.staging, diff --git a/offsets_db_api/routers/projects.py b/offsets_db_api/routers/projects.py index d1b6292..f259b13 100644 --- a/offsets_db_api/routers/projects.py +++ b/offsets_db_api/routers/projects.py @@ -4,22 +4,23 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi_cache.decorator import cache from sqlalchemy import or_ -from sqlalchemy.orm import contains_eager -from sqlmodel import Session +from sqlmodel import Session, col, select -from ..cache import CACHE_NAMESPACE -from ..database import get_session -from ..logging import get_logger -from ..models import Clip, ClipProject, PaginatedProjects, Project, ProjectWithClips -from ..query_helpers import apply_filters, apply_sorting, handle_pagination -from ..schemas import Pagination, Registries -from ..security import check_api_key +from offsets_db_api.cache import CACHE_NAMESPACE +from offsets_db_api.database import get_session +from offsets_db_api.log import get_logger +from offsets_db_api.models import Clip, ClipProject, PaginatedProjects, Project, ProjectWithClips +from offsets_db_api.schemas import Pagination, Registries +from offsets_db_api.security import check_api_key +from offsets_db_api.sql_helpers import apply_filters, apply_sorting, handle_pagination router = APIRouter() logger = get_logger() -@router.get('/', response_model=PaginatedProjects) +@router.get( + '/', response_model=PaginatedProjects, summary='Get projects with pagination and filtering' +) @cache(namespace=CACHE_NAMESPACE) async def get_projects( request: Request, @@ -55,13 +56,6 @@ 
async def get_projects( logger.info(f'Getting projects: {request.url}') - query = ( - session.query(Project, Clip) - .join(Project.clip_relationships, isouter=True) - .join(ClipProject.clip, isouter=True) - .options(contains_eager(Project.clip_relationships).contains_eager(ClipProject.clip)) - ) - filters = [ ('registry', registry, 'ilike', Project), ('country', country, 'ilike', Project), @@ -76,46 +70,63 @@ async def get_projects( ('retired', retired_max, '<=', Project), ] + statement = select(Project) + for attribute, values, operation, model in filters: - query = apply_filters( - query=query, model=model, attribute=attribute, values=values, operation=operation + statement = apply_filters( + statement=statement, + model=model, + attribute=attribute, + values=values, + operation=operation, ) - # Handle 'search' filter separately due to its unique logic if search: search_pattern = f'%{search}%' - query = query.filter( - or_(Project.project_id.ilike(search_pattern), Project.name.ilike(search_pattern)) + statement = statement.where( + or_( + col(Project.project_id).ilike(search_pattern), + col(Project.name).ilike(search_pattern), + ) ) if sort: - query = apply_sorting(query=query, sort=sort, model=Project, primary_key='project_id') + statement = apply_sorting( + statement=statement, sort=sort, model=Project, primary_key='project_id' + ) total_entries, current_page, total_pages, next_page, results = handle_pagination( - query=query, + statement=statement, primary_key=Project.project_id, current_page=current_page, per_page=per_page, request=request, + session=session, ) - # Execute the query - project_clip_pairs = results + # Get the list of project IDs from the results + project_ids = [project.project_id for project in results] + + # Subquery to get clips related to the project IDs + clip_statement = ( + select(ClipProject.project_id, Clip) + .where(col(ClipProject.project_id).in_(project_ids)) + .join(Clip, col(Clip.id) == col(ClipProject.clip_id)) + ) + clip_results = session.exec(clip_statement).all() - # Group clips by project using a dictionary and project_id as the key + # Group clips by project_id project_to_clips = defaultdict(list) - projects = {} - for project, clip in project_clip_pairs: - if project.project_id not in projects: - projects[project.project_id] = project - project_to_clips[project.project_id].append(clip) + for project_id, clip in clip_results: + project_to_clips[project_id].append(clip) # Transform the dictionary into a list of projects with clips projects_with_clips = [] - for project_id, clips in project_to_clips.items(): - project = projects[project_id] + for project in results: project_data = project.model_dump() - project_data['clips'] = [clip.model_dump() for clip in clips if clip is not None] + project_data['clips'] = [ + clip.model_dump() for clip in project_to_clips.get(project.project_id, []) + ] projects_with_clips.append(project_data) return PaginatedProjects( @@ -144,26 +155,24 @@ async def get_project( """Get a project by registry and project_id""" logger.info(f'Getting project: {request.url}') - # Start the query to get the project and related clips - project_with_clips = ( - session.query(Project) - .join(Project.clip_relationships, isouter=True) - .join(ClipProject.clip, isouter=True) - .options(contains_eager(Project.clip_relationships).contains_eager(ClipProject.clip)) - .filter(Project.project_id == project_id) - .one_or_none() - ) + # main query to get the project details + statement = select(Project).where(Project.project_id == project_id) + 
project = session.exec(statement).one_or_none() - if project_with_clips: - # Extract the Project and related Clips from the query result - project_data = project_with_clips.model_dump() - project_data['clips'] = [ - clip_project.clip.model_dump() - for clip_project in project_with_clips.clip_relationships - if clip_project.clip - ] - return project_data - else: + if not project: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f'project {project_id} not found' ) + + # Subquery to get the related clips + clip_statement = ( + select(Clip) + .join(ClipProject, col(Clip.id) == col(ClipProject.clip_id)) + .where(ClipProject.project_id == project_id) + ) + clip_projects_subquery = session.exec(clip_statement).all() + + # Construct the response data + project_data = project.model_dump() + project_data['clips'] = [clip.model_dump() for clip in clip_projects_subquery] + return project_data diff --git a/offsets_db_api/security.py b/offsets_db_api/security.py index f375d79..81577fe 100644 --- a/offsets_db_api/security.py +++ b/offsets_db_api/security.py @@ -4,7 +4,7 @@ from fastapi import Depends, HTTPException, Security, status from fastapi.security import APIKeyHeader -from .settings import Settings, get_settings +from offsets_db_api.settings import Settings, get_settings api_key_header = APIKeyHeader(name='X-API-KEY', auto_error=False) diff --git a/offsets_db_api/sql_helpers.py b/offsets_db_api/sql_helpers.py new file mode 100644 index 0000000..ef04d60 --- /dev/null +++ b/offsets_db_api/sql_helpers.py @@ -0,0 +1,193 @@ +import datetime +import typing + +from fastapi import HTTPException, Request +from sqlalchemy.dialects.postgresql import ARRAY +from sqlmodel import Session, SQLModel, and_, asc, desc, distinct, func, nullslast, or_, select +from sqlmodel.sql.expression import Select as _Select, SelectOfScalar + +from offsets_db_api.models import Clip, ClipProject, Credit, File, Project +from offsets_db_api.query_helpers import _generate_next_page_url +from offsets_db_api.schemas import Registries + + +def apply_sorting( + *, + statement: _Select[typing.Any] | SelectOfScalar[typing.Any], + sort: list[str], + model: type[Credit | Project | Clip | ClipProject | File | SQLModel], + primary_key: str, +) -> _Select[typing.Any] | SelectOfScalar[typing.Any]: + # Define valid column names + columns = [c.name for c in model.__table__.columns] + + # Ensure that the primary key field is always included in the sort parameters list to ensure consistent pagination + if primary_key not in sort and f'-{primary_key}' not in sort and f'+{primary_key}' not in sort: + sort.append(primary_key) + + for sort_param in sort: + sort_param = sort_param.strip() + # Check if sort_param starts with '-' for descending order + if sort_param.startswith('-'): + order = desc + field = sort_param[1:] # Remove the '-' from sort_param + + elif sort_param.startswith('+'): + order = asc + field = sort_param[1:] # Remove the '+' from sort_param + else: + order = asc + field = sort_param + + # Check if field is a valid column name + if field not in columns: + raise HTTPException( + status_code=400, + detail=f'Invalid sort field: {field}. 
Must be one of {columns}', + ) + + # Apply sorting to the statement + statement = statement.order_by(nullslast(order(getattr(model, field)))) + + return statement + + +def apply_filters( + *, + statement: _Select[typing.Any] | SelectOfScalar[typing.Any], + model: type[Credit | Project | Clip | ClipProject | File | SQLModel], + attribute: str, + values: list[str] | None | int | datetime.date | list[Registries] | typing.Any, + operation: str, +) -> _Select[typing.Any] | SelectOfScalar[typing.Any]: + """ + Apply filters to the statement based on operation type. + Supports 'ilike', '==', '>=', and '<=' operations. + + Parameters + ---------- + statement: Select + SQLAlchemy Select statement + model: Credit | Project | Clip | ClipProject + SQLAlchemy model class + attribute: str + model attribute to apply filter on + values: list + list of values to filter with + operation: str + operation type to apply to the filter ('ilike', '==', '>=', '<=') + + + Returns + ------- + statement: Select + updated SQLAlchemy Select statement + """ + + if values is not None: + attr_type = getattr(model, attribute).type + is_array = isinstance(attr_type, ARRAY) + is_list = isinstance(values, list | tuple | set) + + if is_array and is_list: + if operation == 'ALL': + statement = statement.where( + and_(*[getattr(model, attribute).op('@>')(f'{{{v}}}') for v in values]) + ) + else: + statement = statement.where( + or_(*[getattr(model, attribute).op('@>')(f'{{{v}}}') for v in values]) + ) + + if operation == 'ilike': + statement = ( + statement.where(or_(*[getattr(model, attribute).ilike(v) for v in values])) + if is_list + else statement.where(getattr(model, attribute).ilike(values)) + ) + elif operation == '==': + statement = ( + statement.where(or_(*[getattr(model, attribute) == v for v in values])) + if is_list + else statement.where(getattr(model, attribute) == values) + ) + elif operation == '>=': + statement = ( + statement.where(or_(*[getattr(model, attribute) >= v for v in values])) + if is_list + else statement.where(getattr(model, attribute) >= values) + ) + elif operation == '<=': + statement = ( + statement.where(or_(*[getattr(model, attribute) <= v for v in values])) + if is_list + else statement.where(getattr(model, attribute) <= values) + ) + + return statement + + +def handle_pagination( + *, + statement: _Select[typing.Any] | SelectOfScalar[typing.Any], + primary_key: typing.Any, + current_page: int, + per_page: int, + request: Request, + session: Session, +) -> tuple[ + int, + int, + int, + str | None, + typing.Iterable[Project | Clip | ClipProject | Credit], +]: + """ + Calculate total records, pages and next page URL for a given query. 
+ + Parameters + ---------- + statement: Select + SQLAlchemy Select statement + primary_key + Primary key field for distinct count + current_page: int + Current page number + per_page: int + Number of records per page + request: Request + FastAPI request instance + session: Session + SQLAlchemy session instance + + Returns + ------- + total_entries: int + Total records in query + total_pages: int + Total pages in query + next_page: Optional[str] + URL of next page + results: List[SQLModel] + Results for the current page + """ + + pk_column = primary_key if isinstance(primary_key, str) else primary_key.key + count_query = select(func.count(distinct(getattr(statement.columns, pk_column)))) + total_entries = session.exec(count_query).one() + + total_pages = (total_entries + per_page - 1) // per_page # ceil(total / per_page) + + # Calculate the next page URL + next_page = None + + if current_page < total_pages: + next_page = _generate_next_page_url( + request=request, current_page=current_page, per_page=per_page + ) + + # Get the results for the current page + paginated_statement = statement.offset((current_page - 1) * per_page).limit(per_page) + results = session.exec(paginated_statement).all() + + return total_entries, current_page, total_pages, next_page, results diff --git a/offsets_db_api/tasks.py b/offsets_db_api/tasks.py index 37e41e4..061c621 100644 --- a/offsets_db_api/tasks.py +++ b/offsets_db_api/tasks.py @@ -5,9 +5,9 @@ from offsets_db_data.models import clip_schema, credit_schema, project_schema from sqlmodel import ARRAY, BigInteger, Boolean, Date, DateTime, String, text -from .cache import watch_dog_file -from .logging import get_logger -from .models import File +from offsets_db_api.cache import watch_dog_file +from offsets_db_api.log import get_logger +from offsets_db_api.models import File logger = get_logger() diff --git a/tests/test_charts.py b/tests/test_charts.py index 92b236c..d4a1065 100644 --- a/tests/test_charts.py +++ b/tests/test_charts.py @@ -1,8 +1,6 @@ import pandas as pd import pytest -from offsets_db_api.routers.charts import filter_valid_projects, projects_by_category - @pytest.fixture def sample_projects(): @@ -24,49 +22,6 @@ def sample_projects(): return data -@pytest.mark.parametrize( - 'categories, expected', - [ - (None, ['ghg-management', 'renewable-energy', 'biodiversity', 'water-management']), - (['renewable-energy'], ['renewable-energy']), - ], -) -def test_filter_valid_projects(sample_projects, categories, expected): - result = filter_valid_projects(sample_projects, categories=categories) - assert list(result['category'].unique()) == expected - - -@pytest.mark.parametrize( - 'categories,expected', - [ - ( - None, - [ - {'category': 'ghg-management', 'value': 2}, - {'category': 'renewable-energy', 'value': 1}, - {'category': 'biodiversity', 'value': 1}, - {'category': 'water-management', 'value': 1}, - ], - ), - (['ghg-management'], [{'category': 'ghg-management', 'value': 2}]), - ( - ['renewable-energy', 'biodiversity'], - [ - {'category': 'renewable-energy', 'value': 1}, - {'category': 'biodiversity', 'value': 1}, - ], - ), - (['non-existent-category'], []), - ([], []), - ], -) -def test_projects_by_category(categories, expected, sample_projects): - result = projects_by_category(df=sample_projects, categories=categories) - sorted_result = sorted(result, key=lambda x: x['category']) - sorted_expected = sorted(expected, key=lambda x: x['category']) - assert sorted_result == sorted_expected - - @pytest.mark.parametrize('freq', ['D', 'M', 'Y', 'W']) 
@pytest.mark.parametrize('registry', ['american-carbon-registry', 'climate-action-reserve']) @pytest.mark.parametrize('country', ['US', 'CA'])
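The common pattern across these chart-endpoint rewrites is to push work that the removed helpers did in pandas (`df.explode('category')` followed by `pd.cut` binning and a groupby) down into PostgreSQL, using `func.unnest` on the ARRAY `category` column and a `CASE` expression over precomputed bin edges. The sketch below is not part of the patch: it reproduces that shape against the `Project` model from `offsets_db_api.models`, with invented yearly bin edges standing in for what the real endpoints derive from `generate_date_bins()`. Building the statement needs no database connection, so it can be inspected by compiling it for the PostgreSQL dialect.

```python
# Minimal sketch of the unnest + CASE binning pattern used by the rewritten
# chart endpoints. The bin edges below are made up for illustration; the
# endpoints compute them with generate_date_bins().
import datetime

from sqlalchemy.dialects import postgresql
from sqlmodel import Date, and_, case, cast, col, func, select

from offsets_db_api.models import Project

# Hypothetical yearly bin edges (left-closed, right-open intervals).
date_bins = [
    datetime.date(2020, 1, 1),
    datetime.date(2021, 1, 1),
    datetime.date(2022, 1, 1),
]

# 1) "Explode" the ARRAY column in SQL, like df.explode('category').
subquery = select(
    col(Project.listed_at),
    col(Project.project_id),
    func.unnest(Project.category).label('category'),
).alias('subquery')

# 2) Assign each row to a bin with a CASE expression, like pd.cut(..., right=False).
bin_case = case(
    *[
        (
            and_(subquery.c.listed_at >= bin_start, subquery.c.listed_at < bin_end),
            cast(bin_start, Date),
        )
        for bin_start, bin_end in zip(date_bins[:-1], date_bins[1:])
    ],
    else_=cast(date_bins[-1], Date),
).label('bin')

# 3) Aggregate per bin and category, like groupby(['bin', 'category']).count().
binned_query = (
    select(
        bin_case,
        subquery.c.category,
        func.count(subquery.c.project_id.distinct()).label('value'),
    )
    .select_from(subquery)
    .group_by(bin_case, subquery.c.category)
)

# No engine required to build the statement; print the SQL it would run.
print(binned_query.compile(dialect=postgresql.dialect()))
```

The same `unnest`/`case` construction appears, with different columns and aggregates, in `get_credits_by_transaction_date`, `get_credits_by_project_id`, and `get_projects_by_credit_totals`, while the per-row filtering, sorting, and pagination are delegated to the new `offsets_db_api.sql_helpers` module shown in the patch.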