Update map to do filtering

This commit is contained in:
Andras Schmelczer 2026-01-30 18:34:12 +00:00
parent 6122ee44da
commit d4fe881ef4
8 changed files with 349 additions and 372 deletions

View file

@ -1,9 +1,9 @@
import { useState, useEffect, useCallback, useRef } from 'react';
import { useState, useEffect, useCallback, useRef, useMemo } from 'react';
import Map from './components/Map';
import Filters from './components/Filters';
import { DEFAULT_FILTERS } from './lib/constants';
import type {
Filters as FiltersType,
FeatureMeta,
FeatureFilters,
Bounds,
HexagonData,
ViewChangeParams,
@ -11,7 +11,6 @@ import type {
POI,
POIResponse,
POICategoriesMap,
ColorMode,
} from './types';
const DEBOUNCE_MS = 150;
@ -42,8 +41,10 @@ function getApiBaseUrl(): string {
}
export default function App() {
const [filters, setFilters] = useState<FiltersType>(DEFAULT_FILTERS);
const [data, setData] = useState<HexagonData[]>([]);
const [features, setFeatures] = useState<FeatureMeta[]>([]);
const [filters, setFilters] = useState<FeatureFilters>({});
const [activeFeature, setActiveFeature] = useState<string | null>(null);
const [rawData, setRawData] = useState<HexagonData[]>([]);
const [resolution, setResolution] = useState<number>(8);
const [bounds, setBounds] = useState<Bounds | null>(null);
const [loading, setLoading] = useState<boolean>(false);
@ -51,8 +52,6 @@ export default function App() {
const debounceRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
const [colorMode, setColorMode] = useState<ColorMode>('price');
// POI state
const [pois, setPois] = useState<POI[]>([]);
const [poiCategories, setPOICategories] = useState<POICategoriesMap>({});
@ -60,8 +59,21 @@ export default function App() {
const poiDebounceRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const poiAbortControllerRef = useRef<AbortController | null>(null);
// Fetch POI category definitions from server on mount
// Fetch feature metadata + POI categories on mount
useEffect(() => {
fetch(`${getApiBaseUrl()}/api/features`)
.then((res) => res.json())
.then((json: { features: FeatureMeta[] }) => {
setFeatures(json.features);
// Initialize filters with full range for each feature
const initial: FeatureFilters = {};
for (const f of json.features) {
initial[f.name] = [f.min, f.max];
}
setFilters(initial);
})
.catch((err) => console.error('Failed to fetch features:', err));
fetch(`${getApiBaseUrl()}/api/poi-categories`)
.then((res) => res.json())
.then((json: { categories: POICategoriesMap }) => {
@ -70,7 +82,7 @@ export default function App() {
.catch((err) => console.error('Failed to fetch POI categories:', err));
}, []);
// Debounced fetch when dependencies change
// Debounced fetch when resolution/bounds change (no filter params sent)
useEffect(() => {
if (!bounds) return;
@ -89,17 +101,13 @@ export default function App() {
const boundsStr = `${bounds.south},${bounds.west},${bounds.north},${bounds.east}`;
const params = new URLSearchParams({
resolution: resolution.toString(),
min_year: filters.minYear.toString(),
max_year: filters.maxYear.toString(),
min_price: filters.minPrice.toString(),
max_price: filters.maxPrice.toString(),
bounds: boundsStr,
});
const res = await fetch(`${getApiBaseUrl()}/api/hexagons?${params}`, {
signal: abortControllerRef.current.signal,
});
const json: ApiResponse = await res.json();
setData(json.features || []);
setRawData(json.features || []);
} catch (err) {
if (err instanceof Error && err.name !== 'AbortError') {
console.error('Failed to fetch data:', err);
@ -114,7 +122,36 @@ export default function App() {
clearTimeout(debounceRef.current);
}
};
}, [filters, resolution, bounds]);
}, [resolution, bounds]);
// Client-side filtering
const data = useMemo(() => {
if (features.length === 0) return rawData;
return rawData.filter((hex) => {
if (activeFeature) {
// Only apply the active feature's filter
const range = filters[activeFeature];
if (!range) return true;
const minVal = hex[`min_${activeFeature}`];
const maxVal = hex[`max_${activeFeature}`];
if (minVal == null || maxVal == null) return true;
return (minVal as number) <= range[1] && (maxVal as number) >= range[0];
}
// Apply ALL filters as intersection
for (const f of features) {
const range = filters[f.name];
if (!range) continue;
// Skip features where filter is at full range
if (range[0] === f.min && range[1] === f.max) continue;
const minVal = hex[`min_${f.name}`];
const maxVal = hex[`max_${f.name}`];
if (minVal == null || maxVal == null) continue;
if ((minVal as number) > range[1] || (maxVal as number) < range[0]) return false;
}
return true;
});
}, [rawData, filters, activeFeature, features]);
// Fetch POIs when bounds or selected categories change
useEffect(() => {
@ -171,17 +208,24 @@ export default function App() {
return (
<div className="h-screen flex">
<Filters
features={features}
filters={filters}
onChange={setFilters}
activeFeature={activeFeature}
onFiltersChange={setFilters}
onActiveFeatureChange={setActiveFeature}
zoom={zoom}
poiCategories={poiCategories}
selectedPOICategories={selectedPOICategories}
onPOICategoriesChange={setSelectedPOICategories}
colorMode={colorMode}
onColorModeChange={setColorMode}
/>
<div className="flex-1 relative">
<Map data={data} pois={pois} onViewChange={handleViewChange} colorMode={colorMode} />
<Map
data={data}
pois={pois}
onViewChange={handleViewChange}
activeFeature={activeFeature}
features={features}
/>
{loading && (
<div className="absolute top-4 right-4 bg-white px-3 py-1 rounded shadow">Loading...</div>
)}

View file

@ -1,32 +1,38 @@
import { useState, useRef, useEffect } from 'react';
import { Slider } from './ui/slider';
import { Label } from './ui/label';
import { YEAR_MIN, YEAR_MAX, YEAR_STEP, PRICE_MIN, PRICE_MAX, PRICE_STEP } from '../lib/constants';
import type { Filters as FiltersType, POICategoriesMap, ColorMode } from '../types';
import type { FeatureMeta, FeatureFilters, POICategoriesMap } from '../types';
interface FiltersProps {
filters: FiltersType;
onChange: (filters: FiltersType) => void;
features: FeatureMeta[];
filters: FeatureFilters;
activeFeature: string | null;
onFiltersChange: (filters: FeatureFilters) => void;
onActiveFeatureChange: (feature: string | null) => void;
zoom: number;
poiCategories: POICategoriesMap;
selectedPOICategories: Set<string>;
onPOICategoriesChange: (categories: Set<string>) => void;
colorMode: ColorMode;
onColorModeChange: (mode: ColorMode) => void;
}
/**
 * Format a numeric feature value compactly for display in a slider label.
 *
 * - magnitudes of one million or more are shown in millions with one decimal (`1.5M`)
 * - magnitudes of one thousand or more are shown in thousands with one decimal (`1.5k`)
 * - smaller integers are shown as-is, other values with two decimals
 *
 * Rounding is taken into account when choosing the unit: a value that would
 * round up to `1000.0k` is rendered as `1.0M`, and one that would round up to
 * `1000.00` is rendered as `1.0k`, so the label never shows four integer digits
 * next to a unit suffix.
 *
 * @param value - The number to format; negative values keep their sign.
 * @returns The formatted label text.
 */
function formatValue(value: number): string {
  const magnitude = Math.abs(value);
  if (magnitude >= 1_000_000) return `${(value / 1_000_000).toFixed(1)}M`;
  if (magnitude >= 1_000) {
    const thousands = (value / 1_000).toFixed(1);
    // toFixed may round e.g. 999.999 (thousands) up to "1000.0": promote to millions.
    return Math.abs(Number(thousands)) >= 1_000
      ? `${(value / 1_000_000).toFixed(1)}M`
      : `${thousands}k`;
  }
  if (Number.isInteger(value)) return value.toString();
  const fixed = value.toFixed(2);
  // Same rollover guard one unit lower: 999.999 would otherwise print as "1000.00".
  return Math.abs(Number(fixed)) >= 1_000 ? `${(value / 1_000).toFixed(1)}k` : fixed;
}
export default function Filters({
features,
filters,
onChange,
activeFeature,
onFiltersChange,
onActiveFeatureChange,
zoom,
poiCategories,
selectedPOICategories,
onPOICategoriesChange,
colorMode,
onColorModeChange,
}: FiltersProps) {
const update = (key: keyof FiltersType, value: number) => onChange({ ...filters, [key]: value });
const [dropdownOpen, setDropdownOpen] = useState(false);
const dropdownRef = useRef<HTMLDivElement>(null);
@ -63,99 +69,53 @@ export default function Filters({
const selectedCount = selectedPOICategories.size;
return (
<div className="w-72 p-4 bg-white shadow-lg space-y-6 overflow-y-auto max-h-screen">
<div className="w-72 p-4 bg-white shadow-lg space-y-4 overflow-y-auto max-h-screen">
<h1 className="text-xl font-bold">UK Property Prices</h1>
<div className="text-sm text-slate-500">Zoom: {zoom.toFixed(1)}</div>
<div className="space-y-2">
<Label>
Year Range: {filters.minYear} - {filters.maxYear}
</Label>
<Slider
min={YEAR_MIN}
max={YEAR_MAX}
step={YEAR_STEP}
value={[filters.minYear, filters.maxYear]}
onValueChange={([min, max]) => onChange({ ...filters, minYear: min, maxYear: max })}
/>
</div>
{features.map((feature) => {
const range = filters[feature.name] || [feature.min, feature.max];
const isActive = activeFeature === feature.name;
const step = (feature.max - feature.min) / 100;
<div className="space-y-2">
<Label>Min Price: £{filters.minPrice.toLocaleString()}</Label>
<Slider
min={PRICE_MIN}
max={PRICE_MAX}
step={PRICE_STEP}
value={[filters.minPrice]}
onValueChange={([v]) => update('minPrice', v)}
/>
</div>
<div className="space-y-2">
<Label>Max Price: £{filters.maxPrice.toLocaleString()}</Label>
<Slider
min={PRICE_MIN}
max={PRICE_MAX}
step={PRICE_STEP}
value={[filters.maxPrice]}
onValueChange={([v]) => update('maxPrice', v)}
/>
</div>
<div className="space-y-2">
<Label>Color By</Label>
<div className="flex gap-2">
<button
className={`flex-1 px-3 py-1.5 text-sm rounded ${colorMode === 'price' ? 'bg-slate-800 text-white' : 'bg-slate-100 text-slate-700'}`}
onClick={() => onColorModeChange('price')}
>
Price
</button>
<button
className={`flex-1 px-3 py-1.5 text-sm rounded ${colorMode === 'journey_time' ? 'bg-slate-800 text-white' : 'bg-slate-100 text-slate-700'}`}
onClick={() => onColorModeChange('journey_time')}
>
Journey Time
</button>
</div>
</div>
{colorMode === 'price' ? (
<div className="p-3 bg-slate-100 rounded text-xs">
<div className="mb-2 font-medium">Average Price</div>
return (
<div
className="h-4 rounded"
style={{
background:
'linear-gradient(to right, rgb(46, 204, 113), rgb(241, 196, 15), rgb(231, 76, 60), rgb(142, 68, 173))',
}}
></div>
<div className="flex justify-between mt-1">
<span>£0</span>
<span>£200k</span>
<span>£400k</span>
<span>£800k+</span>
key={feature.name}
className={`space-y-1 p-2 rounded ${isActive ? 'ring-2 ring-blue-400 bg-blue-50' : ''}`}
>
<Label className="text-xs">
{feature.label}: {formatValue(range[0])} - {formatValue(range[1])}
</Label>
<Slider
min={feature.min}
max={feature.max}
step={step}
value={[range[0], range[1]]}
onValueChange={([min, max]) => {
onFiltersChange({ ...filters, [feature.name]: [min, max] });
}}
onPointerDown={() => onActiveFeatureChange(feature.name)}
onPointerUp={() => onActiveFeatureChange(null)}
/>
</div>
);
})}
<div className="p-3 bg-slate-100 rounded text-xs">
<div className="mb-2 font-medium">Color Scale</div>
<div
className="h-4 rounded"
style={{
background:
'linear-gradient(to right, rgb(46, 204, 113), rgb(241, 196, 15), rgb(231, 76, 60), rgb(142, 68, 173))',
}}
></div>
<div className="flex justify-between mt-1">
<span>Low</span>
<span>High</span>
</div>
) : (
<div className="p-3 bg-slate-100 rounded text-xs">
<div className="mb-2 font-medium">Journey Time to Bank</div>
<div
className="h-4 rounded"
style={{
background:
'linear-gradient(to right, rgb(46, 204, 113), rgb(241, 196, 15), rgb(231, 76, 60), rgb(142, 68, 173))',
}}
></div>
<div className="flex justify-between mt-1">
<span>0 min</span>
<span>30 min</span>
<span>60 min</span>
<span>120+ min</span>
</div>
</div>
)}
</div>
<div className="space-y-2" ref={dropdownRef}>
<Label>Points of Interest</Label>
@ -199,7 +159,7 @@ export default function Filters({
</div>
<div className="max-h-64 overflow-y-auto py-1">
{categoryKeys.map((key) => {
const { emoji, label } = poiCategories[key];
const { emoji, label, count } = poiCategories[key];
return (
<label
key={key}
@ -211,9 +171,10 @@ export default function Filters({
onChange={() => toggleCategory(key)}
className="rounded"
/>
<span className="text-sm">
<span className="text-sm flex-1">
{emoji} {label}
</span>
<span className="text-xs text-slate-400">{count.toLocaleString()}</span>
</label>
);
})}

View file

@ -6,13 +6,14 @@ import { H3HexagonLayer } from '@deck.gl/geo-layers';
import { IconLayer } from '@deck.gl/layers';
import type { PickingInfo } from '@deck.gl/core';
import 'maplibre-gl/dist/maplibre-gl.css';
import type { HexagonData, ViewState, ViewChangeParams, Bounds, POI, ColorMode } from '../types';
import type { HexagonData, ViewState, ViewChangeParams, Bounds, POI, FeatureMeta } from '../types';
interface MapProps {
data: HexagonData[];
pois: POI[];
onViewChange: (params: ViewChangeParams) => void;
colorMode: ColorMode;
activeFeature: string | null;
features: FeatureMeta[];
}
// Twemoji CDN base URL
@ -185,66 +186,31 @@ const INITIAL_VIEW: ViewState = {
const MAP_STYLE = 'https://basemaps.cartocdn.com/gl/positron-gl-style/style.json';
interface ColorStop {
price: number;
color: [number, number, number];
}
// Continuous color scale from green (low) -> yellow -> red -> purple (high)
const COLOR_SCALE: ColorStop[] = [
{ price: 0, color: [46, 204, 113] }, // Green
{ price: 200000, color: [241, 196, 15] }, // Yellow
{ price: 400000, color: [231, 76, 60] }, // Red
{ price: 800000, color: [142, 68, 173] }, // Purple
// Gradient stops for normalized [0,1] values
const GRADIENT: { t: number; color: [number, number, number] }[] = [
{ t: 0, color: [46, 204, 113] }, // Green
{ t: 0.33, color: [241, 196, 15] }, // Yellow
{ t: 0.66, color: [231, 76, 60] }, // Red
{ t: 1, color: [142, 68, 173] }, // Purple
];
function interpolateColor(
c1: [number, number, number],
c2: [number, number, number],
t: number
): [number, number, number] {
return [
Math.round(c1[0] + (c2[0] - c1[0]) * t),
Math.round(c1[1] + (c2[1] - c1[1]) * t),
Math.round(c1[2] + (c2[2] - c1[2]) * t),
];
}
function normalizedToColor(t: number): [number, number, number] {
if (t <= 0) return GRADIENT[0].color;
if (t >= 1) return GRADIENT[GRADIENT.length - 1].color;
function scaleToColor(
value: number | null | undefined,
scale: ColorStop[]
): [number, number, number] {
if (value == null || isNaN(value)) return [128, 128, 128];
if (value <= scale[0].price) return scale[0].color;
if (value >= scale[scale.length - 1].price) return scale[scale.length - 1].color;
for (let i = 0; i < scale.length - 1; i++) {
const lower = scale[i];
const upper = scale[i + 1];
if (value >= lower.price && value <= upper.price) {
const t = (value - lower.price) / (upper.price - lower.price);
return interpolateColor(lower.color, upper.color, t);
for (let i = 0; i < GRADIENT.length - 1; i++) {
const lo = GRADIENT[i];
const hi = GRADIENT[i + 1];
if (t >= lo.t && t <= hi.t) {
const frac = (t - lo.t) / (hi.t - lo.t);
return [
Math.round(lo.color[0] + (hi.color[0] - lo.color[0]) * frac),
Math.round(lo.color[1] + (hi.color[1] - lo.color[1]) * frac),
Math.round(lo.color[2] + (hi.color[2] - lo.color[2]) * frac),
];
}
}
return scale[scale.length - 1].color;
}
function priceToColor(price: number | null | undefined): [number, number, number] {
return scaleToColor(price, COLOR_SCALE);
}
// Journey time color scale: green (short) -> yellow -> orange -> red (long)
const JOURNEY_COLOR_SCALE: ColorStop[] = [
{ price: 0, color: [46, 204, 113] }, // Green
{ price: 30, color: [241, 196, 15] }, // Yellow
{ price: 60, color: [231, 76, 60] }, // Red
{ price: 120, color: [142, 68, 173] }, // Purple
];
function journeyTimeToColor(minutes: number | null | undefined): [number, number, number] {
return scaleToColor(minutes, JOURNEY_COLOR_SCALE);
return GRADIENT[GRADIENT.length - 1].color;
}
function zoomToResolution(zoom: number): number {
@ -271,7 +237,6 @@ function getBoundsFromViewState(viewState: ViewState, width: number, height: num
const halfWidthDeg = (width / 2) * degreesPerPixelLng;
// Latitude uses Mercator projection (non-linear)
// Convert center lat to pixel y, offset by half height, convert back to lat
const latRad = (clampedLat * Math.PI) / 180;
const mercatorY = (1 - Math.log(Math.tan(latRad) + 1 / Math.cos(latRad)) / Math.PI) / 2;
const centerPixelY = mercatorY * worldSize;
@ -281,7 +246,7 @@ function getBoundsFromViewState(viewState: ViewState, width: number, height: num
// Convert pixel Y back to latitude
const pixelYToLat = (pixelY: number): number => {
const mercY = Math.max(0.001, Math.min(0.999, pixelY / worldSize)); // Clamp to avoid edge cases
const mercY = Math.max(0.001, Math.min(0.999, pixelY / worldSize));
const latRadians = Math.atan(Math.sinh(Math.PI * (1 - 2 * mercY)));
return (latRadians * 180) / Math.PI;
};
@ -315,7 +280,7 @@ function DeckOverlay({
return null;
}
export default function Map({ data, pois, onViewChange, colorMode }: MapProps) {
export default function Map({ data, pois, onViewChange, activeFeature, features }: MapProps) {
const containerRef = useRef<HTMLDivElement>(null);
const [viewState, setViewState] = useState<ViewState>(INITIAL_VIEW);
const [dimensions, setDimensions] = useState<Dimensions>({ width: 0, height: 0 });
@ -355,7 +320,6 @@ export default function Map({ data, pois, onViewChange, colorMode }: MapProps) {
const map = evt.target;
for (const layer of map.getStyle().layers || []) {
if (layer.type !== 'symbol') continue;
// Stronger white halo so text pops over hex fills
map.setPaintProperty(layer.id, 'text-halo-color', 'rgba(255,255,255,1)');
map.setPaintProperty(layer.id, 'text-halo-width', 2);
map.setPaintProperty(layer.id, 'text-color', '#222');
@ -383,24 +347,32 @@ export default function Map({ data, pois, onViewChange, colorMode }: MapProps) {
}
}, []);
// Determine which feature to use for coloring
const colorFeatureName = activeFeature || (features.length > 0 ? features[0].name : null);
const colorFeatureMeta = features.find((f) => f.name === colorFeatureName) || null;
const layers = useMemo(
() => [
new H3HexagonLayer<HexagonData>({
id: 'h3-hexagons',
data,
getHexagon: (d) => d.h3,
getFillColor: (d) =>
colorMode === 'journey_time'
? journeyTimeToColor(d.median_journey_minutes)
: priceToColor(d.avg_price),
getFillColor: (d) => {
if (!colorFeatureName || !colorFeatureMeta) return [128, 128, 128] as [number, number, number];
const val = d[`min_${colorFeatureName}`];
if (val == null) return [128, 128, 128] as [number, number, number];
const range = colorFeatureMeta.max - colorFeatureMeta.min;
if (range === 0) return GRADIENT[0].color;
const t = ((val as number) - colorFeatureMeta.min) / range;
return normalizedToColor(t);
},
updateTriggers: {
getFillColor: colorMode,
getFillColor: [colorFeatureName, colorFeatureMeta],
},
extruded: false,
pickable: true,
opacity: 0.5,
highPrecision: true,
// Render below labels so road names, place names etc. stay visible
// @ts-expect-error beforeId is a MapboxOverlay interleave prop, not typed in LayerProps
beforeId: LABEL_LAYER_ID,
}),
@ -420,41 +392,39 @@ export default function Map({ data, pois, onViewChange, colorMode }: MapProps) {
onHover: handlePoiHover,
}),
],
[data, pois, handlePoiHover, colorMode]
[data, pois, handlePoiHover, colorFeatureName, colorFeatureMeta]
);
const getTooltip = useCallback(({ object }: { object?: HexagonData }) => {
if (!object || !('h3' in object)) return null;
const getTooltip = useCallback(
({ object }: { object?: HexagonData }) => {
if (!object || !('h3' in object)) return null;
const hex = object as HexagonData;
const journeyLines: string[] = [];
if (hex.median_pt_quick_minutes != null)
journeyLines.push(`🚇 Quick PT: ${hex.median_pt_quick_minutes} min`);
if (hex.median_pt_easy_minutes != null)
journeyLines.push(`🚌 Easy PT: ${hex.median_pt_easy_minutes} min`);
if (hex.median_cycling_minutes != null)
journeyLines.push(`🚲 Cycling: ${hex.median_cycling_minutes} min`);
const journeyTimeHtml =
journeyLines.length > 0
? `<div style="color: #0066cc; margin-top: 4px; font-size: 12px;">${journeyLines.join('<br/>')}</div>`
: '';
const hex = object;
const lines: string[] = [];
lines.push(`<strong>${(hex.count as number).toLocaleString()} properties</strong>`);
return {
html: `<div style="padding: 8px; font-size: 14px;">
<strong>Avg: £${hex.avg_price?.toLocaleString() || 'N/A'}</strong>
<div style="color: #666; font-size: 12px;">
${hex.count} sales<br/>
Range: £${hex.min_price?.toLocaleString()} - £${hex.max_price?.toLocaleString()}
</div>
${journeyTimeHtml}
</div>`,
style: {
backgroundColor: 'white',
borderRadius: '4px',
boxShadow: '0 2px 4px rgba(0,0,0,0.2)',
},
};
}, []);
for (const f of features) {
const minVal = hex[`min_${f.name}`];
const maxVal = hex[`max_${f.name}`];
if (minVal != null && maxVal != null) {
const minStr = typeof minVal === 'number' ? minVal.toLocaleString(undefined, { maximumFractionDigits: 1 }) : String(minVal);
const maxStr = typeof maxVal === 'number' ? maxVal.toLocaleString(undefined, { maximumFractionDigits: 1 }) : String(maxVal);
const highlight = f.name === colorFeatureName ? 'font-weight: bold;' : '';
lines.push(`<div style="${highlight}">${f.label}: ${minStr} - ${maxStr}</div>`);
}
}
return {
html: `<div style="padding: 8px; font-size: 12px;">${lines.join('')}</div>`,
style: {
backgroundColor: 'white',
borderRadius: '4px',
boxShadow: '0 2px 4px rgba(0,0,0,0.2)',
},
};
},
[features, colorFeatureName]
);
return (
<div className="flex-1 h-full relative" ref={containerRef}>

View file

@ -1,19 +1 @@
import type { Filters } from '../types';
// Filter configuration constants
// Should match backend pipeline/config.py
export const YEAR_MIN = 1995;
export const YEAR_MAX = 2024;
export const YEAR_STEP = 1;
export const PRICE_MIN = 0;
export const PRICE_MAX = 5000000; // £5M max for slider, but no server-side cap
export const PRICE_STEP = 50000;
export const DEFAULT_FILTERS: Filters = {
minYear: 2020,
maxYear: YEAR_MAX,
minPrice: PRICE_MIN,
maxPrice: PRICE_MAX,
};
// No hardcoded filter constants - features are discovered dynamically from the API.

View file

@ -1,8 +1,17 @@
export interface Filters {
minYear: number;
maxYear: number;
minPrice: number;
maxPrice: number;
export interface FeatureMeta {
name: string;
min: number;
max: number;
label: string;
}
// Filters: feature name -> [selectedMin, selectedMax]
export type FeatureFilters = Record<string, [number, number]>;
/**
 * One hexagon record from the `features` array of the `/api/hexagons` response.
 *
 * Besides the two fixed fields, a record carries dynamically named columns.
 * The client reads them as `min_<feature>` / `max_<feature>` pairs, where
 * `<feature>` is a `FeatureMeta.name`, for filtering, coloring and tooltips.
 */
export interface HexagonData {
/** H3 cell index; passed to the hexagon layer's `getHexagon` accessor. */
h3: string;
/** Shown in the map tooltip as the number of properties in the cell. */
count: number;
/** Dynamic per-feature columns; a value may be missing/null and must be checked before use. */
[key: string]: string | number | null;
}
export interface Bounds {
@ -12,21 +21,6 @@ export interface Bounds {
east: number;
}
export interface HexagonData {
h3: string;
count: number;
avg_price: number;
median_price: number;
min_price: number;
max_price: number;
median_journey_minutes: number | null;
median_pt_easy_minutes: number | null;
median_pt_quick_minutes: number | null;
median_cycling_minutes: number | null;
}
export type ColorMode = 'price' | 'journey_time';
export interface ViewState {
longitude: number;
latitude: number;
@ -60,6 +54,7 @@ export interface POIResponse {
/**
 * Display metadata for one POI category group, as returned by
 * `/api/poi-categories` under `categories[<group key>]`.
 */
export interface POICategoryInfo {
/** Emoji rendered in front of the label in the category selector. */
emoji: string;
/** Human-readable group name. */
label: string;
/** Number of POIs in the group; the server sums the counts of the group's member categories. */
count: number;
}
export type POICategoriesMap = Record<string, POICategoryInfo>;

View file

@ -4,12 +4,6 @@ from pipeline.config import (
AGGREGATES_DIR,
H3_RESOLUTIONS as VALID_RESOLUTIONS,
DEFAULT_H3_RESOLUTION as DEFAULT_RESOLUTION,
MIN_YEAR,
MAX_YEAR,
DEFAULT_MIN_YEAR,
DEFAULT_MAX_YEAR,
DEFAULT_MIN_PRICE,
DEFAULT_MAX_PRICE,
)
# Extra area to return beyond requested bounds (0.2 = 20%)
@ -20,11 +14,5 @@ __all__ = [
"AGGREGATES_DIR",
"VALID_RESOLUTIONS",
"DEFAULT_RESOLUTION",
"MIN_YEAR",
"MAX_YEAR",
"DEFAULT_MIN_YEAR",
"DEFAULT_MAX_YEAR",
"DEFAULT_MIN_PRICE",
"DEFAULT_MAX_PRICE",
"BOUNDS_BUFFER_PERCENT",
]

View file

@ -10,10 +10,6 @@ from server.config import (
AGGREGATES_DIR,
VALID_RESOLUTIONS,
DEFAULT_RESOLUTION,
DEFAULT_MIN_YEAR,
DEFAULT_MAX_YEAR,
DEFAULT_MIN_PRICE,
DEFAULT_MAX_PRICE,
BOUNDS_BUFFER_PERCENT,
)
@ -22,6 +18,38 @@ router = APIRouter()
# Cache loaded dataframes in memory (one per resolution)
_df_cache: dict[int, pl.DataFrame] = {}
# Discovered features (computed once on first load)
_features_cache: list[dict] | None = None
def _snake_to_label(name: str) -> str:
"""Convert snake_case feature name to a human-readable label."""
return name.replace("_", " ").title()
def _discover_features(df: pl.DataFrame) -> list[dict]:
    """Discover filterable features from ``min_X`` / ``max_X`` column pairs.

    Every column named ``min_<name>`` that has a matching ``max_<name>``
    column yields one feature entry, in column order.

    Args:
        df: Dataframe whose columns are scanned for min/max pairs.

    Returns:
        A list of dicts with keys ``name``, ``min`` (minimum of the ``min_``
        column), ``max`` (maximum of the ``max_`` column) and ``label``.
        A pair is skipped when either extreme is missing (all-null column),
        cannot be converted to ``float`` (non-numeric column), or is not
        finite: NaN and infinity are not valid JSON and make no usable
        slider range.
    """
    columns = set(df.columns)
    features: list[dict] = []
    for col in df.columns:
        if not col.startswith("min_"):
            continue
        name = col[4:]
        max_col = f"max_{name}"
        if max_col not in columns:
            continue
        try:
            # float(None) raises TypeError, which covers the all-null case too.
            lower = float(df[col].min())
            upper = float(df[max_col].max())
        except (TypeError, ValueError):
            continue
        if not (math.isfinite(lower) and math.isfinite(upper)):
            continue
        features.append(
            {
                "name": name,
                "min": lower,
                "max": upper,
                "label": _snake_to_label(name),
            }
        )
    return features
def preload_dataframes() -> None:
"""Load all resolution dataframes into cache on startup."""
@ -38,25 +66,41 @@ def get_cached_df(resolution: int) -> pl.DataFrame | None:
# Load and add H3 cell centroids for fast bbox filtering
df = pl.read_parquet(parquet_path)
# Pre-compute cell centroids for bbox filtering (much faster than is_in)
# Pre-compute cell centroids for bbox filtering
centroids = [h3.cell_to_latlng(cell) for cell in df["h3"].to_list()]
df = df.with_columns(
[
pl.Series("lat", [c[0] for c in centroids]),
pl.Series("lng", [c[1] for c in centroids]),
pl.Series("_lat", [c[0] for c in centroids]),
pl.Series("_lng", [c[1] for c in centroids]),
]
)
_df_cache[resolution] = df
return _df_cache[resolution]
def get_features() -> list[dict]:
    """Return the discovered feature list, computing it on first use.

    The resolutions in ``VALID_RESOLUTIONS`` are tried in order; the first one
    whose dataframe is available is passed to ``_discover_features``. If no
    dataframe is available, an empty list is cached. Later calls return the
    cached list.
    """
    global _features_cache
    if _features_cache is not None:
        return _features_cache

    discovered: list[dict] = []
    for res in VALID_RESOLUTIONS:
        frame = get_cached_df(res)
        if frame is None:
            continue
        discovered = _discover_features(frame)
        break

    _features_cache = discovered
    return _features_cache
@router.get("/features")
async def get_features_endpoint() -> dict:
    """Serve the discovered feature metadata (name, min, max, label) under the ``features`` key."""
    payload = {"features": get_features()}
    return payload
@lru_cache(maxsize=128)
def query_hexagons_cached(
resolution: int,
min_year: int,
max_year: int,
min_price: int,
max_price: int,
bounds_tuple: tuple[float, float, float, float],
) -> list[dict]:
"""Cached query - returns features list."""
@ -64,65 +108,18 @@ def query_hexagons_cached(
df = get_cached_df(resolution)
if df is None:
return [], False
return []
# Fast bbox filter using pre-computed centroids (O(1) per row)
# Fast bbox filter using pre-computed centroids
df = df.filter(
(pl.col("lat") >= south)
& (pl.col("lat") <= north)
& (pl.col("lng") >= west)
& (pl.col("lng") <= east)
(pl.col("_lat") >= south)
& (pl.col("_lat") <= north)
& (pl.col("_lng") >= west)
& (pl.col("_lng") <= east)
)
# Filter by year range
df = df.filter((pl.col("year") >= min_year) & (pl.col("year") <= max_year))
# Check which journey time columns exist
journey_cols = [
"median_journey_minutes",
"median_pt_easy_minutes",
"median_pt_quick_minutes",
"median_cycling_minutes",
]
available_journey_cols = [c for c in journey_cols if c in df.columns]
# Aggregate across years (weighted by count)
agg_exprs = [
pl.col("count").sum().alias("count"),
(pl.col("avg_price") * pl.col("count")).sum().alias("weighted_price_sum"),
pl.col("median_price").median().alias("median_price"),
pl.col("min_price").min().alias("min_price"),
pl.col("max_price").max().alias("max_price"),
]
for jc in available_journey_cols:
# Journey time is same across years, just take first non-null
agg_exprs.append(pl.col(jc).first())
df = df.group_by("h3").agg(agg_exprs)
# Calculate weighted average price
df = df.with_columns(
(pl.col("weighted_price_sum") / pl.col("count")).alias("avg_price")
).drop("weighted_price_sum")
# Filter by price range
df = df.filter(
(pl.col("avg_price") >= min_price) & (pl.col("avg_price") <= max_price)
)
# Build response efficiently using Polars
select_cols = [
pl.col("h3"),
pl.col("count"),
pl.col("avg_price").round(2),
pl.col("median_price").round(2),
pl.col("min_price"),
pl.col("max_price"),
]
for jc in available_journey_cols:
select_cols.append(pl.col(jc).round(0))
df = df.select(select_cols)
# Drop internal centroid columns before returning
df = df.drop("_lat", "_lng")
return df.to_dicts()
@ -135,13 +132,9 @@ async def get_hexagons(
le=max(VALID_RESOLUTIONS),
description=f"H3 resolution ({min(VALID_RESOLUTIONS)}-{max(VALID_RESOLUTIONS)})",
),
min_year: int = Query(DEFAULT_MIN_YEAR, description="Minimum year filter"),
max_year: int = Query(DEFAULT_MAX_YEAR, description="Maximum year filter"),
min_price: float = Query(DEFAULT_MIN_PRICE, description="Minimum average price"),
max_price: float = Query(DEFAULT_MAX_PRICE, description="Maximum average price"),
bounds: str | None = Query(None, description="Bounding box: south,west,north,east"),
) -> dict:
"""Get aggregated property data as GeoJSON hexagons within bounds."""
"""Get aggregated property data as hexagons within bounds."""
if resolution not in VALID_RESOLUTIONS:
resolution = DEFAULT_RESOLUTION
@ -165,9 +158,7 @@ async def get_hexagons(
west -= lng_buffer
east += lng_buffer
# Round bounds to reduce cache misses (0.01 degree ≈ 1km precision)
# Always expand bounds (floor for min, ceil for max) to prevent hexagons
# popping in when crossing rounding boundaries
# Round bounds to reduce cache misses (0.01 degree ~ 1km precision)
precision = 0.01
bounds_tuple = (
math.floor(south / precision) * precision,
@ -176,14 +167,6 @@ async def get_hexagons(
math.ceil(east / precision) * precision,
)
# Convert prices to int for cache key hashability
features = query_hexagons_cached(
resolution,
min_year,
max_year,
int(min_price),
int(max_price),
bounds_tuple,
)
features = query_hexagons_cached(resolution, bounds_tuple)
return {"features": features}

View file

@ -9,8 +9,11 @@ router = APIRouter()
DATA_FILE = Path("data_sources/uk_pois.parquet")
# Category groups with emoji and member categories
POI_CATEGORY_GROUPS: dict[str, dict] = {
# Group definitions: maps a group key to its display metadata and the
# individual POI categories it contains. Categories are matched against
# the values that actually exist in the loaded parquet so that the
# selector only shows groups with real data.
_GROUP_DEFS: dict[str, dict] = {
"schools": {
"emoji": "🏫",
"label": "Schools",
@ -189,33 +192,80 @@ POI_CATEGORY_GROUPS: dict[str, dict] = {
},
}
# Flatten for quick lookup
ALL_CATEGORIES = {
cat for group in POI_CATEGORY_GROUPS.values() for cat in group["categories"]
}
# Built at startup from the data — only groups whose member categories
# actually appear in the parquet file are included.
_active_groups: dict[str, dict] = {}
# Reverse lookup: category value -> group key (built at startup)
_cat_to_group: dict[str, str] = {}
# Cache the dataframe
_df_cache: pl.DataFrame | None = None
def _load_and_build() -> pl.DataFrame | None:
    """Read the POI parquet and derive the active category groups from it.

    Updates three module globals: ``_active_groups`` (groups from
    ``_GROUP_DEFS`` that have at least one member category in the data, with
    their POI count), ``_cat_to_group`` (category -> group key for every
    category listed in ``_GROUP_DEFS``) and ``_df_cache`` (the rows whose
    category belongs to a known group).

    Returns:
        The filtered dataframe, or ``None`` if ``DATA_FILE`` does not exist
        (the globals are left untouched in that case).
    """
    global _df_cache, _active_groups, _cat_to_group

    if not DATA_FILE.exists():
        return None

    pois = pl.read_parquet(DATA_FILE).select("id", "name", "category", "lat", "lng")

    # Categories that occur in the file, and how many rows each one has
    distinct = pois.select("category").unique().to_series().to_list()
    data_categories: set[str] = set(distinct)
    per_category = pois.group_by("category").agg(pl.len().alias("n"))
    counts: dict[str, int] = dict(per_category.iter_rows())

    # Category -> owning group, for every category named in the definitions
    cat_to_group: dict[str, str] = {
        cat: group_key
        for group_key, group_def in _GROUP_DEFS.items()
        for cat in group_def["categories"]
    }

    # Categories that are both present in the data and assigned to a group
    known_categories = data_categories & cat_to_group.keys()

    # Keep only the groups that have at least one category with data
    active: dict[str, dict] = {}
    for group_key, group_def in _GROUP_DEFS.items():
        members = [cat for cat in group_def["categories"] if cat in known_categories]
        if not members:
            continue
        active[group_key] = {
            "emoji": group_def["emoji"],
            "label": group_def["label"],
            "categories": members,
            "count": sum(counts.get(cat, 0) for cat in members),
        }

    _active_groups = active
    _cat_to_group = cat_to_group

    # Drop rows whose category is not part of any known group
    _df_cache = pois.filter(pl.col("category").is_in(known_categories))
    return _df_cache
def get_df() -> pl.DataFrame | None:
"""Load and cache the POI dataframe."""
global _df_cache
"""Return cached POI dataframe, loading if necessary."""
if _df_cache is None:
if not DATA_FILE.exists():
return None
df = pl.read_parquet(DATA_FILE)
_df_cache = df.select("id", "name", "category", "lat", "lng").filter(
pl.col("category").is_in(ALL_CATEGORIES)
)
return _load_and_build()
return _df_cache
def preload_pois() -> None:
"""Preload POI data on startup."""
df = get_df()
df = _load_and_build()
if df is not None:
print(f"Loaded {len(df):,} POIs")
n_groups = len(_active_groups)
print(f"Loaded {len(df):,} POIs across {n_groups} category groups")
@router.get("/pois")
@ -234,10 +284,10 @@ async def get_pois(
return {"features": []}
requested_groups = [g.strip() for g in categories.split(",")]
cats_to_include = set()
cats_to_include: set[str] = set()
for group in requested_groups:
if group in POI_CATEGORY_GROUPS:
cats_to_include.update(POI_CATEGORY_GROUPS[group]["categories"])
if group in _active_groups:
cats_to_include.update(_active_groups[group]["categories"])
if not cats_to_include:
return {"features": []}
@ -259,10 +309,14 @@ async def get_pois(
@router.get("/poi-categories")
async def get_poi_categories() -> dict:
"""Get available POI category groups with emoji and labels."""
"""Get available POI category groups derived from loaded data."""
return {
"categories": {
key: {"emoji": group["emoji"], "label": group["label"]}
for key, group in POI_CATEGORY_GROUPS.items()
key: {
"emoji": group["emoji"],
"label": group["label"],
"count": group["count"],
}
for key, group in _active_groups.items()
}
}