perfect-postcode/server/routes/hexagons.py

162 lines
5 KiB
Python

import math
from functools import lru_cache
from fastapi import APIRouter, Query, HTTPException
import polars as pl
import h3
from tqdm import tqdm
from server.config import (
AGGREGATES_DIR,
VALID_RESOLUTIONS,
DEFAULT_RESOLUTION,
DEFAULT_MIN_YEAR,
DEFAULT_MAX_YEAR,
DEFAULT_MIN_PRICE,
DEFAULT_MAX_PRICE,
)
# Router that this module's endpoints are registered on (see the
# ``@router.get("/hexagons")`` handler in this file).
router = APIRouter()
# Cache loaded dataframes in memory (one per resolution).
# Keys are H3 resolutions; values are the dataframes produced by
# ``get_cached_df``. Nothing in this module evicts entries once loaded.
_df_cache: dict[int, pl.DataFrame] = {}
def preload_dataframes() -> None:
    """Warm the dataframe cache for every entry in ``VALID_RESOLUTIONS``.

    Each resolution is passed through ``get_cached_df`` so that its parquet
    file (if it exists) is read up front; a progress bar labelled
    "Loading parquet files" tracks the loop.
    """
    progress = tqdm(VALID_RESOLUTIONS, desc="Loading parquet files")
    for res in progress:
        # Return value is deliberately ignored; the call is made for its
        # cache-populating side effect only.
        get_cached_df(res)
def get_cached_df(resolution: int) -> pl.DataFrame | None:
    """Return the dataframe for ``resolution``, reading it from disk on first use.

    The frame is read from ``AGGREGATES_DIR / "res{resolution}.parquet"`` and
    extended with ``lat`` / ``lng`` columns holding the centroid of each row's
    ``h3`` cell, so later bounding-box filters can compare plain numbers.

    Returns:
        The cached (or freshly loaded) dataframe, or ``None`` when no parquet
        file exists for the resolution. A missing file is not remembered, so
        the path is re-checked on the next call.
    """
    # Fast path: already loaded for this resolution.
    if resolution in _df_cache:
        return _df_cache[resolution]

    source = AGGREGATES_DIR / f"res{resolution}.parquet"
    if not source.exists():
        return None

    frame = pl.read_parquet(source)

    # One centroid lookup per row; each result is indexed as (lat, lng).
    centroids = list(map(h3.cell_to_latlng, frame["h3"].to_list()))
    latitudes = [centre[0] for centre in centroids]
    longitudes = [centre[1] for centre in centroids]

    frame = frame.with_columns(
        pl.Series("lat", latitudes),
        pl.Series("lng", longitudes),
    )

    _df_cache[resolution] = frame
    return frame
@lru_cache(maxsize=128)
def query_hexagons_cached(
    resolution: int,
    min_year: int,
    max_year: int,
    min_price: int,
    max_price: int,
    bounds_tuple: tuple[float, float, float, float],
) -> list[dict]:
    """Return per-cell price aggregates for a bounding box, memoised by arguments.

    Args:
        resolution: H3 resolution whose dataframe is fetched via ``get_cached_df``.
        min_year: Inclusive lower bound on the ``year`` column.
        max_year: Inclusive upper bound on the ``year`` column.
        min_price: Inclusive lower bound on the aggregated ``avg_price``.
        max_price: Inclusive upper bound on the aggregated ``avg_price``.
        bounds_tuple: ``(south, west, north, east)``; rows are kept when their
            ``lat`` / ``lng`` centroid columns fall inside it (inclusive).

    Returns:
        One dict per ``h3`` cell with keys ``h3``, ``count``, ``avg_price``,
        ``median_price``, ``min_price`` and ``max_price``. An empty list is
        returned when no dataframe is available for ``resolution``.

    Notes:
        * All arguments must be hashable because they form the ``lru_cache`` key.
        * Repeat calls with the same arguments return the *same* list object,
          so callers must treat the result as read-only.
        * ``median_price`` is the median of the per-row medians, not a median
          recomputed over the underlying records.
    """
    south, west, north, east = bounds_tuple
    df = get_cached_df(resolution)
    if df is None:
        # No data for this resolution. The result must be a plain list (the
        # previous ``return [], False`` produced a tuple, which the endpoint
        # then emitted verbatim under its "features" key).
        return []

    # Fast bbox filter using pre-computed centroids (simple numeric comparisons)
    df = df.filter(
        (pl.col("lat") >= south)
        & (pl.col("lat") <= north)
        & (pl.col("lng") >= west)
        & (pl.col("lng") <= east)
    )

    # Filter by year range
    df = df.filter((pl.col("year") >= min_year) & (pl.col("year") <= max_year))

    # Aggregate across years (weighted by count)
    df = df.group_by("h3").agg(
        pl.col("count").sum().alias("count"),
        (pl.col("avg_price") * pl.col("count")).sum().alias("weighted_price_sum"),
        pl.col("median_price").median().alias("median_price"),
        pl.col("min_price").min().alias("min_price"),
        pl.col("max_price").max().alias("max_price"),
    )

    # Calculate weighted average price: sum(avg_i * count_i) / sum(count_i)
    df = df.with_columns(
        (pl.col("weighted_price_sum") / pl.col("count")).alias("avg_price")
    ).drop("weighted_price_sum")

    # Filter by price range (applied to the aggregated average, not per year)
    df = df.filter(
        (pl.col("avg_price") >= min_price) & (pl.col("avg_price") <= max_price)
    )

    # Build response efficiently using Polars
    df = df.select(
        [
            pl.col("h3"),
            pl.col("count"),
            pl.col("avg_price").round(2),
            pl.col("median_price").round(2),
            pl.col("min_price"),
            pl.col("max_price"),
        ]
    )
    return df.to_dicts()
@router.get("/hexagons")
async def get_hexagons(
    resolution: int = Query(
        DEFAULT_RESOLUTION,
        ge=min(VALID_RESOLUTIONS),
        le=max(VALID_RESOLUTIONS),
        description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})",
    ),
    min_year: int = Query(DEFAULT_MIN_YEAR, description="Minimum year filter"),
    max_year: int = Query(DEFAULT_MAX_YEAR, description="Maximum year filter"),
    min_price: float = Query(DEFAULT_MIN_PRICE, description="Minimum average price"),
    max_price: float = Query(DEFAULT_MAX_PRICE, description="Maximum average price"),
    bounds: str | None = Query(None, description="Bounding box: south,west,north,east"),
) -> dict:
    """Return aggregated property data for the H3 cells inside a bounding box.

    Args:
        resolution: Requested H3 resolution. A value inside the min/max range
            but not present in ``VALID_RESOLUTIONS`` falls back to
            ``DEFAULT_RESOLUTION``.
        min_year: Inclusive lower year bound.
        max_year: Inclusive upper year bound.
        min_price: Lower bound on the average price (truncated to ``int``).
        max_price: Upper bound on the average price (truncated to ``int``).
        bounds: Required ``"south,west,north,east"`` string of four numbers.

    Returns:
        ``{"features": [...]}`` where the list is whatever
        ``query_hexagons_cached`` returns for the normalised arguments.

    Raises:
        HTTPException: 400 when ``bounds`` is missing, is not four
            comma-separated finite numbers, or when a price is not finite.
    """
    if resolution not in VALID_RESOLUTIONS:
        resolution = DEFAULT_RESOLUTION

    if not bounds:
        raise HTTPException(status_code=400, detail="bounds parameter is required")

    # A wrong number of parts also raises ValueError (from the unpacking).
    try:
        south, west, north, east = map(float, bounds.split(","))
    except ValueError:
        raise HTTPException(
            status_code=400, detail="Invalid bounds format. Use: south,west,north,east"
        ) from None

    # Round bounds to reduce cache misses (0.01 degree ≈ 1km precision)
    # Always expand bounds (floor for min, ceil for max) to prevent hexagons
    # popping in when crossing rounding boundaries
    precision = 0.01
    try:
        bounds_tuple = (
            math.floor(south / precision) * precision,
            math.floor(west / precision) * precision,
            math.ceil(north / precision) * precision,
            math.ceil(east / precision) * precision,
        )
    except (OverflowError, ValueError):
        # float() happily parses "nan" / "inf", and huge values overflow to inf
        # once divided by the precision; floor/ceil reject those. Report them
        # as a client error rather than letting the request fail with a 500.
        raise HTTPException(
            status_code=400, detail="Invalid bounds format. Use: south,west,north,east"
        ) from None

    # Convert prices to int for cache key hashability
    try:
        price_floor = int(min_price)
        price_ceiling = int(max_price)
    except (OverflowError, ValueError):
        # int() cannot represent NaN or infinite floats.
        raise HTTPException(
            status_code=400, detail="min_price and max_price must be finite numbers"
        ) from None

    # NOTE(review): this call is synchronous, so the coroutine does not yield
    # while the query runs.
    features = query_hexagons_cached(
        resolution,
        min_year,
        max_year,
        price_floor,
        price_ceiling,
        bounds_tuple,
    )
    return {"features": features}