172 lines
5.3 KiB
Python
172 lines
5.3 KiB
Python
import math
|
|
from functools import lru_cache
|
|
from fastapi import APIRouter, Query, HTTPException
|
|
import polars as pl
|
|
import h3
|
|
|
|
from tqdm import tqdm
|
|
|
|
from server.config import (
|
|
AGGREGATES_DIR,
|
|
VALID_RESOLUTIONS,
|
|
DEFAULT_RESOLUTION,
|
|
BOUNDS_BUFFER_PERCENT,
|
|
)
|
|
|
|
# FastAPI router that carries the /features and /hexagons endpoints below.
router = APIRouter()

# Loaded aggregate frames keyed by H3 resolution; populated lazily by get_cached_df.
_df_cache: dict[int, pl.DataFrame] = dict()

# Feature metadata produced by get_features; None means it has not been computed yet.
_features_cache: list[dict] | None = None
|
|
|
|
|
|
def _snake_to_label(name: str) -> str:
|
|
"""Convert snake_case feature name to a human-readable label."""
|
|
return name.replace("_", " ").title()
|
|
|
|
|
|
def _discover_features(df: pl.DataFrame) -> list[dict]:
    """Build feature metadata from paired ``min_<name>`` / ``max_<name>`` columns.

    For each ``min_<name>`` column that has a matching ``max_<name>`` column,
    the feature's range is the smallest value of the min column and the
    largest value of the max column. Features where either aggregate is null
    are omitted. Output order follows the order of the ``min_`` columns.

    Each entry is a dict with keys ``name``, ``min``, ``max`` (as floats) and
    ``label`` (a human-readable form of ``name``).
    """
    columns = df.columns
    available = set(columns)
    discovered: list[dict] = []
    handled: set[str] = set()

    for column in columns:
        if not column.startswith("min_"):
            continue
        feature = column.removeprefix("min_")
        partner = "max_" + feature
        if partner not in available or feature in handled:
            continue
        handled.add(feature)

        lowest = df[column].min()
        highest = df[partner].max()
        # An all-null (or empty) column yields None; skip such features.
        if lowest is None or highest is None:
            continue

        discovered.append(
            {
                "name": feature,
                "min": float(lowest),
                "max": float(highest),
                "label": _snake_to_label(feature),
            }
        )

    return discovered
|
|
|
|
|
|
def preload_dataframes() -> None:
    """Warm the dataframe cache for every configured resolution.

    Progress is shown with tqdm. A resolution whose parquet file is missing
    is simply skipped, because ``get_cached_df`` returns None for it and the
    return value is ignored here.
    """
    progress = tqdm(VALID_RESOLUTIONS, desc="Loading parquet files")
    for res in progress:
        get_cached_df(res)
|
|
|
|
|
|
def get_cached_df(resolution: int) -> pl.DataFrame | None:
    """Return the aggregate dataframe for ``resolution``, loading it lazily.

    On first access, ``res{resolution}.parquet`` is read from
    ``AGGREGATES_DIR``. Two helper columns, ``_lat`` and ``_lng``, are
    appended holding the centroid of each row's ``h3`` cell. The result is
    stored in ``_df_cache`` and reused on later calls.

    Returns None when the parquet file does not exist. That outcome is not
    cached, so the file is looked for again on the next call.
    """
    if resolution in _df_cache:
        return _df_cache[resolution]

    parquet_path = AGGREGATES_DIR / f"res{resolution}.parquet"
    if not parquet_path.exists():
        return None

    frame = pl.read_parquet(parquet_path)

    # Compute every cell centroid once so bbox queries can filter on plain
    # numeric columns.
    lats = []
    lngs = []
    for cell in frame["h3"].to_list():
        lat, lng = h3.cell_to_latlng(cell)
        lats.append(lat)
        lngs.append(lng)

    frame = frame.with_columns(
        [pl.Series("_lat", lats), pl.Series("_lng", lngs)]
    )
    _df_cache[resolution] = frame
    return frame
|
|
|
|
|
|
def get_features() -> list[dict]:
    """Return feature metadata discovered from the first loadable resolution.

    Resolutions are tried in ``VALID_RESOLUTIONS`` order. The first one whose
    dataframe loads is passed to ``_discover_features``, and that result is
    memoized in ``_features_cache`` for later calls.

    If no resolution can be loaded yet, an empty list is returned *without*
    being memoized. Features therefore appear once the parquet aggregates
    exist, mirroring ``get_cached_df``, which re-checks the disk until a file
    loads.
    """
    global _features_cache
    if _features_cache is not None:
        return _features_cache

    for resolution in VALID_RESOLUTIONS:
        df = get_cached_df(resolution)
        if df is not None:
            _features_cache = _discover_features(df)
            return _features_cache

    # Nothing on disk yet: don't pin an empty answer for the process lifetime.
    return []
