149 lines
4.8 KiB
Python
149 lines
4.8 KiB
Python
from typing import Any
|
|
from fastapi import APIRouter, Query, HTTPException
|
|
import polars as pl
|
|
import h3
|
|
|
|
from server.config import (
|
|
AGGREGATES_DIR,
|
|
VALID_RESOLUTIONS,
|
|
DEFAULT_RESOLUTION,
|
|
DEFAULT_MIN_YEAR,
|
|
DEFAULT_MAX_YEAR,
|
|
DEFAULT_MIN_PRICE,
|
|
DEFAULT_MAX_PRICE,
|
|
)
|
|
|
|
# Router on which this module's endpoints (e.g. GET /hexagons) are registered.
router = APIRouter()
|
|
|
|
|
|
def get_h3_cells_for_bounds(
|
|
south: float, west: float, north: float, east: float, resolution: int
|
|
) -> set[str] | None:
|
|
"""Get all H3 cells that cover a bounding box. Returns None if area too large."""
|
|
# Clamp to valid ranges
|
|
south = max(-85, min(85, south))
|
|
north = max(-85, min(85, north))
|
|
west = max(-180, min(180, west))
|
|
east = max(-180, min(180, east))
|
|
|
|
# Ensure valid bounds
|
|
if south >= north or west >= east:
|
|
return set()
|
|
|
|
# If viewport is too large, return None to skip filtering
|
|
# This prevents H3 from trying to enumerate millions of cells
|
|
lat_span = north - south
|
|
lng_span = east - west
|
|
if lat_span > 20 or lng_span > 30:
|
|
return None
|
|
|
|
# Create polygon from bounds (counter-clockwise winding for H3/GeoJSON)
|
|
# Order: SW -> NW -> NE -> SE -> SW
|
|
polygon = [
|
|
(south, west),
|
|
(north, west),
|
|
(north, east),
|
|
(south, east),
|
|
(south, west),
|
|
]
|
|
|
|
try:
|
|
return h3.polygon_to_cells(h3.LatLngPoly(polygon), resolution)
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
@router.get("/hexagons")
async def get_hexagons(
    resolution: int = Query(
        DEFAULT_RESOLUTION,
        ge=min(VALID_RESOLUTIONS),
        le=max(VALID_RESOLUTIONS),
        description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})",
    ),
    min_year: int = Query(DEFAULT_MIN_YEAR, description="Minimum year filter"),
    max_year: int = Query(DEFAULT_MAX_YEAR, description="Maximum year filter"),
    min_price: float = Query(DEFAULT_MIN_PRICE, description="Minimum average price"),
    max_price: float = Query(DEFAULT_MAX_PRICE, description="Maximum average price"),
    bounds: str | None = Query(
        None, description="Bounding box: south,west,north,east"
    ),
) -> dict:
    """Return per-cell aggregated property data for the H3 cells in a viewport.

    Reads ``res{resolution}.parquet`` from ``AGGREGATES_DIR``, keeps rows whose
    ``h3`` cell lies in the viewport and whose ``year`` is within
    ``[min_year, max_year]``, merges the remaining rows per cell, and keeps
    cells whose count-weighted average price is within
    ``[min_price, max_price]``.

    Args:
        resolution: H3 resolution; a value not in ``VALID_RESOLUTIONS`` falls
            back to ``DEFAULT_RESOLUTION``.
        min_year: Inclusive lower bound on ``year``.
        max_year: Inclusive upper bound on ``year``.
        min_price: Inclusive lower bound on the merged average price.
        max_price: Inclusive upper bound on the merged average price.
        bounds: Required ``"south,west,north,east"`` string of four floats.

    Returns:
        ``{"features": [...], "truncated": bool}``. Each feature is a flat dict
        with ``h3``, ``count``, ``avg_price``, ``median_price``, ``min_price``
        and ``max_price``; no geometry is included. ``truncated`` is True only
        when more than ``MAX_HEXAGONS`` cells matched and the list was cut.

    Raises:
        HTTPException: 400 when ``bounds`` is missing or is not four
            comma-separated floats.
    """
    if resolution not in VALID_RESOLUTIONS:
        resolution = DEFAULT_RESOLUTION

    # Bounds are required for efficient queries
    if not bounds:
        raise HTTPException(status_code=400, detail="bounds parameter is required")

    # Parse bounds; a wrong number of fields also surfaces as ValueError
    try:
        south, west, north, east = map(float, bounds.split(","))
    except ValueError:
        raise HTTPException(
            status_code=400, detail="Invalid bounds format. Use: south,west,north,east"
        ) from None

    # Load the appropriate resolution file
    parquet_path = AGGREGATES_DIR / f"res{resolution}.parquet"
    if not parquet_path.exists():
        return {"features": [], "truncated": False}

    df = pl.scan_parquet(parquet_path)

    # Get H3 cells for the viewport (None means "too large / unavailable":
    # skip the cell filter entirely)
    viewport_cells = get_h3_cells_for_bounds(south, west, north, east, resolution)
    if viewport_cells is not None:
        if not viewport_cells:
            return {"features": [], "truncated": False}
        df = df.filter(pl.col("h3").is_in(viewport_cells))

    # Filter by year range
    df = df.filter((pl.col("year") >= min_year) & (pl.col("year") <= max_year))

    # Merge the remaining rows of each cell. avg_price is re-weighted by count;
    # median_price is the median of the per-row medians, not a median
    # recomputed from underlying records.
    df = df.group_by("h3").agg(
        pl.col("count").sum().alias("count"),
        (pl.col("avg_price") * pl.col("count")).sum().alias("weighted_price_sum"),
        pl.col("median_price").median().alias("median_price"),
        pl.col("min_price").min().alias("min_price"),
        pl.col("max_price").max().alias("max_price"),
    )

    # Calculate weighted average price
    df = df.with_columns(
        (pl.col("weighted_price_sum") / pl.col("count")).alias("avg_price")
    ).drop("weighted_price_sum")

    # Filter by price range
    df = df.filter(
        (pl.col("avg_price") >= min_price) & (pl.col("avg_price") <= max_price)
    )

    # Limit results to prevent browser crashes. One extra row is fetched so
    # that "exactly MAX_HEXAGONS matches" is not misreported as truncated.
    MAX_HEXAGONS = 50000
    # NOTE(review): collect() is a synchronous call inside an async handler,
    # so it presumably runs on the event loop thread - confirm this is
    # acceptable for the expected query sizes.
    result = df.limit(MAX_HEXAGONS + 1).collect()
    truncated = result.height > MAX_HEXAGONS

    # Return lightweight response - just h3 index and properties.
    # Frontend H3HexagonLayer will render the geometry.
    rows = result.head(MAX_HEXAGONS).to_dicts()
    features = [
        {
            "h3": row["h3"],
            "count": row["count"],
            "avg_price": round(row["avg_price"], 2),
            # Explicit None check: a median of 0 is a value, not "missing"
            "median_price": (
                round(row["median_price"], 2)
                if row["median_price"] is not None
                else None
            ),
            "min_price": row["min_price"],
            "max_price": row["max_price"],
        }
        for row in rows
    ]

    return {"features": features, "truncated": truncated}
|