import asyncio
import math
from typing import Any

from fastapi import APIRouter, Query, HTTPException
import polars as pl
import h3

from server.config import (
    AGGREGATES_DIR,
    VALID_RESOLUTIONS,
    DEFAULT_RESOLUTION,
    DEFAULT_MIN_YEAR,
    DEFAULT_MAX_YEAR,
    DEFAULT_MIN_PRICE,
    DEFAULT_MAX_PRICE,
)

router = APIRouter()

# Maximum number of hexagons returned by /hexagons, to keep the payload small
# enough for the browser to render.
MAX_HEXAGONS = 50000

# Viewports wider than these spans (in degrees) are not enumerated into H3
# cells; get_h3_cells_for_bounds returns None for them instead.
_MAX_LAT_SPAN = 20
_MAX_LNG_SPAN = 30

_BOUNDS_FORMAT_ERROR = "Invalid bounds format. Use: south,west,north,east"


def get_h3_cells_for_bounds(
    south: float, west: float, north: float, east: float, resolution: int
) -> set[str] | None:
    """Get all H3 cells that cover a bounding box. Returns None if area too large.

    Args:
        south, west, north, east: Box edges in degrees. Latitudes are clamped
            to [-85, 85] and longitudes to [-180, 180] before use.
        resolution: H3 resolution of the returned cells.

    Returns:
        The set of cells produced by ``h3.polygon_to_cells`` for the box.
        An empty set if the clamped box is degenerate (``south >= north`` or
        ``west >= east``; note this includes boxes given with west > east,
        e.g. crossing the antimeridian). ``None`` if the box exceeds
        ``_MAX_LAT_SPAN`` x ``_MAX_LNG_SPAN`` degrees or H3 rejects the
        polygon; callers treat ``None`` as "skip spatial filtering".
    """
    # Clamp to valid ranges. The +/-85 latitude limit presumably mirrors the
    # Web Mercator map view (TODO confirm).
    south = max(-85, min(85, south))
    north = max(-85, min(85, north))
    west = max(-180, min(180, west))
    east = max(-180, min(180, east))

    if south >= north or west >= east:
        return set()

    # Enumerating a very large box can yield millions of cells; signal the
    # caller to skip cell filtering instead.
    if north - south > _MAX_LAT_SPAN or east - west > _MAX_LNG_SPAN:
        return None

    # Rectangle ring as (lat, lng) vertices: SW -> NW -> NE -> SE, closed at SW.
    ring = [
        (south, west),
        (north, west),
        (north, east),
        (south, east),
        (south, west),
    ]
    try:
        # h3-py v4 returns a list here; normalise to the declared set type.
        return set(h3.polygon_to_cells(h3.LatLngPoly(ring), resolution))
    except Exception:
        # Deliberate best effort: fall back to unfiltered results rather
        # than failing the request.
        return None


def _parse_bounds(bounds: str) -> tuple[float, float, float, float]:
    """Parse a ``south,west,north,east`` string; raise HTTP 400 if malformed."""
    try:
        south, west, north, east = map(float, bounds.split(","))
    except ValueError:
        # Covers both non-numeric parts and the wrong number of parts.
        raise HTTPException(status_code=400, detail=_BOUNDS_FORMAT_ERROR) from None
    # float() accepts "nan"/"inf", which are meaningless as map bounds.
    if not all(math.isfinite(v) for v in (south, west, north, east)):
        raise HTTPException(status_code=400, detail=_BOUNDS_FORMAT_ERROR)
    return south, west, north, east


