Update map to do filtering
This commit is contained in:
parent
6122ee44da
commit
d4fe881ef4
8 changed files with 349 additions and 372 deletions
|
|
@ -10,10 +10,6 @@ from server.config import (
|
|||
AGGREGATES_DIR,
|
||||
VALID_RESOLUTIONS,
|
||||
DEFAULT_RESOLUTION,
|
||||
DEFAULT_MIN_YEAR,
|
||||
DEFAULT_MAX_YEAR,
|
||||
DEFAULT_MIN_PRICE,
|
||||
DEFAULT_MAX_PRICE,
|
||||
BOUNDS_BUFFER_PERCENT,
|
||||
)
|
||||
|
||||
|
|
@ -22,6 +18,38 @@ router = APIRouter()
|
|||
# Cache loaded dataframes in memory (one per resolution)
|
||||
_df_cache: dict[int, pl.DataFrame] = {}
|
||||
|
||||
# Discovered features (computed once on first load)
|
||||
_features_cache: list[dict] | None = None
|
||||
|
||||
|
||||
def _snake_to_label(name: str) -> str:
|
||||
"""Convert snake_case feature name to a human-readable label."""
|
||||
return name.replace("_", " ").title()
|
||||
|
||||
|
||||
def _discover_features(df: pl.DataFrame) -> list[dict]:
|
||||
"""Discover features from column pairs min_X / max_X."""
|
||||
features = []
|
||||
seen = set()
|
||||
for col in df.columns:
|
||||
if col.startswith("min_"):
|
||||
name = col[4:]
|
||||
max_col = f"max_{name}"
|
||||
if max_col in df.columns and name not in seen:
|
||||
seen.add(name)
|
||||
global_min = df[col].min()
|
||||
global_max = df[max_col].max()
|
||||
if global_min is not None and global_max is not None:
|
||||
features.append(
|
||||
{
|
||||
"name": name,
|
||||
"min": float(global_min),
|
||||
"max": float(global_max),
|
||||
"label": _snake_to_label(name),
|
||||
}
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
def preload_dataframes() -> None:
|
||||
"""Load all resolution dataframes into cache on startup."""
|
||||
|
|
@ -38,25 +66,41 @@ def get_cached_df(resolution: int) -> pl.DataFrame | None:
|
|||
# Load and add H3 cell centroids for fast bbox filtering
|
||||
df = pl.read_parquet(parquet_path)
|
||||
|
||||
# Pre-compute cell centroids for bbox filtering (much faster than is_in)
|
||||
# Pre-compute cell centroids for bbox filtering
|
||||
centroids = [h3.cell_to_latlng(cell) for cell in df["h3"].to_list()]
|
||||
df = df.with_columns(
|
||||
[
|
||||
pl.Series("lat", [c[0] for c in centroids]),
|
||||
pl.Series("lng", [c[1] for c in centroids]),
|
||||
pl.Series("_lat", [c[0] for c in centroids]),
|
||||
pl.Series("_lng", [c[1] for c in centroids]),
|
||||
]
|
||||
)
|
||||
_df_cache[resolution] = df
|
||||
return _df_cache[resolution]
|
||||
|
||||
|
||||
def get_features() -> list[dict]:
|
||||
"""Get discovered features, computing from the first available resolution."""
|
||||
global _features_cache
|
||||
if _features_cache is None:
|
||||
for resolution in VALID_RESOLUTIONS:
|
||||
df = get_cached_df(resolution)
|
||||
if df is not None:
|
||||
_features_cache = _discover_features(df)
|
||||
break
|
||||
if _features_cache is None:
|
||||
_features_cache = []
|
||||
return _features_cache
|
||||
|
||||
|
||||
@router.get("/features")
|
||||
async def get_features_endpoint() -> dict:
|
||||
"""Return discovered feature metadata with global min/max ranges."""
|
||||
return {"features": get_features()}
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def query_hexagons_cached(
|
||||
resolution: int,
|
||||
min_year: int,
|
||||
max_year: int,
|
||||
min_price: int,
|
||||
max_price: int,
|
||||
bounds_tuple: tuple[float, float, float, float],
|
||||
) -> list[dict]:
|
||||
"""Cached query - returns features list."""
|
||||
|
|
@ -64,65 +108,18 @@ def query_hexagons_cached(
|
|||
|
||||
df = get_cached_df(resolution)
|
||||
if df is None:
|
||||
return [], False
|
||||
return []
|
||||
|
||||
# Fast bbox filter using pre-computed centroids (O(1) per row)
|
||||
# Fast bbox filter using pre-computed centroids
|
||||
df = df.filter(
|
||||
(pl.col("lat") >= south)
|
||||
& (pl.col("lat") <= north)
|
||||
& (pl.col("lng") >= west)
|
||||
& (pl.col("lng") <= east)
|
||||
(pl.col("_lat") >= south)
|
||||
& (pl.col("_lat") <= north)
|
||||
& (pl.col("_lng") >= west)
|
||||
& (pl.col("_lng") <= east)
|
||||
)
|
||||
|
||||
# Filter by year range
|
||||
df = df.filter((pl.col("year") >= min_year) & (pl.col("year") <= max_year))
|
||||
|
||||
# Check which journey time columns exist
|
||||
journey_cols = [
|
||||
"median_journey_minutes",
|
||||
"median_pt_easy_minutes",
|
||||
"median_pt_quick_minutes",
|
||||
"median_cycling_minutes",
|
||||
]
|
||||
available_journey_cols = [c for c in journey_cols if c in df.columns]
|
||||
|
||||
# Aggregate across years (weighted by count)
|
||||
agg_exprs = [
|
||||
pl.col("count").sum().alias("count"),
|
||||
(pl.col("avg_price") * pl.col("count")).sum().alias("weighted_price_sum"),
|
||||
pl.col("median_price").median().alias("median_price"),
|
||||
pl.col("min_price").min().alias("min_price"),
|
||||
pl.col("max_price").max().alias("max_price"),
|
||||
]
|
||||
for jc in available_journey_cols:
|
||||
# Journey time is same across years, just take first non-null
|
||||
agg_exprs.append(pl.col(jc).first())
|
||||
|
||||
df = df.group_by("h3").agg(agg_exprs)
|
||||
|
||||
# Calculate weighted average price
|
||||
df = df.with_columns(
|
||||
(pl.col("weighted_price_sum") / pl.col("count")).alias("avg_price")
|
||||
).drop("weighted_price_sum")
|
||||
|
||||
# Filter by price range
|
||||
df = df.filter(
|
||||
(pl.col("avg_price") >= min_price) & (pl.col("avg_price") <= max_price)
|
||||
)
|
||||
|
||||
# Build response efficiently using Polars
|
||||
select_cols = [
|
||||
pl.col("h3"),
|
||||
pl.col("count"),
|
||||
pl.col("avg_price").round(2),
|
||||
pl.col("median_price").round(2),
|
||||
pl.col("min_price"),
|
||||
pl.col("max_price"),
|
||||
]
|
||||
for jc in available_journey_cols:
|
||||
select_cols.append(pl.col(jc).round(0))
|
||||
|
||||
df = df.select(select_cols)
|
||||
# Drop internal centroid columns before returning
|
||||
df = df.drop("_lat", "_lng")
|
||||
|
||||
return df.to_dicts()
|
||||
|
||||
|
|
@ -135,13 +132,9 @@ async def get_hexagons(
|
|||
le=max(VALID_RESOLUTIONS),
|
||||
description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})",
|
||||
),
|
||||
min_year: int = Query(DEFAULT_MIN_YEAR, description="Minimum year filter"),
|
||||
max_year: int = Query(DEFAULT_MAX_YEAR, description="Maximum year filter"),
|
||||
min_price: float = Query(DEFAULT_MIN_PRICE, description="Minimum average price"),
|
||||
max_price: float = Query(DEFAULT_MAX_PRICE, description="Maximum average price"),
|
||||
bounds: str | None = Query(None, description="Bounding box: south,west,north,east"),
|
||||
) -> dict:
|
||||
"""Get aggregated property data as GeoJSON hexagons within bounds."""
|
||||
"""Get aggregated property data as hexagons within bounds."""
|
||||
if resolution not in VALID_RESOLUTIONS:
|
||||
resolution = DEFAULT_RESOLUTION
|
||||
|
||||
|
|
@ -165,9 +158,7 @@ async def get_hexagons(
|
|||
west -= lng_buffer
|
||||
east += lng_buffer
|
||||
|
||||
# Round bounds to reduce cache misses (0.01 degree ≈ 1km precision)
|
||||
# Always expand bounds (floor for min, ceil for max) to prevent hexagons
|
||||
# popping in when crossing rounding boundaries
|
||||
# Round bounds to reduce cache misses (0.01 degree ~ 1km precision)
|
||||
precision = 0.01
|
||||
bounds_tuple = (
|
||||
math.floor(south / precision) * precision,
|
||||
|
|
@ -176,14 +167,6 @@ async def get_hexagons(
|
|||
math.ceil(east / precision) * precision,
|
||||
)
|
||||
|
||||
# Convert prices to int for cache key hashability
|
||||
features = query_hexagons_cached(
|
||||
resolution,
|
||||
min_year,
|
||||
max_year,
|
||||
int(min_price),
|
||||
int(max_price),
|
||||
bounds_tuple,
|
||||
)
|
||||
features = query_hexagons_cached(resolution, bounds_tuple)
|
||||
|
||||
return {"features": features}
|
||||
|
|
|
|||
|
|
@ -9,8 +9,11 @@ router = APIRouter()
|
|||
|
||||
DATA_FILE = Path("data_sources/uk_pois.parquet")
|
||||
|
||||
# Category groups with emoji and member categories
|
||||
POI_CATEGORY_GROUPS: dict[str, dict] = {
|
||||
# Group definitions: maps a group key to its display metadata and the
|
||||
# individual POI categories it contains. Categories are matched against
|
||||
# the values that actually exist in the loaded parquet so that the
|
||||
# selector only shows groups with real data.
|
||||
_GROUP_DEFS: dict[str, dict] = {
|
||||
"schools": {
|
||||
"emoji": "🏫",
|
||||
"label": "Schools",
|
||||
|
|
@ -189,33 +192,80 @@ POI_CATEGORY_GROUPS: dict[str, dict] = {
|
|||
},
|
||||
}
|
||||
|
||||
# Flatten for quick lookup
|
||||
ALL_CATEGORIES = {
|
||||
cat for group in POI_CATEGORY_GROUPS.values() for cat in group["categories"]
|
||||
}
|
||||
# Built at startup from the data — only groups whose member categories
|
||||
# actually appear in the parquet file are included.
|
||||
_active_groups: dict[str, dict] = {}
|
||||
|
||||
# Reverse lookup: category value -> group key (built at startup)
|
||||
_cat_to_group: dict[str, str] = {}
|
||||
|
||||
# Cache the dataframe
|
||||
_df_cache: pl.DataFrame | None = None
|
||||
|
||||
|
||||
def _load_and_build() -> pl.DataFrame | None:
|
||||
"""Load the parquet, build category groups from actual data."""
|
||||
global _df_cache, _active_groups, _cat_to_group
|
||||
|
||||
if not DATA_FILE.exists():
|
||||
return None
|
||||
|
||||
df = pl.read_parquet(DATA_FILE).select("id", "name", "category", "lat", "lng")
|
||||
|
||||
# Distinct categories present in the data
|
||||
data_categories: set[str] = set(
|
||||
df.select("category").unique().to_series().to_list()
|
||||
)
|
||||
|
||||
# Per-category counts for the response
|
||||
counts: dict[str, int] = dict(
|
||||
df.group_by("category")
|
||||
.agg(pl.len().alias("n"))
|
||||
.iter_rows()
|
||||
)
|
||||
|
||||
# Build reverse map from every known category to its group
|
||||
cat_to_group: dict[str, str] = {}
|
||||
for key, gdef in _GROUP_DEFS.items():
|
||||
for cat in gdef["categories"]:
|
||||
cat_to_group[cat] = key
|
||||
|
||||
# Only keep categories that belong to a known group
|
||||
known_categories = data_categories & cat_to_group.keys()
|
||||
|
||||
# Build active groups — only those with at least one matching category
|
||||
active: dict[str, dict] = {}
|
||||
for key, gdef in _GROUP_DEFS.items():
|
||||
present = [c for c in gdef["categories"] if c in known_categories]
|
||||
if present:
|
||||
active[key] = {
|
||||
"emoji": gdef["emoji"],
|
||||
"label": gdef["label"],
|
||||
"categories": present,
|
||||
"count": sum(counts.get(c, 0) for c in present),
|
||||
}
|
||||
|
||||
_active_groups = active
|
||||
_cat_to_group = cat_to_group
|
||||
|
||||
# Filter dataframe to only known categories
|
||||
_df_cache = df.filter(pl.col("category").is_in(known_categories))
|
||||
return _df_cache
|
||||
|
||||
|
||||
def get_df() -> pl.DataFrame | None:
|
||||
"""Load and cache the POI dataframe."""
|
||||
global _df_cache
|
||||
"""Return cached POI dataframe, loading if necessary."""
|
||||
if _df_cache is None:
|
||||
if not DATA_FILE.exists():
|
||||
return None
|
||||
df = pl.read_parquet(DATA_FILE)
|
||||
_df_cache = df.select("id", "name", "category", "lat", "lng").filter(
|
||||
pl.col("category").is_in(ALL_CATEGORIES)
|
||||
)
|
||||
return _load_and_build()
|
||||
return _df_cache
|
||||
|
||||
|
||||
def preload_pois() -> None:
|
||||
"""Preload POI data on startup."""
|
||||
df = get_df()
|
||||
df = _load_and_build()
|
||||
if df is not None:
|
||||
print(f"Loaded {len(df):,} POIs")
|
||||
n_groups = len(_active_groups)
|
||||
print(f"Loaded {len(df):,} POIs across {n_groups} category groups")
|
||||
|
||||
|
||||
@router.get("/pois")
|
||||
|
|
@ -234,10 +284,10 @@ async def get_pois(
|
|||
return {"features": []}
|
||||
|
||||
requested_groups = [g.strip() for g in categories.split(",")]
|
||||
cats_to_include = set()
|
||||
cats_to_include: set[str] = set()
|
||||
for group in requested_groups:
|
||||
if group in POI_CATEGORY_GROUPS:
|
||||
cats_to_include.update(POI_CATEGORY_GROUPS[group]["categories"])
|
||||
if group in _active_groups:
|
||||
cats_to_include.update(_active_groups[group]["categories"])
|
||||
|
||||
if not cats_to_include:
|
||||
return {"features": []}
|
||||
|
|
@ -259,10 +309,14 @@ async def get_pois(
|
|||
|
||||
@router.get("/poi-categories")
|
||||
async def get_poi_categories() -> dict:
|
||||
"""Get available POI category groups with emoji and labels."""
|
||||
"""Get available POI category groups derived from loaded data."""
|
||||
return {
|
||||
"categories": {
|
||||
key: {"emoji": group["emoji"], "label": group["label"]}
|
||||
for key, group in POI_CATEGORY_GROUPS.items()
|
||||
key: {
|
||||
"emoji": group["emoji"],
|
||||
"label": group["label"],
|
||||
"count": group["count"],
|
||||
}
|
||||
for key, group in _active_groups.items()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue