import math
from functools import lru_cache

from fastapi import APIRouter, Query, HTTPException
import polars as pl
import h3
from tqdm import tqdm

from server.config import (
    AGGREGATES_DIR,
    VALID_RESOLUTIONS,
    DEFAULT_RESOLUTION,
    DEFAULT_MIN_YEAR,
    DEFAULT_MAX_YEAR,
    DEFAULT_MIN_PRICE,
    DEFAULT_MAX_PRICE,
)

router = APIRouter()

# Cache loaded dataframes in memory (one per resolution)
_df_cache: dict[int, pl.DataFrame] = {}

# Grid size, in degrees, that request bounds are snapped to before they are
# used as part of the query cache key (0.01 degree ≈ 1km precision).
_BOUNDS_PRECISION = 0.01

# Shared 400 message for every malformed ``bounds`` value.
_BOUNDS_FORMAT_ERROR = "Invalid bounds format. Use: south,west,north,east"


def preload_dataframes() -> None:
    """Load every resolution's dataframe into the in-memory cache.

    Iterates ``VALID_RESOLUTIONS`` with a progress bar and calls
    :func:`get_cached_df` for each one; resolutions whose parquet file is
    missing are silently skipped by that function.
    """
    for resolution in tqdm(VALID_RESOLUTIONS, desc="Loading parquet files"):
        get_cached_df(resolution)


def get_cached_df(resolution: int) -> pl.DataFrame | None:
    """Return the dataframe for ``resolution``, loading it on first use.

    The parquet file ``AGGREGATES_DIR / f"res{resolution}.parquet"`` is read
    once, augmented with ``lat``/``lng`` columns holding each cell's centroid,
    and stored in ``_df_cache``.

    Returns:
        The cached dataframe, or ``None`` when the parquet file does not exist
        (nothing is cached in that case, so a later call retries the disk).
    """
    if resolution not in _df_cache:
        parquet_path = AGGREGATES_DIR / f"res{resolution}.parquet"
        if not parquet_path.exists():
            return None

        df = pl.read_parquet(parquet_path)

        # Pre-compute cell centroids once so bbox filtering is a plain numeric
        # comparison per row (much faster than an is_in over cell ids).
        centroids = [h3.cell_to_latlng(cell) for cell in df["h3"].to_list()]
        # Pin the dtype so an empty file still yields float columns that can
        # be compared against the bounding box.
        df = df.with_columns(
            [
                pl.Series("lat", [lat for lat, _ in centroids], dtype=pl.Float64),
                pl.Series("lng", [lng for _, lng in centroids], dtype=pl.Float64),
            ]
        )
        _df_cache[resolution] = df

    return _df_cache[resolution]


@lru_cache(maxsize=128)
def query_hexagons_cached(
    resolution: int,
    min_year: int,
    max_year: int,
    min_price: int,
    max_price: int,
    bounds_tuple: tuple[float, float, float, float],
) -> list[dict]:
    """Return per-cell aggregates for the given filters (memoised).

    Args:
        resolution: H3 resolution whose dataframe is queried.
        min_year: Inclusive lower bound on the ``year`` column.
        max_year: Inclusive upper bound on the ``year`` column.
        min_price: Inclusive lower bound on the aggregated ``avg_price``.
        max_price: Inclusive upper bound on the aggregated ``avg_price``.
        bounds_tuple: ``(south, west, north, east)``; a row is kept when its
            cell centroid lies inside this box (edges inclusive).

    Returns:
        One dict per H3 cell with keys ``h3``, ``count``, ``avg_price``,
        ``median_price``, ``min_price`` and ``max_price``. An empty list is
        returned when no dataframe exists for ``resolution``.

    Note:
        The returned list object is shared between all callers that hit the
        same cache entry, so callers must treat it as read-only.
    """
    south, west, north, east = bounds_tuple

    df = get_cached_df(resolution)
    if df is None:
        # Keep the declared return type: an empty feature list, not a tuple.
        return []

    # Fast bbox filter using the pre-computed centroids
    df = df.filter(
        (pl.col("lat") >= south)
        & (pl.col("lat") <= north)
        & (pl.col("lng") >= west)
        & (pl.col("lng") <= east)
    )

    # Filter by year range
    df = df.filter((pl.col("year") >= min_year) & (pl.col("year") <= max_year))

    # Aggregate across years (average is weighted by count).
    # NOTE: median_price is the median of the per-row medians, which is only
    # an approximation of the true median over the underlying records.
    df = df.group_by("h3").agg(
        pl.col("count").sum().alias("count"),
        (pl.col("avg_price") * pl.col("count")).sum().alias("weighted_price_sum"),
        pl.col("median_price").median().alias("median_price"),
        pl.col("min_price").min().alias("min_price"),
        pl.col("max_price").max().alias("max_price"),
    )

    # Calculate weighted average price
    df = df.with_columns(
        (pl.col("weighted_price_sum") / pl.col("count")).alias("avg_price")
    ).drop("weighted_price_sum")

    # Filter by price range
    df = df.filter(
        (pl.col("avg_price") >= min_price) & (pl.col("avg_price") <= max_price)
    )

    # Build response efficiently using Polars
    df = df.select(
        [
            pl.col("h3"),
            pl.col("count"),
            pl.col("avg_price").round(2),
            pl.col("median_price").round(2),
            pl.col("min_price"),
            pl.col("max_price"),
        ]
    )

    return df.to_dicts()


