122 lines
3.4 KiB
Python
122 lines
3.4 KiB
Python
"""POI (Points of Interest) API endpoint."""
|
|
|
|
import os
|
|
|
|
# Set before `import polars` further down. Judging by its name and value, this
# makes Polars load Arrow extension types it does not recognise (presumably
# present in the POI parquet file) as their underlying storage type instead of
# failing -- TODO confirm against the docs for the installed Polars version.
os.environ["POLARS_UNKNOWN_EXTENSION_TYPE_BEHAVIOR"] = "load_as_storage"
|
|
|
|
from pathlib import Path
|
|
|
|
from fastapi import APIRouter, Query
|
|
import polars as pl
|
|
|
|
# FastAPI router on which this module's POI endpoints are registered.
router = APIRouter()

# POI dataset location. This is a relative path, so it is resolved against the
# process's current working directory, not against this file's directory.
DATA_FILE = Path("data_sources") / "uk_pois.parquet"
|
|
|
|
# Category groups that matter to property buyers. Each group name (exposed to
# clients) maps to the raw primary-category strings stored in the dataset.
POI_CATEGORIES = {
    "schools": [
        "elementary_school",
        "school",
        "high_school",
        "preschool",
        "college_university",
        "private_school",
    ],
    "healthcare": [
        "doctor",
        "dentist",
        "pharmacy",
        "hospital",
        "public_health_clinic",
    ],
    "transport": [
        "train_station",
        "bus_station",
        "metro_station",
        "light_rail_and_subway_stations",
    ],
    "parks": ["park", "national_park", "dog_park"],
    "emergency": ["police_department", "fire_department"],
    "supermarkets": ["supermarket", "grocery_store", "convenience_store"],
}

# Every raw category across all groups, as one set for fast membership tests.
ALL_CATEGORIES = set().union(*POI_CATEGORIES.values())
|
|
|
|
# Process-wide cache of the filtered POI frame; set on the first successful load.
_df_cache: pl.DataFrame | None = None


def get_df() -> pl.DataFrame | None:
    """Load the POI parquet file once and cache the filtered result.

    The cached frame has the columns ``id``, ``name``, ``category``, ``lng`` and
    ``lat``. It keeps only rows whose primary category is in ``ALL_CATEGORIES``.

    Returns:
        The cached DataFrame, or ``None`` if ``DATA_FILE`` does not exist. A
        missing file is not cached, so a later call will retry the load.
    """
    global _df_cache
    if _df_cache is not None:
        return _df_cache
    if not DATA_FILE.exists():
        return None

    # Scan lazily so Polars reads only the referenced columns and can apply the
    # category filter during the scan. An eager read would materialise every
    # column of the file only to discard most of them.
    _df_cache = (
        pl.scan_parquet(DATA_FILE)
        .select(
            pl.col("id"),
            pl.col("names").struct.field("primary").alias("name"),
            pl.col("categories").struct.field("primary").alias("category"),
            # The bbox minimum corner is used as the point location; presumably
            # xmin == xmax / ymin == ymax for point features -- TODO confirm.
            pl.col("bbox").struct.field("xmin").alias("lng"),
            pl.col("bbox").struct.field("ymin").alias("lat"),
        )
        .filter(pl.col("category").is_in(ALL_CATEGORIES))
        .collect()
    )
    return _df_cache
|
|
|
|
|
|
def preload_pois() -> None:
    """Warm the POI cache at startup and print how many POIs were loaded.

    Prints nothing when the data file is missing.
    """
    pois = get_df()
    if pois is None:
        return
    print(f"Loaded {len(pois):,} POIs")
|
|
|
|
|
|
# Maximum number of POIs returned per request, to avoid overwhelming the frontend.
MAX_POIS = 5000


def _parse_bounds(bounds: str) -> tuple[float, float, float, float] | None:
    """Parse ``"south,west,north,east"`` into four floats; ``None`` if malformed."""
    try:
        south, west, north, east = map(float, bounds.split(","))
    except ValueError:
        # Raised for non-numeric parts and for the wrong number of parts.
        return None
    return south, west, north, east


def _resolve_categories(categories: str) -> set[str]:
    """Expand comma-separated group names into raw categories; unknown groups are ignored."""
    cats: set[str] = set()
    for group in categories.split(","):
        cats.update(POI_CATEGORIES.get(group.strip(), ()))
    return cats


@router.get("/pois")
async def get_pois(
    categories: str = Query(..., description="Comma-separated category groups"),
    bounds: str = Query(..., description="Bounding box: south,west,north,east"),
) -> dict:
    """Get POIs within bounds for specified category groups.

    Args:
        categories: Comma-separated group names (keys of ``POI_CATEGORIES``).
            Unknown names are ignored.
        bounds: ``"south,west,north,east"``. Each edge is inclusive.

    Returns:
        ``{"features": [...]}``, one dict per POI. The list is empty when the
        data file is missing, the bounds are malformed, or no known group was
        requested. When more than ``MAX_POIS`` rows match, a sample of
        ``MAX_POIS`` rows with a fixed seed is returned.
    """
    # NOTE(review): the Polars work below is synchronous and runs inside
    # `async def`, i.e. on the event loop. A plain `def` would let FastAPI run
    # it in its thread pool -- consider once callers that await it are checked.
    df = get_df()
    if df is None:
        return {"features": []}

    parsed = _parse_bounds(bounds)
    if parsed is None:
        return {"features": []}
    south, west, north, east = parsed

    cats_to_include = _resolve_categories(categories)
    if not cats_to_include:
        return {"features": []}

    # is_between is inclusive on both ends by default, matching >= / <=.
    filtered = df.filter(
        pl.col("lat").is_between(south, north)
        & pl.col("lng").is_between(west, east)
        & pl.col("category").is_in(cats_to_include)
    )

    if len(filtered) > MAX_POIS:
        # Fixed seed: identical requests over the same data get the same sample.
        filtered = filtered.sample(n=MAX_POIS, seed=42)

    return {"features": filtered.to_dicts()}
|
|
|
|
|
|
@router.get("/poi-categories")
async def get_poi_categories() -> dict:
    """Return the category-group names accepted by the ``/pois`` endpoint."""
    group_names = [*POI_CATEGORIES]
    return {"categories": group_names}
|