"""POI (Points of Interest) API endpoint.""" import os os.environ["POLARS_UNKNOWN_EXTENSION_TYPE_BEHAVIOR"] = "load_as_storage" from pathlib import Path from fastapi import APIRouter, Query import polars as pl router = APIRouter() DATA_FILE = Path("data_sources/uk_pois.parquet") # Categories useful for property buyers POI_CATEGORIES = { "schools": [ "elementary_school", "school", "high_school", "preschool", "college_university", "private_school", ], "healthcare": [ "doctor", "dentist", "pharmacy", "hospital", "public_health_clinic", ], "transport": [ "train_station", "bus_station", "metro_station", "light_rail_and_subway_stations", ], "parks": ["park", "national_park", "dog_park"], "emergency": ["police_department", "fire_department"], "supermarkets": ["supermarket", "grocery_store", "convenience_store"], } # Flatten for quick lookup ALL_CATEGORIES = {cat for cats in POI_CATEGORIES.values() for cat in cats} # Cache the dataframe _df_cache: pl.DataFrame | None = None def get_df() -> pl.DataFrame | None: """Load and cache the POI dataframe.""" global _df_cache if _df_cache is None: if not DATA_FILE.exists(): return None df = pl.read_parquet(DATA_FILE) # Extract fields we need and filter to relevant categories _df_cache = df.select( pl.col("id"), pl.col("names").struct.field("primary").alias("name"), pl.col("categories").struct.field("primary").alias("category"), pl.col("bbox").struct.field("xmin").alias("lng"), pl.col("bbox").struct.field("ymin").alias("lat"), ).filter(pl.col("category").is_in(ALL_CATEGORIES)) return _df_cache def preload_pois() -> None: """Preload POI data on startup.""" df = get_df() if df is not None: print(f"Loaded {len(df):,} POIs") @router.get("/pois") async def get_pois( categories: str = Query(..., description="Comma-separated category groups"), bounds: str = Query(..., description="Bounding box: south,west,north,east"), ) -> dict: """Get POIs within bounds for specified category groups.""" df = get_df() if df is None: return {"features": []} # Parse bounds try: south, west, north, east = map(float, bounds.split(",")) except ValueError: return {"features": []} # Get categories to include requested_groups = [g.strip() for g in categories.split(",")] cats_to_include = set() for group in requested_groups: if group in POI_CATEGORIES: cats_to_include.update(POI_CATEGORIES[group]) if not cats_to_include: return {"features": []} # Filter by bounds and categories filtered = df.filter( (pl.col("lat") >= south) & (pl.col("lat") <= north) & (pl.col("lng") >= west) & (pl.col("lng") <= east) & (pl.col("category").is_in(cats_to_include)) ) # Limit results to avoid overwhelming the frontend MAX_POIS = 5000 if len(filtered) > MAX_POIS: filtered = filtered.sample(n=MAX_POIS, seed=42) return {"features": filtered.to_dicts()} @router.get("/poi-categories") async def get_poi_categories() -> dict: """Get available POI category groups.""" return {"categories": list(POI_CATEGORIES.keys())}