def _parse_bounds(bounds: str) -> tuple[float, float, float, float]:
    """Parse ``"south,west,north,east"`` into four finite floats or raise 400."""
    try:
        south, west, north, east = map(float, bounds.split(","))
    except ValueError as err:
        # Covers both non-numeric parts and a wrong number of parts.
        raise HTTPException(status_code=400, detail=_BOUNDS_FORMAT_ERROR) from err

    # float() accepts "nan"/"inf", but math.floor/ceil would later raise on
    # them and surface as a 500; reject them here as a client error instead.
    if not all(math.isfinite(value) for value in (south, west, north, east)):
        raise HTTPException(status_code=400, detail=_BOUNDS_FORMAT_ERROR)

    return south, west, north, east


@router.get("/hexagons")
async def get_hexagons(
    resolution: int = Query(
        DEFAULT_RESOLUTION,
        ge=min(VALID_RESOLUTIONS),
        le=max(VALID_RESOLUTIONS),
        description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})",
    ),
    min_year: int = Query(DEFAULT_MIN_YEAR, description="Minimum year filter"),
    max_year: int = Query(DEFAULT_MAX_YEAR, description="Maximum year filter"),
    min_price: float = Query(DEFAULT_MIN_PRICE, description="Minimum average price"),
    max_price: float = Query(DEFAULT_MAX_PRICE, description="Maximum average price"),
    bounds: str | None = Query(None, description="Bounding box: south,west,north,east"),
) -> dict:
    """Get aggregated property data per H3 cell within bounds.

    Args:
        resolution: H3 resolution; values inside the numeric range but not in
            ``VALID_RESOLUTIONS`` fall back to ``DEFAULT_RESOLUTION``.
        min_year: Inclusive minimum year.
        max_year: Inclusive maximum year.
        min_price: Minimum average price (truncated to ``int`` for caching).
        max_price: Maximum average price (truncated to ``int`` for caching).
        bounds: Required bounding box as ``"south,west,north,east"``.

    Returns:
        ``{"features": [...]}`` where each feature is a dict of aggregates
        for one H3 cell (see :func:`query_hexagons_cached`).

    Raises:
        HTTPException: 400 when ``bounds`` is missing or malformed, or when a
            price filter is not a finite number.
    """
    if resolution not in VALID_RESOLUTIONS:
        resolution = DEFAULT_RESOLUTION

    if not bounds:
        raise HTTPException(status_code=400, detail="bounds parameter is required")

    south, west, north, east = _parse_bounds(bounds)

    # int() raises on nan/inf; report that as a client error, not a 500.
    if not (math.isfinite(min_price) and math.isfinite(max_price)):
        raise HTTPException(
            status_code=400,
            detail="min_price and max_price must be finite numbers",
        )

    # Round bounds to reduce cache misses (0.01 degree ≈ 1km precision)
    # Always expand bounds (floor for min, ceil for max) to prevent hexagons
    # popping in when crossing rounding boundaries
    precision = _BOUNDS_PRECISION
    bounds_tuple = (
        math.floor(south / precision) * precision,
        math.floor(west / precision) * precision,
        math.ceil(north / precision) * precision,
        math.ceil(east / precision) * precision,
    )

    # Convert prices to int so near-identical float filters share a cache key
    features = query_hexagons_cached(
        resolution,
        min_year,
        max_year,
        int(min_price),
        int(max_price),
        bounds_tuple,
    )

    return {"features": features}