perfect-postcode/server/routes/pois.py
2026-01-26 22:02:23 +00:00

122 lines
3.4 KiB
Python

"""POI (Points of Interest) API endpoint."""
import os
os.environ["POLARS_UNKNOWN_EXTENSION_TYPE_BEHAVIOR"] = "load_as_storage"
from pathlib import Path
from fastapi import APIRouter, Query
import polars as pl
# Router holding the POI endpoints defined in this module.
router = APIRouter()

# Parquet file with the raw POI records. The path is relative, so it is
# resolved against the process's current working directory at read time.
DATA_FILE = Path("data_sources") / "uk_pois.parquet"
# Category groups offered to clients, each mapped to the raw POI category
# values it covers (groups chosen for what matters to property buyers).
# Key order is preserved and is the order reported by the categories endpoint.
POI_CATEGORIES = {
    "schools": [
        "elementary_school",
        "school",
        "high_school",
        "preschool",
        "college_university",
        "private_school",
    ],
    "healthcare": [
        "doctor",
        "dentist",
        "pharmacy",
        "hospital",
        "public_health_clinic",
    ],
    "transport": [
        "train_station",
        "bus_station",
        "metro_station",
        "light_rail_and_subway_stations",
    ],
    "parks": ["park", "national_park", "dog_park"],
    "emergency": ["police_department", "fire_department"],
    "supermarkets": ["supermarket", "grocery_store", "convenience_store"],
}

# Union of every raw category across all groups, used to pre-filter the
# loaded data down to rows that some group can actually request.
ALL_CATEGORIES = set().union(*POI_CATEGORIES.values())
# Module-level cache of the prepared POI frame; filled on the first
# successful get_df() call and reused afterwards.
_df_cache: pl.DataFrame | None = None

# Source columns that get_df() actually reads; everything else in the
# parquet file (e.g. geometry) is skipped at load time.
_SOURCE_COLUMNS = ["id", "names", "categories", "bbox"]


def get_df() -> pl.DataFrame | None:
    """Load and cache the POI dataframe.

    On first use, reads ``DATA_FILE`` and keeps, for each row whose primary
    category is in ``ALL_CATEGORIES``: ``id``, the primary name as ``name``,
    the primary category as ``category``, and the bbox minimum corner as
    ``lng``/``lat``. The result is cached for all later calls.

    Returns:
        The cached frame with columns ``id``, ``name``, ``category``, ``lng``
        and ``lat``, or ``None`` if ``DATA_FILE`` does not exist. A missing
        file is not cached, so a later call checks for it again.
    """
    global _df_cache
    if _df_cache is None:
        if not DATA_FILE.exists():
            return None
        # Only load the columns used below instead of the whole file.
        df = pl.read_parquet(DATA_FILE, columns=_SOURCE_COLUMNS)
        _df_cache = df.select(
            pl.col("id"),
            pl.col("names").struct.field("primary").alias("name"),
            pl.col("categories").struct.field("primary").alias("category"),
            # NOTE(review): uses the bbox min corner as the POI position,
            # which presumes point geometries (xmin == xmax) — confirm.
            pl.col("bbox").struct.field("xmin").alias("lng"),
            pl.col("bbox").struct.field("ymin").alias("lat"),
        ).filter(pl.col("category").is_in(ALL_CATEGORIES))
    return _df_cache
def preload_pois() -> None:
    """Preload POI data on startup.

    Fills the ``get_df`` cache so the first request does not pay for the
    parquet read. Prints how many POIs were loaded or, if the data file is
    missing, prints where it was expected. In that case the POI endpoint will
    return no features, because ``get_df`` returns ``None``.
    """
    df = get_df()
    if df is None:
        # DATA_FILE is relative to the working directory; show the resolved
        # path so a wrong launch directory is easy to spot.
        print(f"POI data file not found at {DATA_FILE.resolve()}; no POIs will be served")
        return
    print(f"Loaded {len(df):,} POIs")
# Upper bound on features returned by one POI request, so dense areas do not
# overwhelm the frontend.
MAX_POIS = 5000


@router.get("/pois")
async def get_pois(
    categories: str = Query(..., description="Comma-separated category groups"),
    bounds: str = Query(..., description="Bounding box: south,west,north,east"),
) -> dict:
    """Get POIs within bounds for specified category groups.

    Args:
        categories: Comma-separated keys of ``POI_CATEGORIES`` (for example
            ``"schools,parks"``). Whitespace around each key is ignored and
            unknown keys are skipped.
        bounds: ``"south,west,north,east"``. Each is compared inclusively
            against the POI ``lat``/``lng`` columns.

    Returns:
        ``{"features": [...]}``, a list of row dicts with keys ``id``, ``name``,
        ``category``, ``lng`` and ``lat``. The list is empty when the data file
        is missing, when ``bounds`` is not exactly four numbers, or when no
        known group was requested. If more than ``MAX_POIS`` rows match, a
        sample of ``MAX_POIS`` rows is returned. The sample uses a fixed seed,
        so repeating an identical query returns the same subset.
    """
    df = get_df()
    if df is None:
        return {"features": []}

    # Both a non-numeric part (float) and a wrong number of parts (tuple
    # unpacking) raise ValueError; either way nothing can match.
    try:
        south, west, north, east = map(float, bounds.split(","))
    except ValueError:
        return {"features": []}

    # Expand the requested groups into raw categories; unknown groups add nothing.
    cats_to_include: set[str] = set()
    for group in (g.strip() for g in categories.split(",")):
        cats_to_include.update(POI_CATEGORIES.get(group, ()))
    if not cats_to_include:
        return {"features": []}

    filtered = df.filter(
        (pl.col("lat") >= south)
        & (pl.col("lat") <= north)
        & (pl.col("lng") >= west)
        & (pl.col("lng") <= east)
        & (pl.col("category").is_in(cats_to_include))
    )

    if len(filtered) > MAX_POIS:
        filtered = filtered.sample(n=MAX_POIS, seed=42)
    return {"features": filtered.to_dicts()}
@router.get("/poi-categories")
async def get_poi_categories() -> dict:
    """List the category group names accepted by the POI endpoint.

    Returns:
        ``{"categories": [...]}`` holding the keys of ``POI_CATEGORIES`` in
        their definition order.
    """
    group_names = [*POI_CATEGORIES]
    return {"categories": group_names}