|
|
|
|
|
|
@router.get("/features")
async def get_features_endpoint() -> dict:
    """Serve metadata for every discovered feature, including its global min/max."""
    payload = {"features": get_features()}
    return payload
|
|
|
|
|
|
@lru_cache(maxsize=128)
def query_hexagons_cached(
    resolution: int,
    bounds_tuple: tuple[float, float, float, float],
) -> list[dict]:
    """Return the rows whose H3 cell centroid lies inside ``bounds_tuple``.

    ``bounds_tuple`` is ``(south, west, north, east)`` and every edge is
    inclusive. Results are memoized per ``(resolution, bounds_tuple)`` by
    ``lru_cache``, so the returned list is shared between calls and should
    not be mutated. Returns an empty list when no dataframe is available for
    ``resolution``.
    """
    south, west, north, east = bounds_tuple

    source = get_cached_df(resolution)
    if source is None:
        return []

    # Only the centroid is tested: a hexagon that overlaps the box edge but
    # whose centre is outside it is excluded. A box with west > east (one that
    # crosses the antimeridian) matches nothing.
    lat_ok = (pl.col("_lat") >= south) & (pl.col("_lat") <= north)
    lng_ok = (pl.col("_lng") >= west) & (pl.col("_lng") <= east)
    visible = source.filter(lat_ok & lng_ok)

    # Strip the internal centroid helpers before handing rows to callers.
    return visible.drop("_lat", "_lng").to_dicts()
|
|
|
|
|
|
@router.get("/hexagons")
async def get_hexagons(
    resolution: int = Query(
        DEFAULT_RESOLUTION,
        ge=min(VALID_RESOLUTIONS),
        le=max(VALID_RESOLUTIONS),
        description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})",
    ),
    bounds: str | None = Query(None, description="Bounding box: south,west,north,east"),
) -> dict:
    """Get aggregated property data as hexagons within bounds.

    Args:
        resolution: H3 resolution. A value inside the allowed min/max range
            but not listed in ``VALID_RESOLUTIONS`` falls back to
            ``DEFAULT_RESOLUTION``.
        bounds: Viewport as ``"south,west,north,east"`` in degrees. Required.

    Returns:
        ``{"features": rows}``. ``rows`` are the aggregate rows (as dicts)
        whose H3 cell centroid lies inside the buffered, rounded box; the
        list is empty when no data exists for the resolution.

    Raises:
        HTTPException: 400 when ``bounds`` is missing, is not four
            comma-separated numbers, or contains ``nan``/``inf``.
    """
    if resolution not in VALID_RESOLUTIONS:
        resolution = DEFAULT_RESOLUTION

    if not bounds:
        raise HTTPException(status_code=400, detail="bounds parameter is required")

    try:
        south, west, north, east = map(float, bounds.split(","))
    except ValueError:
        raise HTTPException(
            status_code=400, detail="Invalid bounds format. Use: south,west,north,east"
        ) from None

    # float() accepts "nan" and "inf". Those would make math.floor/ceil below
    # raise ValueError/OverflowError and surface as a 500 instead of a 400.
    if not all(math.isfinite(v) for v in (south, west, north, east)):
        raise HTTPException(
            status_code=400, detail="Invalid bounds: values must be finite numbers"
        )

    # Expand bounds by buffer percentage for smoother panning
    lat_buffer = (north - south) * BOUNDS_BUFFER_PERCENT
    lng_buffer = (east - west) * BOUNDS_BUFFER_PERCENT
    south -= lat_buffer
    north += lat_buffer
    west -= lng_buffer
    east += lng_buffer

    # Clamp to the ranges H3 centroids can occupy. This never drops a cell
    # that was inside the unclamped box. It also lets wrapped viewports share
    # cache keys, and stops huge-but-finite inputs from overflowing to inf in
    # the rounding below.
    south = min(max(south, -90.0), 90.0)
    north = min(max(north, -90.0), 90.0)
    west = min(max(west, -180.0), 180.0)
    east = min(max(east, -180.0), 180.0)

    # Round bounds to reduce cache misses (0.01 degree ~ 1km precision).
    # Floor the lower edges and ceil the upper ones so the box only grows.
    precision = 0.01
    bounds_tuple = (
        math.floor(south / precision) * precision,
        math.floor(west / precision) * precision,
        math.ceil(north / precision) * precision,
        math.ceil(east / precision) * precision,
    )

    # query_hexagons_cached memoizes its result, including the [] it returns
    # when no parquet exists for this resolution. Short-circuit here so that
    # empty answer is never cached and data shows up once the file exists.
    if get_cached_df(resolution) is None:
        return {"features": []}

    features = query_hexagons_cached(resolution, bounds_tuple)

    return {"features": features}
|