From d4fe881ef49793ff9093237db3fc5121da070375 Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Fri, 30 Jan 2026 18:34:12 +0000 Subject: [PATCH] Update map to do filtering --- frontend/src/App.tsx | 84 ++++++++++---- frontend/src/components/Filters.tsx | 157 ++++++++++---------------- frontend/src/components/Map.tsx | 166 ++++++++++++---------------- frontend/src/lib/constants.ts | 20 +--- frontend/src/types.ts | 35 +++--- server/config.py | 12 -- server/routes/hexagons.py | 149 +++++++++++-------------- server/routes/pois.py | 98 ++++++++++++---- 8 files changed, 349 insertions(+), 372 deletions(-) diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index f4838fb..1ba8e38 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,9 +1,9 @@ -import { useState, useEffect, useCallback, useRef } from 'react'; +import { useState, useEffect, useCallback, useRef, useMemo } from 'react'; import Map from './components/Map'; import Filters from './components/Filters'; -import { DEFAULT_FILTERS } from './lib/constants'; import type { - Filters as FiltersType, + FeatureMeta, + FeatureFilters, Bounds, HexagonData, ViewChangeParams, @@ -11,7 +11,6 @@ import type { POI, POIResponse, POICategoriesMap, - ColorMode, } from './types'; const DEBOUNCE_MS = 150; @@ -42,8 +41,10 @@ function getApiBaseUrl(): string { } export default function App() { - const [filters, setFilters] = useState(DEFAULT_FILTERS); - const [data, setData] = useState([]); + const [features, setFeatures] = useState([]); + const [filters, setFilters] = useState({}); + const [activeFeature, setActiveFeature] = useState(null); + const [rawData, setRawData] = useState([]); const [resolution, setResolution] = useState(8); const [bounds, setBounds] = useState(null); const [loading, setLoading] = useState(false); @@ -51,8 +52,6 @@ export default function App() { const debounceRef = useRef | null>(null); const abortControllerRef = useRef(null); - const [colorMode, setColorMode] = useState('price'); - // POI state const [pois, setPois] = useState([]); const [poiCategories, setPOICategories] = useState({}); @@ -60,8 +59,21 @@ export default function App() { const poiDebounceRef = useRef | null>(null); const poiAbortControllerRef = useRef(null); - // Fetch POI category definitions from server on mount + // Fetch feature metadata + POI categories on mount useEffect(() => { + fetch(`${getApiBaseUrl()}/api/features`) + .then((res) => res.json()) + .then((json: { features: FeatureMeta[] }) => { + setFeatures(json.features); + // Initialize filters with full range for each feature + const initial: FeatureFilters = {}; + for (const f of json.features) { + initial[f.name] = [f.min, f.max]; + } + setFilters(initial); + }) + .catch((err) => console.error('Failed to fetch features:', err)); + fetch(`${getApiBaseUrl()}/api/poi-categories`) .then((res) => res.json()) .then((json: { categories: POICategoriesMap }) => { @@ -70,7 +82,7 @@ export default function App() { .catch((err) => console.error('Failed to fetch POI categories:', err)); }, []); - // Debounced fetch when dependencies change + // Debounced fetch when resolution/bounds change (no filter params sent) useEffect(() => { if (!bounds) return; @@ -89,17 +101,13 @@ export default function App() { const boundsStr = `${bounds.south},${bounds.west},${bounds.north},${bounds.east}`; const params = new URLSearchParams({ resolution: resolution.toString(), - min_year: filters.minYear.toString(), - max_year: filters.maxYear.toString(), - min_price: filters.minPrice.toString(), - max_price: filters.maxPrice.toString(), bounds: boundsStr, }); const res = await fetch(`${getApiBaseUrl()}/api/hexagons?${params}`, { signal: abortControllerRef.current.signal, }); const json: ApiResponse = await res.json(); - setData(json.features || []); + setRawData(json.features || []); } catch (err) { if (err instanceof Error && err.name !== 'AbortError') { console.error('Failed to fetch data:', err); @@ -114,7 +122,36 @@ export default function App() { clearTimeout(debounceRef.current); } }; - }, [filters, resolution, bounds]); + }, [resolution, bounds]); + + // Client-side filtering + const data = useMemo(() => { + if (features.length === 0) return rawData; + + return rawData.filter((hex) => { + if (activeFeature) { + // Only apply the active feature's filter + const range = filters[activeFeature]; + if (!range) return true; + const minVal = hex[`min_${activeFeature}`]; + const maxVal = hex[`max_${activeFeature}`]; + if (minVal == null || maxVal == null) return true; + return (minVal as number) <= range[1] && (maxVal as number) >= range[0]; + } + // Apply ALL filters as intersection + for (const f of features) { + const range = filters[f.name]; + if (!range) continue; + // Skip features where filter is at full range + if (range[0] === f.min && range[1] === f.max) continue; + const minVal = hex[`min_${f.name}`]; + const maxVal = hex[`max_${f.name}`]; + if (minVal == null || maxVal == null) continue; + if ((minVal as number) > range[1] || (maxVal as number) < range[0]) return false; + } + return true; + }); + }, [rawData, filters, activeFeature, features]); // Fetch POIs when bounds or selected categories change useEffect(() => { @@ -171,17 +208,24 @@ export default function App() { return (
- + {loading && (
Loading...
)} diff --git a/frontend/src/components/Filters.tsx b/frontend/src/components/Filters.tsx index b956df8..181d685 100644 --- a/frontend/src/components/Filters.tsx +++ b/frontend/src/components/Filters.tsx @@ -1,32 +1,38 @@ import { useState, useRef, useEffect } from 'react'; import { Slider } from './ui/slider'; import { Label } from './ui/label'; -import { YEAR_MIN, YEAR_MAX, YEAR_STEP, PRICE_MIN, PRICE_MAX, PRICE_STEP } from '../lib/constants'; -import type { Filters as FiltersType, POICategoriesMap, ColorMode } from '../types'; +import type { FeatureMeta, FeatureFilters, POICategoriesMap } from '../types'; interface FiltersProps { - filters: FiltersType; - onChange: (filters: FiltersType) => void; + features: FeatureMeta[]; + filters: FeatureFilters; + activeFeature: string | null; + onFiltersChange: (filters: FeatureFilters) => void; + onActiveFeatureChange: (feature: string | null) => void; zoom: number; poiCategories: POICategoriesMap; selectedPOICategories: Set; onPOICategoriesChange: (categories: Set) => void; - colorMode: ColorMode; - onColorModeChange: (mode: ColorMode) => void; +} + +function formatValue(value: number): string { + if (Math.abs(value) >= 1_000_000) return `${(value / 1_000_000).toFixed(1)}M`; + if (Math.abs(value) >= 1_000) return `${(value / 1_000).toFixed(1)}k`; + if (Number.isInteger(value)) return value.toString(); + return value.toFixed(2); } export default function Filters({ + features, filters, - onChange, + activeFeature, + onFiltersChange, + onActiveFeatureChange, zoom, poiCategories, selectedPOICategories, onPOICategoriesChange, - colorMode, - onColorModeChange, }: FiltersProps) { - const update = (key: keyof FiltersType, value: number) => onChange({ ...filters, [key]: value }); - const [dropdownOpen, setDropdownOpen] = useState(false); const dropdownRef = useRef(null); @@ -63,99 +69,53 @@ export default function Filters({ const selectedCount = selectedPOICategories.size; return ( -
+

UK Property Prices

Zoom: {zoom.toFixed(1)}
-
- - onChange({ ...filters, minYear: min, maxYear: max })} - /> -
+ {features.map((feature) => { + const range = filters[feature.name] || [feature.min, feature.max]; + const isActive = activeFeature === feature.name; + const step = (feature.max - feature.min) / 100; -
- - update('minPrice', v)} - /> -
- -
- - update('maxPrice', v)} - /> -
- -
- -
- - -
-
- - {colorMode === 'price' ? ( -
-
Average Price
+ return (
-
- £0 - £200k - £400k - £800k+ + key={feature.name} + className={`space-y-1 p-2 rounded ${isActive ? 'ring-2 ring-blue-400 bg-blue-50' : ''}`} + > + + { + onFiltersChange({ ...filters, [feature.name]: [min, max] }); + }} + onPointerDown={() => onActiveFeatureChange(feature.name)} + onPointerUp={() => onActiveFeatureChange(null)} + />
+ ); + })} + +
+
Color Scale
+
+
+ Low + High
- ) : ( -
-
Journey Time to Bank
-
-
- 0 min - 30 min - 60 min - 120+ min -
-
- )} +
@@ -199,7 +159,7 @@ export default function Filters({
{categoryKeys.map((key) => { - const { emoji, label } = poiCategories[key]; + const { emoji, label, count } = poiCategories[key]; return ( ); })} diff --git a/frontend/src/components/Map.tsx b/frontend/src/components/Map.tsx index a3e8ec2..6353cac 100644 --- a/frontend/src/components/Map.tsx +++ b/frontend/src/components/Map.tsx @@ -6,13 +6,14 @@ import { H3HexagonLayer } from '@deck.gl/geo-layers'; import { IconLayer } from '@deck.gl/layers'; import type { PickingInfo } from '@deck.gl/core'; import 'maplibre-gl/dist/maplibre-gl.css'; -import type { HexagonData, ViewState, ViewChangeParams, Bounds, POI, ColorMode } from '../types'; +import type { HexagonData, ViewState, ViewChangeParams, Bounds, POI, FeatureMeta } from '../types'; interface MapProps { data: HexagonData[]; pois: POI[]; onViewChange: (params: ViewChangeParams) => void; - colorMode: ColorMode; + activeFeature: string | null; + features: FeatureMeta[]; } // Twemoji CDN base URL @@ -185,66 +186,31 @@ const INITIAL_VIEW: ViewState = { const MAP_STYLE = 'https://basemaps.cartocdn.com/gl/positron-gl-style/style.json'; -interface ColorStop { - price: number; - color: [number, number, number]; -} - -// Continuous color scale from green (low) -> yellow -> red -> purple (high) -const COLOR_SCALE: ColorStop[] = [ - { price: 0, color: [46, 204, 113] }, // Green - { price: 200000, color: [241, 196, 15] }, // Yellow - { price: 400000, color: [231, 76, 60] }, // Red - { price: 800000, color: [142, 68, 173] }, // Purple +// Gradient stops for normalized [0,1] values +const GRADIENT: { t: number; color: [number, number, number] }[] = [ + { t: 0, color: [46, 204, 113] }, // Green + { t: 0.33, color: [241, 196, 15] }, // Yellow + { t: 0.66, color: [231, 76, 60] }, // Red + { t: 1, color: [142, 68, 173] }, // Purple ]; -function interpolateColor( - c1: [number, number, number], - c2: [number, number, number], - t: number -): [number, number, number] { - return [ - Math.round(c1[0] + (c2[0] - c1[0]) * t), - Math.round(c1[1] + (c2[1] - c1[1]) * t), - Math.round(c1[2] + (c2[2] - c1[2]) * t), - ]; -} +function normalizedToColor(t: number): [number, number, number] { + if (t <= 0) return GRADIENT[0].color; + if (t >= 1) return GRADIENT[GRADIENT.length - 1].color; -function scaleToColor( - value: number | null | undefined, - scale: ColorStop[] -): [number, number, number] { - if (value == null || isNaN(value)) return [128, 128, 128]; - - if (value <= scale[0].price) return scale[0].color; - if (value >= scale[scale.length - 1].price) return scale[scale.length - 1].color; - - for (let i = 0; i < scale.length - 1; i++) { - const lower = scale[i]; - const upper = scale[i + 1]; - if (value >= lower.price && value <= upper.price) { - const t = (value - lower.price) / (upper.price - lower.price); - return interpolateColor(lower.color, upper.color, t); + for (let i = 0; i < GRADIENT.length - 1; i++) { + const lo = GRADIENT[i]; + const hi = GRADIENT[i + 1]; + if (t >= lo.t && t <= hi.t) { + const frac = (t - lo.t) / (hi.t - lo.t); + return [ + Math.round(lo.color[0] + (hi.color[0] - lo.color[0]) * frac), + Math.round(lo.color[1] + (hi.color[1] - lo.color[1]) * frac), + Math.round(lo.color[2] + (hi.color[2] - lo.color[2]) * frac), + ]; } } - - return scale[scale.length - 1].color; -} - -function priceToColor(price: number | null | undefined): [number, number, number] { - return scaleToColor(price, COLOR_SCALE); -} - -// Journey time color scale: green (short) -> yellow -> orange -> red (long) -const JOURNEY_COLOR_SCALE: ColorStop[] = [ - { price: 0, color: [46, 204, 113] }, // Green - { price: 30, color: [241, 196, 15] }, // Yellow - { price: 60, color: [231, 76, 60] }, // Red - { price: 120, color: [142, 68, 173] }, // Purple -]; - -function journeyTimeToColor(minutes: number | null | undefined): [number, number, number] { - return scaleToColor(minutes, JOURNEY_COLOR_SCALE); + return GRADIENT[GRADIENT.length - 1].color; } function zoomToResolution(zoom: number): number { @@ -271,7 +237,6 @@ function getBoundsFromViewState(viewState: ViewState, width: number, height: num const halfWidthDeg = (width / 2) * degreesPerPixelLng; // Latitude uses Mercator projection (non-linear) - // Convert center lat to pixel y, offset by half height, convert back to lat const latRad = (clampedLat * Math.PI) / 180; const mercatorY = (1 - Math.log(Math.tan(latRad) + 1 / Math.cos(latRad)) / Math.PI) / 2; const centerPixelY = mercatorY * worldSize; @@ -281,7 +246,7 @@ function getBoundsFromViewState(viewState: ViewState, width: number, height: num // Convert pixel Y back to latitude const pixelYToLat = (pixelY: number): number => { - const mercY = Math.max(0.001, Math.min(0.999, pixelY / worldSize)); // Clamp to avoid edge cases + const mercY = Math.max(0.001, Math.min(0.999, pixelY / worldSize)); const latRadians = Math.atan(Math.sinh(Math.PI * (1 - 2 * mercY))); return (latRadians * 180) / Math.PI; }; @@ -315,7 +280,7 @@ function DeckOverlay({ return null; } -export default function Map({ data, pois, onViewChange, colorMode }: MapProps) { +export default function Map({ data, pois, onViewChange, activeFeature, features }: MapProps) { const containerRef = useRef(null); const [viewState, setViewState] = useState(INITIAL_VIEW); const [dimensions, setDimensions] = useState({ width: 0, height: 0 }); @@ -355,7 +320,6 @@ export default function Map({ data, pois, onViewChange, colorMode }: MapProps) { const map = evt.target; for (const layer of map.getStyle().layers || []) { if (layer.type !== 'symbol') continue; - // Stronger white halo so text pops over hex fills map.setPaintProperty(layer.id, 'text-halo-color', 'rgba(255,255,255,1)'); map.setPaintProperty(layer.id, 'text-halo-width', 2); map.setPaintProperty(layer.id, 'text-color', '#222'); @@ -383,24 +347,32 @@ export default function Map({ data, pois, onViewChange, colorMode }: MapProps) { } }, []); + // Determine which feature to use for coloring + const colorFeatureName = activeFeature || (features.length > 0 ? features[0].name : null); + const colorFeatureMeta = features.find((f) => f.name === colorFeatureName) || null; + const layers = useMemo( () => [ new H3HexagonLayer({ id: 'h3-hexagons', data, getHexagon: (d) => d.h3, - getFillColor: (d) => - colorMode === 'journey_time' - ? journeyTimeToColor(d.median_journey_minutes) - : priceToColor(d.avg_price), + getFillColor: (d) => { + if (!colorFeatureName || !colorFeatureMeta) return [128, 128, 128] as [number, number, number]; + const val = d[`min_${colorFeatureName}`]; + if (val == null) return [128, 128, 128] as [number, number, number]; + const range = colorFeatureMeta.max - colorFeatureMeta.min; + if (range === 0) return GRADIENT[0].color; + const t = ((val as number) - colorFeatureMeta.min) / range; + return normalizedToColor(t); + }, updateTriggers: { - getFillColor: colorMode, + getFillColor: [colorFeatureName, colorFeatureMeta], }, extruded: false, pickable: true, opacity: 0.5, highPrecision: true, - // Render below labels so road names, place names etc. stay visible // @ts-expect-error beforeId is a MapboxOverlay interleave prop, not typed in LayerProps beforeId: LABEL_LAYER_ID, }), @@ -420,41 +392,39 @@ export default function Map({ data, pois, onViewChange, colorMode }: MapProps) { onHover: handlePoiHover, }), ], - [data, pois, handlePoiHover, colorMode] + [data, pois, handlePoiHover, colorFeatureName, colorFeatureMeta] ); - const getTooltip = useCallback(({ object }: { object?: HexagonData }) => { - if (!object || !('h3' in object)) return null; + const getTooltip = useCallback( + ({ object }: { object?: HexagonData }) => { + if (!object || !('h3' in object)) return null; - const hex = object as HexagonData; - const journeyLines: string[] = []; - if (hex.median_pt_quick_minutes != null) - journeyLines.push(`🚇 Quick PT: ${hex.median_pt_quick_minutes} min`); - if (hex.median_pt_easy_minutes != null) - journeyLines.push(`🚌 Easy PT: ${hex.median_pt_easy_minutes} min`); - if (hex.median_cycling_minutes != null) - journeyLines.push(`🚲 Cycling: ${hex.median_cycling_minutes} min`); - const journeyTimeHtml = - journeyLines.length > 0 - ? `
${journeyLines.join('
')}
` - : ''; + const hex = object; + const lines: string[] = []; + lines.push(`${(hex.count as number).toLocaleString()} properties`); - return { - html: `
- Avg: £${hex.avg_price?.toLocaleString() || 'N/A'} -
- ${hex.count} sales
- Range: £${hex.min_price?.toLocaleString()} - £${hex.max_price?.toLocaleString()} -
- ${journeyTimeHtml} -
`, - style: { - backgroundColor: 'white', - borderRadius: '4px', - boxShadow: '0 2px 4px rgba(0,0,0,0.2)', - }, - }; - }, []); + for (const f of features) { + const minVal = hex[`min_${f.name}`]; + const maxVal = hex[`max_${f.name}`]; + if (minVal != null && maxVal != null) { + const minStr = typeof minVal === 'number' ? minVal.toLocaleString(undefined, { maximumFractionDigits: 1 }) : String(minVal); + const maxStr = typeof maxVal === 'number' ? maxVal.toLocaleString(undefined, { maximumFractionDigits: 1 }) : String(maxVal); + const highlight = f.name === colorFeatureName ? 'font-weight: bold;' : ''; + lines.push(`
${f.label}: ${minStr} - ${maxStr}
`); + } + } + + return { + html: `
${lines.join('')}
`, + style: { + backgroundColor: 'white', + borderRadius: '4px', + boxShadow: '0 2px 4px rgba(0,0,0,0.2)', + }, + }; + }, + [features, colorFeatureName] + ); return (
diff --git a/frontend/src/lib/constants.ts b/frontend/src/lib/constants.ts index 0b7504d..890093b 100644 --- a/frontend/src/lib/constants.ts +++ b/frontend/src/lib/constants.ts @@ -1,19 +1 @@ -import type { Filters } from '../types'; - -// Filter configuration constants -// Should match backend pipeline/config.py - -export const YEAR_MIN = 1995; -export const YEAR_MAX = 2024; -export const YEAR_STEP = 1; - -export const PRICE_MIN = 0; -export const PRICE_MAX = 5000000; // £5M max for slider, but no server-side cap -export const PRICE_STEP = 50000; - -export const DEFAULT_FILTERS: Filters = { - minYear: 2020, - maxYear: YEAR_MAX, - minPrice: PRICE_MIN, - maxPrice: PRICE_MAX, -}; +// No hardcoded filter constants - features are discovered dynamically from the API. diff --git a/frontend/src/types.ts b/frontend/src/types.ts index d372bc2..7416021 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -1,8 +1,17 @@ -export interface Filters { - minYear: number; - maxYear: number; - minPrice: number; - maxPrice: number; +export interface FeatureMeta { + name: string; + min: number; + max: number; + label: string; +} + +// Filters: feature name -> [selectedMin, selectedMax] +export type FeatureFilters = Record; + +export interface HexagonData { + h3: string; + count: number; + [key: string]: string | number | null; } export interface Bounds { @@ -12,21 +21,6 @@ export interface Bounds { east: number; } -export interface HexagonData { - h3: string; - count: number; - avg_price: number; - median_price: number; - min_price: number; - max_price: number; - median_journey_minutes: number | null; - median_pt_easy_minutes: number | null; - median_pt_quick_minutes: number | null; - median_cycling_minutes: number | null; -} - -export type ColorMode = 'price' | 'journey_time'; - export interface ViewState { longitude: number; latitude: number; @@ -60,6 +54,7 @@ export interface POIResponse { export interface POICategoryInfo { emoji: string; label: string; + count: number; } export type POICategoriesMap = Record; diff --git a/server/config.py b/server/config.py index 798a6bd..fcb27d9 100644 --- a/server/config.py +++ b/server/config.py @@ -4,12 +4,6 @@ from pipeline.config import ( AGGREGATES_DIR, H3_RESOLUTIONS as VALID_RESOLUTIONS, DEFAULT_H3_RESOLUTION as DEFAULT_RESOLUTION, - MIN_YEAR, - MAX_YEAR, - DEFAULT_MIN_YEAR, - DEFAULT_MAX_YEAR, - DEFAULT_MIN_PRICE, - DEFAULT_MAX_PRICE, ) # Extra area to return beyond requested bounds (0.2 = 20%) @@ -20,11 +14,5 @@ __all__ = [ "AGGREGATES_DIR", "VALID_RESOLUTIONS", "DEFAULT_RESOLUTION", - "MIN_YEAR", - "MAX_YEAR", - "DEFAULT_MIN_YEAR", - "DEFAULT_MAX_YEAR", - "DEFAULT_MIN_PRICE", - "DEFAULT_MAX_PRICE", "BOUNDS_BUFFER_PERCENT", ] diff --git a/server/routes/hexagons.py b/server/routes/hexagons.py index afe3777..c12177e 100644 --- a/server/routes/hexagons.py +++ b/server/routes/hexagons.py @@ -10,10 +10,6 @@ from server.config import ( AGGREGATES_DIR, VALID_RESOLUTIONS, DEFAULT_RESOLUTION, - DEFAULT_MIN_YEAR, - DEFAULT_MAX_YEAR, - DEFAULT_MIN_PRICE, - DEFAULT_MAX_PRICE, BOUNDS_BUFFER_PERCENT, ) @@ -22,6 +18,38 @@ router = APIRouter() # Cache loaded dataframes in memory (one per resolution) _df_cache: dict[int, pl.DataFrame] = {} +# Discovered features (computed once on first load) +_features_cache: list[dict] | None = None + + +def _snake_to_label(name: str) -> str: + """Convert snake_case feature name to a human-readable label.""" + return name.replace("_", " ").title() + + +def _discover_features(df: pl.DataFrame) -> list[dict]: + """Discover features from column pairs min_X / max_X.""" + features = [] + seen = set() + for col in df.columns: + if col.startswith("min_"): + name = col[4:] + max_col = f"max_{name}" + if max_col in df.columns and name not in seen: + seen.add(name) + global_min = df[col].min() + global_max = df[max_col].max() + if global_min is not None and global_max is not None: + features.append( + { + "name": name, + "min": float(global_min), + "max": float(global_max), + "label": _snake_to_label(name), + } + ) + return features + def preload_dataframes() -> None: """Load all resolution dataframes into cache on startup.""" @@ -38,25 +66,41 @@ def get_cached_df(resolution: int) -> pl.DataFrame | None: # Load and add H3 cell centroids for fast bbox filtering df = pl.read_parquet(parquet_path) - # Pre-compute cell centroids for bbox filtering (much faster than is_in) + # Pre-compute cell centroids for bbox filtering centroids = [h3.cell_to_latlng(cell) for cell in df["h3"].to_list()] df = df.with_columns( [ - pl.Series("lat", [c[0] for c in centroids]), - pl.Series("lng", [c[1] for c in centroids]), + pl.Series("_lat", [c[0] for c in centroids]), + pl.Series("_lng", [c[1] for c in centroids]), ] ) _df_cache[resolution] = df return _df_cache[resolution] +def get_features() -> list[dict]: + """Get discovered features, computing from the first available resolution.""" + global _features_cache + if _features_cache is None: + for resolution in VALID_RESOLUTIONS: + df = get_cached_df(resolution) + if df is not None: + _features_cache = _discover_features(df) + break + if _features_cache is None: + _features_cache = [] + return _features_cache + + +@router.get("/features") +async def get_features_endpoint() -> dict: + """Return discovered feature metadata with global min/max ranges.""" + return {"features": get_features()} + + @lru_cache(maxsize=128) def query_hexagons_cached( resolution: int, - min_year: int, - max_year: int, - min_price: int, - max_price: int, bounds_tuple: tuple[float, float, float, float], ) -> list[dict]: """Cached query - returns features list.""" @@ -64,65 +108,18 @@ def query_hexagons_cached( df = get_cached_df(resolution) if df is None: - return [], False + return [] - # Fast bbox filter using pre-computed centroids (O(1) per row) + # Fast bbox filter using pre-computed centroids df = df.filter( - (pl.col("lat") >= south) - & (pl.col("lat") <= north) - & (pl.col("lng") >= west) - & (pl.col("lng") <= east) + (pl.col("_lat") >= south) + & (pl.col("_lat") <= north) + & (pl.col("_lng") >= west) + & (pl.col("_lng") <= east) ) - # Filter by year range - df = df.filter((pl.col("year") >= min_year) & (pl.col("year") <= max_year)) - - # Check which journey time columns exist - journey_cols = [ - "median_journey_minutes", - "median_pt_easy_minutes", - "median_pt_quick_minutes", - "median_cycling_minutes", - ] - available_journey_cols = [c for c in journey_cols if c in df.columns] - - # Aggregate across years (weighted by count) - agg_exprs = [ - pl.col("count").sum().alias("count"), - (pl.col("avg_price") * pl.col("count")).sum().alias("weighted_price_sum"), - pl.col("median_price").median().alias("median_price"), - pl.col("min_price").min().alias("min_price"), - pl.col("max_price").max().alias("max_price"), - ] - for jc in available_journey_cols: - # Journey time is same across years, just take first non-null - agg_exprs.append(pl.col(jc).first()) - - df = df.group_by("h3").agg(agg_exprs) - - # Calculate weighted average price - df = df.with_columns( - (pl.col("weighted_price_sum") / pl.col("count")).alias("avg_price") - ).drop("weighted_price_sum") - - # Filter by price range - df = df.filter( - (pl.col("avg_price") >= min_price) & (pl.col("avg_price") <= max_price) - ) - - # Build response efficiently using Polars - select_cols = [ - pl.col("h3"), - pl.col("count"), - pl.col("avg_price").round(2), - pl.col("median_price").round(2), - pl.col("min_price"), - pl.col("max_price"), - ] - for jc in available_journey_cols: - select_cols.append(pl.col(jc).round(0)) - - df = df.select(select_cols) + # Drop internal centroid columns before returning + df = df.drop("_lat", "_lng") return df.to_dicts() @@ -135,13 +132,9 @@ async def get_hexagons( le=max(VALID_RESOLUTIONS), description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})", ), - min_year: int = Query(DEFAULT_MIN_YEAR, description="Minimum year filter"), - max_year: int = Query(DEFAULT_MAX_YEAR, description="Maximum year filter"), - min_price: float = Query(DEFAULT_MIN_PRICE, description="Minimum average price"), - max_price: float = Query(DEFAULT_MAX_PRICE, description="Maximum average price"), bounds: str | None = Query(None, description="Bounding box: south,west,north,east"), ) -> dict: - """Get aggregated property data as GeoJSON hexagons within bounds.""" + """Get aggregated property data as hexagons within bounds.""" if resolution not in VALID_RESOLUTIONS: resolution = DEFAULT_RESOLUTION @@ -165,9 +158,7 @@ async def get_hexagons( west -= lng_buffer east += lng_buffer - # Round bounds to reduce cache misses (0.01 degree ≈ 1km precision) - # Always expand bounds (floor for min, ceil for max) to prevent hexagons - # popping in when crossing rounding boundaries + # Round bounds to reduce cache misses (0.01 degree ~ 1km precision) precision = 0.01 bounds_tuple = ( math.floor(south / precision) * precision, @@ -176,14 +167,6 @@ async def get_hexagons( math.ceil(east / precision) * precision, ) - # Convert prices to int for cache key hashability - features = query_hexagons_cached( - resolution, - min_year, - max_year, - int(min_price), - int(max_price), - bounds_tuple, - ) + features = query_hexagons_cached(resolution, bounds_tuple) return {"features": features} diff --git a/server/routes/pois.py b/server/routes/pois.py index fcc225b..edf48db 100644 --- a/server/routes/pois.py +++ b/server/routes/pois.py @@ -9,8 +9,11 @@ router = APIRouter() DATA_FILE = Path("data_sources/uk_pois.parquet") -# Category groups with emoji and member categories -POI_CATEGORY_GROUPS: dict[str, dict] = { +# Group definitions: maps a group key to its display metadata and the +# individual POI categories it contains. Categories are matched against +# the values that actually exist in the loaded parquet so that the +# selector only shows groups with real data. +_GROUP_DEFS: dict[str, dict] = { "schools": { "emoji": "🏫", "label": "Schools", @@ -189,33 +192,80 @@ POI_CATEGORY_GROUPS: dict[str, dict] = { }, } -# Flatten for quick lookup -ALL_CATEGORIES = { - cat for group in POI_CATEGORY_GROUPS.values() for cat in group["categories"] -} +# Built at startup from the data — only groups whose member categories +# actually appear in the parquet file are included. +_active_groups: dict[str, dict] = {} + +# Reverse lookup: category value -> group key (built at startup) +_cat_to_group: dict[str, str] = {} # Cache the dataframe _df_cache: pl.DataFrame | None = None +def _load_and_build() -> pl.DataFrame | None: + """Load the parquet, build category groups from actual data.""" + global _df_cache, _active_groups, _cat_to_group + + if not DATA_FILE.exists(): + return None + + df = pl.read_parquet(DATA_FILE).select("id", "name", "category", "lat", "lng") + + # Distinct categories present in the data + data_categories: set[str] = set( + df.select("category").unique().to_series().to_list() + ) + + # Per-category counts for the response + counts: dict[str, int] = dict( + df.group_by("category") + .agg(pl.len().alias("n")) + .iter_rows() + ) + + # Build reverse map from every known category to its group + cat_to_group: dict[str, str] = {} + for key, gdef in _GROUP_DEFS.items(): + for cat in gdef["categories"]: + cat_to_group[cat] = key + + # Only keep categories that belong to a known group + known_categories = data_categories & cat_to_group.keys() + + # Build active groups — only those with at least one matching category + active: dict[str, dict] = {} + for key, gdef in _GROUP_DEFS.items(): + present = [c for c in gdef["categories"] if c in known_categories] + if present: + active[key] = { + "emoji": gdef["emoji"], + "label": gdef["label"], + "categories": present, + "count": sum(counts.get(c, 0) for c in present), + } + + _active_groups = active + _cat_to_group = cat_to_group + + # Filter dataframe to only known categories + _df_cache = df.filter(pl.col("category").is_in(known_categories)) + return _df_cache + + def get_df() -> pl.DataFrame | None: - """Load and cache the POI dataframe.""" - global _df_cache + """Return cached POI dataframe, loading if necessary.""" if _df_cache is None: - if not DATA_FILE.exists(): - return None - df = pl.read_parquet(DATA_FILE) - _df_cache = df.select("id", "name", "category", "lat", "lng").filter( - pl.col("category").is_in(ALL_CATEGORIES) - ) + return _load_and_build() return _df_cache def preload_pois() -> None: """Preload POI data on startup.""" - df = get_df() + df = _load_and_build() if df is not None: - print(f"Loaded {len(df):,} POIs") + n_groups = len(_active_groups) + print(f"Loaded {len(df):,} POIs across {n_groups} category groups") @router.get("/pois") @@ -234,10 +284,10 @@ async def get_pois( return {"features": []} requested_groups = [g.strip() for g in categories.split(",")] - cats_to_include = set() + cats_to_include: set[str] = set() for group in requested_groups: - if group in POI_CATEGORY_GROUPS: - cats_to_include.update(POI_CATEGORY_GROUPS[group]["categories"]) + if group in _active_groups: + cats_to_include.update(_active_groups[group]["categories"]) if not cats_to_include: return {"features": []} @@ -259,10 +309,14 @@ async def get_pois( @router.get("/poi-categories") async def get_poi_categories() -> dict: - """Get available POI category groups with emoji and labels.""" + """Get available POI category groups derived from loaded data.""" return { "categories": { - key: {"emoji": group["emoji"], "label": group["label"]} - for key, group in POI_CATEGORY_GROUPS.items() + key: { + "emoji": group["emoji"], + "label": group["label"], + "count": group["count"], + } + for key, group in _active_groups.items() } }