Improve perf
This commit is contained in:
parent
86690f41f1
commit
8c1f6a82e2
2 changed files with 106 additions and 98 deletions
|
|
@ -30,5 +30,6 @@ def run_pipeline():
|
|||
size_mb = path.stat().st_size / (1024 * 1024)
|
||||
print(f" Saved: {path.name} ({size_mb:.1f} MB)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_pipeline()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any
|
||||
from functools import lru_cache
|
||||
from fastapi import APIRouter, Query, HTTPException
|
||||
import polars as pl
|
||||
import h3
|
||||
|
|
@ -15,91 +15,54 @@ from server.config import (
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def get_h3_cells_for_bounds(
|
||||
south: float, west: float, north: float, east: float, resolution: int
|
||||
) -> set[str] | None:
|
||||
"""Get all H3 cells that cover a bounding box. Returns None if area too large."""
|
||||
# Clamp to valid ranges
|
||||
south = max(-85, min(85, south))
|
||||
north = max(-85, min(85, north))
|
||||
west = max(-180, min(180, west))
|
||||
east = max(-180, min(180, east))
|
||||
|
||||
# Ensure valid bounds
|
||||
if south >= north or west >= east:
|
||||
return set()
|
||||
|
||||
# If viewport is too large, return None to skip filtering
|
||||
# This prevents H3 from trying to enumerate millions of cells
|
||||
lat_span = north - south
|
||||
lng_span = east - west
|
||||
if lat_span > 20 or lng_span > 30:
|
||||
return None
|
||||
|
||||
# Create polygon from bounds (counter-clockwise winding for H3/GeoJSON)
|
||||
# Order: SW -> NW -> NE -> SE -> SW
|
||||
polygon = [
|
||||
(south, west),
|
||||
(north, west),
|
||||
(north, east),
|
||||
(south, east),
|
||||
(south, west),
|
||||
]
|
||||
|
||||
try:
|
||||
return h3.polygon_to_cells(h3.LatLngPoly(polygon), resolution)
|
||||
except Exception:
|
||||
return None
|
||||
# Cache loaded dataframes in memory (one per resolution)
|
||||
# Keyed by H3 resolution. get_cached_df fills an entry the first time a
# resolution is requested; no eviction is visible in this part of the module,
# so entries presumably live for the lifetime of the process (TODO confirm).
_df_cache: dict[int, pl.DataFrame] = {}
|
||||
|
||||
|
||||
@router.get("/hexagons")
|
||||
async def get_hexagons(
|
||||
resolution: int = Query(
|
||||
DEFAULT_RESOLUTION,
|
||||
ge=min(VALID_RESOLUTIONS),
|
||||
le=max(VALID_RESOLUTIONS),
|
||||
description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})",
|
||||
),
|
||||
min_year: int = Query(DEFAULT_MIN_YEAR, description="Minimum year filter"),
|
||||
max_year: int = Query(DEFAULT_MAX_YEAR, description="Maximum year filter"),
|
||||
min_price: float = Query(DEFAULT_MIN_PRICE, description="Minimum average price"),
|
||||
max_price: float = Query(DEFAULT_MAX_PRICE, description="Maximum average price"),
|
||||
bounds: str | None = Query(
|
||||
None, description="Bounding box: south,west,north,east"
|
||||
),
|
||||
) -> dict:
|
||||
"""Get aggregated property data as GeoJSON hexagons within bounds."""
|
||||
if resolution not in VALID_RESOLUTIONS:
|
||||
resolution = DEFAULT_RESOLUTION
|
||||
def get_cached_df(resolution: int) -> pl.DataFrame | None:
|
||||
"""Get cached dataframe for resolution, loading from disk if needed."""
|
||||
if resolution not in _df_cache:
|
||||
parquet_path = AGGREGATES_DIR / f"res{resolution}.parquet"
|
||||
if not parquet_path.exists():
|
||||
return None
|
||||
# Load and add H3 cell centroids for fast bbox filtering
|
||||
df = pl.read_parquet(parquet_path)
|
||||
|
||||
# Bounds are required for efficient queries
|
||||
if not bounds:
|
||||
raise HTTPException(status_code=400, detail="bounds parameter is required")
|
||||
|
||||
# Parse bounds
|
||||
try:
|
||||
south, west, north, east = map(float, bounds.split(","))
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid bounds format. Use: south,west,north,east"
|
||||
# Pre-compute cell centroids for bbox filtering (much faster than is_in)
|
||||
centroids = [h3.cell_to_latlng(cell) for cell in df["h3"].to_list()]
|
||||
df = df.with_columns(
|
||||
[
|
||||
pl.Series("lat", [c[0] for c in centroids]),
|
||||
pl.Series("lng", [c[1] for c in centroids]),
|
||||
]
|
||||
)
|
||||
_df_cache[resolution] = df
|
||||
return _df_cache[resolution]
|
||||
|
||||
# Load the appropriate resolution file
|
||||
parquet_path = AGGREGATES_DIR / f"res{resolution}.parquet"
|
||||
if not parquet_path.exists():
|
||||
return {"features": []}
|
||||
|
||||
df = pl.scan_parquet(parquet_path)
|
||||
@lru_cache(maxsize=128)
|
||||
def query_hexagons_cached(
|
||||
resolution: int,
|
||||
min_year: int,
|
||||
max_year: int,
|
||||
min_price: int,
|
||||
max_price: int,
|
||||
bounds_tuple: tuple[float, float, float, float],
|
||||
) -> tuple[list[dict], bool]:
|
||||
"""Cached query - returns (features, truncated)."""
|
||||
south, west, north, east = bounds_tuple
|
||||
|
||||
# Get H3 cells that cover the viewport (None if too large to enumerate)
|
||||
viewport_cells = get_h3_cells_for_bounds(south, west, north, east, resolution)
|
||||
df = get_cached_df(resolution)
|
||||
if df is None:
|
||||
return [], False
|
||||
|
||||
# Filter to only cells in viewport (skip if viewport too large)
|
||||
if viewport_cells is not None:
|
||||
if len(viewport_cells) == 0:
|
||||
return {"features": []}
|
||||
df = df.filter(pl.col("h3").is_in(viewport_cells))
|
||||
# Fast bbox filter using pre-computed centroids (O(1) per row)
|
||||
df = df.filter(
|
||||
(pl.col("lat") >= south)
|
||||
& (pl.col("lat") <= north)
|
||||
& (pl.col("lng") >= west)
|
||||
& (pl.col("lng") <= east)
|
||||
)
|
||||
|
||||
# Filter by year range
|
||||
df = df.filter((pl.col("year") >= min_year) & (pl.col("year") <= max_year))
|
||||
|
|
@ -123,27 +86,71 @@ async def get_hexagons(
|
|||
(pl.col("avg_price") >= min_price) & (pl.col("avg_price") <= max_price)
|
||||
)
|
||||
|
||||
# Limit results to prevent browser crashes
|
||||
# Limit results
|
||||
MAX_HEXAGONS = 50000
|
||||
df = df.limit(MAX_HEXAGONS)
|
||||
truncated = len(df) >= MAX_HEXAGONS
|
||||
if truncated:
|
||||
df = df.limit(MAX_HEXAGONS)
|
||||
|
||||
# Collect results
|
||||
result = df.collect()
|
||||
# Build response efficiently using Polars
|
||||
df = df.select(
|
||||
[
|
||||
pl.col("h3"),
|
||||
pl.col("count"),
|
||||
pl.col("avg_price").round(2),
|
||||
pl.col("median_price").round(2),
|
||||
pl.col("min_price"),
|
||||
pl.col("max_price"),
|
||||
]
|
||||
)
|
||||
|
||||
# Return lightweight response - just h3 index and properties
|
||||
# Frontend H3HexagonLayer will render the geometry
|
||||
# Use to_dicts() which is faster than iter_rows for large results
|
||||
rows = result.to_dicts()
|
||||
features = [
|
||||
{
|
||||
"h3": row["h3"],
|
||||
"count": row["count"],
|
||||
"avg_price": round(row["avg_price"], 2),
|
||||
"median_price": round(row["median_price"], 2) if row["median_price"] else None,
|
||||
"min_price": row["min_price"],
|
||||
"max_price": row["max_price"],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
return df.to_dicts(), truncated
|
||||
|
||||
return {"features": features, "truncated": len(rows) >= MAX_HEXAGONS}
|
||||
|
||||
@router.get("/hexagons")
async def get_hexagons(
    resolution: int = Query(
        DEFAULT_RESOLUTION,
        ge=min(VALID_RESOLUTIONS),
        le=max(VALID_RESOLUTIONS),
        description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})",
    ),
    min_year: int = Query(DEFAULT_MIN_YEAR, description="Minimum year filter"),
    max_year: int = Query(DEFAULT_MAX_YEAR, description="Maximum year filter"),
    min_price: float = Query(DEFAULT_MIN_PRICE, description="Minimum average price"),
    max_price: float = Query(DEFAULT_MAX_PRICE, description="Maximum average price"),
    bounds: str | None = Query(None, description="Bounding box: south,west,north,east"),
) -> dict:
    """Return aggregated property data for the H3 cells inside a bounding box.

    The response is a lightweight payload, not GeoJSON geometry: each feature
    is a row dict produced by ``query_hexagons_cached``.

    Args:
        resolution: H3 resolution; falls back to ``DEFAULT_RESOLUTION`` when
            the value is not one of ``VALID_RESOLUTIONS``.
        min_year: Lower bound of the year filter (inclusive).
        max_year: Upper bound of the year filter (inclusive).
        min_price: Lower bound on average price (inclusive).
        max_price: Upper bound on average price (inclusive).
        bounds: Required bounding box as ``"south,west,north,east"``.

    Returns:
        ``{"features": [...], "truncated": bool}``.

    Raises:
        HTTPException: 400 when ``bounds`` is missing or is not four
            comma-separated numbers.
    """
    if resolution not in VALID_RESOLUTIONS:
        resolution = DEFAULT_RESOLUTION

    if not bounds:
        raise HTTPException(status_code=400, detail="bounds parameter is required")

    try:
        # Both a non-numeric field and a wrong field count raise ValueError.
        south, west, north, east = map(float, bounds.split(","))
    except ValueError:
        raise HTTPException(
            status_code=400, detail="Invalid bounds format. Use: south,west,north,east"
        ) from None

    # Round bounds to reduce cache misses (0.01 degree ≈ 1km precision)
    bounds_tuple = (
        round(south, 2),
        round(west, 2),
        round(north, 2),
        round(east, 2),
    )

    # Prices are passed through unchanged: floats are already hashable cache
    # keys, whereas int() would raise on inf/nan and would truncate max_price,
    # silently narrowing the requested range.
    features, truncated = query_hexagons_cached(
        resolution,
        min_year,
        max_year,
        min_price,
        max_price,
        bounds_tuple,
    )

    # NOTE: `features` comes from a cached call, so the same list object may be
    # shared between responses; it must not be mutated here.
    return {"features": features, "truncated": truncated}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue