import math
from functools import lru_cache

from fastapi import APIRouter, Query, HTTPException
import polars as pl
import h3
from tqdm import tqdm

from server.config import (
    AGGREGATES_DIR,
    VALID_RESOLUTIONS,
    DEFAULT_RESOLUTION,
    BOUNDS_BUFFER_PERCENT,
)

router = APIRouter()

# Cache loaded dataframes in memory (one per resolution)
_df_cache: dict[int, pl.DataFrame] = {}

# Discovered features (computed once on first load)
_features_cache: list[dict] | None = None


def _snake_to_label(name: str) -> str:
    """Convert a snake_case feature name to a human-readable label.

    Example: ``"price_per_sqm"`` -> ``"Price Per Sqm"``.
    """
    return name.replace("_", " ").title()


def _discover_features(df: pl.DataFrame) -> list[dict]:
    """Discover features from paired ``min_<name>`` / ``max_<name>`` columns.

    Args:
        df: Aggregated dataframe whose column names are scanned for pairs.

    Returns:
        One dict per feature, in the order its ``min_`` column appears, with
        keys ``name``, ``min`` (smallest value of ``min_<name>``), ``max``
        (largest value of ``max_<name>``) and ``label``. A feature is skipped
        when either aggregate is ``None`` (e.g. an all-null column).
    """
    features = []
    seen = set()
    columns = set(df.columns)  # O(1) membership tests for the max_ partner
    for col in df.columns:
        if not col.startswith("min_"):
            continue
        name = col.removeprefix("min_")
        max_col = f"max_{name}"
        if max_col not in columns or name in seen:
            continue
        seen.add(name)
        global_min = df[col].min()
        global_max = df[max_col].max()
        if global_min is None or global_max is None:
            continue
        features.append(
            {
                "name": name,
                "min": float(global_min),
                "max": float(global_max),
                "label": _snake_to_label(name),
            }
        )
    return features


def preload_dataframes() -> None:
    """Load all resolution dataframes into cache on startup."""
    for resolution in tqdm(VALID_RESOLUTIONS, desc="Loading parquet files"):
        get_cached_df(resolution)


def get_cached_df(resolution: int) -> pl.DataFrame | None:
    """Return the dataframe for ``resolution``, loading it from disk if needed.

    On first load the parquet file ``res<resolution>.parquet`` under
    ``AGGREGATES_DIR`` is read and two internal Float64 columns, ``_lat`` and
    ``_lng``, are appended with each H3 cell's centroid (used for fast bbox
    filtering). The result is memoized in ``_df_cache``.

    Returns:
        The cached dataframe, or ``None`` if the parquet file does not exist
        (a missing file is not cached, so it is re-checked on the next call).
    """
    cached = _df_cache.get(resolution)
    if cached is not None:
        return cached

    parquet_path = AGGREGATES_DIR / f"res{resolution}.parquet"
    if not parquet_path.exists():
        return None

    df = pl.read_parquet(parquet_path)

    # Pre-compute cell centroids for bbox filtering.
    centroids = [h3.cell_to_latlng(cell) for cell in df["h3"].to_list()]
    # Explicit dtype: an empty file would otherwise yield Null-typed columns.
    df = df.with_columns(
        [
            pl.Series("_lat", [lat for lat, _ in centroids], dtype=pl.Float64),
            pl.Series("_lng", [lng for _, lng in centroids], dtype=pl.Float64),
        ]
    )
    _df_cache[resolution] = df
    return df


def get_features() -> list[dict]:
    """Return discovered features, computed from the first available resolution.

    The result is memoized in ``_features_cache``. NOTE: if no resolution's
    parquet file exists on the first call, an empty list is cached and returned
    for the rest of the process lifetime.
    """
    global _features_cache
    if _features_cache is None:
        for resolution in VALID_RESOLUTIONS:
            df = get_cached_df(resolution)
            if df is not None:
                _features_cache = _discover_features(df)
                break
        if _features_cache is None:
            _features_cache = []
    return _features_cache