def _query_hexagons(
    parquet_path: Any,  # path-like accepted by pl.scan_parquet
    south: float,
    west: float,
    north: float,
    east: float,
    resolution: int,
    min_year: int,
    max_year: int,
    min_price: float,
    max_price: float,
) -> dict[str, Any]:
    """Run the blocking viewport/year/price query and build the response dict."""
    viewport_cells = get_h3_cells_for_bounds(south, west, north, east, resolution)

    lf = pl.scan_parquet(parquet_path)

    # Restrict to cells in the viewport (skipped when the viewport is too
    # large to enumerate, i.e. viewport_cells is None).
    if viewport_cells is not None:
        if not viewport_cells:
            return {"features": [], "truncated": False}
        lf = lf.filter(pl.col("h3").is_in(list(viewport_cells)))

    lf = lf.filter((pl.col("year") >= min_year) & (pl.col("year") <= max_year))

    # Collapse years per cell. avg_price is re-weighted by count. median_price
    # is the median of the per-year medians, which only approximates the true
    # median over all sales.
    lf = lf.group_by("h3").agg(
        pl.col("count").sum().alias("count"),
        (pl.col("avg_price") * pl.col("count")).sum().alias("weighted_price_sum"),
        pl.col("median_price").median().alias("median_price"),
        pl.col("min_price").min().alias("min_price"),
        pl.col("max_price").max().alias("max_price"),
    )
    lf = lf.with_columns(
        (pl.col("weighted_price_sum") / pl.col("count")).alias("avg_price")
    ).drop("weighted_price_sum")

    lf = lf.filter(
        (pl.col("avg_price") >= min_price) & (pl.col("avg_price") <= max_price)
    )

    # Fetch one extra row so "truncated" is only reported when rows were
    # actually dropped (exactly MAX_HEXAGONS results is not a truncation).
    rows = lf.limit(MAX_HEXAGONS + 1).collect().to_dicts()
    truncated = len(rows) > MAX_HEXAGONS
    rows = rows[:MAX_HEXAGONS]

    # Lightweight payload: the frontend renders geometry from the h3 index.
    features = [
        {
            "h3": row["h3"],
            "count": row["count"],
            "avg_price": round(row["avg_price"], 2),
            # `is not None`, not truthiness: a 0.0 median is a real value.
            "median_price": (
                round(row["median_price"], 2)
                if row["median_price"] is not None
                else None
            ),
            "min_price": row["min_price"],
            "max_price": row["max_price"],
        }
        for row in rows
    ]
    return {"features": features, "truncated": truncated}


@router.get("/hexagons")
async def get_hexagons(
    resolution: int = Query(
        DEFAULT_RESOLUTION,
        ge=min(VALID_RESOLUTIONS),
        le=max(VALID_RESOLUTIONS),
        description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})",
    ),
    min_year: int = Query(DEFAULT_MIN_YEAR, description="Minimum year filter"),
    max_year: int = Query(DEFAULT_MAX_YEAR, description="Maximum year filter"),
    min_price: float = Query(DEFAULT_MIN_PRICE, description="Minimum average price"),
    max_price: float = Query(DEFAULT_MAX_PRICE, description="Maximum average price"),
    bounds: str | None = Query(
        None, description="Bounding box: south,west,north,east"
    ),
) -> dict:
    """Get aggregated property data as GeoJSON hexagons within bounds.

    Reads ``AGGREGATES_DIR / res{resolution}.parquet``, keeps rows inside the
    viewport and year range, aggregates them per H3 cell, and keeps cells
    whose count-weighted average price lies in [min_price, max_price].

    Returns:
        ``{"features": [...], "truncated": bool}``. Each feature holds
        ``h3``, ``count``, ``avg_price``, ``median_price``, ``min_price`` and
        ``max_price``. At most ``MAX_HEXAGONS`` features are returned, and
        ``truncated`` says whether more matched. If the resolution file does
        not exist, the result is empty.

    Raises:
        HTTPException: 400 if ``bounds`` is missing, malformed or non-finite.
    """
    # Query() bounds only check the range; VALID_RESOLUTIONS may have gaps.
    if resolution not in VALID_RESOLUTIONS:
        resolution = DEFAULT_RESOLUTION

    # Bounds are required for efficient queries.
    if not bounds:
        raise HTTPException(status_code=400, detail="bounds parameter is required")

    south, west, north, east = _parse_bounds(bounds)

    parquet_path = AGGREGATES_DIR / f"res{resolution}.parquet"
    if not parquet_path.exists():
        return {"features": [], "truncated": False}

    # H3 enumeration and Polars collect() are blocking; run them in a worker
    # thread so this coroutine doesn't stall the event loop.
    return await asyncio.to_thread(
        _query_hexagons,
        parquet_path,
        south,
        west,
        north,
        east,
        resolution,
        min_year,
        max_year,
        min_price,
        max_price,
    )