diff --git a/pipeline/run.py b/pipeline/run.py index 8bb3884..3502ff2 100644 --- a/pipeline/run.py +++ b/pipeline/run.py @@ -30,5 +30,6 @@ def run_pipeline(): size_mb = path.stat().st_size / (1024 * 1024) print(f" Saved: {path.name} ({size_mb:.1f} MB)") + if __name__ == "__main__": run_pipeline() diff --git a/server/routes/hexagons.py b/server/routes/hexagons.py index 592b097..40835e9 100644 --- a/server/routes/hexagons.py +++ b/server/routes/hexagons.py @@ -1,4 +1,4 @@ -from typing import Any +from functools import lru_cache from fastapi import APIRouter, Query, HTTPException import polars as pl import h3 @@ -15,91 +15,54 @@ from server.config import ( router = APIRouter() - -def get_h3_cells_for_bounds( - south: float, west: float, north: float, east: float, resolution: int -) -> set[str] | None: - """Get all H3 cells that cover a bounding box. Returns None if area too large.""" - # Clamp to valid ranges - south = max(-85, min(85, south)) - north = max(-85, min(85, north)) - west = max(-180, min(180, west)) - east = max(-180, min(180, east)) - - # Ensure valid bounds - if south >= north or west >= east: - return set() - - # If viewport is too large, return None to skip filtering - # This prevents H3 from trying to enumerate millions of cells - lat_span = north - south - lng_span = east - west - if lat_span > 20 or lng_span > 30: - return None - - # Create polygon from bounds (counter-clockwise winding for H3/GeoJSON) - # Order: SW -> NW -> NE -> SE -> SW - polygon = [ - (south, west), - (north, west), - (north, east), - (south, east), - (south, west), - ] - - try: - return h3.polygon_to_cells(h3.LatLngPoly(polygon), resolution) - except Exception: - return None +# Cache loaded dataframes in memory (one per resolution) +_df_cache: dict[int, pl.DataFrame] = {} -@router.get("/hexagons") -async def get_hexagons( - resolution: int = Query( - DEFAULT_RESOLUTION, - ge=min(VALID_RESOLUTIONS), - le=max(VALID_RESOLUTIONS), - description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})", - ), - min_year: int = Query(DEFAULT_MIN_YEAR, description="Minimum year filter"), - max_year: int = Query(DEFAULT_MAX_YEAR, description="Maximum year filter"), - min_price: float = Query(DEFAULT_MIN_PRICE, description="Minimum average price"), - max_price: float = Query(DEFAULT_MAX_PRICE, description="Maximum average price"), - bounds: str | None = Query( - None, description="Bounding box: south,west,north,east" - ), -) -> dict: - """Get aggregated property data as GeoJSON hexagons within bounds.""" - if resolution not in VALID_RESOLUTIONS: - resolution = DEFAULT_RESOLUTION +def get_cached_df(resolution: int) -> pl.DataFrame | None: + """Get cached dataframe for resolution, loading from disk if needed.""" + if resolution not in _df_cache: + parquet_path = AGGREGATES_DIR / f"res{resolution}.parquet" + if not parquet_path.exists(): + return None + # Load and add H3 cell centroids for fast bbox filtering + df = pl.read_parquet(parquet_path) - # Bounds are required for efficient queries - if not bounds: - raise HTTPException(status_code=400, detail="bounds parameter is required") - - # Parse bounds - try: - south, west, north, east = map(float, bounds.split(",")) - except ValueError: - raise HTTPException( - status_code=400, detail="Invalid bounds format. Use: south,west,north,east" + # Pre-compute cell centroids for bbox filtering (much faster than is_in) + centroids = [h3.cell_to_latlng(cell) for cell in df["h3"].to_list()] + df = df.with_columns( + [ + pl.Series("lat", [c[0] for c in centroids]), + pl.Series("lng", [c[1] for c in centroids]), + ] ) + _df_cache[resolution] = df + return _df_cache[resolution] - # Load the appropriate resolution file - parquet_path = AGGREGATES_DIR / f"res{resolution}.parquet" - if not parquet_path.exists(): - return {"features": []} - df = pl.scan_parquet(parquet_path) +@lru_cache(maxsize=128) +def query_hexagons_cached( + resolution: int, + min_year: int, + max_year: int, + min_price: int, + max_price: int, + bounds_tuple: tuple[float, float, float, float], +) -> tuple[list[dict], bool]: + """Cached query - returns (features, truncated).""" + south, west, north, east = bounds_tuple - # Get H3 cells that cover the viewport (None if too large to enumerate) - viewport_cells = get_h3_cells_for_bounds(south, west, north, east, resolution) + df = get_cached_df(resolution) + if df is None: + return [], False - # Filter to only cells in viewport (skip if viewport too large) - if viewport_cells is not None: - if len(viewport_cells) == 0: - return {"features": []} - df = df.filter(pl.col("h3").is_in(viewport_cells)) + # Fast bbox filter using pre-computed centroids (O(1) per row) + df = df.filter( + (pl.col("lat") >= south) + & (pl.col("lat") <= north) + & (pl.col("lng") >= west) + & (pl.col("lng") <= east) + ) # Filter by year range df = df.filter((pl.col("year") >= min_year) & (pl.col("year") <= max_year)) @@ -123,27 +86,71 @@ async def get_hexagons( (pl.col("avg_price") >= min_price) & (pl.col("avg_price") <= max_price) ) - # Limit results to prevent browser crashes + # Limit results MAX_HEXAGONS = 50000 - df = df.limit(MAX_HEXAGONS) + truncated = len(df) >= MAX_HEXAGONS + if truncated: + df = df.limit(MAX_HEXAGONS) - # Collect results - result = df.collect() + # Build response efficiently using Polars + df = df.select( + [ + pl.col("h3"), + pl.col("count"), + pl.col("avg_price").round(2), + pl.col("median_price").round(2), + pl.col("min_price"), + pl.col("max_price"), + ] + ) - # Return lightweight response - just h3 index and properties - # Frontend H3HexagonLayer will render the geometry - # Use to_dicts() which is faster than iter_rows for large results - rows = result.to_dicts() - features = [ - { - "h3": row["h3"], - "count": row["count"], - "avg_price": round(row["avg_price"], 2), - "median_price": round(row["median_price"], 2) if row["median_price"] else None, - "min_price": row["min_price"], - "max_price": row["max_price"], - } - for row in rows - ] + return df.to_dicts(), truncated - return {"features": features, "truncated": len(rows) >= MAX_HEXAGONS} + +@router.get("/hexagons") +async def get_hexagons( + resolution: int = Query( + DEFAULT_RESOLUTION, + ge=min(VALID_RESOLUTIONS), + le=max(VALID_RESOLUTIONS), + description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})", + ), + min_year: int = Query(DEFAULT_MIN_YEAR, description="Minimum year filter"), + max_year: int = Query(DEFAULT_MAX_YEAR, description="Maximum year filter"), + min_price: float = Query(DEFAULT_MIN_PRICE, description="Minimum average price"), + max_price: float = Query(DEFAULT_MAX_PRICE, description="Maximum average price"), + bounds: str | None = Query(None, description="Bounding box: south,west,north,east"), +) -> dict: + """Get aggregated property data as GeoJSON hexagons within bounds.""" + if resolution not in VALID_RESOLUTIONS: + resolution = DEFAULT_RESOLUTION + + if not bounds: + raise HTTPException(status_code=400, detail="bounds parameter is required") + + try: + south, west, north, east = map(float, bounds.split(",")) + except ValueError: + raise HTTPException( + status_code=400, detail="Invalid bounds format. Use: south,west,north,east" + ) + + # Round bounds to reduce cache misses (0.01 degree ≈ 1km precision) + bounds_tuple = ( + round(south, 2), + round(west, 2), + round(north, 2), + round(east, 2), + ) + + # Convert prices to int for cache key hashability + features, truncated = query_hexagons_cached( + resolution, + min_year, + max_year, + int(min_price), + int(max_price), + bounds_tuple, + ) + + return {"features": features, "truncated": truncated}