@router.get("/features")
async def get_features_endpoint() -> dict:
    """Return discovered feature metadata with global min/max ranges."""
    return {"features": get_features()}


@lru_cache(maxsize=128)
def query_hexagons_cached(
    resolution: int,
    bounds_tuple: tuple[float, float, float, float],
) -> list[dict]:
    """Return rows whose H3 centroid lies inside a bounding box.

    Args:
        resolution: H3 resolution whose dataframe is queried.
        bounds_tuple: ``(south, west, north, east)`` in degrees, inclusive.

    Returns:
        The matching rows as dicts, without the internal ``_lat``/``_lng``
        columns. The returned list is shared by the LRU cache, so callers
        must not mutate it.

    NOTE: an empty result for a missing parquet file is cached too, until
    ``query_hexagons_cached.cache_clear()`` or a restart. A box crossing the
    antimeridian (``west > east``) matches nothing with this filter.
    """
    south, west, north, east = bounds_tuple
    df = get_cached_df(resolution)
    if df is None:
        return []

    # Fast bbox filter using pre-computed centroids
    df = df.filter(
        (pl.col("_lat") >= south)
        & (pl.col("_lat") <= north)
        & (pl.col("_lng") >= west)
        & (pl.col("_lng") <= east)
    )

    # Drop internal centroid columns before returning
    df = df.drop("_lat", "_lng")
    return df.to_dicts()


@router.get("/hexagons")
async def get_hexagons(
    resolution: int = Query(
        DEFAULT_RESOLUTION,
        ge=min(VALID_RESOLUTIONS),
        le=max(VALID_RESOLUTIONS),
        description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})",
    ),
    bounds: str | None = Query(None, description="Bounding box: south,west,north,east"),
) -> dict:
    """Get aggregated property data as hexagons within bounds.

    The requested box is expanded by ``BOUNDS_BUFFER_PERCENT`` of its span on
    each side, then snapped outward to a 0.01-degree grid so nearby viewports
    share ``query_hexagons_cached`` entries.

    Raises:
        HTTPException: 400 if ``bounds`` is missing, malformed, or contains
            non-finite / overflowing values.

    NOTE(review): this ``async def`` runs the synchronous query (and, on a
    cold cache, the parquet load) directly on the event loop; a plain ``def``
    endpoint would let FastAPI offload it to its threadpool.
    """
    if resolution not in VALID_RESOLUTIONS:
        resolution = DEFAULT_RESOLUTION

    if not bounds:
        raise HTTPException(status_code=400, detail="bounds parameter is required")

    try:
        south, west, north, east = map(float, bounds.split(","))
    except ValueError:
        # Covers both non-numeric parts and the wrong number of parts.
        raise HTTPException(
            status_code=400, detail="Invalid bounds format. Use: south,west,north,east"
        ) from None

    # Expand bounds by buffer percentage for smoother panning
    lat_buffer = (north - south) * BOUNDS_BUFFER_PERCENT
    lng_buffer = (east - west) * BOUNDS_BUFFER_PERCENT
    south -= lat_buffer
    north += lat_buffer
    west -= lng_buffer
    east += lng_buffer

    # Round bounds outward to reduce cache misses (0.01 degree ~ 1km precision)
    precision = 0.01
    try:
        bounds_tuple = (
            math.floor(south / precision) * precision,
            math.floor(west / precision) * precision,
            math.ceil(north / precision) * precision,
            math.ceil(east / precision) * precision,
        )
    except (ValueError, OverflowError):
        # float() accepts "nan"/"inf" and huge magnitudes; floor/ceil cannot
        # convert those to int, which would otherwise surface as a 500.
        raise HTTPException(
            status_code=400, detail="Invalid bounds: values must be finite numbers"
        ) from None

    features = query_hexagons_cached(resolution, bounds_tuple)
    return {"features": features}