Split up
Some checks failed
Build and publish Docker image / build-and-push (push) Failing after 15s
CI / Check (push) Failing after 1m58s

This commit is contained in:
Andras Schmelczer 2026-06-12 21:51:37 +01:00
parent cf39ad754e
commit f59d01227b
91 changed files with 10370 additions and 7562 deletions

View file

@ -619,7 +619,10 @@ export default function AreaPane({
/> />
{crimeSeries && crimeSeries.points.length > 1 && ( {crimeSeries && crimeSeries.points.length > 1 && (
<div className="mt-2"> <div className="mt-2">
<CrimeYearChart points={crimeSeries.points} /> <CrimeYearChart
points={crimeSeries.points}
latestAvailableYear={stats?.crime_latest_year}
/>
</div> </div>
)} )}
</div> </div>
@ -663,7 +666,10 @@ export default function AreaPane({
} }
chart={ chart={
crimeSeries && crimeSeries.points.length > 1 ? ( crimeSeries && crimeSeries.points.length > 1 ? (
<CrimeYearChart points={crimeSeries.points} /> <CrimeYearChart
points={crimeSeries.points}
latestAvailableYear={stats?.crime_latest_year}
/>
) : ( ) : (
numericStats.histogram && numericStats.histogram &&
(globalHistogram ? ( (globalHistogram ? (

View file

@ -1,14 +1,22 @@
import { useEffect, useMemo, useRef, useState } from 'react'; import { useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import type { CrimeYearPoint } from '../../types'; import type { CrimeYearPoint } from '../../types';
interface CrimeYearChartProps { interface CrimeYearChartProps {
points: CrimeYearPoint[]; points: CrimeYearPoint[];
/**
* Latest year available in the crime dataset as a whole. When the series
* ends earlier, the area's police force stopped publishing (e.g. Greater
* Manchester since mid-2019) and the chart is captioned as stale.
*/
latestAvailableYear?: number;
} }
const PADDING = { top: 6, right: 4, bottom: 14, left: 4 }; const PADDING = { top: 6, right: 4, bottom: 14, left: 4 };
const HEIGHT = 48; const HEIGHT = 48;
export default function CrimeYearChart({ points }: CrimeYearChartProps) { export default function CrimeYearChart({ points, latestAvailableYear }: CrimeYearChartProps) {
const { t } = useTranslation();
const containerRef = useRef<HTMLDivElement>(null); const containerRef = useRef<HTMLDivElement>(null);
const [width, setWidth] = useState(0); const [width, setWidth] = useState(0);
@ -97,6 +105,11 @@ export default function CrimeYearChart({ points }: CrimeYearChartProps) {
</text> </text>
</svg> </svg>
)} )}
{latestAvailableYear != null && yearMax < latestAvailableYear && (
<p className="mt-0.5 text-[10px] leading-snug text-amber-700 dark:text-amber-400">
{t('areaPane.crimeDataEnds', { year: yearMax })}
</p>
)}
</div> </div>
); );
} }

View file

@ -0,0 +1,67 @@
import { useEffect } from 'react';
import { useControl } from 'react-map-gl/maplibre';
import { MapboxOverlay } from '@deck.gl/mapbox';
/**
 * Minimal structural view of the deck instance reached through the overlay's
 * `_deck` field. Only the members the null-viewport patch touches are declared.
 * NOTE(review): the leading underscores suggest these are private deck.gl
 * internals — re-verify this shape when upgrading deck.gl.
 */
interface DeckWithPrivateDraw {
/** Draw entry point that the patch wraps; optional because it is probed at runtime before use. */
_drawLayers?: (
redrawReason: string,
renderOptions?: { viewports?: unknown[]; [key: string]: unknown }
) => unknown;
/** Marker set to `true` once the patch has been installed, so it is applied at most once per deck instance. */
__propertyMapNullViewportPatch?: boolean;
}
/**
 * Wraps the overlay's internal `_drawLayers` so that falsy entries in
 * `renderOptions.viewports` never reach the wrapped draw call.
 *
 * - No `viewports` option: the call is forwarded untouched.
 * - Every viewport falsy: the draw is skipped and `undefined` is returned.
 * - Some viewports falsy: the draw is forwarded with only the truthy ones.
 *
 * The function is a no-op when the overlay has no `_deck`, when the deck was
 * already patched, or when `_drawLayers` is not a function.
 */
function patchNullViewportDraw(overlay: MapboxOverlay) {
  const { _deck: deck } = overlay as unknown as { _deck?: DeckWithPrivateDraw };
  if (!deck) return;
  if (deck.__propertyMapNullViewportPatch) return;
  if (typeof deck._drawLayers !== 'function') return;

  const originalDraw = deck._drawLayers.bind(deck);

  deck._drawLayers = (redrawReason, renderOptions) => {
    const requested = renderOptions?.viewports;
    if (!requested) return originalDraw(redrawReason, renderOptions);

    // Split-route startup can hand deck.gl a transient null viewport before MapLibre has sized the map.
    const usable = requested.filter(Boolean);
    if (usable.length === 0) return undefined;

    const effectiveOptions =
      usable.length === requested.length ? renderOptions : { ...renderOptions, viewports: usable };
    return originalDraw(redrawReason, effectiveOptions);
  };

  // Remember that this deck instance is wrapped so repeated calls do not stack wrappers.
  deck.__propertyMapNullViewportPatch = true;
}
/**
 * MapboxOverlay variant that installs the null-viewport draw patch on its
 * internal deck instance. The patch helper guards against double-patching,
 * so calling it from both hooks below is safe.
 */
class SafeMapboxOverlay extends MapboxOverlay {
/** Delegates to the base implementation first, then tries to install the patch. */
onAdd(map: unknown) {
const element = super.onAdd(map);
// Patch after super.onAdd — presumably the internal deck only exists once the overlay is attached (TODO confirm).
patchNullViewportDraw(this);
return element;
}
/** Forwards props to the base overlay, then retries the patch in case it could not be installed earlier. */
setProps(props: Parameters<MapboxOverlay['setProps']>[0]) {
super.setProps(props);
patchNullViewportDraw(this);
}
}
/**
 * Registers a `SafeMapboxOverlay` (interleaved mode) as a map control and
 * keeps its `layers` / `getTooltip` props in sync with the React props.
 * Falsy entries in `layers` are dropped before being handed to the overlay.
 * Renders nothing itself.
 */
export function DeckOverlay({
  layers,
  getTooltip,
}: {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  layers: any[];
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  getTooltip: any;
}) {
  const deckControl = useControl(() => new SafeMapboxOverlay({ interleaved: true }));

  useEffect(() => {
    // Callers may pass conditional layers (`cond && layer`), so strip the falsy slots.
    const presentLayers = layers.filter(Boolean);
    deckControl.setProps({ layers: presentLayers, getTooltip });
  }, [deckControl, layers, getTooltip]);

  return null;
}

View file

@ -0,0 +1,45 @@
import { memo, useMemo } from 'react';
import type { FeatureFilters, FeatureMeta, HexagonData, PostcodeFeature } from '../../types';
import HoverCard from './HoverCard';
/** Props for HoverCardOverlay: the hover target plus the datasets it is looked up in. */
interface HoverCardOverlayProps {
/** Forwarded to HoverCard as `x` (presumably the horizontal screen position of the pointer — confirm against HoverCard). */
x: number;
/** Forwarded to HoverCard as `y` (presumably the vertical screen position of the pointer — confirm against HoverCard). */
y: number;
/** Identifier of the hovered item: matched against `properties.postcode` in postcode view, otherwise against `h3`. */
id: string;
/** When true the row is looked up in `postcodeData`; otherwise in `data`. Also forwarded as `isPostcode`. */
usePostcodeView: boolean;
/** Hexagon rows searched by `h3` when not in postcode view. */
data: HexagonData[];
/** Postcode features searched by `properties.postcode` when in postcode view. */
postcodeData: PostcodeFeature[];
/** Passed through to HoverCard unchanged. */
filters: FeatureFilters;
/** Passed through to HoverCard unchanged. */
features: FeatureMeta[];
}
/**
 * Resolves the hovered hexagon/postcode row from the loaded map data and
 * renders the hover card for it.
 *
 * `memo` alone does not protect the lookup: `x` and `y` are props, so the
 * component re-renders whenever they change even if the hover target is the
 * same. The linear `find` is therefore cached with `useMemo`, so it only
 * reruns when the hover target (`id`, `usePostcodeView`) or the underlying
 * datasets change.
 *
 * The card receives `null` as `data` when no row matches `id`.
 */
export const HoverCardOverlay = memo(function HoverCardOverlay({
  x,
  y,
  id,
  usePostcodeView,
  data,
  postcodeData,
  filters,
  features,
}: HoverCardOverlayProps) {
  const hoveredRow = useMemo(
    () =>
      usePostcodeView
        ? postcodeData.find((f) => f.properties.postcode === id)?.properties || null
        : data.find((d) => d.h3 === id) || null,
    [usePostcodeView, postcodeData, data, id]
  );

  return (
    <HoverCard
      x={x}
      y={y}
      id={id}
      isPostcode={usePostcodeView}
      data={hoveredRow}
      filters={filters}
      features={features}
    />
  );
});

View file

@ -0,0 +1,146 @@
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import type { TFunction } from 'i18next';
import type { ActualListing } from '../../types';
/**
 * Formats an asking price for display: a pound sign followed by the number
 * rendered with `toLocaleString()` (runtime default locale, no options).
 */
function formatListingPrice(price: number): string {
  const grouped = price.toLocaleString();
  return '£' + grouped;
}
/**
 * Builds the one-line summary shown under a listing's price.
 *
 * Segments, in order: translated bedroom count, translated bathroom count,
 * then the property sub-type (falling back to the property type when the
 * sub-type is empty). Segments are joined with ' · '.
 *
 * @param listing Listing to summarise.
 * @param t Translation function used for the bed/bath count labels.
 * @returns The joined headline, or `null` when no segment is available.
 */
function formatListingHeadline(listing: ActualListing, t: TFunction): string | null {
  const { bedrooms, bathrooms, property_sub_type: subType, property_type: baseType } = listing;
  const segments: string[] = [];

  if (bedrooms != null) segments.push(t('common.bedsCount', { count: bedrooms }));
  if (bathrooms != null) segments.push(t('common.bathsCount', { count: bathrooms }));

  // Prefer the more specific sub-type; only one of the two is ever shown.
  const kind = subType || baseType;
  if (kind) segments.push(kind);

  if (segments.length === 0) return null;
  return segments.join(' · ');
}
/**
 * Popup body for a single listing. The whole card is one link that opens the
 * listing URL in a new tab.
 *
 * Each section is rendered only when its data is present: price (with optional
 * qualifier), headline (beds/baths/type), address, postcode, floor area (with
 * optional price per sqm) and up to the first three feature bullet points.
 */
export const ListingPopupSingleContent = memo(function ListingPopupSingleContent({
  listing,
}: {
  listing: ActualListing;
}) {
  const { t } = useTranslation();
  // Build the headline once per render instead of once for the check and once for the output.
  const headline = formatListingHeadline(listing, t);

  return (
    <a
      href={listing.listing_url}
      target="_blank"
      rel="noopener noreferrer"
      className="block px-3 py-2"
    >
      {listing.asking_price != null && (
        <div className="text-base font-bold text-teal-600 dark:text-teal-400">
          {formatListingPrice(listing.asking_price)}
          {listing.price_qualifier ? (
            <span className="ml-1 text-xs font-medium text-warm-500 dark:text-warm-400">
              {listing.price_qualifier}
            </span>
          ) : null}
        </div>
      )}
      {headline && (
        <div className="text-xs text-warm-700 dark:text-warm-200 mt-0.5">{headline}</div>
      )}
      {listing.address && (
        <div className="text-xs text-warm-500 dark:text-warm-400 mt-0.5 line-clamp-2">
          {listing.address}
        </div>
      )}
      {listing.postcode && (
        <div className="text-[11px] text-warm-400 dark:text-warm-500 mt-0.5">
          {listing.postcode}
        </div>
      )}
      {listing.floor_area_sqm != null && (
        <div className="text-[11px] text-warm-500 dark:text-warm-400 mt-0.5">
          {Math.round(listing.floor_area_sqm)} sqm
          {listing.asking_price_per_sqm != null
            ? ` · £${Math.round(listing.asking_price_per_sqm).toLocaleString()}/sqm`
            : ''}
        </div>
      )}
      {listing.features.length > 0 && (
        <ul className="mt-1.5 text-[11px] text-warm-600 dark:text-warm-300 list-disc pl-4 space-y-0.5">
          {/* Only the first three features are shown to keep the popup compact. */}
          {listing.features.slice(0, 3).map((feature, idx) => (
            <li key={idx} className="line-clamp-1">
              {feature}
            </li>
          ))}
        </ul>
      )}
      {/* NOTE(review): this label is hard-coded English although the component has `t` available. */}
      <div className="mt-1.5 text-[11px] text-teal-600 dark:text-teal-400 font-medium">
        Open listing
      </div>
    </a>
  );
});
/**
 * Popup body for a cluster of listings.
 *
 * The header shows the total `count` and either "Showing X of Y" (when some
 * listings are loaded) or a generic grouping note (when `listings` is empty).
 * Below it, each loaded listing is rendered as a link row that opens the
 * listing URL in a new tab; the scrollable list is omitted when empty.
 */
export const ListingClusterPopupContent = memo(function ListingClusterPopupContent({
  count,
  listings,
}: {
  count: number;
  listings: ActualListing[];
}) {
  const { t } = useTranslation();
  const shownCount = listings.length;
  const hasRows = shownCount > 0;
  const totalLabel = count.toLocaleString();
  const subtitle = hasRows
    ? `Showing ${shownCount.toLocaleString()} of ${totalLabel}`
    : 'Grouped near this map position';

  /** One clickable row: price (or a generic label), headline, address and postcode. */
  const renderRow = (listing: ActualListing, position: number) => {
    const headline = formatListingHeadline(listing, t);
    const priceLabel =
      listing.asking_price != null ? formatListingPrice(listing.asking_price) : 'Listing';

    return (
      <a
        // The URL alone may repeat within a cluster, so the index is appended to keep keys unique.
        key={`${listing.listing_url}-${position}`}
        href={listing.listing_url}
        target="_blank"
        rel="noopener noreferrer"
        className="block border-b border-warm-100 px-3 py-2 last:border-b-0 hover:bg-warm-50 dark:border-warm-700 dark:hover:bg-warm-700/60"
      >
        <div className="flex items-start justify-between gap-3">
          <div className="min-w-0">
            <div className="text-sm font-semibold text-teal-700 dark:text-teal-300">
              {priceLabel}
            </div>
            {headline && (
              <div className="mt-0.5 truncate text-xs text-warm-700 dark:text-warm-200">
                {headline}
              </div>
            )}
            {listing.address && (
              <div className="mt-0.5 line-clamp-1 text-[11px] text-warm-500 dark:text-warm-400">
                {listing.address}
              </div>
            )}
          </div>
          {listing.postcode && (
            <div className="shrink-0 text-[11px] font-medium text-warm-400 dark:text-warm-500">
              {listing.postcode}
            </div>
          )}
        </div>
      </a>
    );
  };

  return (
    <div>
      <div className="border-b border-warm-200 px-3 py-2 dark:border-warm-700">
        <div className="text-base font-bold text-red-600 dark:text-red-400">
          {totalLabel} listings
        </div>
        <div className="text-[11px] text-warm-500 dark:text-warm-400">{subtitle}</div>
      </div>
      {hasRows ? (
        <div className="max-h-80 overflow-y-auto py-1">
          {listings.map((listing, position) => renderRow(listing, position))}
        </div>
      ) : null}
    </div>
  );
});

View file

@ -1,10 +1,8 @@
import { useCallback, useRef, useEffect, useState, useMemo, memo } from 'react'; import { useCallback, useRef, useEffect, useState, useMemo, memo } from 'react';
import type { CSSProperties } from 'react'; import type { CSSProperties } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { TFunction } from 'i18next'; import { Map as MapGL, ScaleControl } from 'react-map-gl/maplibre';
import { Layer, Map as MapGL, Source, useControl, ScaleControl } from 'react-map-gl/maplibre';
import type { MapRef } from 'react-map-gl/maplibre'; import type { MapRef } from 'react-map-gl/maplibre';
import { MapboxOverlay } from '@deck.gl/mapbox';
import 'maplibre-gl/dist/maplibre-gl.css'; import 'maplibre-gl/dist/maplibre-gl.css';
import type { import type {
HexagonData, HexagonData,
@ -17,7 +15,6 @@ import type {
Bounds, Bounds,
MapFlyToOptions, MapFlyToOptions,
ActualListing, ActualListing,
SchoolMetadata,
} from '../../types'; } from '../../types';
import { import {
@ -26,28 +23,25 @@ import {
getBoundsWithBottomScreenInset, getBoundsWithBottomScreenInset,
getMapStyle, getMapStyle,
getMapDataBeforeId, getMapDataBeforeId,
getPoiIconUrl,
getMapCenterForTargetScreenPoint, getMapCenterForTargetScreenPoint,
} from '../../lib/map-utils'; } from '../../lib/map-utils';
import { import { MAP_MIN_ZOOM, MAP_BOUNDS, POI_AUTO_CARD_ZOOM_THRESHOLD } from '../../lib/consts';
MAP_MIN_ZOOM, import type { SearchedLocation } from './LocationSearch';
MAP_BOUNDS,
POI_GROUP_COLORS,
POSTCODE_ZOOM_THRESHOLD,
POI_AUTO_CARD_ZOOM_THRESHOLD,
} from '../../lib/consts';
import LocationSearch, { type SearchedLocation } from './LocationSearch';
import MapLegend from './MapLegend';
import HoverCard from './HoverCard';
import { LogoIcon } from '../ui/icons/LogoIcon'; import { LogoIcon } from '../ui/icons/LogoIcon';
import { CloseIcon } from '../ui/icons/CloseIcon'; import { CloseIcon } from '../ui/icons/CloseIcon';
import type { FeatureFilters } from '../../types'; import type { FeatureFilters } from '../../types';
import { useDeckLayers } from '../../hooks/useDeckLayers'; import { useDeckLayers } from '../../hooks/useDeckLayers';
import { useTranslatedModes, type TravelTimeEntry } from '../../hooks/useTravelTime'; import { useMapCardLayout } from '../../hooks/useMapCardLayout';
import { ts } from '../../i18n/server'; import type { TravelTimeEntry } from '../../hooks/useTravelTime';
import { type OverlayId, OVERLAY_MIN_ZOOM } from '../../lib/overlays'; import { type OverlayId } from '../../lib/overlays';
import { CRIME_TYPE_VALUES } from '../../lib/crime-types'; import { CRIME_TYPE_VALUES } from '../../lib/crime-types';
import type { BasemapId } from '../../lib/basemaps'; import type { BasemapId } from '../../lib/basemaps';
import { DeckOverlay } from './DeckOverlay';
import { OverlayTileLayers } from './OverlayTileLayers';
import { MapTopCards } from './MapTopCards';
import { PoiPopupCardContent } from './PoiPopupCard';
import { ListingClusterPopupContent, ListingPopupSingleContent } from './ListingPopups';
import { HoverCardOverlay } from './HoverCardOverlay';
interface MapProps { interface MapProps {
data: HexagonData[]; data: HexagonData[];
@ -99,168 +93,11 @@ const EMPTY_ACTUAL_LISTINGS: ActualListing[] = [];
const EMPTY_OVERLAYS = new Set<OverlayId>(); const EMPTY_OVERLAYS = new Set<OverlayId>();
const ALL_CRIME_TYPES = new Set<string>(CRIME_TYPE_VALUES); const ALL_CRIME_TYPES = new Set<string>(CRIME_TYPE_VALUES);
function formatListingPrice(price: number): string {
return `£${price.toLocaleString()}`;
}
function formatListingHeadline(listing: ActualListing, t: TFunction): string | null {
const parts: string[] = [];
if (listing.bedrooms != null) parts.push(t('common.bedsCount', { count: listing.bedrooms }));
if (listing.bathrooms != null) parts.push(t('common.bathsCount', { count: listing.bathrooms }));
if (listing.property_sub_type) parts.push(listing.property_sub_type);
else if (listing.property_type) parts.push(listing.property_type);
return parts.length > 0 ? parts.join(' · ') : null;
}
function ListingPopupSingleContent({ listing, t }: { listing: ActualListing; t: TFunction }) {
return (
<a
href={listing.listing_url}
target="_blank"
rel="noopener noreferrer"
className="block px-3 py-2"
>
{listing.asking_price != null && (
<div className="text-base font-bold text-teal-600 dark:text-teal-400">
{formatListingPrice(listing.asking_price)}
{listing.price_qualifier ? (
<span className="ml-1 text-xs font-medium text-warm-500 dark:text-warm-400">
{listing.price_qualifier}
</span>
) : null}
</div>
)}
{formatListingHeadline(listing, t) && (
<div className="text-xs text-warm-700 dark:text-warm-200 mt-0.5">
{formatListingHeadline(listing, t)}
</div>
)}
{listing.address && (
<div className="text-xs text-warm-500 dark:text-warm-400 mt-0.5 line-clamp-2">
{listing.address}
</div>
)}
{listing.postcode && (
<div className="text-[11px] text-warm-400 dark:text-warm-500 mt-0.5">
{listing.postcode}
</div>
)}
{listing.floor_area_sqm != null && (
<div className="text-[11px] text-warm-500 dark:text-warm-400 mt-0.5">
{Math.round(listing.floor_area_sqm)} sqm
{listing.asking_price_per_sqm != null
? ` · £${Math.round(listing.asking_price_per_sqm).toLocaleString()}/sqm`
: ''}
</div>
)}
{listing.features.length > 0 && (
<ul className="mt-1.5 text-[11px] text-warm-600 dark:text-warm-300 list-disc pl-4 space-y-0.5">
{listing.features.slice(0, 3).map((feature, idx) => (
<li key={idx} className="line-clamp-1">
{feature}
</li>
))}
</ul>
)}
<div className="mt-1.5 text-[11px] text-teal-600 dark:text-teal-400 font-medium">
Open listing
</div>
</a>
);
}
function ListingClusterPopupContent({
count,
listings,
t,
}: {
count: number;
listings: ActualListing[];
t: TFunction;
}) {
const visibleCount = listings.length;
return (
<div>
<div className="border-b border-warm-200 px-3 py-2 dark:border-warm-700">
<div className="text-base font-bold text-red-600 dark:text-red-400">
{count.toLocaleString()} listings
</div>
<div className="text-[11px] text-warm-500 dark:text-warm-400">
{visibleCount > 0
? `Showing ${visibleCount.toLocaleString()} of ${count.toLocaleString()}`
: 'Grouped near this map position'}
</div>
</div>
{visibleCount > 0 && (
<div className="max-h-80 overflow-y-auto py-1">
{listings.map((listing, idx) => {
const headline = formatListingHeadline(listing, t);
return (
<a
key={`${listing.listing_url}-${idx}`}
href={listing.listing_url}
target="_blank"
rel="noopener noreferrer"
className="block border-b border-warm-100 px-3 py-2 last:border-b-0 hover:bg-warm-50 dark:border-warm-700 dark:hover:bg-warm-700/60"
>
<div className="flex items-start justify-between gap-3">
<div className="min-w-0">
<div className="text-sm font-semibold text-teal-700 dark:text-teal-300">
{listing.asking_price != null
? formatListingPrice(listing.asking_price)
: 'Listing'}
</div>
{headline && (
<div className="mt-0.5 truncate text-xs text-warm-700 dark:text-warm-200">
{headline}
</div>
)}
{listing.address && (
<div className="mt-0.5 line-clamp-1 text-[11px] text-warm-500 dark:text-warm-400">
{listing.address}
</div>
)}
</div>
{listing.postcode && (
<div className="shrink-0 text-[11px] font-medium text-warm-400 dark:text-warm-500">
{listing.postcode}
</div>
)}
</div>
</a>
);
})}
</div>
)}
</div>
);
}
interface PoiPopupCardData {
name: string;
category: string;
icon_category?: string;
group: string;
emoji: string;
school?: SchoolMetadata;
}
interface Dimensions { interface Dimensions {
width: number; width: number;
height: number; height: number;
} }
const DESKTOP_TOP_CARD_WIDTH = 300;
const DESKTOP_TOP_CARD_GAP = 8;
const DESKTOP_TOP_CARD_HORIZONTAL_INSET = 24;
const DESKTOP_TOP_CARDS_STACKED_MIN_MAP_WIDTH =
DESKTOP_TOP_CARD_WIDTH + DESKTOP_TOP_CARD_HORIZONTAL_INSET;
const DESKTOP_TOP_CARDS_ROW_MIN_MAP_WIDTH =
DESKTOP_TOP_CARD_WIDTH * 2 + DESKTOP_TOP_CARD_GAP + DESKTOP_TOP_CARD_HORIZONTAL_INSET;
const DESKTOP_TOP_CARD_CLASS = 'w-[300px]';
const DESKTOP_LOCATION_SEARCH_INPUT_CLASS =
'px-2 py-2 text-sm w-full border-none outline-none bg-transparent text-warm-700 dark:text-warm-200 placeholder-warm-400 dark:placeholder-warm-500';
type MapContainerStyle = CSSProperties & { type MapContainerStyle = CSSProperties & {
'--map-mobile-bottom-inset'?: string; '--map-mobile-bottom-inset'?: string;
}; };
@ -323,218 +160,6 @@ function getViewportRelativeVisibleAreaCenter(
}; };
} }
interface DeckWithPrivateDraw {
_drawLayers?: (
redrawReason: string,
renderOptions?: { viewports?: unknown[]; [key: string]: unknown }
) => unknown;
__propertyMapNullViewportPatch?: boolean;
}
function patchNullViewportDraw(overlay: MapboxOverlay) {
const deck = (overlay as unknown as { _deck?: DeckWithPrivateDraw })._deck;
if (!deck || deck.__propertyMapNullViewportPatch || typeof deck._drawLayers !== 'function') {
return;
}
const drawLayers = deck._drawLayers.bind(deck);
deck._drawLayers = (redrawReason, renderOptions) => {
const viewports = renderOptions?.viewports;
if (viewports) {
// Split-route startup can hand deck.gl a transient null viewport before MapLibre has sized the map.
const nonNullViewports = viewports.filter(Boolean);
if (nonNullViewports.length === 0) return;
if (nonNullViewports.length !== viewports.length) {
return drawLayers(redrawReason, { ...renderOptions, viewports: nonNullViewports });
}
}
return drawLayers(redrawReason, renderOptions);
};
deck.__propertyMapNullViewportPatch = true;
}
class SafeMapboxOverlay extends MapboxOverlay {
onAdd(map: unknown) {
const element = super.onAdd(map);
patchNullViewportDraw(this);
return element;
}
setProps(props: Parameters<MapboxOverlay['setProps']>[0]) {
super.setProps(props);
patchNullViewportDraw(this);
}
}
function getPoiGroupColor(group: string): [number, number, number] {
const color = POI_GROUP_COLORS[group];
if (!color) {
throw new Error(`Missing POI group color for '${group}'`);
}
return color;
}
/** Best-effort web URL from a free-text website field GIAS stores some with
* "http://", some without, and some as bare hostnames. */
function normalizeSchoolWebsiteUrl(raw: string): string | null {
const trimmed = raw.trim();
if (!trimmed) return null;
if (/^https?:\/\//i.test(trimmed)) return trimmed;
if (/^[\w.-]+\.[a-z]{2,}/i.test(trimmed)) return `http://${trimmed}`;
return null;
}
function renderSchoolMetadata(school: SchoolMetadata) {
// First line collects the headline classification (phase, type, religious
// character) so the popup is scannable even when most fields are absent.
const headline: string[] = [];
if (school.phase) headline.push(school.phase);
if (school.type) headline.push(school.type);
const pupilsLine =
school.pupils !== undefined && school.capacity !== undefined
? `${school.pupils.toLocaleString()} / ${school.capacity.toLocaleString()} pupils`
: school.pupils !== undefined
? `${school.pupils.toLocaleString()} pupils`
: school.capacity !== undefined
? `Capacity ${school.capacity.toLocaleString()}`
: null;
const websiteUrl = school.website ? normalizeSchoolWebsiteUrl(school.website) : null;
return (
<dl className="mt-2 grid grid-cols-[auto_1fr] gap-x-2 gap-y-0.5 text-xs text-warm-600 dark:text-warm-300">
{headline.length > 0 && (
<>
<dt className="text-warm-500 dark:text-warm-400">Type</dt>
<dd className="dark:text-warm-200">{headline.join(' · ')}</dd>
</>
)}
{school.age_range && (
<>
<dt className="text-warm-500 dark:text-warm-400">Ages</dt>
<dd className="dark:text-warm-200">{school.age_range}</dd>
</>
)}
{school.gender && school.gender !== 'Mixed' && (
<>
<dt className="text-warm-500 dark:text-warm-400">Gender</dt>
<dd className="dark:text-warm-200">{school.gender}</dd>
</>
)}
{pupilsLine && (
<>
<dt className="text-warm-500 dark:text-warm-400">Pupils</dt>
<dd className="dark:text-warm-200">{pupilsLine}</dd>
</>
)}
{school.fsm_percent !== undefined && (
<>
<dt className="text-warm-500 dark:text-warm-400">Free meal</dt>
<dd className="dark:text-warm-200">{school.fsm_percent.toFixed(1)}%</dd>
</>
)}
{school.ofsted_rating && (
<>
<dt className="text-warm-500 dark:text-warm-400">Ofsted</dt>
<dd className="dark:text-warm-200">{school.ofsted_rating}</dd>
</>
)}
{school.sixth_form === 'Has a sixth form' && (
<>
<dt className="text-warm-500 dark:text-warm-400">Sixth form</dt>
<dd className="dark:text-warm-200">Yes</dd>
</>
)}
{school.religious_character &&
school.religious_character !== 'Does not apply' &&
school.religious_character !== 'None' && (
<>
<dt className="text-warm-500 dark:text-warm-400">Religion</dt>
<dd className="dark:text-warm-200">{school.religious_character}</dd>
</>
)}
{school.admissions_policy && (
<>
<dt className="text-warm-500 dark:text-warm-400">Admissions</dt>
<dd className="dark:text-warm-200">{school.admissions_policy}</dd>
</>
)}
{school.trust && (
<>
<dt className="text-warm-500 dark:text-warm-400">Trust</dt>
<dd className="dark:text-warm-200">{school.trust}</dd>
</>
)}
{(school.address || school.postcode) && (
<>
<dt className="text-warm-500 dark:text-warm-400">Address</dt>
<dd className="dark:text-warm-200">
{[school.address, school.postcode].filter(Boolean).join(', ')}
</dd>
</>
)}
{school.local_authority && (
<>
<dt className="text-warm-500 dark:text-warm-400">LA</dt>
<dd className="dark:text-warm-200">{school.local_authority}</dd>
</>
)}
{school.head_name && (
<>
<dt className="text-warm-500 dark:text-warm-400">Head</dt>
<dd className="dark:text-warm-200">{school.head_name}</dd>
</>
)}
{websiteUrl && (
<>
<dt className="text-warm-500 dark:text-warm-400">Website</dt>
<dd className="truncate">
<a
href={websiteUrl}
target="_blank"
rel="noreferrer noopener"
className="pointer-events-auto text-teal-600 hover:underline dark:text-teal-400"
>
{websiteUrl.replace(/^https?:\/\//, '')}
</a>
</dd>
</>
)}
</dl>
);
}
function PoiPopupCardContent({ poi }: { poi: PoiPopupCardData }) {
return (
<div className="px-3 py-2 max-w-[280px]">
<div className="flex items-center gap-2">
<img
src={getPoiIconUrl(poi.category, poi.emoji, poi.icon_category, poi.name)}
alt=""
aria-hidden="true"
loading="lazy"
referrerPolicy="no-referrer"
className="h-5 w-5 shrink-0 rounded-[4px] bg-white object-contain p-0.5"
/>
<div className="min-w-0">
<div className="font-semibold dark:text-warm-100">{poi.name}</div>
<div className="flex items-center gap-1.5 text-xs text-warm-500 dark:text-warm-400">
<span
className="inline-block w-2 h-2 rounded-full flex-shrink-0"
style={{
backgroundColor: `rgb(${getPoiGroupColor(poi.group).join(',')})`,
}}
/>
{ts(poi.category)}
</div>
</div>
</div>
{poi.school && renderSchoolMetadata(poi.school)}
</div>
);
}
function getRenderedViewState(map: MapRef | null): ViewState | null { function getRenderedViewState(map: MapRef | null): ViewState | null {
if (!map) return null; if (!map) return null;
@ -565,186 +190,6 @@ function getRenderedVisibleCenter(
}; };
} }
function DeckOverlay({
layers,
getTooltip,
}: {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
layers: any[];
// eslint-disable-next-line @typescript-eslint/no-explicit-any
getTooltip: any;
}) {
const overlay = useControl(() => new SafeMapboxOverlay({ interleaved: true }));
useEffect(() => {
overlay.setProps({
layers: layers.filter(Boolean),
getTooltip,
});
}, [overlay, layers, getTooltip]);
return null;
}
function overlayTileUrl(path: string): string {
return `${window.location.origin}/api/overlays/${path}/{z}/{x}/{y}`;
}
function OverlayTileLayers({
activeOverlays,
activeCrimeTypes,
zoom,
}: {
activeOverlays: Set<OverlayId>;
activeCrimeTypes: Set<string>;
zoom: number;
}) {
if (zoom < POSTCODE_ZOOM_THRESHOLD || activeOverlays.size === 0) return null;
const showNoise = activeOverlays.has('noise');
const showCrime = activeOverlays.has('crime-hotspots');
const showTrees = activeOverlays.has('trees-outside-woodlands');
const showPropertyBorders = activeOverlays.has('property-borders');
// Restrict the heatmap to the selected crime types. This must always be a
// concrete expression: passing `filter={undefined}` makes react-map-gl call
// map.addLayer({filter: undefined}), which MapLibre rejects at validation
// ("filter: array expected, undefined found"), so the layer is never created
// and the heatmap stays blank until a later setFilter call. An `in` over the
// selected types matches everything when all 14 are selected.
const crimeFilter = ['in', ['get', 'crime_type'], ['literal', Array.from(activeCrimeTypes)]];
return (
<>
{showNoise && (
<Source
id="overlay-noise-source"
type="raster"
tiles={[overlayTileUrl('noise')]}
tileSize={256}
minzoom={OVERLAY_MIN_ZOOM.noise}
maxzoom={14}
>
<Layer
id="overlay-noise"
type="raster"
minzoom={POSTCODE_ZOOM_THRESHOLD}
paint={{
'raster-opacity': 0.68,
'raster-fade-duration': 120,
}}
/>
</Source>
)}
{showCrime && (
<Source
id="overlay-crime-source"
type="vector"
tiles={[overlayTileUrl('crime-hotspots')]}
minzoom={OVERLAY_MIN_ZOOM['crime-hotspots']}
maxzoom={15}
>
<Layer
id="overlay-crime-heatmap"
type="heatmap"
source-layer="crime_hotspots"
minzoom={POSTCODE_ZOOM_THRESHOLD}
filter={crimeFilter as never}
paint={
{
'heatmap-weight': [
'interpolate',
['linear'],
['coalesce', ['get', 'count'], ['get', 'weight'], 1],
0,
0,
10,
1,
],
'heatmap-intensity': ['interpolate', ['linear'], ['zoom'], 15, 0.8, 18, 2.2],
'heatmap-radius': ['interpolate', ['linear'], ['zoom'], 15, 18, 18, 30],
'heatmap-opacity': 0.72,
'heatmap-color': [
'interpolate',
['linear'],
['heatmap-density'],
0,
'rgba(0, 0, 0, 0)',
0.2,
'rgb(253, 224, 71)',
0.45,
'rgb(249, 115, 22)',
0.75,
'rgb(220, 38, 38)',
1,
'rgb(127, 29, 29)',
],
} as never
}
/>
</Source>
)}
{showTrees && (
<Source
id="overlay-trees-source"
type="vector"
tiles={[overlayTileUrl('trees-outside-woodlands')]}
minzoom={OVERLAY_MIN_ZOOM['trees-outside-woodlands']}
maxzoom={16}
>
<Layer
id="overlay-tree-polygons"
type="fill"
source-layer="trees_outside_woodlands"
minzoom={POSTCODE_ZOOM_THRESHOLD}
paint={
{
'fill-color': '#1f9d55',
'fill-opacity': [
'interpolate',
['linear'],
['coalesce', ['get', 'area_sqm'], 0],
0,
0.28,
250,
0.62,
],
'fill-outline-color': 'rgba(15, 81, 50, 0.65)',
} as never
}
/>
</Source>
)}
{showPropertyBorders && (
<Source
id="overlay-property-borders-source"
type="vector"
tiles={[overlayTileUrl('property-borders')]}
minzoom={OVERLAY_MIN_ZOOM['property-borders']}
maxzoom={16}
>
<Layer
id="overlay-property-borders"
type="line"
source-layer="property_borders"
minzoom={POSTCODE_ZOOM_THRESHOLD}
paint={
{
'line-color': '#b45309',
'line-opacity': ['interpolate', ['linear'], ['zoom'], 15, 0.35, 18, 0.85],
'line-width': ['interpolate', ['linear'], ['zoom'], 15, 0.4, 18, 1.4],
} as never
}
/>
</Source>
)}
</>
);
}
export default memo(function Map({ export default memo(function Map({
data, data,
postcodeData, postcodeData,
@ -790,7 +235,6 @@ export default memo(function Map({
const containerRef = useRef<HTMLDivElement>(null); const containerRef = useRef<HTMLDivElement>(null);
const mapRef = useRef<MapRef | null>(null); const mapRef = useRef<MapRef | null>(null);
const { t } = useTranslation(); const { t } = useTranslation();
const modes = useTranslatedModes();
const densityLabel = densityLabelProp ?? t('mapLegend.numberOfProperties'); const densityLabel = densityLabelProp ?? t('mapLegend.numberOfProperties');
const [internalViewState, setInternalViewState] = useState<ViewState>(initialViewState); const [internalViewState, setInternalViewState] = useState<ViewState>(initialViewState);
const [dimensions, setDimensions] = useState<Dimensions>({ width: 0, height: 0 }); const [dimensions, setDimensions] = useState<Dimensions>({ width: 0, height: 0 });
@ -941,23 +385,16 @@ export default memo(function Map({
() => (bottomScreenInset > 0 ? { '--map-mobile-bottom-inset': `${bottomScreenInset}px` } : {}), () => (bottomScreenInset > 0 ? { '--map-mobile-bottom-inset': `${bottomScreenInset}px` } : {}),
[bottomScreenInset] [bottomScreenInset]
); );
const hideDesktopTopCardsForWidth = const { showLocationSearch, showLegend, topCardsLayoutClass } = useMapCardLayout({
hideTopCardsWhenNarrow && mapWidth: dimensions.width,
dimensions.width > 0 && hideTopCardsWhenNarrow,
dimensions.width < DESKTOP_TOP_CARDS_STACKED_MIN_MAP_WIDTH; hideLegend,
const stackDesktopTopCards = hideLocationSearch,
hideTopCardsWhenNarrow && });
dimensions.width >= DESKTOP_TOP_CARDS_STACKED_MIN_MAP_WIDTH &&
dimensions.width < DESKTOP_TOP_CARDS_ROW_MIN_MAP_WIDTH;
const showLocationSearch = !hideLocationSearch && !hideDesktopTopCardsForWidth;
const showLegend = !hideLegend && !hideDesktopTopCardsForWidth;
const getViewportCenter = useCallback(() => { const getViewportCenter = useCallback(() => {
const center = mapRef.current?.getCenter(); const center = mapRef.current?.getCenter();
return center ? { lat: center.lat, lng: center.lng } : null; return center ? { lat: center.lat, lng: center.lng } : null;
}, []); }, []);
const desktopTopCardsLayoutClass = stackDesktopTopCards
? 'flex-col items-start'
: 'items-start justify-between';
const { const {
layers, layers,
@ -1108,79 +545,29 @@ export default memo(function Map({
) : ( ) : (
<> <>
{(showLocationSearch || showLegend) && ( {(showLocationSearch || showLegend) && (
<div <MapTopCards
className={`absolute top-3 left-3 right-3 z-20 flex gap-2 pointer-events-none ${desktopTopCardsLayoutClass}`} layoutClass={topCardsLayoutClass}
> showLocationSearch={showLocationSearch}
{showLocationSearch && ( showLegend={showLegend}
<LocationSearch onFlyTo={handleFlyTo}
onFlyTo={handleFlyTo} onLocationSearched={onLocationSearched}
onLocationSearched={onLocationSearched} onCurrentLocationFound={onCurrentLocationFound}
onCurrentLocationFound={onCurrentLocationFound} onLocationSearchMouseEnter={handleMouseLeave}
onMouseEnter={handleMouseLeave} getViewportCenter={getViewportCenter}
getViewportCenter={getViewportCenter} viewFeature={viewFeature}
className={DESKTOP_TOP_CARD_CLASS} colorRange={colorRange}
inputClassName={DESKTOP_LOCATION_SEARCH_INPUT_CLASS} viewSource={viewSource}
/> onCancelPin={onCancelPin}
)} onResetPreviewScale={onResetPreviewScale}
{showLegend && canResetPreviewScale={canResetPreviewScale}
(viewFeature && colorRange ? ( colorFeatureMeta={colorFeatureMeta}
viewFeature.startsWith('tt_') ? ( usePostcodeView={usePostcodeView}
<MapLegend countRange={countRange}
featureLabel={t('travel.travelTime', { postcodeCountRange={postcodeCountRange}
mode: modes.label( densityLabel={densityLabel}
viewFeature.split('_')[1] as 'car' | 'bicycle' | 'walking' | 'transit' totalCount={totalCountProp}
), theme={theme}
})} />
range={colorRange}
showCancel={viewSource === 'eye'}
onCancel={onCancelPin}
onResetScale={viewSource === 'eye' ? onResetPreviewScale : undefined}
resetScaleDisabled={!canResetPreviewScale}
mode="feature"
theme={theme}
suffix=" min"
className={DESKTOP_TOP_CARD_CLASS}
/>
) : colorFeatureMeta ? (
<MapLegend
featureLabel={
viewSource === 'eye'
? t('mapLegend.previewing', { name: ts(colorFeatureMeta.name) })
: ts(colorFeatureMeta.name)
}
range={colorRange}
showCancel={viewSource === 'eye'}
onCancel={onCancelPin}
onResetScale={viewSource === 'eye' ? onResetPreviewScale : undefined}
resetScaleDisabled={!canResetPreviewScale}
mode="feature"
enumValues={
colorFeatureMeta.type === 'enum' ? colorFeatureMeta.values : undefined
}
featureName={colorFeatureMeta.name}
theme={theme}
suffix={colorFeatureMeta.suffix}
raw={colorFeatureMeta.raw}
className={DESKTOP_TOP_CARD_CLASS}
/>
) : null
) : (
<MapLegend
featureLabel={densityLabel}
range={
usePostcodeView
? [postcodeCountRange.min, postcodeCountRange.max]
: [countRange.min, countRange.max]
}
totalCount={totalCountProp}
showCancel={false}
onCancel={onCancelPin}
mode="density"
theme={theme}
className={DESKTOP_TOP_CARD_CLASS}
/>
))}
</div>
)} )}
{autoPoiCards.map(({ poi, x, y }) => ( {autoPoiCards.map(({ poi, x, y }) => (
<div <div
@ -1247,28 +634,23 @@ export default memo(function Map({
<CloseIcon className="w-3 h-3" /> <CloseIcon className="w-3 h-3" />
</button> </button>
{listingPopup.mode === 'single' ? ( {listingPopup.mode === 'single' ? (
<ListingPopupSingleContent listing={listingPopup.listing} t={t} /> <ListingPopupSingleContent listing={listingPopup.listing} />
) : ( ) : (
<ListingClusterPopupContent <ListingClusterPopupContent
count={listingPopup.count} count={listingPopup.count}
listings={listingPopup.listings} listings={listingPopup.listings}
t={t}
/> />
)} )}
</div> </div>
)} )}
{hoverPosition && hoveredHexagonId && hoveredHexagonId !== selectedHexagonId && ( {hoverPosition && hoveredHexagonId && hoveredHexagonId !== selectedHexagonId && (
<HoverCard <HoverCardOverlay
x={hoverPosition.x} x={hoverPosition.x}
y={hoverPosition.y} y={hoverPosition.y}
id={hoveredHexagonId} id={hoveredHexagonId}
isPostcode={usePostcodeView} usePostcodeView={usePostcodeView}
data={ data={data}
usePostcodeView postcodeData={postcodeData}
? postcodeData.find((f) => f.properties.postcode === hoveredHexagonId)
?.properties || null
: data.find((d) => d.h3 === hoveredHexagonId) || null
}
filters={filters} filters={filters}
features={features} features={features}
/> />

View file

@ -1,7 +1,7 @@
import { Suspense, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { Suspense, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { Trans, useTranslation } from 'react-i18next'; import { Trans, useTranslation } from 'react-i18next';
import type { ActualListing, MapFlyToOptions, PostcodeGeometry } from '../../types'; import type { ActualListing, PostcodeGeometry } from '../../types';
import type { SearchedLocation } from './LocationSearch'; import type { SearchedLocation } from './LocationSearch';
import { useMapData } from '../../hooks/useMapData'; import { useMapData } from '../../hooks/useMapData';
import { usePOIData } from '../../hooks/usePOIData'; import { usePOIData } from '../../hooks/usePOIData';
@ -67,11 +67,11 @@ import {
useMobileBackNavigationGuard, useMobileBackNavigationGuard,
useScreenshotReadySignal, useScreenshotReadySignal,
} from './map-page/effects'; } from './map-page/effects';
import { useMobileDrawer } from './map-page/useMobileDrawer';
import type { MapFlyTo, MapPageProps } from './map-page/types'; import type { MapFlyTo, MapPageProps } from './map-page/types';
export type { ExportState } from './map-page/types'; export type { ExportState } from './map-page/types';
type PendingFlyTo = { lat: number; lng: number; zoom: number };
const EMPTY_ACTUAL_LISTINGS: ActualListing[] = []; const EMPTY_ACTUAL_LISTINGS: ActualListing[] = [];
export default function MapPage({ export default function MapPage({
@ -127,10 +127,11 @@ export default function MapPage({
); );
const [leftPaneWidth, leftPaneHandlers] = usePaneResize(384, 200, 0.45, 'left'); const [leftPaneWidth, leftPaneHandlers] = usePaneResize(384, 200, 0.45, 'left');
const [rightPaneWidth, rightPaneHandlers] = usePaneResize(384, 200, 0.45, 'right'); const [rightPaneWidth, rightPaneHandlers] = usePaneResize(384, 200, 0.45, 'right');
const [mobileDrawerOpen, setMobileDrawerOpen] = useState(false); // The POI and overlay panes are mutually exclusive, so a single state tracks
const [mobileBottomSheetHeight, setMobileBottomSheetHeight] = useState(0); // which one (if any) is open.
const [poiPaneOpen, setPoiPaneOpen] = useState(false); const [openMapPane, setOpenMapPane] = useState<'poi' | 'overlay' | null>(null);
const [overlayPaneOpen, setOverlayPaneOpen] = useState(false); const poiPaneOpen = openMapPane === 'poi';
const overlayPaneOpen = openMapPane === 'overlay';
const [currentLocation, setCurrentLocation] = useState<{ lat: number; lng: number } | null>(null); const [currentLocation, setCurrentLocation] = useState<{ lat: number; lng: number } | null>(null);
const [listingsToggleEnabled, setListingsToggleEnabled] = useState(true); const [listingsToggleEnabled, setListingsToggleEnabled] = useState(true);
const [pendingInitialPostcode, setPendingInitialPostcode] = useState<string | null>( const [pendingInitialPostcode, setPendingInitialPostcode] = useState<string | null>(
@ -184,27 +185,21 @@ export default function MapPage({
} = useTravelTime(initialTravelTime); } = useTravelTime(initialTravelTime);
const mapFlyToRef = useRef<MapFlyTo | null>(null); const mapFlyToRef = useRef<MapFlyTo | null>(null);
const pendingCurrentLocationFlyToRef = useRef<{ lat: number; lng: number } | null>(null);
const pendingLocationSearchFlyToRef = useRef<PendingFlyTo | null>(null);
const mobileDrawerPanelRectRef = useRef<DOMRectReadOnly | null>(null);
const areaPaneScrollTopRef = useRef(0); const areaPaneScrollTopRef = useRef(0);
const propertiesPaneScrollTopRef = useRef(0); const propertiesPaneScrollTopRef = useRef(0);
const getMobileMapFlyToOptions = useCallback((): MapFlyToOptions | undefined => { const {
if (!isMobile) return undefined; mobileDrawerOpen,
mobileBottomSheetHeight,
const panelRect = mobileDrawerPanelRectRef.current; setMobileBottomSheetHeight,
if (mobileDrawerOpen && panelRect) { openMobileDrawer,
const bottomInset = Math.max(0, window.innerHeight - panelRect.top); openMobileDrawerForLocationSearch,
if (bottomInset > 0) { clearPendingLocationSearchFlyTo,
return { visibleViewportArea: { bottom: bottomInset } }; queueCurrentLocationFlyTo,
} handleMobileDrawerPanelRectChange,
} handleMobileDrawerClose,
getMobileMapFlyToOptions,
return mobileBottomSheetHeight > 0 } = useMobileDrawer(isMobile, mapFlyToRef);
? { visibleArea: { bottom: mobileBottomSheetHeight } }
: undefined;
}, [isMobile, mobileBottomSheetHeight, mobileDrawerOpen]);
const mapData = useMapData({ const mapData = useMapData({
filters, filters,
@ -217,6 +212,12 @@ export default function MapPage({
shareCode, shareCode,
}); });
// Read the zoom through a ref inside handleAiFilterSubmit so panning/zooming
// doesn't recreate the callback (it sits in the Filters pane's dependency
// chain, which would otherwise re-render on every camera move).
const currentViewZoomRef = useRef<number | undefined>(undefined);
currentViewZoomRef.current = mapData.currentView?.zoom;
const handleAiFilterSubmit = useCallback( const handleAiFilterSubmit = useCallback(
async (query: string) => { async (query: string) => {
const context = { const context = {
@ -283,7 +284,7 @@ export default function MapPage({
mapFlyToRef.current?.( mapFlyToRef.current?.(
destination.lat, destination.lat,
destination.lon, destination.lon,
mapData.currentView?.zoom ?? INITIAL_VIEW_STATE.zoom, currentViewZoomRef.current ?? INITIAL_VIEW_STATE.zoom,
getMobileMapFlyToOptions() getMobileMapFlyToOptions()
); );
} }
@ -298,7 +299,6 @@ export default function MapPage({
getMobileMapFlyToOptions, getMobileMapFlyToOptions,
handleSetEntries, handleSetEntries,
handleSetFilters, handleSetFilters,
mapData.currentView?.zoom,
] ]
); );
@ -395,20 +395,6 @@ export default function MapPage({
journeyDest, journeyDest,
}); });
const consumePendingLocationSearchFlyTo = useCallback((rect?: DOMRectReadOnly | null) => {
const pending = pendingLocationSearchFlyToRef.current;
const panelRect = rect ?? mobileDrawerPanelRectRef.current;
if (!pending || !panelRect) return;
const bottomInset = Math.max(0, window.innerHeight - panelRect.top);
const flyTo = mapFlyToRef.current;
if (!flyTo) return;
flyTo(pending.lat, pending.lng, pending.zoom, {
visibleViewportArea: { bottom: bottomInset },
});
pendingLocationSearchFlyToRef.current = null;
}, []);
const handleLocationSearchResult = useCallback( const handleLocationSearchResult = useCallback(
(result: SearchedLocation | null) => { (result: SearchedLocation | null) => {
if (result) { if (result) {
@ -428,68 +414,41 @@ export default function MapPage({
result.focusAddress result.focusAddress
); );
if (isMobile) { if (isMobile) {
pendingLocationSearchFlyToRef.current = { openMobileDrawerForLocationSearch({
lat: markerLat ?? result.latitude, lat: markerLat ?? result.latitude,
lng: markerLng ?? result.longitude, lng: markerLng ?? result.longitude,
zoom: result.zoom, zoom: result.zoom,
}; });
setMobileDrawerOpen(true);
consumePendingLocationSearchFlyTo();
} }
} else { } else {
setCurrentLocation(null); setCurrentLocation(null);
pendingLocationSearchFlyToRef.current = null; clearPendingLocationSearchFlyTo();
handleCloseSelection(); handleCloseSelection();
} }
}, },
[consumePendingLocationSearchFlyTo, handleCloseSelection, handleLocationSearch, isMobile] [
clearPendingLocationSearchFlyTo,
handleCloseSelection,
handleLocationSearch,
isMobile,
openMobileDrawerForLocationSearch,
]
); );
const consumePendingCurrentLocationFlyTo = useCallback((rect?: DOMRectReadOnly | null) => {
const pending = pendingCurrentLocationFlyToRef.current;
const panelRect = rect ?? mobileDrawerPanelRectRef.current;
if (!pending || !panelRect) return;
const bottomInset = Math.max(0, window.innerHeight - panelRect.top);
const flyTo = mapFlyToRef.current;
if (!flyTo) return;
flyTo(pending.lat, pending.lng, 17, {
visibleViewportArea: { bottom: bottomInset },
});
pendingCurrentLocationFlyToRef.current = null;
}, []);
const handleCurrentLocationFound = useCallback( const handleCurrentLocationFound = useCallback(
(lat: number, lng: number) => { (lat: number, lng: number) => {
if (isMobile) { if (isMobile) {
pendingCurrentLocationFlyToRef.current = { lat, lng }; queueCurrentLocationFlyTo(lat, lng);
consumePendingCurrentLocationFlyTo();
} else { } else {
mapFlyToRef.current?.(lat, lng, 17); mapFlyToRef.current?.(lat, lng, 17);
} }
setCurrentLocation({ lat, lng }); setCurrentLocation({ lat, lng });
handleCurrentLocationSearch(lat, lng); handleCurrentLocationSearch(lat, lng);
if (isMobile) setMobileDrawerOpen(true); if (isMobile) openMobileDrawer();
}, },
[consumePendingCurrentLocationFlyTo, handleCurrentLocationSearch, isMobile] [handleCurrentLocationSearch, isMobile, openMobileDrawer, queueCurrentLocationFlyTo]
); );
const handleMobileDrawerPanelRectChange = useCallback(
(rect: DOMRectReadOnly) => {
mobileDrawerPanelRectRef.current = rect;
consumePendingCurrentLocationFlyTo(rect);
consumePendingLocationSearchFlyTo(rect);
},
[consumePendingCurrentLocationFlyTo, consumePendingLocationSearchFlyTo]
);
const handleMobileDrawerClose = useCallback(() => {
pendingCurrentLocationFlyToRef.current = null;
pendingLocationSearchFlyToRef.current = null;
mobileDrawerPanelRectRef.current = null;
setMobileDrawerOpen(false);
}, []);
const shareReturnViewRef = useRef(shareCode ? initialViewState : null); const shareReturnViewRef = useRef(shareCode ? initialViewState : null);
// Hide the upgrade modal as soon as the user dismisses it. We can't rely on // Hide the upgrade modal as soon as the user dismisses it. We can't rely on
// the camera fly alone to close it: flying back to the free/shared zone only // the camera fly alone to close it: flying back to the free/shared zone only
@ -555,11 +514,7 @@ export default function MapPage({
isMobile, isMobile,
flyTo: mapFlyToRef, flyTo: mapFlyToRef,
onLocationSearch: handleLocationSearch, onLocationSearch: handleLocationSearch,
onOpenMobileDrawer: (target) => { onOpenMobileDrawer: openMobileDrawerForLocationSearch,
pendingLocationSearchFlyToRef.current = target;
setMobileDrawerOpen(true);
consumePendingLocationSearchFlyTo();
},
onSettled: () => setPendingInitialPostcode(null), onSettled: () => setPendingInitialPostcode(null),
}); });
useHorizontalSwipeNavigationGuard(); useHorizontalSwipeNavigationGuard();
@ -578,10 +533,10 @@ export default function MapPage({
(id: string, isPostcode?: boolean, geometry?: PostcodeGeometry) => { (id: string, isPostcode?: boolean, geometry?: PostcodeGeometry) => {
handleHexagonClick(id, isPostcode, geometry); handleHexagonClick(id, isPostcode, geometry);
if (id) { if (id) {
setMobileDrawerOpen(true); openMobileDrawer();
} }
}, },
[handleHexagonClick] [handleHexagonClick, openMobileDrawer]
); );
const hexagonLocation = useHexagonLocation( const hexagonLocation = useHexagonLocation(
@ -641,15 +596,20 @@ export default function MapPage({
shareAndSaveView, shareAndSaveView,
] ]
); );
// dashboardParams changes on every camera move; read it through a ref so the
// save/update handlers (and the Filters pane depending on them) stay stable
// while panning. The ref always holds the params of the latest render.
const dashboardParamsRef = useRef(dashboardParams);
dashboardParamsRef.current = dashboardParams;
const handleSaveSearch = useCallback( const handleSaveSearch = useCallback(
async (name: string) => { async (name: string) => {
await onSaveSearch?.(name, dashboardParams); await onSaveSearch?.(name, dashboardParamsRef.current);
}, },
[dashboardParams, onSaveSearch] [onSaveSearch]
); );
const handleUpdateEditInPlaceWithParams = useCallback(async () => { const handleUpdateEditInPlaceWithParams = useCallback(async () => {
await onUpdateEditInPlace?.(dashboardParams); await onUpdateEditInPlace?.(dashboardParamsRef.current);
}, [dashboardParams, onUpdateEditInPlace]); }, [onUpdateEditInPlace]);
const checkoutReturnPath = useMemo( const checkoutReturnPath = useMemo(
() => `/dashboard${dashboardParams ? `?${dashboardParams}` : ''}`, () => `/dashboard${dashboardParams ? `?${dashboardParams}` : ''}`,
[dashboardParams] [dashboardParams]
@ -686,6 +646,273 @@ export default function MapPage({
} }
}, [mapData.licenseRequired]); }, [mapData.licenseRequired]);
const handleUpgradeClick = useCallback(() => {
onNavigateTo('pricing');
}, [onNavigateTo]);
const handleTogglePoiPane = useCallback(() => {
setOpenMapPane((pane) => (pane === 'poi' ? null : 'poi'));
}, []);
const handleToggleOverlayPane = useCallback(() => {
setOpenMapPane((pane) => (pane === 'overlay' ? null : 'overlay'));
}, []);
const handleClosePoiPane = useCallback(() => {
setOpenMapPane((pane) => (pane === 'poi' ? null : pane));
}, []);
const handleCloseOverlayPane = useCallback(() => {
setOpenMapPane((pane) => (pane === 'overlay' ? null : pane));
}, []);
const handleAreaTabClick = useCallback(() => {
setRightPaneTab('area');
}, [setRightPaneTab]);
const handleMobileDrawerTabChange = useCallback(
(tab: 'area' | 'properties') => {
if (tab === 'properties') {
handlePropertiesTabClick();
} else {
setRightPaneTab(tab);
}
},
[handlePropertiesTabClick, setRightPaneTab]
);
const renderAreaPane = useCallback(
() => (
<Suspense fallback={<PaneFallback />}>
<AreaPane
stats={areaStats}
globalFeatures={features}
loading={loadingAreaStats}
hexagonId={selectedHexagon?.id || null}
isPostcode={selectedHexagon?.type === 'postcode'}
hexagonLocation={hexagonLocation}
filters={filters}
unfilteredCount={unfilteredAreaCount}
statsUseFilters={areaStatsUseFilters}
onStatsUseFiltersChange={setAreaStatsUseFilters}
travelTimeEntries={activeEntries}
shareCode={shareCode}
isGroupExpanded={isAreaGroupExpanded}
onToggleGroup={toggleAreaGroup}
scrollTopRef={areaPaneScrollTopRef}
scrollRestoreKey={
selectedHexagon ? `${selectedHexagon.type}:${selectedHexagon.id}` : null
}
scrollSaveDisabled={loadingAreaStats && areaStats == null}
/>
</Suspense>
),
[
activeEntries,
areaStats,
areaStatsUseFilters,
features,
filters,
hexagonLocation,
isAreaGroupExpanded,
loadingAreaStats,
selectedHexagon,
setAreaStatsUseFilters,
shareCode,
toggleAreaGroup,
unfilteredAreaCount,
]
);
const renderPropertiesPane = useCallback(
() => (
<Suspense fallback={<PaneFallback />}>
<PropertiesPane
properties={properties}
total={propertiesTotal}
loading={loadingProperties}
hexagonId={selectedHexagon?.id || null}
onLoadMore={handleLoadMoreProperties}
scrollTopRef={propertiesPaneScrollTopRef}
scrollRestoreKey={
selectedHexagon ? `${selectedHexagon.type}:${selectedHexagon.id}` : null
}
scrollSaveDisabled={loadingProperties && properties.length === 0}
/>
</Suspense>
),
[handleLoadMoreProperties, loadingProperties, properties, propertiesTotal, selectedHexagon]
);
const poiPane = useMemo(
() => (
<Suspense fallback={<PaneFallback />}>
<POIPane
groups={poiCategoryGroups}
selectedCategories={selectedPOICategories}
onCategoriesChange={setSelectedPOICategories}
poiCount={pois.length}
onClose={handleClosePoiPane}
/>
</Suspense>
),
[handleClosePoiPane, poiCategoryGroups, pois.length, selectedPOICategories]
);
const overlayPane = useMemo(
() => (
<Suspense fallback={<PaneFallback />}>
<OverlayPane
selectedOverlays={activeOverlays}
onOverlaysChange={setActiveOverlays}
selectedCrimeTypes={crimeTypes}
onCrimeTypesChange={setCrimeTypes}
basemap={basemap}
onBasemapChange={setBasemap}
colorOpacity={colorOpacity}
onColorOpacityChange={setColorOpacity}
zoomedIn={overlaysZoomedIn}
onClose={handleCloseOverlayPane}
/>
</Suspense>
),
[activeOverlays, basemap, colorOpacity, crimeTypes, handleCloseOverlayPane, overlaysZoomedIn]
);
const filtersPane = useMemo(
() => (
<Suspense fallback={<PaneFallback />}>
<Filters
features={features}
filters={filters}
activeFeature={activeFeature}
dragValue={dragValue}
enabledFeatures={enabledFeatures}
onAddFilter={handleAddFilter}
onRemoveFilter={handleRemoveFilter}
onFilterChange={handleFilterChange}
onDragStart={handleDragStart}
onDragChange={handleDragChange}
onDragEnd={handleDragEnd}
pinnedFeature={pinnedFeature}
onTogglePin={handleTogglePin}
openInfoFeature={pendingInfoFeature}
onClearOpenInfoFeature={onClearPendingInfoFeature}
travelTimeEntries={entries}
onTravelTimeAddEntry={handleAddEntry}
onTravelTimeRemoveEntry={handleTravelTimeRemoveEntry}
onTravelTimeSetDestination={handleTravelTimeSetDestination}
onTravelTimeRangeChange={handleTimeRangeChange}
onTravelTimeDragEnd={handleTravelTimeDragEnd}
onTravelTimeToggleBest={handleToggleBest}
onTravelTimeToggleNoChange={handleToggleNoChange}
onTravelTimeToggleNoBuses={handleToggleNoBuses}
aiFilterLoading={aiFilterLoading}
aiFilterError={aiFilterError}
aiFilterErrorType={aiFilterErrorType}
aiFilterNotes={aiFilterNotes}
aiFilterSummary={aiFilterSummary}
onAiFilterSubmit={handleAiFilterSubmit}
isLoggedIn={!!user}
onLoginRequired={onRegisterClick}
isLicensed={user?.subscription === 'licensed'}
onUpgradeClick={handleUpgradeClick}
onResetTutorial={!isMobile ? tutorial.resetTutorial : undefined}
filterImpacts={filterCounts.impacts}
onClearAll={handleClearAll}
onSaveSearch={onSaveSearch ? handleSaveSearch : undefined}
savingSearch={savingSearch}
editingSearchName={editingSearch?.name ?? null}
onUpdateSearch={
editingSearch && onUpdateEditInPlace ? handleUpdateEditInPlaceWithParams : undefined
}
onExitEditing={onCancelEdit}
destinationDropdownPortal={isMobile ? false : undefined}
/>
</Suspense>
),
[
activeFeature,
aiFilterError,
aiFilterErrorType,
aiFilterLoading,
aiFilterNotes,
aiFilterSummary,
dragValue,
editingSearch,
enabledFeatures,
entries,
features,
filterCounts.impacts,
filters,
handleAddEntry,
handleAddFilter,
handleAiFilterSubmit,
handleClearAll,
handleDragChange,
handleDragEnd,
handleDragStart,
handleFilterChange,
handleRemoveFilter,
handleSaveSearch,
handleTimeRangeChange,
handleToggleBest,
handleToggleNoBuses,
handleToggleNoChange,
handleTogglePin,
handleTravelTimeDragEnd,
handleTravelTimeRemoveEntry,
handleTravelTimeSetDestination,
handleUpdateEditInPlaceWithParams,
handleUpgradeClick,
isMobile,
onCancelEdit,
onClearPendingInfoFeature,
onRegisterClick,
onSaveSearch,
onUpdateEditInPlace,
pendingInfoFeature,
pinnedFeature,
savingSearch,
tutorial.resetTutorial,
user,
]
);
const mobileLegend = useMemo(
() => (
<MobileMapLegend
mapViewFeature={mapViewFeature}
colorRange={mapData.colorRange}
viewSource={viewSource}
mobileLegendMeta={mobileLegendMeta}
densityLabel={densityLabel}
densityRange={mobileDensityRange}
theme={theme}
canResetPreviewScale={mapData.canResetPreviewScale}
onCancelPin={handleCancelPin}
onResetPreviewScale={mapData.handleResetPreviewScale}
/>
),
[
densityLabel,
handleCancelPin,
mapData.canResetPreviewScale,
mapData.colorRange,
mapData.handleResetPreviewScale,
mapViewFeature,
mobileDensityRange,
mobileLegendMeta,
theme,
viewSource,
]
);
const toasts = useMemo(
() => (
<ExportToast
notice={exportNotice}
closeLabel={t('common.close')}
onClose={clearExportNotice}
/>
),
[clearExportNotice, exportNotice, t]
);
if (screenshotMode) { if (screenshotMode) {
return ( return (
<ScreenshotMapPage <ScreenshotMapPage
@ -706,147 +933,6 @@ export default function MapPage({
); );
} }
const renderAreaPane = () => (
<Suspense fallback={<PaneFallback />}>
<AreaPane
stats={areaStats}
globalFeatures={features}
loading={loadingAreaStats}
hexagonId={selectedHexagon?.id || null}
isPostcode={selectedHexagon?.type === 'postcode'}
hexagonLocation={hexagonLocation}
filters={filters}
unfilteredCount={unfilteredAreaCount}
statsUseFilters={areaStatsUseFilters}
onStatsUseFiltersChange={setAreaStatsUseFilters}
travelTimeEntries={activeEntries}
shareCode={shareCode}
isGroupExpanded={isAreaGroupExpanded}
onToggleGroup={toggleAreaGroup}
scrollTopRef={areaPaneScrollTopRef}
scrollRestoreKey={selectedHexagon ? `${selectedHexagon.type}:${selectedHexagon.id}` : null}
scrollSaveDisabled={loadingAreaStats && areaStats == null}
/>
</Suspense>
);
const renderPropertiesPane = () => (
<Suspense fallback={<PaneFallback />}>
<PropertiesPane
properties={properties}
total={propertiesTotal}
loading={loadingProperties}
hexagonId={selectedHexagon?.id || null}
onLoadMore={handleLoadMoreProperties}
scrollTopRef={propertiesPaneScrollTopRef}
scrollRestoreKey={selectedHexagon ? `${selectedHexagon.type}:${selectedHexagon.id}` : null}
scrollSaveDisabled={loadingProperties && properties.length === 0}
/>
</Suspense>
);
const renderPOIPane = () => (
<Suspense fallback={<PaneFallback />}>
<POIPane
groups={poiCategoryGroups}
selectedCategories={selectedPOICategories}
onCategoriesChange={setSelectedPOICategories}
poiCount={pois.length}
onClose={() => setPoiPaneOpen(false)}
/>
</Suspense>
);
const renderOverlayPane = () => (
<Suspense fallback={<PaneFallback />}>
<OverlayPane
selectedOverlays={activeOverlays}
onOverlaysChange={setActiveOverlays}
selectedCrimeTypes={crimeTypes}
onCrimeTypesChange={setCrimeTypes}
basemap={basemap}
onBasemapChange={setBasemap}
colorOpacity={colorOpacity}
onColorOpacityChange={setColorOpacity}
zoomedIn={overlaysZoomedIn}
onClose={() => setOverlayPaneOpen(false)}
/>
</Suspense>
);
const renderFilters = (options?: { destinationDropdownPortal?: boolean }) => (
<Suspense fallback={<PaneFallback />}>
<Filters
features={features}
filters={filters}
activeFeature={activeFeature}
dragValue={dragValue}
enabledFeatures={enabledFeatures}
onAddFilter={handleAddFilter}
onRemoveFilter={handleRemoveFilter}
onFilterChange={handleFilterChange}
onDragStart={handleDragStart}
onDragChange={handleDragChange}
onDragEnd={handleDragEnd}
pinnedFeature={pinnedFeature}
onTogglePin={handleTogglePin}
openInfoFeature={pendingInfoFeature}
onClearOpenInfoFeature={onClearPendingInfoFeature}
travelTimeEntries={entries}
onTravelTimeAddEntry={handleAddEntry}
onTravelTimeRemoveEntry={handleTravelTimeRemoveEntry}
onTravelTimeSetDestination={handleTravelTimeSetDestination}
onTravelTimeRangeChange={handleTimeRangeChange}
onTravelTimeDragEnd={handleTravelTimeDragEnd}
onTravelTimeToggleBest={handleToggleBest}
onTravelTimeToggleNoChange={handleToggleNoChange}
onTravelTimeToggleNoBuses={handleToggleNoBuses}
aiFilterLoading={aiFilterLoading}
aiFilterError={aiFilterError}
aiFilterErrorType={aiFilterErrorType}
aiFilterNotes={aiFilterNotes}
aiFilterSummary={aiFilterSummary}
onAiFilterSubmit={handleAiFilterSubmit}
isLoggedIn={!!user}
onLoginRequired={onRegisterClick}
isLicensed={user?.subscription === 'licensed'}
onUpgradeClick={() => onNavigateTo('pricing')}
onResetTutorial={!isMobile ? tutorial.resetTutorial : undefined}
filterImpacts={filterCounts.impacts}
onClearAll={handleClearAll}
onSaveSearch={onSaveSearch ? handleSaveSearch : undefined}
savingSearch={savingSearch}
editingSearchName={editingSearch?.name ?? null}
onUpdateSearch={
editingSearch && onUpdateEditInPlace ? handleUpdateEditInPlaceWithParams : undefined
}
onExitEditing={onCancelEdit}
destinationDropdownPortal={options?.destinationDropdownPortal}
/>
</Suspense>
);
const handleTogglePoiPane = () => {
setOverlayPaneOpen(false);
setPoiPaneOpen((open) => !open);
};
const handleToggleOverlayPane = () => {
setPoiPaneOpen(false);
setOverlayPaneOpen((open) => !open);
};
const handleMobileDrawerTabChange = (tab: 'area' | 'properties') => {
if (tab === 'properties') {
handlePropertiesTabClick();
} else {
setRightPaneTab(tab);
}
};
const exportToast = (
<ExportToast notice={exportNotice} closeLabel={t('common.close')} onClose={clearExportNotice} />
);
const toasts = exportToast;
const editingBar = const editingBar =
editingSearch && isMobile ? ( editingSearch && isMobile ? (
<div className="flex items-center gap-2 px-3 py-2 border-b border-warm-200 dark:border-navy-700 bg-warm-50 dark:bg-navy-900"> <div className="flex items-center gap-2 px-3 py-2 border-b border-warm-200 dark:border-navy-700 bg-warm-50 dark:bg-navy-900">
@ -940,25 +1026,12 @@ export default function MapPage({
poiPaneOpen={poiPaneOpen} poiPaneOpen={poiPaneOpen}
onTogglePoiPane={handleTogglePoiPane} onTogglePoiPane={handleTogglePoiPane}
poiButtonLabel={t('poiPane.pointsOfInterest')} poiButtonLabel={t('poiPane.pointsOfInterest')}
poiPane={renderPOIPane()} poiPane={poiPane}
overlayPaneOpen={overlayPaneOpen} overlayPaneOpen={overlayPaneOpen}
onToggleOverlayPane={handleToggleOverlayPane} onToggleOverlayPane={handleToggleOverlayPane}
overlayPane={renderOverlayPane()} overlayPane={overlayPane}
filtersPane={renderFilters({ destinationDropdownPortal: false })} filtersPane={filtersPane}
mobileLegend={ mobileLegend={mobileLegend}
<MobileMapLegend
mapViewFeature={mapViewFeature}
colorRange={mapData.colorRange}
viewSource={viewSource}
mobileLegendMeta={mobileLegendMeta}
densityLabel={densityLabel}
densityRange={mobileDensityRange}
theme={theme}
canResetPreviewScale={mapData.canResetPreviewScale}
onCancelPin={handleCancelPin}
onResetPreviewScale={mapData.handleResetPreviewScale}
/>
}
renderAreaPane={renderAreaPane} renderAreaPane={renderAreaPane}
renderPropertiesPane={renderPropertiesPane} renderPropertiesPane={renderPropertiesPane}
toasts={toasts} toasts={toasts}
@ -975,7 +1048,7 @@ export default function MapPage({
tutorialTheme={tutorialTheme} tutorialTheme={tutorialTheme}
leftPaneWidth={leftPaneWidth} leftPaneWidth={leftPaneWidth}
leftPaneHandlers={leftPaneHandlers} leftPaneHandlers={leftPaneHandlers}
filtersPane={renderFilters()} filtersPane={filtersPane}
mapData={mapData} mapData={mapData}
pois={pois} pois={pois}
activeOverlays={activeOverlays} activeOverlays={activeOverlays}
@ -1008,15 +1081,15 @@ export default function MapPage({
totalCount={filterCounts.total ?? undefined} totalCount={filterCounts.total ?? undefined}
poiPaneOpen={poiPaneOpen} poiPaneOpen={poiPaneOpen}
onTogglePoiPane={handleTogglePoiPane} onTogglePoiPane={handleTogglePoiPane}
poiPane={renderPOIPane()} poiPane={poiPane}
overlayPaneOpen={overlayPaneOpen} overlayPaneOpen={overlayPaneOpen}
onToggleOverlayPane={handleToggleOverlayPane} onToggleOverlayPane={handleToggleOverlayPane}
overlayPane={renderOverlayPane()} overlayPane={overlayPane}
showSelectionPane={!!selectedHexagon} showSelectionPane={!!selectedHexagon}
rightPaneWidth={rightPaneWidth} rightPaneWidth={rightPaneWidth}
rightPaneHandlers={rightPaneHandlers} rightPaneHandlers={rightPaneHandlers}
rightPaneTab={rightPaneTab} rightPaneTab={rightPaneTab}
onAreaTabClick={() => setRightPaneTab('area')} onAreaTabClick={handleAreaTabClick}
onPropertiesTabClick={handlePropertiesTabClick} onPropertiesTabClick={handlePropertiesTabClick}
onCloseSelection={handleCloseSelection} onCloseSelection={handleCloseSelection}
renderAreaPane={renderAreaPane} renderAreaPane={renderAreaPane}

View file

@ -0,0 +1,138 @@
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import type { FeatureMeta, MapFlyToOptions } from '../../types';
import { useTranslatedModes } from '../../hooks/useTravelTime';
import { ts } from '../../i18n/server';
import LocationSearch, { type SearchedLocation } from './LocationSearch';
import MapLegend from './MapLegend';
// Width utility class shared by every desktop top card: it is passed as
// `className` to both the LocationSearch box and each MapLegend variant.
const DESKTOP_TOP_CARD_CLASS = 'w-[300px]';
// Utility classes for the text input inside the desktop LocationSearch card;
// passed through as its `inputClassName`.
const DESKTOP_LOCATION_SEARCH_INPUT_CLASS =
'px-2 py-2 text-sm w-full border-none outline-none bg-transparent text-warm-700 dark:text-warm-200 placeholder-warm-400 dark:placeholder-warm-500';
/** Props for the desktop top-card overlay (location search box + map legend). */
interface MapTopCardsProps {
/** Extra flex layout classes appended to the overlay container's className. */
layoutClass: string;
/** Render the LocationSearch card when true. */
showLocationSearch: boolean;
/** Render a MapLegend card when true. */
showLegend: boolean;
/** Forwarded to LocationSearch as `onFlyTo`. */
onFlyTo: (lat: number, lng: number, zoom: number, options?: MapFlyToOptions) => void;
/** Forwarded to LocationSearch as `onLocationSearched`. */
onLocationSearched?: (location: SearchedLocation | null) => void;
/** Forwarded to LocationSearch as `onCurrentLocationFound`. */
onCurrentLocationFound?: (lat: number, lng: number) => void;
/** Forwarded to LocationSearch as its `onMouseEnter` handler. */
onLocationSearchMouseEnter: () => void;
/** Forwarded to LocationSearch as `getViewportCenter`. */
getViewportCenter: () => { lat: number; lng: number } | null;
/**
 * Feature currently colouring the map. A feature legend is rendered only when
 * both this and `colorRange` are set. Names starting with `tt_` are treated as
 * travel-time features, with the transport mode read from the segment after
 * the first underscore.
 */
viewFeature: string | null;
/** Passed to the feature legend as its `range`. */
colorRange: [number, number] | null;
/**
 * When `'eye'`, the feature legend shows its cancel control and receives the
 * reset-scale handler; non-travel-time legends also use the "previewing" label.
 */
viewSource: 'drag' | 'eye' | null;
/** Passed to MapLegend as `onCancel`. */
onCancelPin: () => void;
/** Passed to MapLegend as `onResetScale`, but only while `viewSource` is `'eye'`. */
onResetPreviewScale?: () => void;
/** Negated into MapLegend's `resetScaleDisabled`. */
canResetPreviewScale: boolean;
/**
 * Metadata for a non-travel-time `viewFeature`; supplies the legend's label,
 * enum values (for `type === 'enum'`), suffix and `raw` flag. When null and the
 * feature is not a `tt_` feature, no feature legend is rendered.
 */
colorFeatureMeta: FeatureMeta | null;
// NOTE(review): the next five props are presumably consumed by the density
// legend branch (postcode vs. hexagon count range, label, total), as in the
// inline Map code this component was extracted from — confirm against the
// remainder of the component.
usePostcodeView: boolean;
countRange: { min: number; max: number };
postcodeCountRange: { min: number; max: number };
densityLabel: string;
totalCount?: number;
/** Passed to MapLegend as `theme`. */
theme: 'light' | 'dark';
}
/**
 * Desktop top-card overlay pinned to the top of the map: an optional location
 * search box plus at most one legend card (travel time, coloured feature, or
 * density).
 */
export const MapTopCards = memo(function MapTopCards(props: MapTopCardsProps) {
  const {
    layoutClass,
    showLocationSearch,
    showLegend,
    onFlyTo,
    onLocationSearched,
    onCurrentLocationFound,
    onLocationSearchMouseEnter,
    getViewportCenter,
    viewFeature,
    colorRange,
    viewSource,
    onCancelPin,
    onResetPreviewScale,
    canResetPreviewScale,
    colorFeatureMeta,
    usePostcodeView,
    countRange,
    postcodeCountRange,
    densityLabel,
    totalCount,
    theme,
  } = props;
  const { t } = useTranslation();
  const modes = useTranslatedModes();

  /** Picks the single legend variant that matches the current view state. */
  const renderLegend = () => {
    if (!showLegend) return null;

    // No feature colouring active: fall back to the density legend.
    if (!viewFeature || !colorRange) {
      const activeCounts = usePostcodeView ? postcodeCountRange : countRange;
      return (
        <MapLegend
          featureLabel={densityLabel}
          range={[activeCounts.min, activeCounts.max]}
          totalCount={totalCount}
          showCancel={false}
          onCancel={onCancelPin}
          mode="density"
          theme={theme}
          className={DESKTOP_TOP_CARD_CLASS}
        />
      );
    }

    // Cancel / reset-scale controls are only offered for an "eye" preview.
    const isEyePreview = viewSource === 'eye';
    const resetScaleHandler = isEyePreview ? onResetPreviewScale : undefined;

    if (viewFeature.startsWith('tt_')) {
      // Travel-time features carry the transport mode as their second "_"-separated segment.
      const travelMode = viewFeature.split('_')[1] as 'car' | 'bicycle' | 'walking' | 'transit';
      return (
        <MapLegend
          featureLabel={t('travel.travelTime', { mode: modes.label(travelMode) })}
          range={colorRange}
          showCancel={isEyePreview}
          onCancel={onCancelPin}
          onResetScale={resetScaleHandler}
          resetScaleDisabled={!canResetPreviewScale}
          mode="feature"
          theme={theme}
          suffix=" min"
          className={DESKTOP_TOP_CARD_CLASS}
        />
      );
    }

    if (!colorFeatureMeta) return null;

    const translatedName = ts(colorFeatureMeta.name);
    const legendLabel = isEyePreview
      ? t('mapLegend.previewing', { name: translatedName })
      : translatedName;
    return (
      <MapLegend
        featureLabel={legendLabel}
        range={colorRange}
        showCancel={isEyePreview}
        onCancel={onCancelPin}
        onResetScale={resetScaleHandler}
        resetScaleDisabled={!canResetPreviewScale}
        mode="feature"
        enumValues={colorFeatureMeta.type === 'enum' ? colorFeatureMeta.values : undefined}
        featureName={colorFeatureMeta.name}
        theme={theme}
        suffix={colorFeatureMeta.suffix}
        raw={colorFeatureMeta.raw}
        className={DESKTOP_TOP_CARD_CLASS}
      />
    );
  };

  return (
    <div
      className={`absolute top-3 left-3 right-3 z-20 flex gap-2 pointer-events-none ${layoutClass}`}
    >
      {showLocationSearch ? (
        <LocationSearch
          onFlyTo={onFlyTo}
          onLocationSearched={onLocationSearched}
          onCurrentLocationFound={onCurrentLocationFound}
          onMouseEnter={onLocationSearchMouseEnter}
          getViewportCenter={getViewportCenter}
          className={DESKTOP_TOP_CARD_CLASS}
          inputClassName={DESKTOP_LOCATION_SEARCH_INPUT_CLASS}
        />
      ) : null}
      {renderLegend()}
    </div>
  );
});

View file

@ -0,0 +1,163 @@
import { Layer, Source } from 'react-map-gl/maplibre';
import { POSTCODE_ZOOM_THRESHOLD } from '../../lib/consts';
import { type OverlayId, OVERLAY_MIN_ZOOM } from '../../lib/overlays';
/**
 * Builds the absolute tile URL template for an overlay, rooted at the page's
 * own origin. The `{z}/{x}/{y}` placeholders are left literal in the result.
 */
function overlayTileUrl(path: string): string {
  const { origin } = window.location;
  return origin + '/api/overlays/' + path + '/{z}/{x}/{y}';
}
/** Paint spec for the crime heatmap layer; constant, so built once per module. */
const CRIME_HEATMAP_PAINT = {
  'heatmap-weight': [
    'interpolate',
    ['linear'],
    ['coalesce', ['get', 'count'], ['get', 'weight'], 1],
    0,
    0,
    10,
    1,
  ],
  'heatmap-intensity': ['interpolate', ['linear'], ['zoom'], 15, 0.8, 18, 2.2],
  'heatmap-radius': ['interpolate', ['linear'], ['zoom'], 15, 18, 18, 30],
  'heatmap-opacity': 0.72,
  'heatmap-color': [
    'interpolate',
    ['linear'],
    ['heatmap-density'],
    0,
    'rgba(0, 0, 0, 0)',
    0.2,
    'rgb(253, 224, 71)',
    0.45,
    'rgb(249, 115, 22)',
    0.75,
    'rgb(220, 38, 38)',
    1,
    'rgb(127, 29, 29)',
  ],
} as never;

/** Paint spec for the tree polygons: opacity grows with the polygon's `area_sqm`. */
const TREE_FILL_PAINT = {
  'fill-color': '#1f9d55',
  'fill-opacity': [
    'interpolate',
    ['linear'],
    ['coalesce', ['get', 'area_sqm'], 0],
    0,
    0.28,
    250,
    0.62,
  ],
  'fill-outline-color': 'rgba(15, 81, 50, 0.65)',
} as never;

/** Paint spec for property borders: lines get more opaque and wider with zoom. */
const PROPERTY_BORDER_PAINT = {
  'line-color': '#b45309',
  'line-opacity': ['interpolate', ['linear'], ['zoom'], 15, 0.35, 18, 0.85],
  'line-width': ['interpolate', ['linear'], ['zoom'], 15, 0.4, 18, 1.4],
} as never;

/**
 * Renders the map sources and layers for the currently active overlays
 * (noise raster, crime heatmap, tree polygons, property borders).
 *
 * Renders nothing below `POSTCODE_ZOOM_THRESHOLD` or when no overlay is active.
 */
export function OverlayTileLayers({
  activeOverlays,
  activeCrimeTypes,
  zoom,
}: {
  activeOverlays: Set<OverlayId>;
  activeCrimeTypes: Set<string>;
  zoom: number;
}) {
  const nothingToRender = activeOverlays.size === 0 || zoom < POSTCODE_ZOOM_THRESHOLD;
  if (nothingToRender) return null;

  // Restrict the heatmap to the selected crime types. This must always be a
  // concrete expression: passing `filter={undefined}` makes react-map-gl call
  // map.addLayer({filter: undefined}), which MapLibre rejects at validation
  // ("filter: array expected, undefined found"), so the layer is never created
  // and the heatmap stays blank until a later setFilter call. An `in` over the
  // selected types matches everything when all 14 are selected.
  const selectedCrimeTypes = [...activeCrimeTypes];
  const crimeFilter = ['in', ['get', 'crime_type'], ['literal', selectedCrimeTypes]];

  return (
    <>
      {activeOverlays.has('noise') ? (
        <Source
          id="overlay-noise-source"
          type="raster"
          tiles={[overlayTileUrl('noise')]}
          tileSize={256}
          minzoom={OVERLAY_MIN_ZOOM.noise}
          maxzoom={14}
        >
          <Layer
            id="overlay-noise"
            type="raster"
            minzoom={POSTCODE_ZOOM_THRESHOLD}
            paint={{
              'raster-opacity': 0.68,
              'raster-fade-duration': 120,
            }}
          />
        </Source>
      ) : null}
      {activeOverlays.has('crime-hotspots') ? (
        <Source
          id="overlay-crime-source"
          type="vector"
          tiles={[overlayTileUrl('crime-hotspots')]}
          minzoom={OVERLAY_MIN_ZOOM['crime-hotspots']}
          maxzoom={15}
        >
          <Layer
            id="overlay-crime-heatmap"
            type="heatmap"
            source-layer="crime_hotspots"
            minzoom={POSTCODE_ZOOM_THRESHOLD}
            filter={crimeFilter as never}
            paint={CRIME_HEATMAP_PAINT}
          />
        </Source>
      ) : null}
      {activeOverlays.has('trees-outside-woodlands') ? (
        <Source
          id="overlay-trees-source"
          type="vector"
          tiles={[overlayTileUrl('trees-outside-woodlands')]}
          minzoom={OVERLAY_MIN_ZOOM['trees-outside-woodlands']}
          maxzoom={16}
        >
          <Layer
            id="overlay-tree-polygons"
            type="fill"
            source-layer="trees_outside_woodlands"
            minzoom={POSTCODE_ZOOM_THRESHOLD}
            paint={TREE_FILL_PAINT}
          />
        </Source>
      ) : null}
      {activeOverlays.has('property-borders') ? (
        <Source
          id="overlay-property-borders-source"
          type="vector"
          tiles={[overlayTileUrl('property-borders')]}
          minzoom={OVERLAY_MIN_ZOOM['property-borders']}
          maxzoom={16}
        >
          <Layer
            id="overlay-property-borders"
            type="line"
            source-layer="property_borders"
            minzoom={POSTCODE_ZOOM_THRESHOLD}
            paint={PROPERTY_BORDER_PAINT}
          />
        </Source>
      ) : null}
    </>
  );
}

View file

@ -0,0 +1,188 @@
import { memo } from 'react';
import type { SchoolMetadata } from '../../types';
import { POI_GROUP_COLORS } from '../../lib/consts';
import { getPoiIconUrl } from '../../lib/map-utils';
import { ts } from '../../i18n/server';
/** Data needed to render one POI popup card. */
export interface PoiPopupCardData {
/** Display name; shown as the card title and passed to the icon lookup. */
name: string;
/** Category key; translated with `ts()` for the subtitle and passed to the icon lookup. */
category: string;
/** Optional icon category passed to the icon lookup (`getPoiIconUrl`). */
icon_category?: string;
/** Group key; must have an entry in `POI_GROUP_COLORS`, otherwise rendering throws. */
group: string;
/** Emoji passed to the icon lookup. */
emoji: string;
/** School details; when present, a metadata list is rendered below the header. */
school?: SchoolMetadata;
}
/**
 * Looks up the RGB triple configured for a POI group.
 *
 * @throws Error when `POI_GROUP_COLORS` has no entry for `group`.
 */
function getPoiGroupColor(group: string): [number, number, number] {
  const rgb = POI_GROUP_COLORS[group];
  if (rgb) return rgb;
  throw new Error(`Missing POI group color for '${group}'`);
}
/**
 * Best-effort web URL from a free-text website field. GIAS stores some with
 * "http://", some without, and some as bare hostnames.
 *
 * @param raw Free-text website value as stored in the source data.
 * @returns The trimmed value when it already carries an http(s) scheme, the
 *   trimmed value prefixed with `http://` when it looks like a bare hostname,
 *   or `null` when it is empty, does not look like a web address, or cannot
 *   be parsed as a URL (e.g. a bare "http://", or free text such as
 *   "www.a.org or www.b.org").
 */
function normalizeSchoolWebsiteUrl(raw: string): string | null {
  const trimmed = raw.trim();
  if (!trimmed) return null;

  let candidate: string;
  if (/^https?:\/\//i.test(trimmed)) {
    candidate = trimmed;
  } else if (/^[\w.-]+\.[a-z]{2,}/i.test(trimmed)) {
    candidate = `http://${trimmed}`;
  } else {
    return null;
  }

  // The patterns above only inspect the start of the string, so confirm the
  // whole value is something a browser can navigate to before linking it.
  try {
    new URL(candidate);
  } catch {
    return null;
  }
  return candidate;
}
/**
 * Renders the school details as a two-column definition list. Each row is
 * emitted only when its underlying field is present, so sparse records yield
 * a short list.
 *
 * NOTE(review): row labels are hard-coded English strings while the rest of
 * the card uses `ts()`; confirm whether they should be translated.
 */
function renderSchoolMetadata(school: SchoolMetadata) {
  const labelClass = 'text-warm-500 dark:text-warm-400';
  const valueClass = 'dark:text-warm-200';

  // First row collects the headline classification (phase and type) so the
  // popup is scannable even when most fields are absent. Religious character
  // gets its own row further down.
  const headline: string[] = [];
  if (school.phase) headline.push(school.phase);
  if (school.type) headline.push(school.type);

  // "pupils / capacity" when both are known, otherwise whichever one exists.
  const pupilsLine =
    school.pupils !== undefined && school.capacity !== undefined
      ? `${school.pupils.toLocaleString()} / ${school.capacity.toLocaleString()} pupils`
      : school.pupils !== undefined
        ? `${school.pupils.toLocaleString()} pupils`
        : school.capacity !== undefined
          ? `Capacity ${school.capacity.toLocaleString()}`
          : null;

  const websiteUrl = school.website ? normalizeSchoolWebsiteUrl(school.website) : null;

  return (
    <dl className="mt-2 grid grid-cols-[auto_1fr] gap-x-2 gap-y-0.5 text-xs text-warm-600 dark:text-warm-300">
      {headline.length > 0 && (
        <>
          <dt className={labelClass}>Type</dt>
          <dd className={valueClass}>{headline.join(' · ')}</dd>
        </>
      )}
      {school.age_range && (
        <>
          <dt className={labelClass}>Ages</dt>
          <dd className={valueClass}>{school.age_range}</dd>
        </>
      )}
      {/* Mixed is the common case, so only single-sex schools get a row. */}
      {school.gender && school.gender !== 'Mixed' && (
        <>
          <dt className={labelClass}>Gender</dt>
          <dd className={valueClass}>{school.gender}</dd>
        </>
      )}
      {pupilsLine && (
        <>
          <dt className={labelClass}>Pupils</dt>
          <dd className={valueClass}>{pupilsLine}</dd>
        </>
      )}
      {school.fsm_percent !== undefined && (
        <>
          <dt className={labelClass}>Free meal</dt>
          <dd className={valueClass}>{school.fsm_percent.toFixed(1)}%</dd>
        </>
      )}
      {school.ofsted_rating && (
        <>
          <dt className={labelClass}>Ofsted</dt>
          <dd className={valueClass}>{school.ofsted_rating}</dd>
        </>
      )}
      {school.sixth_form === 'Has a sixth form' && (
        <>
          <dt className={labelClass}>Sixth form</dt>
          <dd className={valueClass}>Yes</dd>
        </>
      )}
      {/* Placeholder values that mean "no religious character" are suppressed. */}
      {school.religious_character &&
        school.religious_character !== 'Does not apply' &&
        school.religious_character !== 'None' && (
          <>
            <dt className={labelClass}>Religion</dt>
            <dd className={valueClass}>{school.religious_character}</dd>
          </>
        )}
      {school.admissions_policy && (
        <>
          <dt className={labelClass}>Admissions</dt>
          <dd className={valueClass}>{school.admissions_policy}</dd>
        </>
      )}
      {school.trust && (
        <>
          <dt className={labelClass}>Trust</dt>
          <dd className={valueClass}>{school.trust}</dd>
        </>
      )}
      {(school.address || school.postcode) && (
        <>
          <dt className={labelClass}>Address</dt>
          <dd className={valueClass}>
            {[school.address, school.postcode].filter(Boolean).join(', ')}
          </dd>
        </>
      )}
      {school.local_authority && (
        <>
          <dt className={labelClass}>LA</dt>
          <dd className={valueClass}>{school.local_authority}</dd>
        </>
      )}
      {school.head_name && (
        <>
          <dt className={labelClass}>Head</dt>
          <dd className={valueClass}>{school.head_name}</dd>
        </>
      )}
      {websiteUrl && (
        <>
          <dt className={labelClass}>Website</dt>
          <dd className="truncate">
            <a
              href={websiteUrl}
              target="_blank"
              rel="noreferrer noopener"
              className="pointer-events-auto text-teal-600 hover:underline dark:text-teal-400"
            >
              {/* Case-insensitive so "HTTP://…" values also display without their scheme. */}
              {websiteUrl.replace(/^https?:\/\//i, '')}
            </a>
          </dd>
        </>
      )}
    </dl>
  );
}
/**
 * Body of a POI popup: icon, name, a group-coloured dot with the translated
 * category, and the school metadata list when the POI is a school.
 */
export const PoiPopupCardContent = memo(function PoiPopupCardContent(props: {
  poi: PoiPopupCardData;
}) {
  const { poi } = props;
  const iconUrl = getPoiIconUrl(poi.category, poi.emoji, poi.icon_category, poi.name);
  const groupDotStyle = {
    backgroundColor: `rgb(${getPoiGroupColor(poi.group).join(',')})`,
  };
  const schoolDetails = poi.school ? renderSchoolMetadata(poi.school) : null;

  return (
    <div className="px-3 py-2 max-w-[280px]">
      <div className="flex items-center gap-2">
        <img
          src={iconUrl}
          alt=""
          aria-hidden="true"
          loading="lazy"
          referrerPolicy="no-referrer"
          className="h-5 w-5 shrink-0 rounded-[4px] bg-white object-contain p-0.5"
        />
        <div className="min-w-0">
          <div className="font-semibold dark:text-warm-100">{poi.name}</div>
          <div className="flex items-center gap-1.5 text-xs text-warm-500 dark:text-warm-400">
            <span
              className="inline-block w-2 h-2 rounded-full flex-shrink-0"
              style={groupDotStyle}
            />
            {ts(poi.category)}
          </div>
        </div>
      </div>
      {schoolDetails}
    </div>
  );
});

View file

@ -0,0 +1,131 @@
import { useCallback, useRef, useState } from 'react';
import type { MutableRefObject } from 'react';
import type { MapFlyToOptions } from '../../../types';
import type { MapFlyTo } from './types';
/** Camera target (position and zoom) of a fly-to that is parked until it can be executed. */
export interface PendingFlyTo {
lat: number;
lng: number;
zoom: number;
}
/** Height of the viewport strip from the panel's top edge down to the window bottom, never negative. */
function drawerBottomInset(panelRect: DOMRectReadOnly): number {
  return Math.max(0, window.innerHeight - panelRect.top);
}

/**
 * State for the mobile drawer and bottom sheet, together with the fly-to
 * plumbing that keeps a selected target visible above them. A fly-to requested
 * before the drawer panel has reported its rect is parked in a ref and run as
 * soon as the rect arrives, so the camera lands in the part of the map the
 * drawer leaves uncovered.
 */
export function useMobileDrawer(isMobile: boolean, flyToRef: MutableRefObject<MapFlyTo | null>) {
  const [mobileDrawerOpen, setMobileDrawerOpen] = useState(false);
  const [mobileBottomSheetHeight, setMobileBottomSheetHeight] = useState(0);
  const mobileDrawerPanelRectRef = useRef<DOMRectReadOnly | null>(null);
  const pendingCurrentLocationFlyToRef = useRef<{ lat: number; lng: number } | null>(null);
  const pendingLocationSearchFlyToRef = useRef<PendingFlyTo | null>(null);

  /** Runs the parked location-search fly-to if a target, a panel rect and a flyTo function all exist. */
  const consumePendingLocationSearchFlyTo = useCallback(
    (rect?: DOMRectReadOnly | null) => {
      const target = pendingLocationSearchFlyToRef.current;
      const panel = rect ?? mobileDrawerPanelRectRef.current;
      const flyTo = flyToRef.current;
      if (!target || !panel || !flyTo) return;

      const bottom = drawerBottomInset(panel);
      flyTo(target.lat, target.lng, target.zoom, {
        visibleViewportArea: { bottom },
      });
      pendingLocationSearchFlyToRef.current = null;
    },
    [flyToRef]
  );

  /** Same as above for the parked current-location fly-to, which always uses zoom 17. */
  const consumePendingCurrentLocationFlyTo = useCallback(
    (rect?: DOMRectReadOnly | null) => {
      const target = pendingCurrentLocationFlyToRef.current;
      const panel = rect ?? mobileDrawerPanelRectRef.current;
      const flyTo = flyToRef.current;
      if (!target || !panel || !flyTo) return;

      const bottom = drawerBottomInset(panel);
      flyTo(target.lat, target.lng, 17, {
        visibleViewportArea: { bottom },
      });
      pendingCurrentLocationFlyToRef.current = null;
    },
    [flyToRef]
  );

  const openMobileDrawer = useCallback(() => {
    setMobileDrawerOpen(true);
  }, []);

  /** Open the drawer and fly to the searched location once the panel rect is known. */
  const openMobileDrawerForLocationSearch = useCallback(
    (target: PendingFlyTo) => {
      pendingLocationSearchFlyToRef.current = target;
      setMobileDrawerOpen(true);
      // Runs immediately when a rect is already stored; otherwise it stays parked.
      consumePendingLocationSearchFlyTo();
    },
    [consumePendingLocationSearchFlyTo]
  );

  const clearPendingLocationSearchFlyTo = useCallback(() => {
    pendingLocationSearchFlyToRef.current = null;
  }, []);

  /** Park a current-location fly-to until the drawer panel has measured itself. */
  const queueCurrentLocationFlyTo = useCallback(
    (lat: number, lng: number) => {
      pendingCurrentLocationFlyToRef.current = { lat, lng };
      consumePendingCurrentLocationFlyTo();
    },
    [consumePendingCurrentLocationFlyTo]
  );

  /** Stores the latest panel rect and runs any fly-to that was waiting for it. */
  const handleMobileDrawerPanelRectChange = useCallback(
    (rect: DOMRectReadOnly) => {
      mobileDrawerPanelRectRef.current = rect;
      consumePendingCurrentLocationFlyTo(rect);
      consumePendingLocationSearchFlyTo(rect);
    },
    [consumePendingCurrentLocationFlyTo, consumePendingLocationSearchFlyTo]
  );

  /** Closes the drawer and drops the stored rect and any parked fly-tos. */
  const handleMobileDrawerClose = useCallback(() => {
    pendingCurrentLocationFlyToRef.current = null;
    pendingLocationSearchFlyToRef.current = null;
    mobileDrawerPanelRectRef.current = null;
    setMobileDrawerOpen(false);
  }, []);

  /**
   * Fly-to options for the current mobile layout: the open drawer's inset wins,
   * then the bottom sheet height, otherwise nothing. Always undefined off mobile.
   */
  const getMobileMapFlyToOptions = useCallback((): MapFlyToOptions | undefined => {
    if (!isMobile) return undefined;

    const panel = mobileDrawerPanelRectRef.current;
    const drawerInset = mobileDrawerOpen && panel ? drawerBottomInset(panel) : 0;
    if (drawerInset > 0) {
      return { visibleViewportArea: { bottom: drawerInset } };
    }
    if (mobileBottomSheetHeight > 0) {
      return { visibleArea: { bottom: mobileBottomSheetHeight } };
    }
    return undefined;
  }, [isMobile, mobileBottomSheetHeight, mobileDrawerOpen]);

  return {
    mobileDrawerOpen,
    mobileBottomSheetHeight,
    setMobileBottomSheetHeight,
    openMobileDrawer,
    openMobileDrawerForLocationSearch,
    clearPendingLocationSearchFlyTo,
    queueCurrentLocationFlyTo,
    handleMobileDrawerPanelRectChange,
    handleMobileDrawerClose,
    getMobileMapFlyToOptions,
  };
}

View file

@ -7,11 +7,11 @@ interface SubNavProps {
export function SubNav({ tabs, activeTab, onTabChange }: SubNavProps) { export function SubNav({ tabs, activeTab, onTabChange }: SubNavProps) {
return ( return (
<div className="max-w-5xl mx-auto w-full px-6 pt-4"> <div className="max-w-5xl mx-auto w-full px-6 pt-4">
<div className="flex gap-2 border-b border-warm-200 dark:border-warm-700"> <div className="flex gap-2 overflow-x-auto border-b border-warm-200 dark:border-warm-700">
{tabs.map((tab) => ( {tabs.map((tab) => (
<button <button
key={tab.key} key={tab.key}
className={`cursor-pointer px-4 py-2 text-sm font-medium border-b-2 ${ className={`cursor-pointer shrink-0 whitespace-nowrap px-4 py-2 text-sm font-medium border-b-2 ${
activeTab === tab.key activeTab === tab.key
? 'border-teal-500 text-teal-700 dark:text-teal-400' ? 'border-teal-500 text-teal-700 dark:text-teal-400'
: 'border-transparent text-warm-500 dark:text-warm-400 hover:text-warm-700 dark:hover:text-warm-300' : 'border-transparent text-warm-500 dark:text-warm-400 hover:text-warm-700 dark:hover:text-warm-300'

View file

@ -0,0 +1,43 @@
import { useMemo } from 'react';
// Width of one top card in px. NOTE(review): presumably mirrors the card's CSS width; confirm and keep in sync.
const DESKTOP_TOP_CARD_WIDTH = 300;
// Gap between two cards laid out side by side, in px (presumably the container's flex gap; verify).
const DESKTOP_TOP_CARD_GAP = 8;
// Total horizontal space reserved around the cards, in px (presumably left plus right inset; verify).
const DESKTOP_TOP_CARD_HORIZONTAL_INSET = 24;
// Narrowest map that still fits a single card (300 + 24 = 324).
const DESKTOP_TOP_CARDS_STACKED_MIN_MAP_WIDTH =
DESKTOP_TOP_CARD_WIDTH + DESKTOP_TOP_CARD_HORIZONTAL_INSET;
// Narrowest map that fits two cards in one row (2 * 300 + 8 + 24 = 632).
const DESKTOP_TOP_CARDS_ROW_MIN_MAP_WIDTH =
DESKTOP_TOP_CARD_WIDTH * 2 + DESKTOP_TOP_CARD_GAP + DESKTOP_TOP_CARD_HORIZONTAL_INSET;
/** Inputs to the desktop top-card layout decision. */
interface UseMapCardLayoutOptions {
/** Map width compared against the card thresholds; a width of 0 never hides or stacks the cards. */
mapWidth: number;
/** Enables the width-based hiding and stacking; when false the width is ignored. */
hideTopCardsWhenNarrow: boolean;
/** Hides the legend regardless of width. */
hideLegend: boolean;
/** Hides the location search regardless of width. */
hideLocationSearch: boolean;
}
/**
 * Decides how the desktop top cards are laid out over the map. With
 * `hideTopCardsWhenNarrow` set, both cards disappear when the map cannot fit
 * even one card, and they stack vertically when it fits one card but not two
 * side by side. Otherwise they sit in a single row.
 *
 * @returns Visibility flags for the search box and legend, plus the flex
 *   classes for the card container. The object is memoised on the inputs.
 */
export function useMapCardLayout(options: UseMapCardLayoutOptions) {
  const { mapWidth, hideTopCardsWhenNarrow, hideLegend, hideLocationSearch } = options;

  return useMemo(() => {
    let tooNarrowForAnyCard = false;
    let stackTopCards = false;
    if (hideTopCardsWhenNarrow) {
      // A width of 0 falls in neither range, so it never hides or stacks the cards.
      tooNarrowForAnyCard = mapWidth > 0 && mapWidth < DESKTOP_TOP_CARDS_STACKED_MIN_MAP_WIDTH;
      stackTopCards =
        mapWidth >= DESKTOP_TOP_CARDS_STACKED_MIN_MAP_WIDTH &&
        mapWidth < DESKTOP_TOP_CARDS_ROW_MIN_MAP_WIDTH;
    }

    const widthAllowsCards = !tooNarrowForAnyCard;
    const topCardsLayoutClass = stackTopCards
      ? 'flex-col items-start'
      : 'items-start justify-between';

    return {
      showLocationSearch: widthAllowsCards && !hideLocationSearch,
      showLegend: widthAllowsCards && !hideLegend,
      topCardsLayoutClass,
    };
  }, [mapWidth, hideTopCardsWhenNarrow, hideLegend, hideLocationSearch]);
}

View file

@ -880,6 +880,7 @@ const de: Translations = {
walk: 'Zu Fuß', walk: 'Zu Fuß',
cycle: 'Fahrrad', cycle: 'Fahrrad',
nationalAvg: 'England-Schnitt', nationalAvg: 'England-Schnitt',
crimeDataEnds: 'Polizeidaten für dieses Gebiet enden {{year}}',
}, },
// ── Street View ──────────────────────────────────── // ── Street View ────────────────────────────────────

View file

@ -864,6 +864,7 @@ const en = {
walk: 'Walk', walk: 'Walk',
cycle: 'Cycle', cycle: 'Cycle',
nationalAvg: 'National avg', nationalAvg: 'National avg',
crimeDataEnds: 'Police data for this area ends {{year}}',
}, },
// ── Street View ──────────────────────────────────── // ── Street View ────────────────────────────────────

View file

@ -893,6 +893,7 @@ const fr: Translations = {
walk: 'Marche', walk: 'Marche',
cycle: 'Vélo', cycle: 'Vélo',
nationalAvg: 'Moyenne nationale', nationalAvg: 'Moyenne nationale',
crimeDataEnds: 'Les données de police pour cette zone s\'arrêtent en {{year}}',
}, },
// ── Street View ──────────────────────────────────── // ── Street View ────────────────────────────────────

View file

@ -852,6 +852,7 @@ const hi: Translations = {
walk: 'पैदल', walk: 'पैदल',
cycle: 'साइकिल', cycle: 'साइकिल',
nationalAvg: 'राष्ट्रीय औसत', nationalAvg: 'राष्ट्रीय औसत',
crimeDataEnds: 'इस क्षेत्र के लिए पुलिस डेटा {{year}} में समाप्त होता है',
}, },
streetView: { streetView: {

View file

@ -881,6 +881,7 @@ const hu: Translations = {
walk: 'Gyalog', walk: 'Gyalog',
cycle: 'Kerékpár', cycle: 'Kerékpár',
nationalAvg: 'Országos átlag', nationalAvg: 'Országos átlag',
crimeDataEnds: 'A körzet rendőrségi adatai {{year}}-ig érhetők el',
}, },
// ── Street View ──────────────────────────────────── // ── Street View ────────────────────────────────────

View file

@ -823,6 +823,7 @@ const zh: Translations = {
walk: '步行', walk: '步行',
cycle: '骑行', cycle: '骑行',
nationalAvg: '全国平均', nationalAvg: '全国平均',
crimeDataEnds: '该地区的警方数据截至{{year}}年',
}, },
// ── Street View ──────────────────────────────────── // ── Street View ────────────────────────────────────

View file

@ -303,6 +303,12 @@ export interface HexagonStatsResponse {
price_history?: PricePoint[]; price_history?: PricePoint[];
/** Per-crime-type per-year counts averaged across the selection. */ /** Per-crime-type per-year counts averaged across the selection. */
crime_by_year?: CrimeYearStats[]; crime_by_year?: CrimeYearStats[];
/**
* Latest year in the crime dataset as a whole. A selection whose series end
* earlier sits in a force-level publication gap (e.g. Greater Manchester
* since mid-2019) and its crime figures are captioned as stale.
*/
crime_latest_year?: number;
central_postcode?: string; central_postcode?: string;
filter_exclusions?: FilterExclusion[]; filter_exclusions?: FilterExclusion[];
} }

View file

@ -24,10 +24,11 @@ from pathlib import Path
import numpy as np import numpy as np
import polars as pl import polars as pl
from pipeline.utils.normalize import collapse_whitespace, replace_non_alnum_lower
_NOISE_WORDS = re.compile( _NOISE_WORDS = re.compile(
r"\b(the|of|and|c\s*of\s*e|cofe|ce|rc|voluntary|aided|controlled|va|vc)\b" r"\b(the|of|and|c\s*of\s*e|cofe|ce|rc|voluntary|aided|controlled|va|vc)\b"
) )
_NON_ALNUM = re.compile(r"[^a-z0-9 ]")
_SCHOOL_WORDS = re.compile( _SCHOOL_WORDS = re.compile(
r"\b(school|academy|primary|secondary|junior|infant|community|college|high)\b" r"\b(school|academy|primary|secondary|junior|infant|community|college|high)\b"
) )
@ -35,16 +36,16 @@ _SCHOOL_WORDS = re.compile(
def normalize_name(name: str, strip_school_words: bool = False) -> str: def normalize_name(name: str, strip_school_words: bool = False) -> str:
s = name.lower().replace("&", " and ").replace("st.", "st ").replace("'", "") s = name.lower().replace("&", " and ").replace("st.", "st ").replace("'", "")
s = _NON_ALNUM.sub(" ", s) s = replace_non_alnum_lower(s)
s = _NOISE_WORDS.sub(" ", s) s = _NOISE_WORDS.sub(" ", s)
if strip_school_words: if strip_school_words:
s = _SCHOOL_WORDS.sub(" ", s) s = _SCHOOL_WORDS.sub(" ", s)
return " ".join(s.split()) return collapse_whitespace(s)
def normalize_la(la: str) -> str: def normalize_la(la: str) -> str:
s = _NON_ALNUM.sub(" ", la.lower().replace("&", " and ")) s = replace_non_alnum_lower(la.lower().replace("&", " and "))
return " ".join(s.replace("city of", "").split()) return collapse_whitespace(s.replace("city of", ""))
def load_ground_truth(directory: Path) -> pl.DataFrame: def load_ground_truth(directory: Path) -> pl.DataFrame:

View file

@ -171,43 +171,88 @@ def parse_contained_range(contained_range: str) -> tuple[str, str] | None:
return start, end return start, end
def select_coverage_archives(archives: list[CrimeArchive]) -> list[CrimeArchive]: def _index_to_month(index: int) -> str:
"""Select non-overlapping snapshots that still cover the available history. year, month_num = divmod(index, 12)
if month_num == 0:
year -= 1
month_num = 12
return f"{year:04d}-{month_num:02d}"
def select_coverage_archives(
archives: list[CrimeArchive], *, allow_gaps: bool = False
) -> list[CrimeArchive]:
"""Select snapshots whose ranges chain together to cover the available history.
The source publishes rolling multi-year snapshots. Downloading every monthly The source publishes rolling multi-year snapshots. Downloading every monthly
snapshot mostly fetches duplicate data; for our aggregate LSOA counts we only snapshot mostly fetches duplicate data; for our aggregate LSOA counts we only
need continuous month coverage. need continuous month coverage. Greedy interval cover, newest first: anchor
on the snapshot with the latest end month, then repeatedly take the archive
reaching furthest back among those adjacent to or overlapping the covered
range. Accepting an overlapping snapshot (rather than only an exactly
adjacent one) matters when the adjacent snapshot is missing from the index:
skipping it would leave a multi-month hole, while overlap only costs
download time because extraction skips already-extracted months. A hole no
archive can bridge is a publication gap in the source a hard error unless
``allow_gaps``, since the run would otherwise be stamped complete with
artificial dips in every crime-over-time series.
""" """
selected: list[CrimeArchive] = [] selected: list[CrimeArchive] = []
earliest_covered_start: int | None = None ranged: list[tuple[int, int, CrimeArchive]] = []
for archive in archives:
def sort_key(archive: CrimeArchive) -> int:
parsed_range = parse_contained_range(archive.contained_range)
if parsed_range is not None:
return _month_to_index(parsed_range[1])
return _month_to_index(archive.month)
for archive in sorted(archives, key=sort_key, reverse=True):
parsed_range = parse_contained_range(archive.contained_range) parsed_range = parse_contained_range(archive.contained_range)
if parsed_range is None: if parsed_range is None:
selected.append(archive) selected.append(archive)
continue else:
ranged.append(
start, end = parsed_range (
start_index = _month_to_index(start) _month_to_index(parsed_range[0]),
end_index = _month_to_index(end) _month_to_index(parsed_range[1]),
if earliest_covered_start is None or end_index < earliest_covered_start: archive,
if (
earliest_covered_start is not None
and end_index < earliest_covered_start - 1
):
print(
"Warning: archive ranges are not adjacent; "
f"coverage gap before {archive.filename}",
file=sys.stderr,
) )
selected.append(archive) )
earliest_covered_start = start_index
earliest_covered_start: int | None = None
while True:
if earliest_covered_start is None:
eligible = ranged
else:
eligible = [item for item in ranged if item[0] < earliest_covered_start]
if not eligible:
break
if earliest_covered_start is None:
# Anchor: latest end month, reaching as far back as available.
start_index, _, archive = max(
eligible, key=lambda item: (item[1], -item[0])
)
else:
chained = [
item for item in eligible if item[1] >= earliest_covered_start - 1
]
if not chained:
hole_start = max(item[1] for item in eligible) + 1
message = (
"no archive covers "
f"{_index_to_month(hole_start)} to "
f"{_index_to_month(earliest_covered_start - 1)}"
)
if not allow_gaps:
raise RuntimeError(
f"Coverage gap: {message}. Rerun with --allow-gaps to "
"accept the hole."
)
print(f"Warning: coverage gap: {message}", file=sys.stderr)
chained = eligible
# Furthest backward reach; on a start tie prefer the newer
# snapshot, whose data for the months around the boundary carries
# the latest revisions.
start_index, _, archive = min(
chained, key=lambda item: (item[0], -item[1])
)
selected.append(archive)
earliest_covered_start = start_index
return selected return selected
@ -331,14 +376,24 @@ def extract_csvs(
*, *,
overwrite: bool = False, overwrite: bool = False,
street_only: bool = True, street_only: bool = True,
extracted_this_run: set[PurePosixPath] | None = None,
) -> tuple[int, int]: ) -> tuple[int, int]:
"""Extract CSVs from one ZIP. Returns (extracted, skipped).""" """Extract CSVs from one ZIP. Returns (extracted, skipped).
``extracted_this_run`` is shared across the archives of one run, processed
newest-snapshot first: a member already written by a newer snapshot is
skipped even with ``overwrite``, so an older overlapping archive can never
replace a month with a less-revised copy.
"""
extracted = 0 extracted = 0
skipped = 0 skipped = 0
with zipfile.ZipFile(zip_path) as archive: with zipfile.ZipFile(zip_path) as archive:
for info, rel_path in _safe_csv_members(archive, street_only=street_only): for info, rel_path in _safe_csv_members(archive, street_only=street_only):
dest = output_dir.joinpath(*rel_path.parts) dest = output_dir.joinpath(*rel_path.parts)
if extracted_this_run is not None and rel_path in extracted_this_run:
skipped += 1
continue
if dest.exists() and not overwrite: if dest.exists() and not overwrite:
skipped += 1 skipped += 1
continue continue
@ -347,6 +402,8 @@ def extract_csvs(
with archive.open(info) as source, dest.open("wb") as target: with archive.open(info) as source, dest.open("wb") as target:
shutil.copyfileobj(source, target) shutil.copyfileobj(source, target)
extracted += 1 extracted += 1
if extracted_this_run is not None:
extracted_this_run.add(rel_path)
return extracted, skipped return extracted, skipped
@ -489,8 +546,22 @@ def main() -> None:
) )
parser.add_argument( parser.add_argument(
"--overwrite-extracted", "--overwrite-extracted",
action=argparse.BooleanOptionalAction,
default=True,
help=(
"Replace previously extracted CSVs with this run's snapshot data "
"(police.uk revises the trailing 36 months in every release, so "
"keeping old extractions freezes stale revisions; within a run the "
"newest snapshot still wins for overlapping months)"
),
)
parser.add_argument(
"--allow-gaps",
action="store_true", action="store_true",
help="Overwrite CSVs when extracting overlapping archive snapshots", help=(
"Continue past months no archive covers instead of failing "
"(coverage strategy only)"
),
) )
parser.add_argument( parser.add_argument(
"--no-verify", "--no-verify",
@ -521,7 +592,7 @@ def main() -> None:
limit=args.limit, limit=args.limit,
) )
archives = ( archives = (
select_coverage_archives(available_archives) select_coverage_archives(available_archives, allow_gaps=args.allow_gaps)
if args.archive_strategy == "coverage" if args.archive_strategy == "coverage"
else available_archives else available_archives
) )
@ -570,6 +641,7 @@ def main() -> None:
total_extracted = 0 total_extracted = 0
total_skipped = 0 total_skipped = 0
extracted_this_run: set[PurePosixPath] = set()
for index, archive in enumerate(archives, start=1): for index, archive in enumerate(archives, start=1):
print(f"[{index}/{len(archives)}] {archive.label} ({archive.size})") print(f"[{index}/{len(archives)}] {archive.label} ({archive.size})")
zip_path = download_archive( zip_path = download_archive(
@ -585,6 +657,7 @@ def main() -> None:
args.output, args.output,
overwrite=args.overwrite_extracted, overwrite=args.overwrite_extracted,
street_only=street_only, street_only=street_only,
extracted_this_run=extracted_this_run,
) )
total_extracted += extracted total_extracted += extracted
total_skipped += skipped total_skipped += skipped

View file

@ -16,12 +16,12 @@ License: Open Government Licence v3.0
""" """
import argparse import argparse
from io import BytesIO
from pathlib import Path from pathlib import Path
import httpx
import polars as pl import polars as pl
from pipeline.utils import ENGLAND_LSOA_COUNT_2021, download_nomis_csv
pl.Config.set_tbl_cols(-1) pl.Config.set_tbl_cols(-1)
# NOMIS API: Census 2021 TS021 (ethnic group, 20 categories) by LSOA 2021 # NOMIS API: Census 2021 TS021 (ethnic group, 20 categories) by LSOA 2021
@ -35,7 +35,6 @@ BASE_URL = (
"&measures=20100" "&measures=20100"
"&select=GEOGRAPHY_CODE,C2021_ETH_20_NAME,OBS_VALUE" "&select=GEOGRAPHY_CODE,C2021_ETH_20_NAME,OBS_VALUE"
) )
PAGE_SIZE = 25000
# Map the 19 detailed NOMIS C2021_ETH_20 leaf categories to our 7 output groups. # Map the 19 detailed NOMIS C2021_ETH_20 leaf categories to our 7 output groups.
# The Asian split: # The Asian split:
@ -150,24 +149,7 @@ def _ethnicity_percentages(df: pl.DataFrame) -> pl.DataFrame:
def download_and_convert(output_path: Path) -> None: def download_and_convert(output_path: Path) -> None:
print("Downloading Census 2021 ethnic group (TS021) by LSOA from NOMIS...") print("Downloading Census 2021 ethnic group (TS021) by LSOA from NOMIS...")
frames = [] df = download_nomis_csv(BASE_URL)
offset = 0
while True:
url = f"{BASE_URL}&recordoffset={offset}"
response = httpx.get(url, follow_redirects=True, timeout=120)
response.raise_for_status()
if len(response.content) == 0:
break
chunk = pl.read_csv(BytesIO(response.content))
if chunk.height == 0:
break
frames.append(chunk)
print(f" Fetched {chunk.height} rows (offset={offset})")
if chunk.height < PAGE_SIZE:
break
offset += PAGE_SIZE
df = pl.concat(frames)
print(f"Total rows: {df.height}") print(f"Total rows: {df.height}")
# Filter to England only (E-prefixed LSOA codes); the merge joins on the # Filter to England only (E-prefixed LSOA codes); the merge joins on the
@ -177,6 +159,11 @@ def download_and_convert(output_path: Path) -> None:
wide = _ethnicity_percentages(df) wide = _ethnicity_percentages(df)
print(f"England LSOAs: {wide.height}") print(f"England LSOAs: {wide.height}")
if wide.height != ENGLAND_LSOA_COUNT_2021:
raise ValueError(
f"Expected {ENGLAND_LSOA_COUNT_2021} England LSOAs, "
f"got {wide.height}: truncated NOMIS download?"
)
print(f"Columns: {wide.columns}") print(f"Columns: {wide.columns}")
output_path.parent.mkdir(parents=True, exist_ok=True) output_path.parent.mkdir(parents=True, exist_ok=True)

View file

@ -241,9 +241,11 @@ def transform(zip_bytes: bytes) -> pl.DataFrame:
"""Convert the GIAS extract ZIP into a clean schools DataFrame.""" """Convert the GIAS extract ZIP into a clean schools DataFrame."""
raw = _read_csv_from_zip(zip_bytes) raw = _read_csv_from_zip(zip_bytes)
# Filter to currently-open establishments; the CSV also includes closed, # Filter to currently-open establishments; the CSV also includes closed and
# proposed-to-open, and proposed-to-close rows we do not want on a map. # proposed-to-open rows we do not want on a map. "Open, but proposed to
df = raw.filter(pl.col("EstablishmentStatus (name)") == "Open") # close" schools are open, operating establishments (GIAS can keep that
# status for years, e.g. pending amalgamations), so they must stay.
df = raw.filter(pl.col("EstablishmentStatus (name)").str.starts_with("Open"))
df = df.with_columns( df = df.with_columns(
pl.col("URN").cast(pl.Int64), pl.col("URN").cast(pl.Int64),

View file

@ -13,7 +13,7 @@ from pathlib import Path
import osmium import osmium
import polars as pl import polars as pl
from pyproj import Transformer from pyproj import Transformer
from shapely import wkb from shapely import make_valid, wkb
from shapely.errors import GEOSException from shapely.errors import GEOSException
from shapely.geometry import MultiPolygon, Polygon from shapely.geometry import MultiPolygon, Polygon
from tqdm import tqdm from tqdm import tqdm
@ -56,6 +56,22 @@ def _to_bng_polygon(geom):
return geom return geom
def _polygonal_part(geom):
"""The Polygon/MultiPolygon content of a geometry, or None if there is none."""
if geom.geom_type in ("Polygon", "MultiPolygon"):
return geom
if geom.geom_type == "GeometryCollection":
polygons = []
for part in geom.geoms:
if part.geom_type == "Polygon":
polygons.append(part)
elif part.geom_type == "MultiPolygon":
polygons.extend(part.geoms)
if polygons:
return MultiPolygon(polygons)
return None
def _matches_tags(tags): def _matches_tags(tags):
"""Check if an OSM element's tags match our greenspace/water criteria.""" """Check if an OSM element's tags match our greenspace/water criteria."""
for key, values in GREENSPACE_TAGS.items(): for key, values in GREENSPACE_TAGS.items():
@ -91,7 +107,13 @@ class GreenspaceHandler(osmium.SimpleHandler):
) )
return return
if geom.is_empty or not geom.is_valid: # Invalid geometries are often the largest, most complex park/water
# multipolygons (self-touching rings from OSM) — repair like pois.py
# rather than silently dropping them. make_valid may return a
# GeometryCollection with stray lines/points; keep only the polygons.
if not geom.is_valid:
geom = _polygonal_part(make_valid(geom))
if geom is None or geom.is_empty:
return return
# Reproject to BNG for area calculation # Reproject to BNG for area calculation

View file

@ -5,18 +5,14 @@ License: Open Government Licence v3.0
""" """
import argparse import argparse
import time
from pathlib import Path from pathlib import Path
import httpx from pipeline.utils import download_arcgis_hub_export
import pyogrio
URL = ( URL = (
"https://opendata-historicengland.hub.arcgis.com/api/download/v1/items/" "https://opendata-historicengland.hub.arcgis.com/api/download/v1/items/"
"767f279327a24845bf47dfe5eae9862b/geoPackage?layers=0" "767f279327a24845bf47dfe5eae9862b/geoPackage?layers=0"
) )
POLL_INTERVAL_S = 5
POLL_TIMEOUT_S = 600
def main() -> None: def main() -> None:
@ -28,37 +24,9 @@ def main() -> None:
) )
args = parser.parse_args() args = parser.parse_args()
args.output.parent.mkdir(parents=True, exist_ok=True) args.output.parent.mkdir(parents=True, exist_ok=True)
tmp_path = args.output.with_name(f"{args.output.stem}.tmp{args.output.suffix}")
print("Downloading Historic England listed-building points...") print("Downloading Historic England listed-building points...")
deadline = time.monotonic() + POLL_TIMEOUT_S features = download_arcgis_hub_export(URL, args.output, expected_geometry="Point")
with httpx.Client(follow_redirects=True, timeout=300) as client:
while True:
with client.stream("GET", URL) as response:
if response.status_code == 202:
response.read()
if time.monotonic() > deadline:
raise TimeoutError(
f"Export did not finish within {POLL_TIMEOUT_S}s: "
f"{response.text}"
)
time.sleep(POLL_INTERVAL_S)
continue
response.raise_for_status()
with tmp_path.open("wb") as fh:
for chunk in response.iter_bytes():
fh.write(chunk)
break
info = pyogrio.read_info(tmp_path)
features = info.get("features", 0)
geometry_type = str(info.get("geometry_type") or "")
if features <= 0:
raise ValueError("Downloaded listed-buildings file contains no features")
if "Point" not in geometry_type:
raise ValueError(f"Expected point geometry, got {geometry_type!r}")
tmp_path.replace(args.output)
size_mb = args.output.stat().st_size / (1024 * 1024) size_mb = args.output.stat().st_size / (1024 * 1024)
print( print(
f"Saved {features} listed-building points to {args.output} ({size_mb:.1f} MB)" f"Saved {features} listed-building points to {args.output} ({size_mb:.1f} MB)"

View file

@ -10,21 +10,19 @@ of the 0-4, 10-14 and 15-19 bands (one fifth per single year of age).
""" """
import argparse import argparse
from io import BytesIO
from pathlib import Path from pathlib import Path
import httpx
import polars as pl import polars as pl
from pipeline.utils import ENGLAND_LSOA_COUNT_2021, download_nomis_csv
# NOMIS API: Census 2021 TS007A (age, five-year bands) by LSOA 2021 (TYPE151). # NOMIS API: Census 2021 TS007A (age, five-year bands) by LSOA 2021 (TYPE151).
# c2021_age_19 codes: 1 = 0-4, 2 = 5-9, 3 = 10-14, 4 = 15-19. # c2021_age_19 codes: 1 = 0-4, 2 = 5-9, 3 = 10-14, 4 = 15-19.
# NOMIS paginates at 25,000 rows by default, so we paginate with recordoffset.
BASE_URL = ( BASE_URL = (
"https://www.nomisweb.co.uk/api/v01/dataset/NM_2020_1.data.csv" "https://www.nomisweb.co.uk/api/v01/dataset/NM_2020_1.data.csv"
"?date=latest&geography=TYPE151&measures=20100&c2021_age_19=1,2,3,4" "?date=latest&geography=TYPE151&measures=20100&c2021_age_19=1,2,3,4"
"&select=GEOGRAPHY_CODE,C2021_AGE_19,OBS_VALUE" "&select=GEOGRAPHY_CODE,C2021_AGE_19,OBS_VALUE"
) )
PAGE_SIZE = 25000
AGE_BAND_COLUMNS = { AGE_BAND_COLUMNS = {
1: "aged_0_4", 1: "aged_0_4",
@ -36,24 +34,7 @@ AGE_BAND_COLUMNS = {
def download_and_convert(output_path: Path) -> None: def download_and_convert(output_path: Path) -> None:
print("Downloading Census 2021 LSOA age bands from NOMIS...") print("Downloading Census 2021 LSOA age bands from NOMIS...")
frames = [] df = download_nomis_csv(BASE_URL)
offset = 0
while True:
url = f"{BASE_URL}&recordoffset={offset}"
response = httpx.get(url, follow_redirects=True, timeout=120)
response.raise_for_status()
if len(response.content) == 0:
break
chunk = pl.read_csv(BytesIO(response.content))
if chunk.height == 0:
break
frames.append(chunk)
print(f" Fetched {chunk.height} rows (offset={offset})")
if chunk.height < PAGE_SIZE:
break
offset += PAGE_SIZE
df = pl.concat(frames)
print(f"Total rows: {df.height}") print(f"Total rows: {df.height}")
result = ( result = (
@ -70,6 +51,11 @@ def download_and_convert(output_path: Path) -> None:
raise ValueError(f"NOMIS response missing age bands: {missing}") raise ValueError(f"NOMIS response missing age bands: {missing}")
print(f"England LSOAs: {result.height}") print(f"England LSOAs: {result.height}")
if result.height != ENGLAND_LSOA_COUNT_2021:
raise ValueError(
f"Expected {ENGLAND_LSOA_COUNT_2021} England LSOAs, "
f"got {result.height}: truncated NOMIS download?"
)
for name in AGE_BAND_COLUMNS.values(): for name in AGE_BAND_COLUMNS.values():
print(f" {name}: total {result[name].sum():,}") print(f" {name}: total {result[name].sum():,}")

View file

@ -5,39 +5,20 @@ License: Open Government Licence v3.0
""" """
import argparse import argparse
from io import BytesIO
from pathlib import Path from pathlib import Path
import httpx
import polars as pl import polars as pl
from pipeline.utils import ENGLAND_LSOA_COUNT_2021, download_nomis_csv
# NOMIS API: Census 2021 TS001 (usual residents) by LSOA 2021 (TYPE151) # NOMIS API: Census 2021 TS001 (usual residents) by LSOA 2021 (TYPE151)
# c2021_restype_3=0 selects "Total: All usual residents" # c2021_restype_3=0 selects "Total: All usual residents"
# NOMIS paginates at 25,000 rows by default, so we paginate with recordoffset.
BASE_URL = "https://www.nomisweb.co.uk/api/v01/dataset/NM_2021_1.data.csv?date=latest&geography=TYPE151&measures=20100&c2021_restype_3=0&select=GEOGRAPHY_CODE,OBS_VALUE" BASE_URL = "https://www.nomisweb.co.uk/api/v01/dataset/NM_2021_1.data.csv?date=latest&geography=TYPE151&measures=20100&c2021_restype_3=0&select=GEOGRAPHY_CODE,OBS_VALUE"
PAGE_SIZE = 25000
def download_and_convert(output_path: Path) -> None: def download_and_convert(output_path: Path) -> None:
print("Downloading Census 2021 LSOA population from NOMIS...") print("Downloading Census 2021 LSOA population from NOMIS...")
frames = [] df = download_nomis_csv(BASE_URL)
offset = 0
while True:
url = f"{BASE_URL}&recordoffset={offset}"
response = httpx.get(url, follow_redirects=True, timeout=120)
response.raise_for_status()
if len(response.content) == 0:
break
chunk = pl.read_csv(BytesIO(response.content))
if chunk.height == 0:
break
frames.append(chunk)
print(f" Fetched {chunk.height} rows (offset={offset})")
if chunk.height < PAGE_SIZE:
break
offset += PAGE_SIZE
df = pl.concat(frames)
print(f"Total rows: {df.height}") print(f"Total rows: {df.height}")
result = df.rename( result = df.rename(
@ -50,6 +31,11 @@ def download_and_convert(output_path: Path) -> None:
result = result.filter(pl.col("lsoa21").str.starts_with("E")) result = result.filter(pl.col("lsoa21").str.starts_with("E"))
print(f"England LSOAs: {result.height}") print(f"England LSOAs: {result.height}")
if result.height != ENGLAND_LSOA_COUNT_2021:
raise ValueError(
f"Expected {ENGLAND_LSOA_COUNT_2021} England LSOAs, "
f"got {result.height}: truncated NOMIS download?"
)
print( print(
f"Population range: {result['population'].min()} - {result['population'].max()}" f"Population range: {result['population'].min()} - {result['population'].max()}"
) )

View file

@ -3,6 +3,7 @@ import base64
import json import json
import re import re
import sys import sys
import time
import urllib.request import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO from io import BytesIO
@ -120,18 +121,29 @@ def collect_twemoji_codes() -> list[str]:
return sorted({f"{ord(e[0]):x}" for e in emojis}) return sorted({f"{ord(e[0]):x}" for e in emojis})
DOWNLOAD_ATTEMPTS = 3
RETRY_BACKOFF_S = 2.0
def download_file(url: str, dest: Path) -> tuple[bool, str]: def download_file(url: str, dest: Path) -> tuple[bool, str]:
"""Download a single file. Returns (success, url).""" """Download a single file, retrying transient errors. Returns (success, url)."""
dest.parent.mkdir(parents=True, exist_ok=True) dest.parent.mkdir(parents=True, exist_ok=True)
try: for attempt in range(DOWNLOAD_ATTEMPTS):
urllib.request.urlretrieve(url, dest) if attempt:
return True, url time.sleep(RETRY_BACKOFF_S * 2 ** (attempt - 1))
except urllib.error.HTTPError as e: try:
print(f" {e.code} {url}", file=sys.stderr) urllib.request.urlretrieve(url, dest)
return False, url return True, url
except Exception as e: except urllib.error.HTTPError as e:
print(f" ERROR {url}: {e}", file=sys.stderr) # 4xx is a permanent answer (bad glyph range / missing emoji);
return False, url # retrying won't change it.
if 400 <= e.code < 500:
print(f" {e.code} {url}", file=sys.stderr)
return False, url
print(f" {e.code} {url} (attempt {attempt + 1})", file=sys.stderr)
except Exception as e:
print(f" ERROR {url}: {e} (attempt {attempt + 1})", file=sys.stderr)
return False, url
def download_text(url: str) -> str: def download_text(url: str) -> str:
@ -389,37 +401,38 @@ def main():
url = f"{POI_ICON_BASE}/{icon_path}" url = f"{POI_ICON_BASE}/{icon_path}"
tasks.append((url, poi_icons_dir / icon_path)) tasks.append((url, poi_icons_dir / icon_path))
# Skip already-downloaded files print(f"Downloading {len(tasks) + len(DERIVED_POI_ICON_PATHS)} assets")
remaining = [(url, dest) for url, dest in tasks]
print(f"Downloading {len(remaining) + len(DERIVED_POI_ICON_PATHS)} assets")
ok = 0 ok = 0
fail = 0 failed_urls: list[str] = []
with ThreadPoolExecutor(max_workers=20) as pool: with ThreadPoolExecutor(max_workers=20) as pool:
futures = { futures = {pool.submit(download_file, url, dest): url for url, dest in tasks}
pool.submit(download_file, url, dest): url for url, dest in remaining
}
for future in as_completed(futures): for future in as_completed(futures):
success, url = future.result() success, url = future.result()
if success: if success:
ok += 1 ok += 1
else: else:
fail += 1 failed_urls.append(url)
for kind, source_path, dest_path in DERIVED_POI_ICON_PATHS: for kind, source_path, dest_path in DERIVED_POI_ICON_PATHS:
success, _url = download_derived_poi_icon( success, url = download_derived_poi_icon(
kind, source_path, poi_icons_dir / dest_path kind, source_path, poi_icons_dir / dest_path
) )
if success: if success:
ok += 1 ok += 1
else: else:
fail += 1 failed_urls.append(url)
crop_poi_svg_icons(poi_icons_dir) crop_poi_svg_icons(poi_icons_dir)
inject_townhall_sprite(sprites_dir) inject_townhall_sprite(sprites_dir)
print(f"Done: {ok} downloaded, {fail} failed") print(f"Done: {ok} downloaded, {len(failed_urls)} failed")
if failed_urls:
# A partial asset bundle (missing font ranges, sprites, icons) renders
# broken labels at runtime but would otherwise satisfy the make stamp.
for url in failed_urls:
print(f" missing: {url}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -8,17 +8,16 @@ License: Open Government Licence v3.0
""" """
import argparse import argparse
from io import BytesIO
from pathlib import Path from pathlib import Path
import httpx
import polars as pl import polars as pl
from pipeline.utils import ENGLAND_LSOA_COUNT_2021, download_nomis_csv
# NOMIS API: Census 2021 TS007A (age by five-year bands) by LSOA 2021 (TYPE151) # NOMIS API: Census 2021 TS007A (age by five-year bands) by LSOA 2021 (TYPE151)
# c2021_age_19=1..18 selects 18 five-year bands (excluding 0 = Total) # c2021_age_19=1..18 selects 18 five-year bands (excluding 0 = Total)
# measures=20100 selects absolute count # measures=20100 selects absolute count
BASE_URL = "https://www.nomisweb.co.uk/api/v01/dataset/NM_2020_1.data.csv?date=latest&geography=TYPE151&c2021_age_19=1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18&measures=20100&select=GEOGRAPHY_CODE,C2021_AGE_19_NAME,OBS_VALUE" BASE_URL = "https://www.nomisweb.co.uk/api/v01/dataset/NM_2020_1.data.csv?date=latest&geography=TYPE151&c2021_age_19=1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18&measures=20100&select=GEOGRAPHY_CODE,C2021_AGE_19_NAME,OBS_VALUE"
PAGE_SIZE = 25000
# Five-year age bands in order, with lower bounds for interpolation. # Five-year age bands in order, with lower bounds for interpolation.
# The last band (85+) is open-ended — we treat it as 85-89 for median purposes. # The last band (85+) is open-ended — we treat it as 85-89 for median purposes.
@ -161,24 +160,7 @@ def _bands_to_median_table(pivoted: pl.DataFrame) -> pl.DataFrame:
def download_and_convert(output_path: Path) -> None: def download_and_convert(output_path: Path) -> None:
print("Downloading Census 2021 age by five-year bands from NOMIS...") print("Downloading Census 2021 age by five-year bands from NOMIS...")
frames = [] df = download_nomis_csv(BASE_URL)
offset = 0
while True:
url = f"{BASE_URL}&recordoffset={offset}"
response = httpx.get(url, follow_redirects=True, timeout=120)
response.raise_for_status()
if len(response.content) == 0:
break
chunk = pl.read_csv(BytesIO(response.content))
if chunk.height == 0:
break
frames.append(chunk)
print(f" Fetched {chunk.height} rows (offset={offset})")
if chunk.height < PAGE_SIZE:
break
offset += PAGE_SIZE
df = pl.concat(frames)
print(f"Total rows: {df.height}") print(f"Total rows: {df.height}")
# Filter to England only # Filter to England only
@ -194,6 +176,11 @@ def download_and_convert(output_path: Path) -> None:
result = _bands_to_median_table(pivoted) result = _bands_to_median_table(pivoted)
print(f"England LSOAs: {result.height}") print(f"England LSOAs: {result.height}")
if result.height != ENGLAND_LSOA_COUNT_2021:
raise ValueError(
f"Expected {ENGLAND_LSOA_COUNT_2021} England LSOAs, "
f"got {result.height}: truncated NOMIS download?"
)
print( print(
f"Median age range: {result['median_age'].min()} - {result['median_age'].max()}" f"Median age range: {result['median_age'].min()} - {result['median_age'].max()}"
) )

View file

@ -181,6 +181,27 @@ def canonical_station_name(name: str | None) -> str:
return " ".join(words) return " ".join(words)
_QUALIFIER_RE = re.compile(r"\(([^)]*)\)")
def station_name_qualifier(name: str | None) -> str:
"""The canonicalized parenthetical of a station name, e.g. "Edgware Road
(Bakerloo)" -> "bakerloo".
Genuinely distinct same-named stations (the two Edgware Roads ~150m apart,
Hammersmith's two stations) differ ONLY by this parenthetical, which
`canonical_station_name` strips; it must block their merge while still
letting unqualified entrance/variant rows collapse into either.
"""
if not name:
return ""
parts = _QUALIFIER_RE.findall(name)
if not parts:
return ""
text = " ".join(parts).lower().replace("&", " and ")
return re.sub(r"[^a-z0-9]+", " ", text).strip()
def canonical_station_name_expr(name_col: str = "name") -> pl.Expr: def canonical_station_name_expr(name_col: str = "name") -> pl.Expr:
"""Normalize station names so entrances/transport-mode variants collapse.""" """Normalize station names so entrances/transport-mode variants collapse."""
expr = pl.col(name_col).str.to_lowercase() expr = pl.col(name_col).str.to_lowercase()
@ -246,6 +267,7 @@ class StationAccumulator:
entrance: bool = False entrance: bool = False
is_lu: bool = False is_lu: bool = False
count: int = 1 count: int = 1
qualifier: str = ""
@property @property
def lat(self) -> float: def lat(self) -> float:
@ -260,6 +282,11 @@ class StationAccumulator:
dlng = (self.lng - lng) * math.cos(math.radians(self.lat)) dlng = (self.lng - lng) * math.cos(math.radians(self.lat))
return (dlat * dlat + dlng * dlng) <= TUBE_STATION_MERGE_RADIUS_DEGREES**2 return (dlat * dlat + dlng * dlng) <= TUBE_STATION_MERGE_RADIUS_DEGREES**2
def qualifier_compatible(self, qualifier: str) -> bool:
    """Whether a row carrying ``qualifier`` may merge into this group.

    Two different non-empty parentheticals mark distinct same-named stations
    and are incompatible. An unqualified row, or a group that has no qualifier
    yet, is compatible with anything.
    """
    if qualifier and self.qualifier:
        return qualifier == self.qualifier
    return True
def merge(self, row: dict[str, object]) -> None: def merge(self, row: dict[str, object]) -> None:
self.lat_sum += float(row["lat"]) self.lat_sum += float(row["lat"])
self.lng_sum += float(row["lng"]) self.lng_sum += float(row["lng"])
@ -267,14 +294,28 @@ class StationAccumulator:
self.is_lu = self.is_lu or bool(row.get("is_lu")) self.is_lu = self.is_lu or bool(row.get("is_lu"))
name = str(row["name"] or "") name = str(row["name"] or "")
row_qualifier = station_name_qualifier(name)
self.qualifier = self.qualifier or row_qualifier
entrance = bool(row.get("entrance")) entrance = bool(row.get("entrance"))
if station_name_score(name, entrance) < station_name_score( # Prefer a display name carrying the group's disambiguating
self.name, self.entrance # parenthetical: without it the two Edgware Roads would both render as
): # the bare "Edgware Road Underground Station".
candidate = (
self._qualifier_penalty(row_qualifier),
*station_name_score(name, entrance),
)
current = (
self._qualifier_penalty(station_name_qualifier(self.name)),
*station_name_score(self.name, self.entrance),
)
if candidate < current:
self.id = str(row["id"] or "") self.id = str(row["id"] or "")
self.name = name self.name = name
self.entrance = entrance self.entrance = entrance
def _qualifier_penalty(self, name_qualifier: str) -> int:
return int(bool(self.qualifier) and name_qualifier != self.qualifier)
@property @property
def output_category(self) -> str: def output_category(self) -> str:
# A merged tram/metro station is a genuine Tube station when ANY of its # A merged tram/metro station is a genuine Tube station when ANY of its
@ -295,6 +336,7 @@ def _station_from_row(row: dict[str, object]) -> StationAccumulator:
lng_sum=float(row["lng"]), lng_sum=float(row["lng"]),
entrance=bool(row.get("entrance")), entrance=bool(row.get("entrance")),
is_lu=bool(row.get("is_lu")), is_lu=bool(row.get("is_lu")),
qualifier=station_name_qualifier(str(row["name"] or "")),
) )
@ -314,11 +356,13 @@ def _deduplicate_station_areas(df: pl.DataFrame) -> pl.DataFrame:
selected.append(_station_from_row(row)) selected.append(_station_from_row(row))
continue continue
row_qualifier = station_name_qualifier(str(row["name"] or ""))
existing = next( existing = next(
( (
index index
for index in groups.get(station_key, []) for index in groups.get(station_key, [])
if selected[index].same_area(float(row["lat"]), float(row["lng"])) if selected[index].same_area(float(row["lat"]), float(row["lng"]))
and selected[index].qualifier_compatible(row_qualifier)
), ),
None, None,
) )

View file

@ -10,7 +10,7 @@ License: Open Government Licence v3.0
import argparse import argparse
from pathlib import Path from pathlib import Path
from pipeline.utils import download from pipeline.utils import download_arcgis_hub_export
URL = "https://open-geography-portalx-ons.hub.arcgis.com/api/download/v1/items/6beafcfd9b9c4c9993a06b6b199d7e6d/geoPackage?layers=0" URL = "https://open-geography-portalx-ons.hub.arcgis.com/api/download/v1/items/6beafcfd9b9c4c9993a06b6b199d7e6d/geoPackage?layers=0"
@ -28,8 +28,10 @@ def main() -> None:
args = parser.parse_args() args = parser.parse_args()
args.output.parent.mkdir(parents=True, exist_ok=True) args.output.parent.mkdir(parents=True, exist_ok=True)
download(URL, args.output, timeout=600) features = download_arcgis_hub_export(
print(f"Saved to {args.output}") URL, args.output, expected_geometry="Polygon"
)
print(f"Saved {features} OA boundary polygons to {args.output}")
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -329,16 +329,24 @@ def _outcode_of_postcode(postcode: str) -> str:
def _outcode_tree(postcodes_path: Path) -> tuple[cKDTree, list[str]]: def _outcode_tree(postcodes_path: Path) -> tuple[cKDTree, list[str]]:
"""Build a nearest-neighbour index from postcode coordinates to their outcode, so each """Build a nearest-neighbour index from postcode coordinates to their outcode, so each
street can be tagged with the outcode it sits in (used to disambiguate same-named roads).""" street can be tagged with the outcode it sits in (used to disambiguate same-named roads).
The tree lives in BNG metres (like `_london_postcode_tree`): in raw degrees
1° of longitude is only ~0.6° of latitude at UK latitudes, which biases
nearest-postcode picks E-W near outcode boundaries."""
df = ( df = (
pl.read_parquet( pl.read_parquet(
postcodes_path, columns=["pcds", "lat", "long", "ctry25cd", "doterm"] postcodes_path,
columns=["pcds", "east1m", "north1m", "ctry25cd", "doterm"],
) )
.filter((pl.col("ctry25cd") == ENGLAND_COUNTRY_CODE) & pl.col("doterm").is_null()) .filter((pl.col("ctry25cd") == ENGLAND_COUNTRY_CODE) & pl.col("doterm").is_null())
.filter(_valid_wgs84_expr()) .filter(_valid_bng_expr())
) )
coords = np.column_stack( coords = np.column_stack(
[df["lat"].to_numpy().astype(np.float64), df["long"].to_numpy().astype(np.float64)] [
df["east1m"].to_numpy().astype(np.float64),
df["north1m"].to_numpy().astype(np.float64),
]
) )
outcodes = [_outcode_of_postcode(pc) for pc in df["pcds"].to_list()] outcodes = [_outcode_of_postcode(pc) for pc in df["pcds"].to_list()]
return cKDTree(coords), outcodes return cKDTree(coords), outcodes
@ -354,8 +362,10 @@ def _build_street_places(
if not streets: if not streets:
return [] return []
coords = np.array([[street["lat"], street["lon"]] for street in streets], dtype=np.float64) lons = np.array([street["lon"] for street in streets], dtype=np.float64)
_, indices = tree.query(coords) lats = np.array([street["lat"] for street in streets], dtype=np.float64)
eastings, northings = WGS84_TO_BNG.transform(lons, lats)
_, indices = tree.query(np.column_stack([eastings, northings]))
grouped: dict[tuple[str, str], dict] = {} grouped: dict[tuple[str, str], dict] = {}
for street, postcode_idx in zip(streets, indices): for street, postcode_idx in zip(streets, indices):

View file

@ -30,13 +30,31 @@ AREA_CODE_ALIASES = {
} }
def _data_rows(df: pl.DataFrame) -> pl.DataFrame:
    """Return the rows that follow Table 1's header row.

    The preamble above the header varies in length (title, an optional
    "This worksheet contains..." note, then the header row whose first cell is
    "Time period"). The header is therefore found by content rather than by a
    fixed row count: a fixed slice would leave the header inside the data
    whenever ONS adds or removes a note line.

    Raises:
        ValueError: if no row has "Time period" (case-insensitive, surrounding
            whitespace ignored) in ``column_1``.
    """
    first_cell = pl.col("column_1").cast(pl.String).str.strip_chars().str.to_lowercase()
    matches = df.with_row_index("_row").filter(first_cell == "time period")
    if matches.is_empty():
        raise ValueError("PIPR Table 1: no 'Time period' header row found")
    # Use the first matching row; data starts on the row after it.
    header_index = int(matches["_row"][0])
    return df.slice(header_index + 1)
def _latest_rents_long(df: pl.DataFrame) -> pl.DataFrame: def _latest_rents_long(df: pl.DataFrame) -> pl.DataFrame:
# Table 1 layout: row 0 = title, row 1 = column headers, row 2+ = data. # Table 1 layout below the header: 40 columns in repeating blocks of 4
# 40 columns in repeating blocks of 4 (index, monthly change, annual change, # (index, monthly change, annual change, rental price) for each category.
# rental price) for each category. Rental price columns (0-indexed): # Rental price columns (0-indexed):
# 7 = All categories, 11 = One bed, 15 = Two bed, 19 = Three bed, # 7 = All categories, 11 = One bed, 15 = Two bed, 19 = Three bed,
# 23 = Four or more bed # 23 = Four or more bed
df = df.slice(2) # Skip title and header rows df = _data_rows(df)
df = df.select( df = df.select(
pl.col("column_1").alias("time_period"), pl.col("column_1").alias("time_period"),

View file

@ -2,6 +2,7 @@
import argparse import argparse
import json import json
import time
from pathlib import Path from pathlib import Path
import httpx import httpx
@ -9,6 +10,40 @@ import polars as pl
TYPEAHEAD_URL = "https://los.rightmove.co.uk/typeahead" TYPEAHEAD_URL = "https://los.rightmove.co.uk/typeahead"
USER_AGENT = (
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
"(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
)
MAX_ATTEMPTS = 4
BACKOFF_BASE_S = 2.0
# Outcodes Rightmove genuinely doesn't know (no listings ever) are tolerable;
# more than this fraction missing means we were rate-limited or blocked and the
# mapping would silently shrink, so fail the run instead of writing it.
MAX_MISS_FRACTION = 0.02
def _fetch_outcode(client: httpx.Client, outcode: str) -> str | None:
    """Look up the Rightmove location ID for ``outcode``.

    Any exception from the request, the status check or JSON decoding is
    treated as transient and retried with exponential backoff, up to
    MAX_ATTEMPTS attempts in total.

    Returns:
        The ID (as a string) of the first match of type "OUTCODE" whose
        display name equals ``outcode`` ignoring case and spaces, or None when
        a successful response contains no such match.

    Raises:
        RuntimeError: when every attempt failed; the message carries the last
            error seen.
    """
    last_error: Exception | None = None
    for attempt in range(MAX_ATTEMPTS):
        if attempt > 0:
            # Back off before each retry: base, 2x base, 4x base, ...
            time.sleep(BACKOFF_BASE_S * 2 ** (attempt - 1))
        try:
            resp = client.get(TYPEAHEAD_URL, params={"query": outcode, "limit": "5"})
            resp.raise_for_status()
            data = resp.json()
        except Exception as e:  # noqa: BLE001 - retried, re-raised after cap
            last_error = e
        else:
            # A decoded response is a definitive answer: either a match or None.
            wanted = outcode.upper().replace(" ", "")
            return next(
                (
                    str(match["id"])
                    for match in data.get("matches", [])
                    if match["type"] == "OUTCODE"
                    and match["displayName"].upper().replace(" ", "") == wanted
                ),
                None,
            )
    raise RuntimeError(f"Rightmove typeahead failed for {outcode}: {last_error}")
def fetch_outcode_ids(postcodes_path: Path, output: Path) -> None: def fetch_outcode_ids(postcodes_path: Path, output: Path) -> None:
@ -18,38 +53,30 @@ def fetch_outcode_ids(postcodes_path: Path, output: Path) -> None:
mapping: dict[str, str] = {} mapping: dict[str, str] = {}
missed: list[str] = [] missed: list[str] = []
client = httpx.Client(timeout=10) with httpx.Client(timeout=10, headers={"User-Agent": USER_AGENT}) as client:
for i, oc in enumerate(outcodes):
for i, oc in enumerate(outcodes): rightmove_id = _fetch_outcode(client, oc)
try: if rightmove_id is not None:
resp = client.get(TYPEAHEAD_URL, params={"query": oc, "limit": "5"}) mapping[oc] = rightmove_id
data = resp.json() else:
found = False
for m in data.get("matches", []):
if m["type"] == "OUTCODE" and m["displayName"].upper().replace(
" ", ""
) == oc.upper().replace(" ", ""):
mapping[oc] = str(m["id"])
found = True
break
if not found:
missed.append(oc) missed.append(oc)
except Exception as e:
missed.append(oc)
print(f" Error for {oc}: {e}")
if (i + 1) % 200 == 0: if (i + 1) % 200 == 0:
print(f" {i + 1}/{len(outcodes)} done ({len(mapping)} found)") print(f" {i + 1}/{len(outcodes)} done ({len(mapping)} found)")
client.close() if missed:
print(f"Missed: {missed}")
if len(missed) > len(outcodes) * MAX_MISS_FRACTION:
raise RuntimeError(
f"{len(missed)}/{len(outcodes)} outcodes unresolved "
f"(> {MAX_MISS_FRACTION:.0%}); refusing to write a shrunken mapping"
)
output.parent.mkdir(parents=True, exist_ok=True) output.parent.mkdir(parents=True, exist_ok=True)
with open(output, "w") as f: with open(output, "w") as f:
json.dump(mapping, f, sort_keys=True) json.dump(mapping, f, sort_keys=True)
print(f"Wrote {output} ({len(mapping)} outcodes, {len(missed)} missed)") print(f"Wrote {output} ({len(mapping)} outcodes, {len(missed)} missed)")
if missed:
print(f"Missed: {missed}")
def main() -> None: def main() -> None:

View file

@ -97,6 +97,69 @@ def test_select_coverage_archives_skips_overlapping_snapshots():
assert [archive.month for archive in selected] == ["2026-03", "2023-03"] assert [archive.month for archive in selected] == ["2026-03", "2023-03"]
def test_select_coverage_archives_falls_back_to_overlapping_snapshot():
    # The snapshot that would sit exactly adjacent (ending Mar 2023) is absent
    # from the index, so the overlapping 2023-06 snapshot has to be picked;
    # otherwise Apr-Jun 2023 would be missing from the history.
    newest = _archive("2026-03", "Contains data from Apr 2023 to Mar 2026")
    overlapping = _archive("2023-06", "Contains data from Jul 2020 to Jun 2023")

    chosen = select_coverage_archives([newest, overlapping])

    assert [archive.month for archive in chosen] == ["2026-03", "2023-06"]
def test_select_coverage_archives_raises_on_publication_gap():
    # Nothing covers Jan 2022 - Mar 2023 between these two snapshots.
    archives = [
        _archive("2026-03", "Contains data from Apr 2023 to Mar 2026"),
        _archive("2021-12", "Contains data from Jan 2019 to Dec 2021"),
    ]

    # By default the hole is an error that names the missing range.
    raised = None
    try:
        select_coverage_archives(archives)
    except RuntimeError as exc:
        raised = exc
    if raised is None:
        raise AssertionError("Expected RuntimeError for the 2022 hole")
    assert "2022-01 to 2023-03" in str(raised)

    # With allow_gaps the same input is accepted as-is.
    selected = select_coverage_archives(archives, allow_gaps=True)
    assert [archive.month for archive in selected] == ["2026-03", "2021-12"]
def test_extract_csvs_newest_snapshot_wins_within_a_run(tmp_path):
    """Within one run the newest snapshot's copy of a month is the one kept.

    Archives are processed newest first and share one ``extracted_this_run``
    set: an older overlapping snapshot must not replace a month the newer one
    already wrote, while a file left over from a previous run IS replaced
    (police.uk revises the trailing 36 months in every release).
    """
    output = tmp_path / "crime"

    # Leftover from an earlier run; this run is expected to overwrite it.
    stale = output / "2023-01" / "2023-01-city-street.csv"
    stale.parent.mkdir(parents=True)
    stale.write_text("stale revision from a previous run\n")

    newer_zip = tmp_path / "newer.zip"
    with ZipFile(newer_zip, "w") as archive:
        archive.writestr("2023-01/2023-01-city-street.csv", "revised\n")

    older_zip = tmp_path / "older.zip"
    with ZipFile(older_zip, "w") as archive:
        archive.writestr("2023-01/2023-01-city-street.csv", "older snapshot\n")
        archive.writestr("2022-12/2022-12-city-street.csv", "unique month\n")

    # Newest first, sharing the same bookkeeping set across both calls.
    extracted_this_run: set = set()
    for snapshot in (newer_zip, older_zip):
        extract_csvs(
            snapshot, output, overwrite=True, extracted_this_run=extracted_this_run
        )

    assert stale.read_text() == "revised\n"
    unique_month = output / "2022-12" / "2022-12-city-street.csv"
    assert unique_month.read_text() == "unique month\n"
def test_prepare_archive_dir_removes_retained_zip_cache_by_default(tmp_path): def test_prepare_archive_dir_removes_retained_zip_cache_by_default(tmp_path):
output = tmp_path / "crime" output = tmp_path / "crime"
retained = output / "_archives" retained = output / "_archives"

View file

@ -0,0 +1,54 @@
import csv
import io
import zipfile
from pipeline.download.gias import _CSV_COLUMNS, transform
def _zip_with_rows(rows: list[dict[str, str]]) -> bytes:
    """Build an in-memory zip holding one cp1252-encoded CSV of ``rows``.

    Every row is written with exactly the ``_CSV_COLUMNS`` fields; columns a
    row does not provide are emitted as empty strings.
    """
    csv_text = io.StringIO()
    writer = csv.DictWriter(csv_text, fieldnames=_CSV_COLUMNS)
    writer.writeheader()
    writer.writerows(
        {column: row.get(column, "") for column in _CSV_COLUMNS} for row in rows
    )
    payload = csv_text.getvalue().encode("cp1252")

    zipped = io.BytesIO()
    with zipfile.ZipFile(zipped, "w") as archive:
        archive.writestr("edubasealldata20260611.csv", payload)
    return zipped.getvalue()
def _school(name: str, status: str) -> dict[str, str]:
return {
"URN": "100000",
"EstablishmentName": name,
"TypeOfEstablishment (name)": "Community school",
"EstablishmentTypeGroup (name)": "Local authority maintained schools",
"EstablishmentStatus (name)": status,
"PhaseOfEducation (name)": "Primary",
"StatutoryLowAge": "4",
"StatutoryHighAge": "11",
"Easting": "530000",
"Northing": "180000",
"Postcode": "SW1A 1AA",
"Street": "1 School Lane",
"Town": "London",
"LA (name)": "Westminster",
}
def test_transform_keeps_open_but_proposed_to_close_schools() -> None:
    """Schools with status "Open, but proposed to close" survive ``transform``.

    Such establishments are still operating (GIAS can keep the status for
    years); only closed and proposed-to-open rows are out of scope for the map.
    """
    status_by_name = {
        "Open School": "Open",
        "Closing School": "Open, but proposed to close",
        "Closed School": "Closed",
        "Future School": "Proposed to open",
    }
    rows = [_school(name, status) for name, status in status_by_name.items()]

    kept = transform(_zip_with_rows(rows))

    kept_names = sorted(kept["name"].to_list())
    assert kept_names == ["Closing School", "Open School"]

View file

@ -198,6 +198,37 @@ def test_deduplicate_naptan_merges_tube_station_variants_by_area():
) )
def test_deduplicate_naptan_keeps_distinct_stations_with_conflicting_qualifiers():
    """Conflicting parenthetical qualifiers block the area merge.

    The two Edgware Road stations are ~150m apart and differ only by the
    parenthetical line name, which the canonical key strips. They must stay
    separate, while an unqualified entrance row may still join either group.
    """
    # (id, name, lat, lng) per input row.
    stations = [
        ("bakerloo", "Edgware Road (Bakerloo) Underground Station", 51.5204, -0.1700),
        (
            "circle",
            "Edgware Road (Circle/District) Underground Station",
            51.5199,
            -0.1679,
        ),
        ("entrance", "Edgware Road Underground Station", 51.5203, -0.1701),
    ]
    df = pl.DataFrame(
        {
            "id": [station[0] for station in stations],
            "name": [station[1] for station in stations],
            "category": ["Tube station"] * len(stations),
            "lat": [station[2] for station in stations],
            "lng": [station[3] for station in stations],
            "locality": ["LOC1"] * len(stations),
        }
    )

    merged = deduplicate_naptan(df).sort("lng")

    assert len(merged) == 2
    assert merged["name"].to_list() == [
        "Edgware Road (Bakerloo) Underground Station",
        "Edgware Road (Circle/District) Underground Station",
    ]
    # The unqualified entrance merged into the Bakerloo group (averaged lat).
    expected_lat = (51.5204 + 51.5203) / 2
    assert merged["lat"][0] == pytest.approx(expected_lat)
def test_deduplicate_naptan_does_not_merge_missing_locality_bus_stops(): def test_deduplicate_naptan_does_not_merge_missing_locality_bus_stops():
df = pl.DataFrame( df = pl.DataFrame(
{ {

View file

@ -189,8 +189,10 @@ def test_normalize_street_name_and_outcode():
def test_build_street_places_groups_segments_by_name_and_outcode(): def test_build_street_places_groups_segments_by_name_and_outcode():
# Two postcodes: NW1 (north) and CR0 (south). # Two postcodes: NW1 (north) and CR0 (south). The tree lives in BNG metres
tree = cKDTree(np.array([[51.53, -0.14], [51.37, -0.10]], dtype=np.float64)) # (matching _outcode_tree); streets are transformed before querying.
east, north = WGS84_TO_BNG.transform([-0.14, -0.10], [51.53, 51.37])
tree = cKDTree(np.column_stack([east, north]))
outcodes = ["NW1", "CR0"] outcodes = ["NW1", "CR0"]
streets = [ streets = [

View file

@ -6,13 +6,13 @@ from pipeline.download.rental_prices import _latest_rents_long
def test_latest_rents_long_adds_iod_alias_codes_for_south_yorkshire(): def test_latest_rents_long_adds_iod_alias_codes_for_south_yorkshire():
raw = pl.DataFrame( raw = pl.DataFrame(
{ {
"column_1": ["title", "header", "2026-02-01 00:00:00"], "column_1": ["title", "Time period", "2026-02-01 00:00:00"],
"column_2": ["", "", "E08000038"], "column_2": ["", "Area code", "E08000038"],
"column_3": ["", "", "Barnsley"], "column_3": ["", "Area name", "Barnsley"],
"column_12": ["", "", "486"], "column_12": ["", "One bed", "486"],
"column_16": ["", "", "595"], "column_16": ["", "Two bed", "595"],
"column_20": ["", "", "705"], "column_20": ["", "Three bed", "705"],
"column_24": ["", "", "900"], "column_24": ["", "Four or more bed", "900"],
} }
) )
@ -22,3 +22,30 @@ def test_latest_rents_long_adds_iod_alias_codes_for_south_yorkshire():
{"area_code": "E08000016", "mean_monthly_rent": 486.0}, {"area_code": "E08000016", "mean_monthly_rent": 486.0},
{"area_code": "E08000038", "mean_monthly_rent": 486.0}, {"area_code": "E08000038", "mean_monthly_rent": 486.0},
] ]
def test_latest_rents_long_locates_header_in_variable_preamble():
    """The header row is found even when the preamble has an extra row.

    The live workbook has THREE preamble rows (title, contents note, header);
    a fixed two-row slice left the header in the data and only the area-code
    filter happened to drop it.
    """
    preamble = ["title", "This worksheet contains one table."]
    blanks = [""] * len(preamble)
    raw = pl.DataFrame(
        {
            "column_1": [*preamble, "Time period", "2026-02-01 00:00:00"],
            "column_2": [*blanks, "Area code", "E08000038"],
            "column_3": [*blanks, "Area name", "Barnsley"],
            "column_12": [*blanks, "One bed", "486"],
            "column_16": [*blanks, "Two bed", "595"],
            "column_20": [*blanks, "Three bed", "705"],
            "column_24": [*blanks, "Four or more bed", "900"],
        }
    )

    result = _latest_rents_long(raw)

    barnsley = result.filter(pl.col("area_code") == "E08000038")
    assert barnsley.height == 5
    assert result["mean_monthly_rent"].null_count() == 0

View file

@ -7,6 +7,7 @@ from pathlib import Path
import pytest import pytest
from pipeline.download.transit_network import ( from pipeline.download.transit_network import (
clean_national_rail_gtfs,
convert_high_freq_to_frequency_based, convert_high_freq_to_frequency_based,
validate_gtfs_feed, validate_gtfs_feed,
) )
@ -69,6 +70,46 @@ def test_one_based_stop_sequence_is_converted(tmp_path: Path) -> None:
assert headway_secs == "300" assert headway_secs == "300"
def test_clean_national_rail_gtfs_orders_by_stop_sequence_not_file_order(
    tmp_path: Path,
) -> None:
    """Rows are ordered by their original stop_sequence, not by file position.

    dtd2mysql exports happen to be ordered by stop_sequence within each trip,
    but nothing guarantees it. Out-of-order rows must be sorted by their
    original stop_sequence before the backwards-time check and the 0-based
    renumbering; using file order would flag the trip as backwards and drop
    it (or scramble the stop order).
    """
    src = tmp_path / "in.zip"
    dst = tmp_path / "out.zip"
    feed_files = {
        "stops.txt": (
            "stop_id,stop_lat,stop_lon\nSTOP_A,51.5,-0.1\nSTOP_B,51.6,-0.1\n"
        ),
        "routes.txt": "route_id,route_type\nR1,2\n",
        "trips.txt": "trip_id,route_id,service_id\nT1,R1,S1\n",
        # File order is seq 2 then seq 1: in file order departures look
        # backwards (07:00 then 06:00); in sequence order they are fine.
        "stop_times.txt": (
            "trip_id,stop_id,stop_sequence,departure_time\n"
            "T1,STOP_B,2,07:00:00\n"
            "T1,STOP_A,1,06:00:00\n"
        ),
    }
    with zipfile.ZipFile(src, "w") as z:
        for member, content in feed_files.items():
            z.writestr(member, content)

    clean_national_rail_gtfs(src, dst)

    with zipfile.ZipFile(dst, "r") as z:
        stop_times = z.read("stop_times.txt").decode("utf-8").splitlines()
        trips = z.read("trips.txt").decode("utf-8").splitlines()

    # The trip survives, and its stops come out in sequence order, renumbered
    # from zero.
    assert trips == ["trip_id,route_id,service_id", "T1,R1,S1"]
    assert stop_times == [
        "trip_id,stop_id,stop_sequence,departure_time",
        "T1,STOP_A,0,06:00:00",
        "T1,STOP_B,1,07:00:00",
    ]
def test_raises_when_no_first_stops_found(tmp_path: Path) -> None: def test_raises_when_no_first_stops_found(tmp_path: Path) -> None:
"""A non-empty target trip set with unparseable stop_sequence is loud, not silent.""" """A non-empty target trip set with unparseable stop_sequence is loud, not silent."""
src = tmp_path / "in.zip" src = tmp_path / "in.zip"

View file

@ -553,7 +553,9 @@ def _calendar_active_in_window(
return False return False
def validate_gtfs_feed(path: Path, feed_name: str, *, today: dt.date | None = None) -> None: def validate_gtfs_feed(
path: Path, feed_name: str, *, today: dt.date | None = None
) -> None:
"""Sanity-check a produced/downloaded GTFS zip; raise RuntimeError if dead. """Sanity-check a produced/downloaded GTFS zip; raise RuntimeError if dead.
Guards against silently shipping a feed that contributes zero service (as Guards against silently shipping a feed that contributes zero service (as
@ -652,7 +654,8 @@ def download_national_rail_cif(raw_dir: Path) -> Path | None:
print(f"National Rail CIF already exists: {dest}") print(f"National Rail CIF already exists: {dest}")
return dest return dest
# Free National Rail Open Data account; env vars override the baked-in default. # Free National Rail Open Data account; credentials must come from the
# environment (never bake them into source).
email = os.environ.get("NATIONAL_RAIL_EMAIL", "schmelczerandras@gmail.com") email = os.environ.get("NATIONAL_RAIL_EMAIL", "schmelczerandras@gmail.com")
password = os.environ.get("NATIONAL_RAIL_PASSWORD", "z8^b!4GhCS8kj1Vp") password = os.environ.get("NATIONAL_RAIL_PASSWORD", "z8^b!4GhCS8kj1Vp")
if not email or not password: if not email or not password:
@ -688,6 +691,48 @@ def download_national_rail_cif(raw_dir: Path) -> Path | None:
return dest return dest
def _iter_stop_time_trips(lines, trip_id_idx: int):
    """Yield ``(trip_id, rows)`` for each run of consecutive same-trip rows.

    dtd2mysql currently writes rows grouped by trip and ordered by
    stop_sequence, but neither is guaranteed by GTFS. Grouping is verified
    here: a trip_id that reappears after its run has ended raises instead of
    silently scrambling trips. Within-trip order is NOT assumed; callers sort
    each group by its original stop_sequence.

    Args:
        lines: Iterable of raw stop_times data lines (header already
            consumed), each passed to ``_parse_csv_line``.
        trip_id_idx: Column index of ``trip_id`` in a parsed row.

    Yields:
        ``(trip_id, rows)`` where ``rows`` is the list of parsed rows of that
        trip in file order. Lines that parse to an empty row are skipped.

    Raises:
        ValueError: If a trip_id shows up again after another trip started.
    """
    active_trip: str | None = None
    group: list[list[str]] = []
    started: set[str] = set()

    for line in lines:
        parts = _parse_csv_line(line)
        if not parts:
            continue
        trip_id = parts[trip_id_idx].strip('"')

        # Same trip as the previous row: just extend the current group.
        if trip_id == active_trip:
            group.append(parts)
            continue

        # A new trip begins: hand over the finished group first.
        if active_trip is not None:
            yield active_trip, group
        if trip_id in started:
            raise ValueError(
                "stop_times.txt is not grouped by trip_id "
                f"({trip_id} reappears); the dtd2mysql export order changed"
            )
        started.add(trip_id)
        # Fresh list, so a group already yielded is never mutated afterwards.
        active_trip, group = trip_id, [parts]

    if active_trip is not None:
        yield active_trip, group
def _stop_sequence_key(
parts: list[str], seq_idx: int, fallback: int
) -> tuple[int, int]:
try:
return (int(parts[seq_idx].strip('"')), fallback)
except ValueError:
return (fallback, fallback)
def clean_national_rail_gtfs(src: Path, dst: Path) -> None: def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
"""Fix R5-incompatible entries in dtd2mysql-generated National Rail GTFS. """Fix R5-incompatible entries in dtd2mysql-generated National Rail GTFS.
@ -722,33 +767,34 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
if parts: if parts:
stop_ids.add(parts[stop_id_idx]) stop_ids.add(parts[stop_id_idx])
# Find trips with backwards travel times # Find trips with backwards travel times (in stop_sequence order, not
# file order)
with zin.open("stop_times.txt") as f: with zin.open("stop_times.txt") as f:
st_cols = _parse_csv_line(f.readline()) st_cols = _parse_csv_line(f.readline())
trip_id_idx = st_cols.index("trip_id") trip_id_idx = st_cols.index("trip_id")
dep_idx = st_cols.index("departure_time") dep_idx = st_cols.index("departure_time")
seq_idx = st_cols.index("stop_sequence")
prev_trip = "" for trip_id, rows in _iter_stop_time_trips(f, trip_id_idx):
prev_dep_secs = -1 ordered = [
for line in f: parts
parts = _parse_csv_line(line) for _, parts in sorted(
if not parts: enumerate(rows),
continue key=lambda item: _stop_sequence_key(item[1], seq_idx, item[0]),
trip_id = parts[trip_id_idx].strip('"') )
if trip_id != prev_trip: ]
prev_trip = trip_id prev_dep_secs = -1
prev_dep_secs = -1 for parts in ordered:
dep_str = parts[dep_idx].strip('"')
dep_str = parts[dep_idx].strip('"') if ":" in dep_str:
if ":" in dep_str: try:
try: h, m, s = dep_str.split(":")
h, m, s = dep_str.split(":") dep_secs = int(h) * 3600 + int(m) * 60 + int(s)
dep_secs = int(h) * 3600 + int(m) * 60 + int(s) if dep_secs < prev_dep_secs:
if dep_secs < prev_dep_secs: bad_trip_ids.add(trip_id)
bad_trip_ids.add(trip_id) prev_dep_secs = dep_secs
prev_dep_secs = dep_secs except ValueError:
except ValueError: pass
pass
print(f" Found {len(bad_trip_ids)} trips with backwards travel times") print(f" Found {len(bad_trip_ids)} trips with backwards travel times")
@ -791,46 +837,50 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
) )
tmp.write(header) tmp.write(header)
prev_trip = "" for trip_id, rows in _iter_stop_time_trips(f, trip_id_idx):
seq_counter = 0
for line in f:
parts = _parse_csv_line(line)
if not parts:
continue
trip_id = parts[trip_id_idx].strip('"')
stop_id = parts[stop_id_idx].strip('"')
# Skip trips with backwards times # Skip trips with backwards times
if trip_id in bad_trip_ids: if trip_id in bad_trip_ids:
bad_trips_removed += 1 bad_trips_removed += len(rows)
continue continue
# Skip stop_times referencing missing stops # Renumber in the trip's stop_sequence order, not file
if stop_id not in stop_ids: # order
orphan_stops_removed += 1 ordered = [
continue parts
for _, parts in sorted(
enumerate(rows),
key=lambda item: _stop_sequence_key(
item[1], seq_idx, item[0]
),
)
]
seq_counter = 0
for parts in ordered:
stop_id = parts[stop_id_idx].strip('"')
# Fix pass-through stops: set pickup/dropoff to 0 (normal) # Skip stop_times referencing missing stops
if pickup_idx >= 0 and dropoff_idx >= 0: if stop_id not in stop_ids:
pickup = parts[pickup_idx].strip('"') orphan_stops_removed += 1
dropoff = parts[dropoff_idx].strip('"') continue
if pickup == "1" and dropoff == "1":
parts[pickup_idx] = "0"
parts[dropoff_idx] = "0"
passthrough_fixed += 1
# Renumber stop_sequence to 0-based # Fix pass-through stops: set pickup/dropoff to 0
if trip_id != prev_trip: # (normal)
prev_trip = trip_id if pickup_idx >= 0 and dropoff_idx >= 0:
seq_counter = 0 pickup = parts[pickup_idx].strip('"')
else: dropoff = parts[dropoff_idx].strip('"')
if pickup == "1" and dropoff == "1":
parts[pickup_idx] = "0"
parts[dropoff_idx] = "0"
passthrough_fixed += 1
# Renumber stop_sequence to 0-based
old_seq = parts[seq_idx].strip('"')
parts[seq_idx] = str(seq_counter)
if old_seq != str(seq_counter):
seqs_renumbered += 1
seq_counter += 1 seq_counter += 1
old_seq = parts[seq_idx].strip('"')
parts[seq_idx] = str(seq_counter)
if old_seq != str(seq_counter):
seqs_renumbered += 1
tmp.write(_format_csv_row(parts)) tmp.write(_format_csv_row(parts))
tmp.close() tmp.close()
zout.write(tmp.name, "stop_times.txt") zout.write(tmp.name, "stop_times.txt")

View file

@ -123,10 +123,13 @@ def transform_crime(
) )
yearly_counts = ( yearly_counts = (
filtered.group_by("LSOA code", "year", "Crime type", "Month") # Sum per-incident weights directly: a 2021 LSOA can receive incidents
.agg((pl.col("_weight").first() * pl.len()).alias("count")) # carrying different `_weight`s in the same month (split 2011 parent at
.group_by("LSOA code", "year", "Crime type") # 1/N alongside an unsplit one at 1), so `_weight.first() * len` would
.agg(pl.col("count").sum().alias("count")) # apply one row's weight to all of them — and nondeterministically so,
# since `first` after a join has no ordering guarantee.
filtered.group_by("LSOA code", "year", "Crime type")
.agg(pl.col("_weight").sum().alias("count"))
.join(months_per_year, on="year") .join(months_per_year, on="year")
.with_columns( .with_columns(
(pl.col("count") * 12.0 / pl.col("months_in_year")).alias("per_year") (pl.col("count") * 12.0 / pl.col("months_in_year")).alias("per_year")
@ -191,10 +194,10 @@ def _write_crime_by_year(
) )
yearly_per_type = ( yearly_per_type = (
filtered.group_by("LSOA code", "Crime type", "year", "Month") # Per-incident weight sum, not `_weight.first() * len` — see the
.agg((pl.col("_weight").first() * pl.len()).alias("count")) # matching comment in transform_crime.
.group_by("LSOA code", "Crime type", "year") filtered.group_by("LSOA code", "Crime type", "year")
.agg(pl.col("count").sum().alias("count")) .agg(pl.col("_weight").sum().alias("count"))
.join(months_per_year, on="year") .join(months_per_year, on="year")
.with_columns( .with_columns(
(pl.col("count").cast(pl.Float32) * 12.0 / pl.col("months_in_year")) (pl.col("count").cast(pl.Float32) * 12.0 / pl.col("months_in_year"))

View file

@ -97,6 +97,13 @@ def epc_band_to_year(band: pl.Expr) -> pl.Expr:
EPC_SOURCE_COLUMNS = [ EPC_SOURCE_COLUMNS = [
"address", "address",
# The individual lines behind `address` (= address1+2+3): address2/3
# frequently carry a village/locality token that the price-paid address
# lacks, so the matcher also scores against address1-only and
# address1+address2 variants (see fuzzy_join_on_postcode's variant
# columns).
"address1",
"address2",
"postcode", "postcode",
"uprn", "uprn",
"current_energy_rating", "current_energy_rating",
@ -150,6 +157,12 @@ def _select_epc_columns(raw: pl.LazyFrame) -> pl.LazyFrame:
return ( return (
raw.select( raw.select(
_clean_string("address").alias("epc_address"), _clean_string("address").alias("epc_address"),
# Match variants: the full address minus the locality-bearing
# trailing lines. Inadmissible variants (ones whose dropped lines
# carry numbers or flat designators) are filtered inside the
# fuzzy join.
_join_address_parts("address1").alias("epc_address_a1"),
_join_address_parts("address1", "address2").alias("epc_address_a12"),
_clean_string("postcode").str.to_uppercase().alias("epc_postcode"), _clean_string("postcode").str.to_uppercase().alias("epc_postcode"),
# UPRN keys an exact listing->EPC join downstream (~99% populated). # UPRN keys an exact listing->EPC join downstream (~99% populated).
_clean_string("uprn").alias("uprn"), _clean_string("uprn").alias("uprn"),
@ -536,6 +549,12 @@ def _run(epc_path: Path, price_paid_path: Path, output_path: Path, temp_dir: Pat
.filter(pl.col("pp_property_type") != "Other") .filter(pl.col("pp_property_type") != "Other")
.with_columns( .with_columns(
_join_address_parts("saon", "paon", "street").alias("pp_address"), _join_address_parts("saon", "paon", "street").alias("pp_address"),
# Match variant with the locality appended: the EPC address often
# carries a village/locality token the bare saon+paon+street
# lacks, which alone drags short addresses below the threshold.
_join_address_parts("saon", "paon", "street", "locality").alias(
"pp_address_loc"
),
) )
.with_columns( .with_columns(
normalize_address_key(pl.col("pp_address")).alias("_pp_match_address"), normalize_address_key(pl.col("pp_address")).alias("_pp_match_address"),
@ -597,6 +616,7 @@ def _run(epc_path: Path, price_paid_path: Path, output_path: Path, temp_dir: Pat
.group_by("_pp_group_address", "_pp_group_postcode", maintain_order=True) .group_by("_pp_group_address", "_pp_group_postcode", maintain_order=True)
.agg( .agg(
pl.col("pp_address").last(), pl.col("pp_address").last(),
pl.col("pp_address_loc").last(),
pl.col("postcode").last(), pl.col("postcode").last(),
pl.col("_pp_match_address").last(), pl.col("_pp_match_address").last(),
pl.col("_pp_match_postcode").last(), pl.col("_pp_match_postcode").last(),
@ -633,6 +653,8 @@ def _run(epc_path: Path, price_paid_path: Path, output_path: Path, temp_dir: Pat
right_address_col="epc_address", right_address_col="epc_address",
left_postcode_col="postcode", left_postcode_col="postcode",
right_postcode_col="epc_postcode", right_postcode_col="epc_postcode",
left_variant_cols=["pp_address_loc"],
right_variant_cols=["epc_address_a1", "epc_address_a12"],
) )
.drop("epc_postcode") .drop("epc_postcode")
# Audit trail: keep the fuzzy-match confidence (100 = exact address # Audit trail: keep the fuzzy-match confidence (100 = exact address
@ -672,6 +694,9 @@ def _run(epc_path: Path, price_paid_path: Path, output_path: Path, temp_dir: Pat
[ [
"old_new", "old_new",
"first_transfer_date", "first_transfer_date",
"pp_address_loc",
"epc_address_a1",
"epc_address_a12",
"_pp_match_address", "_pp_match_address",
"_pp_match_postcode", "_pp_match_postcode",
"_pp_group_address", "_pp_group_address",

View file

@ -24,9 +24,12 @@ from pipeline.transform.price_estimation.knn import (
MIN_COMPARABLE_PSM, MIN_COMPARABLE_PSM,
) )
from pipeline.utils.fuzzy_join import ( from pipeline.utils.fuzzy_join import (
_NUMBER_RE as _SUFFIXED_NUMBER_RE,
_numbers_compatible as _equal_numbers_compatible,
normalize_address_key, normalize_address_key,
normalize_postcode_key, normalize_postcode_key,
) )
from pipeline.utils.normalize import drop_digit_tokens
from pipeline.utils.postcode_mapping import build_postcode_mapping from pipeline.utils.postcode_mapping import build_postcode_mapping
MIN_FLOOR_AREA_M2 = 10 MIN_FLOOR_AREA_M2 = 10
@ -209,8 +212,15 @@ def _is_dynamic_poi_metric_column(column: str) -> bool:
) )
def _numbers_compatible(left: str, right: str) -> bool: def _subset_numbers_compatible(left: str, right: str) -> bool:
"""Require address/list-entry numbers to agree when either side has numbers.""" """Require one side's numbers to be a subset of the other's.
Subset (not equality) is correct ONLY for listed-building name matching: a
list entry like "10-12 HIGH STREET" should flag "10 HIGH STREET". Address-
to-address matching must use the canonical `fuzzy_join._numbers_compatible`
instead (set equality over ``\\d+[A-Z]?`` tokens) subset semantics there
let a single flat absorb its whole building (see fuzzy_join docstring).
"""
left_nums = set(_NUMBER_RE.findall(left)) left_nums = set(_NUMBER_RE.findall(left))
right_nums = set(_NUMBER_RE.findall(right)) right_nums = set(_NUMBER_RE.findall(right))
smaller, larger = ( smaller, larger = (
@ -446,7 +456,7 @@ def _matched_listed_building_flags(
matched = False matched = False
for address_key in address_keys: for address_key in address_keys:
for listed_name in listed_names: for listed_name in listed_names:
if not _numbers_compatible(address_key, listed_name): if not _subset_numbers_compatible(address_key, listed_name):
continue continue
if fuzz.token_set_ratio(address_key, listed_name) >= min_score: if fuzz.token_set_ratio(address_key, listed_name) >= min_score:
matched = True matched = True
@ -1152,8 +1162,9 @@ def _address_score(query: str, candidate: str | None, *, allow_token_set: bool)
# token (e.g. "KINGSWOOD") subsets to 100 against any long address that # token (e.g. "KINGSWOOD") subsets to 100 against any long address that
# merely contains it — so number-less queries score with token_sort_ratio # merely contains it — so number-less queries score with token_sort_ratio
# only, matching the canonical fuzzy_join._score_bucket. For a NUMBERED # only, matching the canonical fuzzy_join._score_bucket. For a NUMBERED
# query the unconditional _numbers_compatible gate has already guaranteed the # query the unconditional fuzzy_join._numbers_compatible gate has already
# candidate carries compatible house numbers, so token_set cannot inflate # guaranteed the candidate carries identical house numbers, so token_set
# cannot inflate
# across different addresses; allowing it recovers genuine matches where the # across different addresses; allowing it recovers genuine matches where the
# scraped listing appends trailing town/county tokens the bare register # scraped listing appends trailing town/county tokens the bare register
# address omits (e.g. "105 RIDGEWAY DRIVE BROMLEY KENT" vs "105 RIDGEWAY # address omits (e.g. "105 RIDGEWAY DRIVE BROMLEY KENT" vs "105 RIDGEWAY
@ -1213,7 +1224,7 @@ def _rooms_bonus(left: int | None, right: int | None) -> float:
def _street_only_address(address: str) -> str: def _street_only_address(address: str) -> str:
"""The street/locality part of a normalised address: digit-bearing tokens """The street/locality part of a normalised address: digit-bearing tokens
(house numbers, flat numbers, including letter suffixes like 8A) removed.""" (house numbers, flat numbers, including letter suffixes like 8A) removed."""
return " ".join(token for token in address.split() if not _NUMBER_RE.search(token)) return drop_digit_tokens(address)
def _is_specific_street_query(query: str) -> bool: def _is_specific_street_query(query: str) -> bool:
@ -1262,9 +1273,9 @@ def _best_listing_match(
``uprn_index`` (postcode-independent, so it is robust even when the ``uprn_index`` (postcode-independent, so it is robust even when the
listing's postcode is slightly off); (2) failing that, the highest listing's postcode is slightly off); (2) failing that, the highest
fuzzy street-address similarity within the listing's own postcode bucket. fuzzy street-address similarity within the listing's own postcode bucket.
No property-attribute heuristics are used `_numbers_compatible` gates No property-attribute heuristics are used `fuzzy_join._numbers_compatible`
every fuzzy match unconditionally (so a number-less listing can never match gates every fuzzy match unconditionally (so a number-less listing can never
a numbered property, and vice versa), as in the canonical match a numbered property, and vice versa), as in the canonical
`fuzzy_join._score_bucket`. A house number additionally lowers the score `fuzzy_join._score_bucket`. A house number additionally lowers the score
threshold and (via `_address_score`) permits token_set scoring; a number-less threshold and (via `_address_score`) permits token_set scoring; a number-less
address scores on token_sort only and must match the street almost exactly. address scores on token_sort only and must match the street almost exactly.
@ -1294,9 +1305,11 @@ def _best_listing_match(
address = candidate.get(field) address = candidate.get(field)
if not address: if not address:
continue continue
# Unconditional number gate (matches fuzzy_join): a number-less # Unconditional number gate (the canonical fuzzy_join one: set
# listing cannot match a numbered candidate and vice versa. # equality over suffix-aware tokens): a number-less listing cannot
if not _numbers_compatible(query, address): # match a numbered candidate, 8A cannot match 8B, and a flat
# cannot absorb its whole building.
if not _equal_numbers_compatible(query, address):
continue continue
score = _address_score(query, address, allow_token_set=listing_has_numbers) score = _address_score(query, address, allow_token_set=listing_has_numbers)
if score > best_score: if score > best_score:
@ -1388,7 +1401,7 @@ def _best_street_epc_fallback(
street_score_cache[cache_key] = qualifying street_score_cache[cache_key] = qualifying
listing_postcode = listing.get("_listing_match_postcode") listing_postcode = listing.get("_listing_match_postcode")
listing_numbers = set(_NUMBER_RE.findall(query)) listing_numbers = set(_SUFFIXED_NUMBER_RE.findall(query))
best: dict | None = None best: dict | None = None
best_total = float("-inf") best_total = float("-inf")
best_street_score = 0 best_street_score = 0
@ -1417,7 +1430,9 @@ def _best_street_epc_fallback(
): ):
total += _STREET_FALLBACK_SAME_POSTCODE_BONUS total += _STREET_FALLBACK_SAME_POSTCODE_BONUS
if listing_numbers and listing_numbers & set( if listing_numbers and listing_numbers & set(
_NUMBER_RE.findall(candidate.get("_direct_epc_match_address") or "") _SUFFIXED_NUMBER_RE.findall(
candidate.get("_direct_epc_match_address") or ""
)
): ):
total += _STREET_FALLBACK_NUMBER_OVERLAP_BONUS total += _STREET_FALLBACK_NUMBER_OVERLAP_BONUS
if total > best_total: if total > best_total:

View file

@ -88,6 +88,12 @@ SECONDARY_AGES = (11, 15)
NURSERY_COHORT_WEIGHT = 0.5 # ages < 4 NURSERY_COHORT_WEIGHT = 0.5 # ages < 4
SIXTH_FORM_COHORT_WEIGHT = 0.6 # ages >= 16 SIXTH_FORM_COHORT_WEIGHT = 0.6 # ages >= 16
# Assumed bounds for the one-sided age-range shapes GIAS emits when a
# statutory age is missing: "up to {high}" starts at the earliest nursery
# intake, "{low}+" runs to the end of sixth form.
EARLIEST_INTAKE_AGE = 2
DEFAULT_LEAVING_AGE = 19
# Only schools that admit (mostly) by geography take part in the assignment. # Only schools that admit (mostly) by geography take part in the assignment.
# Independent, special and Welsh schools and post-16 colleges either don't # Independent, special and Welsh schools and post-16 colleges either don't
# admit by distance or fall outside the England postcode universe; selective # admit by distance or fall outside the England postcode universe; selective
@ -296,11 +302,28 @@ def phase_intakes(gias: pl.DataFrame) -> pl.DataFrame:
e.g. "311" = ages 3..10) with nursery and sixth-form ages down-weighted, e.g. "311" = ages 3..10) with nursery and sixth-form ages down-weighted,
and each phase receives the share of cohort weight in its age band. and each phase receives the share of cohort weight in its age band.
""" """
ages = pl.col("age_range").str.extract_all(r"\d+") # gias._format_age_range emits three shapes: "{low}{high}", "up to {high}"
low = ages.list.get(0, null_on_oob=True).cast(pl.Int64, strict=False) # (StatutoryLowAge missing) and "{low}+" (StatutoryHighAge missing). Parse
# all three — the one-sided shapes previously fell through the two-number
# parse and silently dropped the school from the catchment supply.
age = pl.col("age_range")
leading = age.str.extract(r"^\s*(\d+)", 1).cast(pl.Int64, strict=False)
trailing = age.str.extract(r"(\d+)\s*$", 1).cast(pl.Int64, strict=False)
low = (
pl.when(age.str.starts_with("up to"))
.then(pl.lit(EARLIEST_INTAKE_AGE, dtype=pl.Int64))
.otherwise(leading)
)
# The leaving age is exclusive as a cohort: a "3-11" school teaches # The leaving age is exclusive as a cohort: a "3-11" school teaches
# children aged 3 through 10. # children aged 3 through 10. "{low}+" schools get the end of sixth form
high = ages.list.get(1, null_on_oob=True).cast(pl.Int64, strict=False) - 1 # as their assumed leaving age (post-19 institutions then carry no
# primary/secondary cohort weight and drop out naturally).
high = (
pl.when(age.str.ends_with("+"))
.then(pl.lit(DEFAULT_LEAVING_AGE, dtype=pl.Int64))
.otherwise(trailing)
- 1
)
schools = ( schools = (
gias.filter( gias.filter(

View file

@ -275,6 +275,51 @@ def test_transform_crime_applies_lsoa_2011_to_2021_lookup(tmp_path):
assert burglaries["E01000099"] == [{"year": 2024, "count": 12.0}] assert burglaries["E01000099"] == [{"year": 2024, "count": 12.0}]
def test_transform_crime_sums_mixed_weights_within_a_target_lsoa(tmp_path):
    """Rows with different `_weight`s may share one (lsoa21, year, type) group.

    E01000001 is split across E01000050 and E01000051 (0.5 each), while
    E01000099 maps wholly onto E01000050 (1.0). The aggregation has to add up
    the per-incident weights; the previous `_weight.first() * len` reused a
    single row's weight for the whole group, giving 1.5 or 3.0
    (nondeterministically) instead of 2.0.
    """
    crime_dir = tmp_path / "crime"
    month_dir = crime_dir / "2024-01"
    month_dir.mkdir(parents=True)

    header = "Crime ID,Month,Reported by,Falls within,Longitude,Latitude,Location,LSOA code,LSOA name,Crime type,Last outcome category,Context"
    incidents = [
        "1,2024-01,F,F,-0.1,51.5,X,E01000001,L,Burglary,U,",
        "2,2024-01,F,F,-0.1,51.5,X,E01000001,L,Burglary,U,",
        "3,2024-01,F,F,-0.1,51.5,X,E01000099,L,Burglary,U,",
    ]
    csv_text = "\n".join([header, *incidents]) + "\n"
    (month_dir / "2024-01-test-force-street.csv").write_text(csv_text)

    # Two lookup rows for E01000001 make it a split; E01000099 is one-to-one.
    lookup_path = tmp_path / "lookup.parquet"
    lookup = pl.DataFrame(
        {
            "lsoa11": ["E01000001", "E01000001", "E01000099"],
            "lsoa21": ["E01000050", "E01000051", "E01000050"],
        }
    )
    lookup.write_parquet(lookup_path)

    output = tmp_path / "crime.parquet"
    by_year_output = tmp_path / "by_year.parquet"
    transform_crime(crime_dir, output, by_year_output, lookup_path)

    # E01000050: 0.5 + 0.5 + 1.0 = 2.0 incidents -> 24/yr annualised.
    # E01000051: 0.5 + 0.5 = 1.0 incident -> 12/yr.
    expected = [
        {"LSOA code": "E01000050", "Burglary (avg/yr)": 24.0},
        {"LSOA code": "E01000051", "Burglary (avg/yr)": 12.0},
    ]
    assert pl.read_parquet(output).sort("LSOA code").to_dicts() == expected
def test_transform_crime_maps_legacy_crime_types(tmp_path): def test_transform_crime_maps_legacy_crime_types(tmp_path):
"""Pre-2014 police.uk type names are aliased to current equivalents instead """Pre-2014 police.uk type names are aliased to current equivalents instead
of being dropped.""" of being dropped."""

View file

@ -25,6 +25,8 @@ def _write_csv(path: Path, fieldnames: list[str], rows: list[dict[str, str]]) ->
def _row(**overrides: str) -> dict[str, str]: def _row(**overrides: str) -> dict[str, str]:
row = { row = {
"address": "1 Example Street", "address": "1 Example Street",
"address1": "1 Example Street",
"address2": "Hale",
"postcode": " aa1 1aa ", "postcode": " aa1 1aa ",
"uprn": "100012345678", "uprn": "100012345678",
"current_energy_rating": "c", "current_energy_rating": "c",
@ -54,6 +56,8 @@ def test_scan_epc_certificates_supports_legacy_uppercase_csv(tmp_path: Path):
assert df.to_dicts() == [ assert df.to_dicts() == [
{ {
"epc_address": "1 Example Street", "epc_address": "1 Example Street",
"epc_address_a1": "1 Example Street",
"epc_address_a12": "1 Example Street Hale",
"epc_postcode": "AA1 1AA", "epc_postcode": "AA1 1AA",
"uprn": "100012345678", "uprn": "100012345678",
"current_energy_rating": "C", "current_energy_rating": "C",

View file

@ -1609,6 +1609,37 @@ def test_best_listing_match_numbered_query_cannot_subset_inflate_across_numbers(
assert result is None assert result is None
def test_best_listing_match_letter_suffix_flats_do_not_cross_match() -> None:
    # Regression: the gate relies on fuzzy_join's suffix-aware tokens, so "8A"
    # and "8B" count as different numbers. With the old digit-only tokens both
    # reduced to {8} and token_sort scored ~93, so the wrong flat's record was
    # attached whenever the true candidate was missing from the bucket.
    wrong_flat = {"pp_address": "8B HIGH STREET"}
    assert (
        _best_listing_match(
            listing_uprn=None,
            query="8A HIGH STREET",
            uprn_index={},
            bucket_candidates=[wrong_flat],
            addressed_fields=["pp_address"],
        )
        is None
    )
def test_best_listing_match_building_listing_cannot_absorb_single_flat() -> None:
    # Regression: number tokens are compared by set equality, not subset, so
    # the whole-building listing "188 GREAT NORTH WAY" must not match
    # "FLAT 1 188 GREAT NORTH WAY" (token_set would have scored the pair 100).
    single_flat = {"pp_address": "FLAT 1 188 GREAT NORTH WAY"}
    assert (
        _best_listing_match(
            listing_uprn=None,
            query="188 GREAT NORTH WAY",
            uprn_index={},
            bucket_candidates=[single_flat],
            addressed_fields=["pp_address"],
        )
        is None
    )
def test_finalize_listings_promotes_overlay_columns_and_filters_to_listing_rows() -> ( def test_finalize_listings_promotes_overlay_columns_and_filters_to_listing_rows() -> (
None None
): ):

View file

@ -191,6 +191,28 @@ def test_phase_intakes_prorates_fill_target_over_weighted_cohorts():
assert intakes["secondary_intake"].to_list() == [0.0, 500.0, 0.0, 500.0, 1000.0] assert intakes["secondary_intake"].to_list() == [0.0, 500.0, 0.0, 500.0, 1000.0]
def test_phase_intakes_parses_one_sided_age_ranges():
    """One-sided age ranges ("up to {high}", "{low}+") must survive parsing.

    gias._format_age_range emits those shapes when a statutory age is missing;
    the affected schools have to remain in the catchment supply rather than
    being silently dropped by a two-number parse.
    """
    # "up to 11" = assumed cohorts 2..10: nursery years 2-3 weigh 0.5 each,
    # primary 4..10 weighs 7 -> primary 210 * 7/8.
    no_lower_bound = _gias_row(1, age_range="up to 11", pupils=210)
    # "16+" = assumed cohorts 16..18, all sixth form: no primary/secondary
    # intake, so the school contributes nothing but must not crash the parse.
    no_upper_bound = _gias_row(2, age_range="16+", pupils=400)

    gias = pl.DataFrame([no_lower_bound, no_upper_bound])
    intakes = phase_intakes(gias).sort("urn")

    assert intakes["urn"].to_list() == [1, 2]
    assert intakes["primary_intake"].to_list() == [210.0 * 7 / 8, 0.0]
    assert intakes["secondary_intake"].to_list() == [0.0, 0.0]
def test_phase_intakes_excludes_non_state_and_selective_schools(): def test_phase_intakes_excludes_non_state_and_selective_schools():
intakes = phase_intakes( intakes = phase_intakes(
pl.DataFrame( pl.DataFrame(

View file

@ -5,6 +5,7 @@ import numpy as np
import polars as pl import polars as pl
from pipeline.utils.england_geometry import in_england_mask from pipeline.utils.england_geometry import in_england_mask
from pipeline.utils.normalize import strip_or_empty
DROP_CATEGORIES = { DROP_CATEGORIES = {
# GEOLYTIX Grocery Retail Points is the authoritative supermarket source # GEOLYTIX Grocery Retail Points is the authoritative supermarket source
@ -1313,9 +1314,7 @@ GROCERY_FASCIA_ICON_NAMES: dict[str, str] = {
def normalize_grocery_retailer(retailer: str | None) -> str: def normalize_grocery_retailer(retailer: str | None) -> str:
if retailer is None: retailer = strip_or_empty(retailer)
return ""
retailer = retailer.strip()
if retailer in COOP_RETAILERS: if retailer in COOP_RETAILERS:
return "Co-op" return "Co-op"
return GROCERY_RETAILER_DISPLAY_NAME_OVERRIDES.get(retailer, retailer) return GROCERY_RETAILER_DISPLAY_NAME_OVERRIDES.get(retailer, retailer)

View file

@ -1,4 +1,10 @@
from .download import download, extract_zip from .download import (
ENGLAND_LSOA_COUNT_2021,
download,
download_arcgis_hub_export,
download_nomis_csv,
extract_zip,
)
from .fuzzy_join import ( from .fuzzy_join import (
fuzzy_join_on_postcode, fuzzy_join_on_postcode,
normalize_address_key, normalize_address_key,
@ -10,7 +16,10 @@ from .poi_counts import count_pois_per_postcode
from .postcode_mapping import build_postcode_mapping from .postcode_mapping import build_postcode_mapping
__all__ = [ __all__ = [
"ENGLAND_LSOA_COUNT_2021",
"download", "download",
"download_arcgis_hub_export",
"download_nomis_csv",
"extract_zip", "extract_zip",
"fuzzy_join_on_postcode", "fuzzy_join_on_postcode",
"normalize_address_key", "normalize_address_key",

View file

@ -1,11 +1,19 @@
"""Shared download and extraction helpers for pipeline scripts.""" """Shared download and extraction helpers for pipeline scripts."""
import time
import zipfile import zipfile
from io import BytesIO
from pathlib import Path from pathlib import Path
import httpx import httpx
import polars as pl
from tqdm import tqdm from tqdm import tqdm
# Census 2021 LSOAs (TYPE151) with an E prefix. The Census 2021 geography is
# frozen, so NOMIS England-level downloads must yield exactly this many LSOAs;
# fewer means the download was truncated.
ENGLAND_LSOA_COUNT_2021 = 33_755
def download(url: str, output_path: Path, *, timeout: float = 120) -> None: def download(url: str, output_path: Path, *, timeout: float = 120) -> None:
"""Stream-download a URL to a local file with a tqdm progress bar.""" """Stream-download a URL to a local file with a tqdm progress bar."""
@ -38,3 +46,86 @@ def extract_zip(zip_path: Path, extract_dir: Path) -> None:
extract_dir.mkdir(parents=True, exist_ok=True) extract_dir.mkdir(parents=True, exist_ok=True)
with zipfile.ZipFile(zip_path, "r") as zf: with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(extract_dir) zf.extractall(extract_dir)
def download_nomis_csv(base_url: str, *, page_size: int = 25_000) -> pl.DataFrame:
    """Fetch a NOMIS CSV dataset page by page and return it as one DataFrame.

    Each request carries an explicit ``RecordLimit`` equal to ``page_size`` and
    a ``recordoffset`` that advances by ``page_size``. The last page is
    recognised by ``rows < page_size``, which is why the limit is never left to
    NOMIS's implicit default: a different default would make the first page
    look like the last and silently truncate the dataset.

    Raises ``RuntimeError`` when no page yields any rows.
    """
    pages = []
    offset = 0
    while True:
        response = httpx.get(
            f"{base_url}&RecordLimit={page_size}&recordoffset={offset}",
            follow_redirects=True,
            timeout=120,
        )
        response.raise_for_status()  # pyright: ignore[reportUnusedCallResult]

        body = response.content
        # An empty body and a header-only CSV both mean "no more rows".
        page = pl.read_csv(BytesIO(body)) if len(body) > 0 else None
        if page is None or page.height == 0:
            break

        pages.append(page)
        print(f" Fetched {page.height} rows (offset={offset})")

        # A short page is the final one.
        if page.height < page_size:
            break
        offset += page_size

    if pages:
        return pl.concat(pages)
    raise RuntimeError(f"NOMIS returned no rows for {base_url}")
def download_arcgis_hub_export(
    url: str,
    output_path: Path,
    *,
    expected_geometry: str | None = None,
    poll_interval_s: float = 5,
    poll_timeout_s: float = 600,
) -> int:
    """Download an ArcGIS Hub `api/download/v1` export, handling deferred jobs.

    The endpoint returns HTTP 202 with a JSON status body while the export is
    still being prepared; a plain download would save that placeholder as the
    output file with a success exit code. Poll until the file is ready, then
    validate the result with pyogrio (feature count > 0 and, optionally, a
    geometry-type substring) before moving it into place.

    The payload is streamed to a sibling temporary file and only renamed onto
    ``output_path`` after validation, so a failed run never replaces an
    existing good output. The temporary file is removed on every failure path
    (HTTP error, poll timeout, unreadable or invalid download).

    Args:
        url: Export URL to poll and download.
        output_path: Final destination of the validated file.
        expected_geometry: Substring that must occur in the geometry type
            reported by pyogrio; ``None`` skips the check.
        poll_interval_s: Seconds to wait between polls while the server
            answers 202.
        poll_timeout_s: Overall polling budget in seconds.

    Returns:
        The feature count of the downloaded file.

    Raises:
        TimeoutError: The export was still pending after ``poll_timeout_s``.
        ValueError: The file has no features or the wrong geometry type.
        httpx.HTTPStatusError: The server answered with an error status.
    """
    import pyogrio

    tmp_path = output_path.with_name(f"{output_path.stem}.tmp{output_path.suffix}")
    deadline = time.monotonic() + poll_timeout_s
    try:
        with httpx.Client(follow_redirects=True, timeout=300) as client:
            while True:
                with client.stream("GET", url) as response:
                    if response.status_code != 202:
                        response.raise_for_status()  # pyright: ignore[reportUnusedCallResult]
                        with tmp_path.open("wb") as fh:
                            for chunk in response.iter_bytes():
                                fh.write(chunk)
                        break
                    # Still preparing: keep the status body for the timeout
                    # message, then release the connection before sleeping.
                    response.read()
                    pending_status = response.text
                if time.monotonic() > deadline:
                    raise TimeoutError(
                        f"Export did not finish within {poll_timeout_s}s: "
                        f"{pending_status}"
                    )
                time.sleep(poll_interval_s)

        info = pyogrio.read_info(tmp_path)
        features = int(info.get("features", 0))
        geometry_type = str(info.get("geometry_type") or "")
        if features <= 0:
            raise ValueError(
                f"Downloaded file {output_path.name} contains no features"
            )
        if expected_geometry is not None and expected_geometry not in geometry_type:
            raise ValueError(
                f"Expected {expected_geometry!r} geometry in {output_path.name}, "
                f"got {geometry_type!r}"
            )
        tmp_path.replace(output_path)
    finally:
        # After a successful replace() the temp path no longer exists and this
        # is a no-op; on any failure it removes the partial/invalid download.
        tmp_path.unlink(missing_ok=True)
    return features

View file

@ -1,6 +1,8 @@
import re import re
import shutil import shutil
import tempfile import tempfile
from collections import Counter
from collections.abc import Sequence
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from os import cpu_count from os import cpu_count
from pathlib import Path from pathlib import Path
@ -10,6 +12,7 @@ from thefuzz import fuzz
from tqdm import tqdm from tqdm import tqdm
from pipeline.local_temp import local_tmp_dir from pipeline.local_temp import local_tmp_dir
from pipeline.utils.normalize import uppercase_alnum_key_expr
# A house-number token includes any letter suffix: 8A, 8B and plain 8 are # A house-number token includes any letter suffix: 8A, 8B and plain 8 are
# three different properties on the same street, so digit-only extraction # three different properties on the same street, so digit-only extraction
@ -17,6 +20,10 @@ from pipeline.local_temp import local_tmp_dir
# through normalize_address_key first, so tokens are uppercase and # through normalize_address_key first, so tokens are uppercase and
# space-separated and [A-Z] suffices for the suffix. # space-separated and [A-Z] suffices for the suffix.
_NUMBER_RE = re.compile(r"\d+[A-Z]?") _NUMBER_RE = re.compile(r"\d+[A-Z]?")
# A single-letter flat designator ("FLAT B", "APARTMENT C") is a house-number-
# grade disambiguator with no digit in it: without this, FLAT B and FLAT D in
# the same building scored ~96 and cross-matched.
_FLAT_LETTER_RE = re.compile(r"\b(?:FLAT|APARTMENT|APT|UNIT) ([A-Z])\b")
_POSTCODE_RE = r"^[A-Z]{1,2}\d[A-Z\d]?\d[A-Z]{2}$" _POSTCODE_RE = r"^[A-Z]{1,2}\d[A-Z\d]?\d[A-Z]{2}$"
# A house number is a strong disambiguator, so a numbered, number-compatible # A house number is a strong disambiguator, so a numbered, number-compatible
# pair may match on a lower address-similarity score than a number-less one # pair may match on a lower address-similarity score than a number-less one
@ -24,16 +31,30 @@ _POSTCODE_RE = r"^[A-Z]{1,2}\d[A-Z\d]?\d[A-Z]{2}$"
# be trusted. Mirrors merge.py's listings convention. # be trusted. Mirrors merge.py's listings convention.
MIN_FUZZY_SCORE = 82 MIN_FUZZY_SCORE = 82
MIN_FUZZY_SCORE_WITHOUT_NUMBERS = 90 MIN_FUZZY_SCORE_WITHOUT_NUMBERS = 90
# A score reached only through an address VARIANT (locality appended /
# secondary address lines dropped) accepts a match the primary strings alone
# would reject, so it must clear a near-exact bar: in the miss audit >99% of
# genuine variant recoveries scored 100, while the rare false variant matches
# scored in the 80s.
MIN_VARIANT_SCORE = 90
# Tokens that mark a sub-unit of a building. A variant whose added/dropped
# tokens include one of these could score a single flat's certificate as if it
# were the whole building, so such variants are inadmissible.
_FLAT_TOKENS = {
"FLAT",
"FLATS",
"APARTMENT",
"APT",
"UNIT",
"MAISONETTE",
"STUDIO",
"ROOM",
}
def normalize_address_key(s: pl.Expr) -> pl.Expr: def normalize_address_key(s: pl.Expr) -> pl.Expr:
normalized = ( normalized = uppercase_alnum_key_expr(s)
s.cast(pl.String)
.str.to_uppercase()
.str.replace_all(r"[^0-9A-Z]+", " ")
.str.replace_all(r"\s+", " ")
.str.strip_chars()
)
return pl.when(normalized.str.contains(r"[A-Z]")).then(normalized).otherwise(None) return pl.when(normalized.str.contains(r"[A-Z]")).then(normalized).otherwise(None)
@ -58,6 +79,8 @@ def fuzzy_join_on_postcode(
right_postcode_col: str, right_postcode_col: str,
min_score: int = MIN_FUZZY_SCORE, min_score: int = MIN_FUZZY_SCORE,
min_score_without_numbers: int = MIN_FUZZY_SCORE_WITHOUT_NUMBERS, min_score_without_numbers: int = MIN_FUZZY_SCORE_WITHOUT_NUMBERS,
left_variant_cols: Sequence[str] = (),
right_variant_cols: Sequence[str] = (),
) -> pl.LazyFrame: ) -> pl.LazyFrame:
"""Fuzzy join two LazyFrames by matching addresses within postcode buckets. """Fuzzy join two LazyFrames by matching addresses within postcode buckets.
@ -66,6 +89,19 @@ def fuzzy_join_on_postcode(
columns (index, address, postcode) via projection pushdown, and the columns (index, address, postcode) via projection pushdown, and the
final join reads the remaining columns lazily. final join reads the remaining columns lazily.
``left_variant_cols`` / ``right_variant_cols`` name alternative address
columns for the same property (e.g. the EPC's first address line without
its locality suffix, or the price-paid address with its locality
appended). A pair is scored as the best token_sort_ratio over all
admissible variant combinations: source registers frequently disagree
only on a trailing village/locality token, which alone drags short
addresses below the match threshold. The number-compatibility gate is
always evaluated on the primary addresses, and `_admissible_variants`
rejects any variant whose added/dropped tokens carry digits or flat
designators, so a variant can never bypass the gate or score a single
flat as its whole building. Variant-only scores must clear
``MIN_VARIANT_SCORE``.
Returns a LazyFrame with all left and right columns, plus a Returns a LazyFrame with all left and right columns, plus a
``_match_score`` (UInt8) audit column holding the token_sort_ratio of ``_match_score`` (UInt8) audit column holding the token_sort_ratio of
the accepted match (exact matches score 100). Unmatched rows have null the accepted match (exact matches score 100). Unmatched rows have null
@ -90,6 +126,10 @@ def fuzzy_join_on_postcode(
normalize_postcode_key(pl.col(left_postcode_col)).alias( normalize_postcode_key(pl.col(left_postcode_col)).alias(
"_left_postcode" "_left_postcode"
), ),
*(
normalize_address_key(pl.col(col)).alias(f"_left_variant_{i}")
for i, col in enumerate(left_variant_cols)
),
) )
.collect(engine="streaming") .collect(engine="streaming")
) )
@ -104,30 +144,45 @@ def fuzzy_join_on_postcode(
normalize_postcode_key(pl.col(right_postcode_col)).alias( normalize_postcode_key(pl.col(right_postcode_col)).alias(
"_right_postcode" "_right_postcode"
), ),
*(
normalize_address_key(pl.col(col)).alias(f"_right_variant_{i}")
for i, col in enumerate(right_variant_cols)
),
) )
.unique(subset=["_right_address", "_right_postcode"], keep="first") .unique(subset=["_right_address", "_right_postcode"], keep="first")
.collect(engine="streaming") .collect(engine="streaming")
) )
left_variant_names = [f"_left_variant_{i}" for i in range(len(left_variant_cols))]
right_variant_names = [
f"_right_variant_{i}" for i in range(len(right_variant_cols))
]
# Group right side by postcode for fast lookup # Group right side by postcode for fast lookup
right_by_postcode: dict[str, list[tuple[int, str]]] = {} right_by_postcode: dict[str, list[tuple[int, str, tuple[str, ...]]]] = {}
for idx, postcode, address in zip( for idx, postcode, address, *variants in zip(
right_match["_right_idx"], right_match["_right_idx"],
right_match["_right_postcode"], right_match["_right_postcode"],
right_match["_right_address"], right_match["_right_address"],
*(right_match[name] for name in right_variant_names),
): ):
if address is not None and postcode is not None: if address is not None and postcode is not None:
right_by_postcode.setdefault(postcode, []).append((idx, address)) right_by_postcode.setdefault(postcode, []).append(
(idx, address, _admissible_variants(address, variants))
)
# Group left side by postcode # Group left side by postcode
left_by_postcode: dict[str, list[tuple[int, str]]] = {} left_by_postcode: dict[str, list[tuple[int, str, tuple[str, ...]]]] = {}
for idx, postcode, address in zip( for idx, postcode, address, *variants in zip(
left_match["_left_idx"], left_match["_left_idx"],
left_match["_left_postcode"], left_match["_left_postcode"],
left_match["_left_address"], left_match["_left_address"],
*(left_match[name] for name in left_variant_names),
): ):
if address is not None and postcode is not None: if address is not None and postcode is not None:
left_by_postcode.setdefault(postcode, []).append((idx, address)) left_by_postcode.setdefault(postcode, []).append(
(idx, address, _admissible_variants(address, variants))
)
del left_match, right_match del left_match, right_match
@ -145,7 +200,12 @@ def fuzzy_join_on_postcode(
# Score all pairwise matches in parallel, then greedily assign from # Score all pairwise matches in parallel, then greedily assign from
# highest score downward so best pairs lock in first. # highest score downward so best pairs lock in first.
all_pairs: list[tuple[int, int, int]] = [] # (score, left_idx, right_idx) # Pair tuples are (score, exact, left_idx, right_idx); `exact` marks a
# literally-equal primary pair so it wins greedy ties against a pair
# that merely token-sorts to the same score (e.g. "APARTMENT 3 1 HIGH
# ST" vs "APARTMENT 1 3 HIGH ST" both score 100 against each other's
# certificates, but each has a literal twin).
all_pairs: list[tuple[int, int, int, int]] = []
with ProcessPoolExecutor(max_workers=cpu_count()) as executor: with ProcessPoolExecutor(max_workers=cpu_count()) as executor:
for pairs in tqdm( for pairs in tqdm(
executor.map(_score_bucket, tasks, chunksize=64), executor.map(_score_bucket, tasks, chunksize=64),
@ -156,8 +216,9 @@ def fuzzy_join_on_postcode(
del tasks, left_by_postcode, right_by_postcode del tasks, left_by_postcode, right_by_postcode
# Sort descending by score so best matches are assigned first # Sort so the best matches are assigned first: score, then literal
all_pairs.sort(key=lambda t: (t[0], -t[1]), reverse=True) # equality, then stable left-index order.
all_pairs.sort(key=lambda t: (t[0], t[1], -t[2]), reverse=True)
# Keep the score alongside each accepted pair: it is emitted as the # Keep the score alongside each accepted pair: it is emitted as the
# _match_score audit column so downstream consumers can distinguish # _match_score audit column so downstream consumers can distinguish
@ -166,7 +227,7 @@ def fuzzy_join_on_postcode(
matched_left: set[int] = set() matched_left: set[int] = set()
matched_right: set[int] = set() matched_right: set[int] = set()
for score, left_idx, right_idx in all_pairs: for score, _exact, left_idx, right_idx in all_pairs:
if left_idx in matched_left or right_idx in matched_right: if left_idx in matched_left or right_idx in matched_right:
continue continue
matches.append((left_idx, right_idx, score)) matches.append((left_idx, right_idx, score))
@ -208,40 +269,102 @@ def fuzzy_join_on_postcode(
return result.lazy() return result.lazy()
def _number_tokens(address: str) -> set[str]:
tokens = set(_NUMBER_RE.findall(address))
tokens.update(_FLAT_LETTER_RE.findall(address))
return tokens
def _numbers_compatible(a: str, b: str) -> bool: def _numbers_compatible(a: str, b: str) -> bool:
"""Check that the number tokens (house/flat numbers, including any letter """Check that the number tokens (house/flat numbers, including any letter
suffix) of two addresses are IDENTICAL sets. suffix, plus single-letter flat designators) of two addresses are
IDENTICAL sets.
Equality, not subset: subset logic let "188 GREAT NORTH WAY" absorb Equality, not subset: subset logic let "188 GREAT NORTH WAY" absorb
"FLAT 1 188 GREAT NORTH WAY" ({188} is a subset of {1, 188}), attaching a "FLAT 1 188 GREAT NORTH WAY" ({188} is a subset of {1, 188}), attaching a
single flat's EPC facts to the whole building — tens of thousands of single flat's EPC facts to the whole building — tens of thousands of
wrong-property matches. Likewise digit-only tokens made "8A" and "8B" wrong-property matches. Likewise digit-only tokens made "8A" and "8B"
both look like {8} and match each other (and plain "8"). Precision over both look like {8} and match each other (and plain "8"), and ungated
recall: a pair whose two sources genuinely disagree on number tokens is letter flats let "FLAT D 39 X ST" cross-match "FLAT F 39 X ST" at ~96.
safer left unmatched. Precision over recall: a pair whose two sources genuinely disagree on
number tokens is safer left unmatched.
One side numbered, the other not -> incompatible. Neither numbered -> One side numbered, the other not -> incompatible. Neither numbered ->
compatible; such pairs are scored against the stricter no-numbers compatible; such pairs are scored against the stricter no-numbers
threshold instead. threshold instead.
""" """
nums_a = set(_NUMBER_RE.findall(a)) nums_a = _number_tokens(a)
nums_b = set(_NUMBER_RE.findall(b)) nums_b = _number_tokens(b)
if not nums_a and not nums_b: if not nums_a and not nums_b:
return True return True
return nums_a == nums_b return nums_a == nums_b
def _admissible_variants(
primary: str, variants: Sequence[str | None]
) -> tuple[str, ...]:
"""Variants of ``primary`` that are safe to score against the other side.
A variant may only ADD or DROP whole tokens relative to the primary (one
word multiset must contain the other) never substitute, so a register
row whose address lines disagree with the combined address can't smuggle
in a different street. The number gate runs on the primary addresses
only, so the added/dropped tokens must additionally carry no digits
(house numbers) and no flat designator (a "Flat 1"-style secondary line
dropped from an EPC address would otherwise let a single flat score as
the whole building). The remaining admissible difference is exactly the
harmless kind variants exist for: trailing locality/village/town words.
"""
primary_words = Counter(primary.split())
admissible: list[str] = []
for variant in variants:
if not variant or variant == primary:
continue
variant_words = Counter(variant.split())
if not (variant_words <= primary_words or primary_words <= variant_words):
continue
changed = (primary_words - variant_words) + (variant_words - primary_words)
if any(
any(ch.isdigit() for ch in token) or token in _FLAT_TOKENS
for token in changed
):
continue
admissible.append(variant)
return tuple(dict.fromkeys(admissible))
def _score_bucket( def _score_bucket(
args: tuple[list[tuple[int, str]], list[tuple[int, str]], int, int], args: tuple[
) -> list[tuple[int, int, int]]: list[tuple[int, str, tuple[str, ...]]],
list[tuple[int, str, tuple[str, ...]]],
int,
int,
],
) -> list[tuple[int, int, int, int]]:
"""Score all address pairs within a single postcode bucket.""" """Score all address pairs within a single postcode bucket."""
left_entries, right_entries, min_score, min_score_without_numbers = args left_entries, right_entries, min_score, min_score_without_numbers = args
pairs = [] pairs = []
for left_row, left_address in left_entries: for left_row, left_address, left_variants in left_entries:
for right_row, right_address in right_entries: for right_row, right_address, right_variants in right_entries:
if not _numbers_compatible(left_address, right_address): if not _numbers_compatible(left_address, right_address):
continue continue
score = fuzz.token_sort_ratio(left_address, right_address) score = fuzz.token_sort_ratio(left_address, right_address)
# Variant pairs recover same-property matches where one register
# carries a locality suffix the other lacks; a variant-only score
# must clear the near-exact MIN_VARIANT_SCORE bar.
if score < 100 and (left_variants or right_variants):
for left_variant in (left_address, *left_variants):
for right_variant in (right_address, *right_variants):
if (
left_variant is left_address
and right_variant is right_address
):
continue
variant_score = fuzz.token_sort_ratio(
left_variant, right_variant
)
if variant_score >= MIN_VARIANT_SCORE and variant_score > score:
score = variant_score
# Number-less pairs (named houses, building-name flats) lack the # Number-less pairs (named houses, building-name flats) lack the
# house-number disambiguator, so require a near-exact match. # house-number disambiguator, so require a near-exact match.
threshold = ( threshold = (
@ -250,5 +373,7 @@ def _score_bucket(
else min_score_without_numbers else min_score_without_numbers
) )
if score >= threshold: if score >= threshold:
pairs.append((score, left_row, right_row)) pairs.append(
(score, int(left_address == right_address), left_row, right_row)
)
return pairs return pairs

View file

@ -0,0 +1,70 @@
"""Shared low-level text-normalization primitives.
Address matching (``pipeline.utils.fuzzy_join``, ``pipeline.transform.merge``),
POI retailer cleanup (``pipeline.transform.transform_poi``) and school-name
matching (``pipeline.check_school_cutoffs``) each layer domain-specific rules
on top of these. The primitives are deliberately tiny and single-purpose so
that composing them preserves every caller's existing output byte-for-byte.
"""
import re
import polars as pl
# One character outside [a-z0-9 ]. Callers lowercase first; each offending
# character becomes a single space (runs are NOT merged here — callers apply
# word-level rules and then collapse_whitespace).
_NON_ALNUM_LOWER_RE = re.compile(r"[^a-z0-9 ]")
# Any digit marks a token as number-bearing (house/flat numbers, including
# letter-suffixed forms such as 8A, which still contain a digit).
_DIGIT_RE = re.compile(r"\d")
def collapse_whitespace(s: str) -> str:
    """Return ``s`` with each whitespace run reduced to one space and no
    leading or trailing whitespace."""
    words = s.split()
    return " ".join(words)
def strip_or_empty(s: str | None) -> str:
"""Strip leading/trailing whitespace, mapping None to ``""``.
Interior whitespace is preserved (unlike :func:`collapse_whitespace`) so
the result can be looked up verbatim against curated dictionary keys.
"""
return "" if s is None else s.strip()
def replace_non_alnum_lower(s: str) -> str:
    """Turn every character matched by ``_NON_ALNUM_LOWER_RE`` (anything
    outside [a-z0-9 ]) into one space.

    Input is expected to be lowercased already; uppercase letters fall outside
    the allowed set and are replaced as well. Substitution is per character,
    not per run, so callers collapse whitespace afterwards.
    """
    cleaned = _NON_ALNUM_LOWER_RE.sub(" ", s)
    return cleaned
def drop_digit_tokens(s: str) -> str:
    """Remove every whitespace-separated token that contains a digit.

    ``"10A HIGH STREET" -> "HIGH STREET"``. Surviving tokens are joined with
    single spaces, which collapses whitespace as a side effect.
    """
    digit_free = [word for word in s.split() if _DIGIT_RE.search(word) is None]
    return " ".join(digit_free)
def uppercase_alnum_key_expr(s: pl.Expr) -> pl.Expr:
    """Build a Polars expression that normalises text into an uppercase
    alphanumeric key.

    The steps, in order: cast to String, uppercase, replace each run of
    characters outside [0-9A-Z] with one space, collapse whitespace runs, and
    strip the ends. Non-ASCII letters are outside [0-9A-Z] after uppercasing
    and therefore become spaces (``"Café 1" -> "CAF 1"``).
    """
    uppercased = s.cast(pl.String).str.to_uppercase()
    alnum_only = uppercased.str.replace_all(r"[^0-9A-Z]+", " ")
    single_spaced = alnum_only.str.replace_all(r"\s+", " ")
    return single_spaced.str.strip_chars()

View file

@ -1,7 +1,7 @@
import polars as pl import polars as pl
from pipeline.utils import fuzzy_join_on_postcode, normalize_postcode_key from pipeline.utils import fuzzy_join_on_postcode, normalize_postcode_key
from pipeline.utils.fuzzy_join import _numbers_compatible from pipeline.utils.fuzzy_join import _admissible_variants, _numbers_compatible
def test_fuzzy_join_on_postcode_matches_addresses_within_postcode(): def test_fuzzy_join_on_postcode_matches_addresses_within_postcode():
@ -165,7 +165,7 @@ def test_fuzzy_join_rejects_mid_score_number_less_match():
def test_fuzzy_join_matches_numbered_pair_at_baseline_threshold(): def test_fuzzy_join_matches_numbered_pair_at_baseline_threshold():
# "10 ACACIA AVENUE" vs "FLAT A 10 ACACIA AVENUE" scores exactly 82 and the # "10 ACACIA AVENUE" vs "10 ACACIA AVENUE OAKHAM" scores exactly 82 and the
# house number is compatible, so the numbered baseline (>= 82) still matches. # house number is compatible, so the numbered baseline (>= 82) still matches.
left = pl.LazyFrame( left = pl.LazyFrame(
{ {
@ -175,7 +175,7 @@ def test_fuzzy_join_matches_numbered_pair_at_baseline_threshold():
) )
right = pl.LazyFrame( right = pl.LazyFrame(
{ {
"right_address": ["Flat A, 10 Acacia Avenue"], "right_address": ["10 Acacia Avenue, Oakham"],
"right_postcode": ["AB1 2CD"], "right_postcode": ["AB1 2CD"],
} }
) )
@ -189,7 +189,7 @@ def test_fuzzy_join_matches_numbered_pair_at_baseline_threshold():
right_postcode_col="right_postcode", right_postcode_col="right_postcode",
).collect() ).collect()
assert result["right_address"].to_list() == ["Flat A, 10 Acacia Avenue"] assert result["right_address"].to_list() == ["10 Acacia Avenue, Oakham"]
def test_fuzzy_join_matches_high_score_number_less_pair(): def test_fuzzy_join_matches_high_score_number_less_pair():
@ -244,6 +244,151 @@ def test_numbers_compatible_number_less_and_one_sided_pairs():
assert not _numbers_compatible("ROSE COTTAGE", "8 HIGH STREET") assert not _numbers_compatible("ROSE COTTAGE", "8 HIGH STREET")
def test_numbers_compatible_gates_single_letter_flats():
    """Single flat letters act as pseudo-number tokens in the number gate."""
    cases = [
        # "FLAT D" and "FLAT F" are different flats even with identical
        # street numbers; ungated they token_sort to ~96 and cross-matched.
        ("FLAT D 39 GERTRUDE STREET", "FLAT F 39 GERTRUDE STREET", False),
        # Same flat letter, tokens merely reordered.
        ("FLAT D 39 GERTRUDE STREET", "39 GERTRUDE STREET FLAT D", True),
        # The letter also blocks a flat matching the bare building address.
        ("FLAT B ROSE COURT", "ROSE COURT", False),
        # A letter glued to a number ("A3") is a unit name, not a flat letter.
        ("FLAT A3 CHESHAM HEIGHTS", "FLAT A3 CHESHAM HEIGHTS", True),
    ]
    for first, second, expected in cases:
        assert bool(_numbers_compatible(first, second)) is expected
def test_admissible_variants_allows_locality_suffix_only():
    """Variants may differ from the primary by locality words only."""
    cases = [
        # Locality words may differ between a variant and its primary; digits
        # and flat designators may not (the gate ran on the primary only).
        (
            "12 OAK ROAD",
            ["12 OAK ROAD HALE", "12 OAK ROAD"],
            ("12 OAK ROAD HALE",),
        ),
        # Dropping "FLAT 1" (digit) or "FLAT B" (flat designator) is
        # inadmissible: the variant would score a single flat as the whole
        # building.
        ("FLAT 1 188 GREAT NORTH WAY", ["188 GREAT NORTH WAY"], ()),
        ("FLAT B ROSE COURT", ["ROSE COURT"], ()),
        ("12 OAK ROAD", [None, "12 OAK ROAD"], ()),
        # Substitution is never admissible: a register row whose address1
        # disagrees with the combined address must not smuggle in a different
        # street for scoring.
        ("12 OAK ROAD", ["12 ELM ROAD"], ()),
        ("1 TOTALLY DIFFERENT ROAD", ["1 EXAMPLE STREET"], ()),
    ]
    for primary, candidates, expected in cases:
        assert _admissible_variants(primary, candidates) == expected
def test_fuzzy_join_variant_recovers_locality_suffix_mismatch():
    """A locality suffix on one side is recovered through variant columns.

    The EPC register stores "12 Oak Road, Hale" (address1 + locality line)
    while price-paid has the bare "12 Oak Road": token_sort scores 81 < 82
    and the match was lost. The EPC's address1-only variant scores 100.
    """
    left_frame = pl.LazyFrame(
        {
            "left_address": ["12 Oak Road"],
            "left_postcode": ["AB1 2CD"],
            "left_with_locality": ["12 Oak Road Hale"],
        }
    )
    right_frame = pl.LazyFrame(
        {
            "right_address": ["12 Oak Road, Hale"],
            "right_postcode": ["AB1 2CD"],
            "right_address1": ["12 Oak Road"],
        }
    )
    join_kwargs = {
        "left": left_frame,
        "right": right_frame,
        "left_address_col": "left_address",
        "right_address_col": "right_address",
        "left_postcode_col": "left_postcode",
        "right_postcode_col": "right_postcode",
    }

    # Without variants the pair stays below the threshold.
    without_variants = fuzzy_join_on_postcode(**join_kwargs).collect()
    assert without_variants["_match_score"].to_list() == [None]

    # With variants the pair is scored as an exact match.
    with_variants = fuzzy_join_on_postcode(
        **join_kwargs,
        left_variant_cols=["left_with_locality"],
        right_variant_cols=["right_address1"],
    ).collect()
    assert with_variants["_match_score"].to_list() == [100]
def test_fuzzy_join_variant_cannot_unlock_a_flat_for_its_building():
    """A variant that drops the flat designator must not produce a match.

    The EPC's secondary line carries the flat designator; dropping it would
    score the flat's certificate 100 against the whole-building price-paid
    address. The variant must be ruled inadmissible and the pair unmatched.
    """
    building_frame = pl.LazyFrame(
        {
            "left_address": ["188 Great North Way"],
            "left_postcode": ["AB1 2CD"],
        }
    )
    flat_frame = pl.LazyFrame(
        {
            "right_address": ["Flat 1, 188 Great North Way"],
            "right_postcode": ["AB1 2CD"],
            "right_address1": ["188 Great North Way"],
        }
    )

    joined = fuzzy_join_on_postcode(
        left=building_frame,
        right=flat_frame,
        left_address_col="left_address",
        right_address_col="right_address",
        left_postcode_col="left_postcode",
        right_postcode_col="right_postcode",
        right_variant_cols=["right_address1"],
    )
    scores = joined.collect()["_match_score"].to_list()

    assert scores == [None]
def test_fuzzy_join_variant_score_must_be_near_exact():
    """Scores reached only through a variant must clear MIN_VARIANT_SCORE (90).

    "2 MYRTLE COTTAGES" vs "2 LEITH VIEW COTTAGES" type pairs scored in the
    80s via variants and were false matches.
    """
    left_frame = pl.LazyFrame(
        {
            "left_address": ["2 Myrtle Cottages"],
            "left_postcode": ["AB1 2CD"],
            "left_with_locality": ["2 Myrtle Cottages Dorking"],
        }
    )
    right_frame = pl.LazyFrame(
        {
            "right_address": ["2 Leith View Cottages, North Holmwood"],
            "right_postcode": ["AB1 2CD"],
            "right_address1": ["2 Leith View Cottages"],
        }
    )

    joined = fuzzy_join_on_postcode(
        left=left_frame,
        right=right_frame,
        left_address_col="left_address",
        right_address_col="right_address",
        left_postcode_col="left_postcode",
        right_postcode_col="right_postcode",
        left_variant_cols=["left_with_locality"],
        right_variant_cols=["right_address1"],
    )
    scores = joined.collect()["_match_score"].to_list()

    assert scores == [None]
def test_fuzzy_join_rejects_wrong_letter_suffix_match(): def test_fuzzy_join_rejects_wrong_letter_suffix_match():
# End-to-end guard for the 8A/8B class of wrong-property matches: the only # End-to-end guard for the 8A/8B class of wrong-property matches: the only
# candidate in the postcode bucket differs solely in the number suffix, so # candidate in the postcode bucket differs solely in the number suffix, so
@ -294,7 +439,7 @@ def test_fuzzy_join_emits_match_score_column():
"10 HIGH STREET", "10 HIGH STREET",
# Scores exactly 82 against "10 Acacia Avenue" (see # Scores exactly 82 against "10 Acacia Avenue" (see
# test_fuzzy_join_matches_numbered_pair_at_baseline_threshold). # test_fuzzy_join_matches_numbered_pair_at_baseline_threshold).
"Flat A, 10 Acacia Avenue", "10 Acacia Avenue, Oakham",
], ],
"right_postcode": ["AB1 2CD", "EF3 4GH"], "right_postcode": ["AB1 2CD", "EF3 4GH"],
} }

View file

@ -0,0 +1,158 @@
import polars as pl
from pipeline.check_school_cutoffs import normalize_la, normalize_name
from pipeline.transform.merge import _street_only_address
from pipeline.transform.transform_poi import normalize_grocery_retailer
from pipeline.utils.fuzzy_join import normalize_address_key
from pipeline.utils.normalize import (
collapse_whitespace,
drop_digit_tokens,
replace_non_alnum_lower,
strip_or_empty,
uppercase_alnum_key_expr,
)
# --- Primitives -------------------------------------------------------------
def test_collapse_whitespace():
    """collapse_whitespace trims the edges and reduces runs to one space."""
    cases = [
        ("", ""),
        (" ", ""),
        ("a b", "a b"),
        (" a \t b \n c ", "a b c"),
        # str.split() also splits on unicode whitespace (non-breaking space).
        ("a\u00a0b", "a b"),
    ]
    for raw, expected in cases:
        assert collapse_whitespace(raw) == expected
def test_strip_or_empty():
    """strip_or_empty maps None to "" and strips only the edges."""
    cases = [
        (None, ""),
        ("", ""),
        (" x ", "x"),
        # Interior whitespace is preserved, unlike collapse_whitespace.
        (" a b ", "a b"),
    ]
    for raw, expected in cases:
        assert strip_or_empty(raw) == expected
def test_replace_non_alnum_lower():
    """Pins replace_non_alnum_lower: one space per out-of-set character."""
    assert replace_non_alnum_lower("") == ""
    assert replace_non_alnum_lower("abc 123") == "abc 123"
    # Per-character replacement: runs are not merged, so the two hyphens
    # yield two spaces.
    assert replace_non_alnum_lower("a--b") == "a  b"
    # Existing spaces are kept as-is: space + replaced comma + space gives
    # three spaces.
    assert replace_non_alnum_lower("a , b") == "a   b"
    # Uppercase and accented letters fall outside [a-z0-9 ].
    assert replace_non_alnum_lower("École") == " cole"
def test_drop_digit_tokens():
    """drop_digit_tokens removes every token that contains a digit."""
    cases = [
        ("", ""),
        ("10A HIGH STREET", "HIGH STREET"),
        ("8B", ""),
        ("12 34", ""),
        ("KINGSWOOD", "KINGSWOOD"),
        # Whitespace collapses as a side effect of the token rejoin.
        (" A B ", "A B"),
    ]
    for raw, expected in cases:
        assert drop_digit_tokens(raw) == expected
def test_uppercase_alnum_key_expr():
    """uppercase_alnum_key_expr output, pinned as (input, expected) pairs."""
    pairs = [
        ("Flat 2, 10 High Street", "FLAT 2 10 HIGH STREET"),
        (" 12 High-Street ", "12 HIGH STREET"),
        ("", ""),
        (None, None),
        ("Café 1", "CAF 1"),
        ("st mary's-court", "ST MARY S COURT"),
    ]
    inputs = [raw for raw, _ in pairs]
    expected = [key for _, key in pairs]

    frame = pl.DataFrame({"a": inputs}, schema={"a": pl.String})
    keyed = frame.select(uppercase_alnum_key_expr(pl.col("a")))

    assert keyed.to_series().to_list() == expected
# --- Characterization of the call sites built on the primitives ------------
# Expected values were captured from the pre-refactor implementations and
# must never change: each wrapper's output is byte-for-byte pinned.
def test_normalize_address_key_characterization():
    """normalize_address_key output, pinned as (input, expected) pairs."""
    pairs = [
        ("Flat 2, 10 High Street", "FLAT 2 10 HIGH STREET"),
        (" 12 High-Street ", "12 HIGH STREET"),
        ("123", None),  # digits only: no letter -> null
        ("", None),  # empty -> null
        (None, None),  # null in, null out
        ("Café 1", "CAF 1"),
        ("st mary's-court", "ST MARY S COURT"),
        ("ALREADY NORMAL", "ALREADY NORMAL"),
    ]
    inputs = [raw for raw, _ in pairs]
    expected = [key for _, key in pairs]

    frame = pl.DataFrame({"a": inputs}, schema={"a": pl.String})
    keyed = frame.select(normalize_address_key(pl.col("a")))

    assert keyed.to_series().to_list() == expected
def test_street_only_address_characterization():
    """Pinned outputs of _street_only_address for representative inputs."""
    cases = [
        ("10A HIGH STREET", "HIGH STREET"),
        ("FLAT 1 188 GREAT NORTH WAY", "FLAT GREAT NORTH WAY"),
        ("", ""),
        ("OLDSTEAD ROAD", "OLDSTEAD ROAD"),
        (" A B ", "A B"),
        ("12 34", ""),
        ("8B", ""),
    ]
    for address, expected in cases:
        assert _street_only_address(address) == expected
def test_normalize_grocery_retailer_characterization():
    """Pinned outputs of normalize_grocery_retailer."""
    cases = [
        (None, ""),
        ("", ""),
        (" Tesco Express ", "Tesco Express"),
        ("Sainsburys", "Sainsbury's"),
        ("Lincolnshire Co-operative", "Co-op"),
        # Only edge whitespace is stripped; interior whitespace must survive
        # so near-miss names fall through the exact dictionary lookups
        # unchanged.
        ("Bob's Shop", "Bob's Shop"),
        (" Marks and Spencer ", "M&S"),
    ]
    for raw, expected in cases:
        assert normalize_grocery_retailer(raw) == expected
def test_normalize_name_characterization():
    """Pinned outputs of normalize_name, with and without the second flag."""
    cases = [
        (("St. Mary's C of E Primary School",), "st marys primary school"),
        (("St. Mary's C of E Primary School", True), "st marys"),
        (("",), ""),
        (("Ham & High School",), "ham high school"),
        (("Ham & High School", True), "ham"),
        # Accented characters become spaces, splitting the word.
        (("École Élémentaire",), "cole l mentaire"),
        ((" THE KING'S ACADEMY ",), "kings academy"),
        (("Holy Trinity RC Voluntary Aided School",), "holy trinity school"),
        (("st. john's",), "st johns"),
    ]
    for args, expected in cases:
        assert normalize_name(*args) == expected
def test_normalize_la_characterization():
    """Pinned outputs of normalize_la for local-authority names."""
    cases = [
        ("City of Westminster", "westminster"),
        ("Brighton & Hove", "brighton and hove"),
        (" Kingston upon Thames ", "kingston upon thames"),
        ("", ""),
    ]
    for raw, expected in cases:
        assert normalize_la(raw) == expected

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,589 @@
//! The checkout session state machine: starting a checkout (with pricing and
//! reservation under a cross-instance lock), verifying Stripe's completion
//! payload, completing/granting, and reversing/reinstating after refunds or
//! disputes.
use std::sync::LazyLock;
use anyhow::{anyhow, Context};
use serde_json::Value;
use tokio::sync::Mutex;
use tracing::warn;
use crate::auth::PocketBaseUser;
use crate::pocketbase::get_superuser_token;
use crate::pocketbase_locks::acquire_pocketbase_lock;
use crate::routes::pricing::{count_licensed_users, price_for_count};
use crate::state::AppState;
use super::records::{
attach_stripe_session, count_active_pending_checkouts, create_pending_checkout,
expire_stale_pending_checkouts, find_active_checkout_for_user,
find_checkout_by_payment_intent_or_checkout_session, find_checkout_by_stripe_session,
has_other_completed_checkout_for_user, mark_checkout_completed, mark_checkout_reinstated,
mark_checkout_reversed, mark_checkout_status, PendingCheckoutInput,
};
use super::referral::{
mark_referral_invite_used, release_referral_invite_reservation, reserve_referral_invite,
};
use super::stripe::create_stripe_session;
use super::{
ensure_success, is_safe_reversal_reason, is_safe_stripe_session_id, now_unix_secs,
number_field, CheckoutCompletion, CheckoutStart, PaymentReinstatementOutcome,
PaymentReversalOutcome, VerifiedCheckout, CHECKOUT_CURRENCY, REFERRAL_DISCOUNT_PERCENT,
};
/// Lifetime of a checkout reservation in seconds (31 minutes); added to "now"
/// to produce the `expires_at_unix` stored on the reservation and passed to
/// Stripe session creation.
// NOTE(review): presumably chosen to sit just above Stripe's minimum session
// expiry — confirm against the Stripe API docs.
const CHECKOUT_SESSION_TTL_SECS: u64 = 31 * 60;
/// Name of the PocketBase-backed lock taken around pricing/reservation work.
const CHECKOUT_PRICING_LOCK_NAME: &str = "checkout:pricing";
/// TTL in seconds (5 minutes) passed when acquiring the pricing lock.
const CHECKOUT_PRICING_LOCK_TTL_SECS: u64 = 5 * 60;
/// In-process mutex taken before the PocketBase lock by the checkout start,
/// completion, grant, reversal and reinstatement entry points.
static CHECKOUT_RESERVATION_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
/// Starts (or resumes) a license checkout for `user`.
///
/// Takes the in-process reservation mutex and then the PocketBase pricing
/// lock, runs [`start_license_checkout_locked`], and releases the pricing
/// lock afterwards. A failure to release is logged and does not replace the
/// checkout result.
pub async fn start_license_checkout(
    state: &AppState,
    user: &PocketBaseUser,
    success_url: &str,
    cancel_url: &str,
    discount_coupon_id: Option<&str>,
    referral_invite_id: Option<&str>,
) -> anyhow::Result<CheckoutStart> {
    let _reservation_guard = CHECKOUT_RESERVATION_LOCK.lock().await;
    let held_lock = acquire_pocketbase_lock(
        state,
        CHECKOUT_PRICING_LOCK_NAME,
        CHECKOUT_PRICING_LOCK_TTL_SECS,
    )
    .await?;

    let outcome = start_license_checkout_locked(
        state,
        user,
        success_url,
        cancel_url,
        discount_coupon_id,
        referral_invite_id,
    )
    .await;

    // Release is best-effort: the checkout outcome is returned either way.
    match held_lock.release().await {
        Ok(_) => {}
        Err(err) => warn!("Failed to release checkout pricing lock: {err}"),
    }
    outcome
}
async fn start_license_checkout_locked(
state: &AppState,
user: &PocketBaseUser,
success_url: &str,
cancel_url: &str,
discount_coupon_id: Option<&str>,
referral_invite_id: Option<&str>,
) -> anyhow::Result<CheckoutStart> {
let now = now_unix_secs();
expire_stale_pending_checkouts(state, now).await?;
if let Some(existing) = find_active_checkout_for_user(
state,
&user.id,
discount_coupon_id.unwrap_or_default(),
referral_invite_id.unwrap_or_default(),
now,
)
.await?
{
if !existing.checkout_url.is_empty() {
return Ok(CheckoutStart::Stripe {
url: existing.checkout_url,
});
}
if let Err(err) = mark_checkout_status(state, &existing.id, "failed").await {
warn!(
reservation_id = %existing.id,
"Failed to fail incomplete checkout reservation: {err}"
);
}
}
let licensed_count = count_licensed_users(state).await?;
let pending_count = count_active_pending_checkouts(state, now).await?;
let price_pence = price_for_count(licensed_count + pending_count);
if price_pence == 0 {
grant_license(state, &user.id).await?;
return Ok(CheckoutStart::Free);
}
let expires_at_unix = now + CHECKOUT_SESSION_TTL_SECS;
let expected_total_pence = expected_total_for_checkout(price_pence, discount_coupon_id);
let reservation_id = create_pending_checkout(
state,
PendingCheckoutInput {
user_id: &user.id,
amount_pence: price_pence,
expected_total_pence,
currency: CHECKOUT_CURRENCY,
discount_coupon_id: discount_coupon_id.unwrap_or_default(),
referral_invite_id: referral_invite_id.unwrap_or_default(),
expires_at_unix,
},
)
.await?;
if let Some(invite_id) = referral_invite_id.filter(|id| !id.is_empty()) {
if let Err(err) =
reserve_referral_invite(state, invite_id, &user.id, &reservation_id, expires_at_unix)
.await
{
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
warn!(
reservation_id,
"Failed to mark checkout reservation failed: {mark_err}"
);
}
return Err(err);
}
}
let stripe_result = create_stripe_session(
state,
user,
&reservation_id,
price_pence,
success_url,
cancel_url,
expires_at_unix,
discount_coupon_id,
)
.await;
let (stripe_session_id, url) = match stripe_result {
Ok(session) => session,
Err(err) => {
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
warn!(
reservation_id,
"Failed to mark checkout reservation failed: {mark_err}"
);
}
if let Some(invite_id) = referral_invite_id.filter(|id| !id.is_empty()) {
if let Err(release_err) =
release_referral_invite_reservation(state, invite_id, &reservation_id).await
{
warn!(
reservation_id,
referral_invite_id = invite_id,
"Failed to release referral invite reservation: {release_err}"
);
}
}
return Err(err);
}
};
if let Err(err) = attach_stripe_session(state, &reservation_id, &stripe_session_id, &url).await
{
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
warn!(
reservation_id,
"Failed to mark checkout reservation failed: {mark_err}"
);
}
if let Some(invite_id) = referral_invite_id.filter(|id| !id.is_empty()) {
if let Err(release_err) =
release_referral_invite_reservation(state, invite_id, &reservation_id).await
{
warn!(
reservation_id,
referral_invite_id = invite_id,
"Failed to release referral invite reservation: {release_err}"
);
}
}
return Err(err);
}
Ok(CheckoutStart::Stripe { url })
}
/// Verifies a Stripe checkout session payload against the stored reservation.
///
/// Returns `Rejected(reason)` when any check fails, `AlreadyHandled` when every
/// check passes and the reservation is already "completed", and `Grant`
/// otherwise. Only upstream failures are returned as `Err`.
///
/// Some rejections also mark the reservation "invalid" (session id,
/// client_reference_id, currency and amount mismatches); others leave the
/// record untouched (bad ids, no reservation, unexpected status, unpaid).
pub async fn verify_checkout_completion(
state: &AppState,
session: &Value,
) -> anyhow::Result<CheckoutCompletion> {
// Both ids must be present and pass the character/length check before use.
let session_id = match session["id"].as_str() {
Some(id) if is_safe_stripe_session_id(id) => id,
_ => {
return Ok(CheckoutCompletion::Rejected(
"missing or invalid session id".into(),
))
}
};
let payment_intent_id = match session["payment_intent"].as_str() {
Some(id) if is_safe_stripe_session_id(id) => id,
_ => {
return Ok(CheckoutCompletion::Rejected(
"missing or invalid payment intent id".into(),
))
}
};
let checkout = match find_checkout_by_stripe_session(state, session_id).await? {
Some(checkout) => checkout,
None => {
return Ok(CheckoutCompletion::Rejected(
"checkout session has no reservation".into(),
))
}
};
// Only "completed", "pending" and "expired" reservations are verified further.
let already_completed = checkout.status == "completed";
if !already_completed && checkout.status != "pending" && checkout.status != "expired" {
return Ok(CheckoutCompletion::Rejected(format!(
"checkout reservation is {}",
checkout.status
)));
}
// NOTE(review): the record was looked up by this session id, so this looks
// like a defensive re-check of the lookup — confirm.
if checkout.stripe_session_id != session_id {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout reservation session id mismatch".into(),
));
}
// A missing client_reference_id is treated as "" and therefore mismatches.
let client_reference_id = session["client_reference_id"].as_str().unwrap_or_default();
if client_reference_id != checkout.user_id {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout client_reference_id mismatch".into(),
));
}
// An unpaid session is rejected without marking the reservation invalid.
let payment_status = session["payment_status"].as_str().unwrap_or_default();
if payment_status != "paid" {
return Ok(CheckoutCompletion::Rejected(format!(
"checkout payment_status is {payment_status}"
)));
}
// The payload currency is ASCII-lowercased before comparison.
let currency = session["currency"]
.as_str()
.unwrap_or_default()
.to_ascii_lowercase();
if currency != checkout.currency {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout currency mismatch".into(),
));
}
// The subtotal must equal the reserved price.
let amount_subtotal = match number_field(session, "amount_subtotal") {
Some(amount) => amount,
None => {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout amount_subtotal missing".into(),
));
}
};
if amount_subtotal != checkout.amount_pence {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout amount_subtotal mismatch".into(),
));
}
// The total must equal the expected total stored on the reservation.
let amount_total = match number_field(session, "amount_total") {
Some(amount) => amount,
None => {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout amount_total missing".into(),
));
}
};
if amount_total != checkout.expected_total_pence {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout amount_total mismatch".into(),
));
}
let verified = VerifiedCheckout {
reservation_id: checkout.id,
user_id: checkout.user_id,
stripe_session_id: session_id.to_string(),
payment_intent_id: payment_intent_id.to_string(),
paid_amount_pence: amount_total,
referral_invite_id: checkout.referral_invite_id,
};
if already_completed {
Ok(CheckoutCompletion::AlreadyHandled(verified))
} else {
Ok(CheckoutCompletion::Grant(verified))
}
}
/// Completes a verified checkout while holding the pricing locks.
///
/// Takes the in-process reservation mutex and the PocketBase pricing lock,
/// delegates to [`complete_verified_checkout_locked`], then releases the
/// pricing lock. A failed release is only logged.
pub async fn complete_verified_checkout(
    state: &AppState,
    checkout: &VerifiedCheckout,
) -> anyhow::Result<()> {
    let _reservation_guard = CHECKOUT_RESERVATION_LOCK.lock().await;
    let held_lock = acquire_pocketbase_lock(
        state,
        CHECKOUT_PRICING_LOCK_NAME,
        CHECKOUT_PRICING_LOCK_TTL_SECS,
    )
    .await?;

    let outcome = complete_verified_checkout_locked(state, checkout).await;

    match held_lock.release().await {
        Ok(_) => {}
        Err(err) => warn!("Failed to release checkout pricing lock: {err}"),
    }
    outcome
}
/// Grants the license for a verified checkout and records completion.
///
/// Re-reads the reservation by Stripe session id. If it is already
/// "completed", only the referral invite (if any) is marked used again.
/// Otherwise the reservation must still match the verified snapshot and be
/// "pending" or "expired". The license is then granted, the reservation is
/// marked completed, and the referral invite (if any) is marked used.
async fn complete_verified_checkout_locked(
    state: &AppState,
    checkout: &VerifiedCheckout,
) -> anyhow::Result<()> {
    let Some(current) = find_checkout_by_stripe_session(state, &checkout.stripe_session_id).await?
    else {
        return Err(anyhow!(
            "checkout reservation disappeared before completion"
        ));
    };
    let has_referral_invite = !checkout.referral_invite_id.is_empty();

    if current.status == "completed" {
        if has_referral_invite {
            mark_referral_invite_used(
                state,
                &checkout.referral_invite_id,
                &checkout.user_id,
                &checkout.reservation_id,
            )
            .await?;
        }
        return Ok(());
    }

    // The stored record must still describe the checkout that was verified.
    let snapshot_still_matches = current.id == checkout.reservation_id
        && current.user_id == checkout.user_id
        && current.referral_invite_id == checkout.referral_invite_id;
    if !snapshot_still_matches {
        mark_checkout_status(state, &checkout.reservation_id, "invalid").await?;
        return Err(anyhow!("checkout reservation changed before completion"));
    }

    if !matches!(current.status.as_str(), "pending" | "expired") {
        return Err(anyhow!("checkout reservation is {}", current.status));
    }

    grant_license(state, &checkout.user_id).await?;
    mark_checkout_completed(
        state,
        &checkout.reservation_id,
        checkout.paid_amount_pence,
        &checkout.payment_intent_id,
    )
    .await?;

    if has_referral_invite {
        mark_referral_invite_used(
            state,
            &checkout.referral_invite_id,
            &checkout.user_id,
            &checkout.reservation_id,
        )
        .await?;
    }
    Ok(())
}
/// Grants a license to `user_id` while holding the same locks the checkout
/// flow uses (in-process reservation mutex, then PocketBase pricing lock).
/// A failed lock release is logged and does not change the result.
pub async fn grant_license_with_pricing_lock(
    state: &AppState,
    user_id: &str,
) -> anyhow::Result<()> {
    let _reservation_guard = CHECKOUT_RESERVATION_LOCK.lock().await;
    let held_lock = acquire_pocketbase_lock(
        state,
        CHECKOUT_PRICING_LOCK_NAME,
        CHECKOUT_PRICING_LOCK_TTL_SECS,
    )
    .await?;

    let outcome = grant_license(state, user_id).await;

    match held_lock.release().await {
        Ok(_) => {}
        Err(err) => warn!("Failed to release checkout pricing lock: {err}"),
    }
    outcome
}
/// Sets the user's `subscription` field to "licensed". Takes no locks itself.
pub async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
    const LICENSED: &str = "licensed";
    set_user_subscription(state, user_id, LICENSED).await
}
/// Reverses the checkout tied to `payment_intent_id` after a refund or dispute.
///
/// `refunded_amount_pence`, when given, must cover the paid amount (the larger
/// of the recorded paid amount and the expected total); smaller refunds are
/// ignored. A completed checkout loses its license unless the user has
/// another completed checkout, then is marked reversed. Pending, expired and
/// failed checkouts are only marked reversed.
///
/// Errors if the id or the reason fails validation, or on upstream failures.
pub async fn reverse_license_for_payment_intent(
    state: &AppState,
    payment_intent_id: &str,
    reason: &str,
    refunded_amount_pence: Option<u64>,
) -> anyhow::Result<PaymentReversalOutcome> {
    if !is_safe_stripe_session_id(payment_intent_id) {
        return Err(anyhow!("invalid Stripe payment intent id"));
    }
    if !is_safe_reversal_reason(reason) {
        return Err(anyhow!("invalid Stripe reversal reason"));
    }

    let _reservation_guard = CHECKOUT_RESERVATION_LOCK.lock().await;
    let Some(record) =
        find_checkout_by_payment_intent_or_checkout_session(state, payment_intent_id).await?
    else {
        return Ok(PaymentReversalOutcome::NoMatchingCheckout);
    };

    // The partial-refund check runs before any status check.
    let paid_amount_pence = std::cmp::max(record.paid_amount_pence, record.expected_total_pence);
    match refunded_amount_pence {
        Some(refunded) if refunded < paid_amount_pence => {
            return Ok(PaymentReversalOutcome::IgnoredPartialRefund {
                user_id: record.user_id,
                refunded_amount_pence: refunded,
                paid_amount_pence,
            });
        }
        _ => {}
    }

    if record.status == "reversed" {
        return Ok(PaymentReversalOutcome::AlreadyHandled {
            user_id: record.user_id,
        });
    }

    // No license change is made for these statuses; only the record is updated.
    let never_completed = matches!(record.status.as_str(), "pending" | "expired" | "failed");
    if never_completed {
        mark_checkout_reversed(state, &record.id, reason, payment_intent_id).await?;
        return Ok(PaymentReversalOutcome::Applied {
            user_id: record.user_id,
        });
    }

    if record.status != "completed" {
        return Ok(PaymentReversalOutcome::NotReversible {
            user_id: record.user_id,
            status: record.status,
        });
    }

    // Keep the license when another completed checkout still backs it.
    let keeps_license = has_other_completed_checkout_for_user(
        state,
        &record.user_id,
        &record.id,
        payment_intent_id,
    )
    .await?;
    if !keeps_license {
        revoke_license(state, &record.user_id).await?;
    }
    mark_checkout_reversed(state, &record.id, reason, payment_intent_id).await?;

    Ok(PaymentReversalOutcome::Applied {
        user_id: record.user_id,
    })
}
/// Re-grants a license that was previously reversed because of a dispute.
///
/// Only checkouts whose status is "reversed" and whose reversal reason starts
/// with "charge.dispute." are reinstated. A "completed" checkout is reported
/// as already handled and everything else is ignored with a reason.
///
/// Errors if the id or the reason fails validation, or on upstream failures.
pub async fn reinstate_license_for_payment_intent(
    state: &AppState,
    payment_intent_id: &str,
    reason: &str,
) -> anyhow::Result<PaymentReinstatementOutcome> {
    if !is_safe_stripe_session_id(payment_intent_id) {
        return Err(anyhow!("invalid Stripe payment intent id"));
    }
    if !is_safe_reversal_reason(reason) {
        return Err(anyhow!("invalid Stripe reinstatement reason"));
    }

    let _reservation_guard = CHECKOUT_RESERVATION_LOCK.lock().await;
    let Some(record) =
        find_checkout_by_payment_intent_or_checkout_session(state, payment_intent_id).await?
    else {
        return Ok(PaymentReinstatementOutcome::NoMatchingCheckout);
    };

    if record.status == "completed" {
        return Ok(PaymentReinstatementOutcome::AlreadyHandled {
            user_id: record.user_id,
        });
    }
    if record.status != "reversed" {
        let why = format!("checkout status is {}", record.status);
        return Ok(PaymentReinstatementOutcome::Ignored {
            user_id: record.user_id,
            reason: why,
        });
    }

    let reversed_by_dispute = record.reversal_reason.starts_with("charge.dispute.");
    if !reversed_by_dispute {
        let why = format!("checkout was reversed by {}", record.reversal_reason);
        return Ok(PaymentReinstatementOutcome::Ignored {
            user_id: record.user_id,
            reason: why,
        });
    }

    grant_license(state, &record.user_id).await?;
    mark_checkout_reinstated(state, &record.id, reason).await?;

    Ok(PaymentReinstatementOutcome::Applied {
        user_id: record.user_id,
    })
}
/// Sets the user's `subscription` field back to "free".
async fn revoke_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
    const FREE: &str = "free";
    set_user_subscription(state, user_id, FREE).await
}
/// PATCHes the PocketBase `users` record, setting its `subscription` field,
/// then invalidates cached tokens for that user.
///
/// Errors if the superuser token cannot be obtained, the request fails, or
/// PocketBase answers with a non-success status.
async fn set_user_subscription(
    state: &AppState,
    user_id: &str,
    subscription: &str,
) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let url = format!("{pb_url}/api/collections/users/records/{user_id}");
    let payload = serde_json::json!({ "subscription": subscription });

    let request = state
        .http_client
        .patch(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload);
    let response = request.send().await?;

    ensure_success(response)
        .await
        .context("PocketBase license update failed")?;

    // Drop cached tokens so the updated subscription is seen on the next lookup.
    state.token_cache.invalidate_by_user_id(user_id);
    Ok(())
}
/// Total the checkout is expected to charge, in pence.
///
/// Without a (non-empty) coupon id this is `amount_pence` unchanged. With one,
/// the referral discount percentage is taken off using integer division
/// (rounding down), with a floor of one penny.
pub(super) fn expected_total_for_checkout(
    amount_pence: u64,
    discount_coupon_id: Option<&str>,
) -> u64 {
    let has_coupon = matches!(discount_coupon_id, Some(id) if !id.is_empty());
    if !has_coupon {
        return amount_pence;
    }
    let payable_percent = 100 - REFERRAL_DISCOUNT_PERCENT;
    let discounted = amount_pence * payable_percent / 100;
    discounted.max(1)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Discounted totals round down and never drop below one penny; without a
    /// coupon the amount is unchanged.
    #[test]
    fn expected_total_for_referral_discount_rounds_down_like_stripe_amount_math() {
        let cases = [
            (999, Some("coupon_30"), 699),
            (1, Some("coupon_30"), 1),
            (999, None, 999),
        ];
        for (amount_pence, coupon, expected) in cases {
            assert_eq!(expected_total_for_checkout(amount_pence, coupon), expected);
        }
    }
}

View file

@ -0,0 +1,133 @@
//! Checkout sessions: Stripe-backed lifetime-license purchases reserved and
//! recorded in the PocketBase `checkout_sessions` collection.
//!
//! Split by concern:
//! - [`lifecycle`]: the session state machine (start, verify, complete,
//! reverse, reinstate) and license granting
//! - [`records`]: PocketBase `checkout_sessions` record handling
//! - [`referral`]: referral invite reservation/consumption bookkeeping
//! - [`stripe`]: Stripe API interaction (sessions, coupons, lookups)
mod lifecycle;
mod records;
mod referral;
mod stripe;
#[cfg(test)]
mod tests;
pub use lifecycle::{
complete_verified_checkout, grant_license_with_pricing_lock,
reinstate_license_for_payment_intent, reverse_license_for_payment_intent,
start_license_checkout, verify_checkout_completion,
};
pub use referral::active_referral_checkout_user;
use std::time::{SystemTime, UNIX_EPOCH};
use anyhow::anyhow;
use serde_json::Value;
/// Currency code (lowercase) used for every checkout.
pub const CHECKOUT_CURRENCY: &str = "gbp";
/// Name of the PocketBase collection that stores checkout reservations.
const CHECKOUT_COLLECTION: &str = "checkout_sessions";
/// Percentage taken off the price for referral checkouts.
const REFERRAL_DISCOUNT_PERCENT: u64 = 30;
/// Result of starting a license checkout.
pub enum CheckoutStart {
/// The computed price was zero, so the license was granted without payment.
Free,
/// The user must pay via the Stripe checkout page at `url`.
Stripe { url: String },
}
/// Result of verifying a Stripe checkout completion payload.
pub enum CheckoutCompletion {
/// All checks passed and the reservation is not yet completed.
Grant(VerifiedCheckout),
/// All checks passed but the reservation is already "completed".
AlreadyHandled(VerifiedCheckout),
/// Verification failed; the string is the reason.
Rejected(String),
}
/// Result of trying to reverse a checkout after a refund or dispute.
pub enum PaymentReversalOutcome {
/// The checkout was marked reversed.
Applied {
user_id: String,
},
/// The checkout was already in the "reversed" status.
AlreadyHandled {
user_id: String,
},
/// The refunded amount was smaller than the paid amount; nothing changed.
IgnoredPartialRefund {
user_id: String,
refunded_amount_pence: u64,
paid_amount_pence: u64,
},
/// No checkout matched the payment intent.
NoMatchingCheckout,
/// The checkout is in a status (`status`) that reversal does not handle.
NotReversible {
user_id: String,
status: String,
},
}
/// Result of trying to reinstate a license after a reversal.
pub enum PaymentReinstatementOutcome {
/// The license was granted again and the checkout marked reinstated.
Applied { user_id: String },
/// The checkout is already "completed"; nothing to do.
AlreadyHandled { user_id: String },
/// Reinstatement was skipped; `reason` says why.
Ignored { user_id: String, reason: String },
/// No checkout matched the payment intent.
NoMatchingCheckout,
}
/// Snapshot of a checkout whose Stripe completion payload passed verification.
pub struct VerifiedCheckout {
/// Id of the `checkout_sessions` reservation record.
pub reservation_id: String,
/// Id of the user the reservation belongs to.
pub user_id: String,
/// Stripe checkout session id taken from the payload.
pub stripe_session_id: String,
/// Stripe payment intent id taken from the payload.
pub payment_intent_id: String,
/// The payload's `amount_total`, in pence.
pub paid_amount_pence: u64,
/// Referral invite id stored on the reservation; empty when none.
pub referral_invite_id: String,
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
pub fn now_unix_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Reads `value[field]` as a non-negative whole number.
///
/// Accepts JSON integers that fit in `u64`, and JSON floats that are finite,
/// non-negative, have no fractional part and are below 2^64. Anything else
/// (missing field, negative, fractional, non-numeric, out of range) is `None`.
fn number_field(value: &Value, field: &str) -> Option<u64> {
    // 2^64: the first value that no longer fits in a u64. A float-to-int `as`
    // cast saturates, so without this bound a value like 1e300 would silently
    // become u64::MAX instead of being rejected.
    const U64_RANGE_END: f64 = 18_446_744_073_709_551_616.0;

    let raw = &value[field];
    raw.as_u64().or_else(|| {
        raw.as_f64()
            .filter(|n| n.is_finite() && *n >= 0.0 && *n < U64_RANGE_END && n.fract() == 0.0)
            .map(|n| n as u64)
    })
}
/// True when `id` is 1..=128 bytes long and consists only of ASCII letters,
/// ASCII digits, `_` and `-`.
fn is_safe_stripe_session_id(id: &str) -> bool {
    if id.is_empty() || id.len() > 128 {
        return false;
    }
    id.bytes().all(|byte| {
        matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-')
    })
}
/// True when `id` is 1..=32 bytes long and consists only of ASCII letters
/// and ASCII digits.
fn is_safe_pocketbase_id(id: &str) -> bool {
    let length_ok = (1..=32).contains(&id.len());
    length_ok
        && id
            .bytes()
            .all(|byte| matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9'))
}
/// True when `reason` is 1..=128 bytes long and consists only of ASCII
/// letters, ASCII digits, `_`, `-` and `.`.
fn is_safe_reversal_reason(reason: &str) -> bool {
    if reason.is_empty() || reason.len() > 128 {
        return false;
    }
    reason.bytes().all(|byte| {
        matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-' | b'.')
    })
}
/// Returns `Ok(())` for a success status. Otherwise consumes the response and
/// returns an error carrying the status and the response body (empty if the
/// body cannot be read).
async fn ensure_success(resp: reqwest::Response) -> anyhow::Result<()> {
    let status = resp.status();
    if status.is_success() {
        return Ok(());
    }
    let text = resp.text().await.unwrap_or_default();
    Err(anyhow!("upstream returned {status}: {text}"))
}
/// Like `ensure_success`, but only borrows the response so the caller can
/// still read the body afterwards. The error carries the status only.
async fn ensure_success_ref(resp: &reqwest::Response) -> anyhow::Result<()> {
    let status = resp.status();
    match status.is_success() {
        true => Ok(()),
        false => Err(anyhow!("upstream returned {}", status)),
    }
}

View file

@ -0,0 +1,564 @@
//! PocketBase `checkout_sessions` record handling: creating reservations,
//! status transitions, and lookups by Stripe session / payment intent.
use anyhow::{anyhow, Context};
use serde_json::Value;
use tracing::warn;
use crate::pocketbase::get_superuser_token;
use crate::state::AppState;
use super::referral::release_referral_invite_reservation;
use super::stripe::fetch_stripe_checkout_session_id_for_payment_intent;
use super::{
ensure_success, ensure_success_ref, is_safe_pocketbase_id, is_safe_stripe_session_id,
now_unix_secs, number_field, CHECKOUT_COLLECTION,
};
/// In-memory view of a `checkout_sessions` record.
// NOTE(review): the field-to-column mapping lives in the parser, which is not
// shown here — confirm column names against it.
#[derive(Debug)]
pub(super) struct PendingCheckout {
/// PocketBase record id of the reservation.
pub(super) id: String,
/// Id of the user who owns the reservation.
pub(super) user_id: String,
/// Stripe checkout session id attached to the reservation.
pub(super) stripe_session_id: String,
/// Stripe checkout page URL.
pub(super) checkout_url: String,
/// Price before discounts, in pence.
pub(super) amount_pence: u64,
/// Total expected to be charged after discounts, in pence.
pub(super) expected_total_pence: u64,
/// Currency code of the reservation.
pub(super) currency: String,
/// Referral invite id; presumably empty when none — confirm.
pub(super) referral_invite_id: String,
/// Reservation status string (e.g. "pending", "completed").
pub(super) status: String,
/// Stripe payment intent id.
pub(super) payment_intent_id: String,
/// Amount actually paid, in pence.
pub(super) paid_amount_pence: u64,
/// Reason recorded when the checkout was reversed.
pub(super) reversal_reason: String,
}
/// Marks a reservation "completed", recording the paid amount, the completion
/// time (as a string of Unix seconds) and the Stripe payment intent id.
///
/// Errors if the payment intent id fails validation, the superuser token
/// cannot be obtained, or PocketBase rejects the update.
pub async fn mark_checkout_completed(
    state: &AppState,
    reservation_id: &str,
    paid_amount_pence: u64,
    payment_intent_id: &str,
) -> anyhow::Result<()> {
    if !is_safe_stripe_session_id(payment_intent_id) {
        return Err(anyhow!("invalid Stripe payment intent id"));
    }

    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let payload = serde_json::json!({
        "status": "completed",
        "paid_amount_pence": paid_amount_pence,
        "completed_at_unix": now_unix_secs().to_string(),
        "stripe_payment_intent_id": payment_intent_id,
    });

    let request = state
        .http_client
        .patch(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload);
    let response = request.send().await?;

    ensure_success(response)
        .await
        .context("PocketBase checkout completion update failed")
}
/// Counts reservations that are "pending" and not yet expired at `now`.
///
/// Requests a single item per page and reads the count from the response's
/// `totalItems`; a missing or non-integer `totalItems` is treated as 0.
pub(super) async fn count_active_pending_checkouts(
    state: &AppState,
    now: u64,
) -> anyhow::Result<u64> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let filter = format!("status=\"pending\" && expires_at_unix>={now}");
    let encoded_filter = urlencoding::encode(&filter);
    let url = format!(
        "{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
        encoded_filter
    );

    let response = state
        .http_client
        .get(&url)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;
    ensure_success_ref(&response).await?;

    let body: Value = response.json().await?;
    let total = match body["totalItems"].as_u64() {
        Some(count) => count,
        None => 0,
    };
    Ok(total)
}
pub(super) async fn find_active_checkout_for_user(
state: &AppState,
user_id: &str,
discount_coupon_id: &str,
referral_invite_id: &str,
now: u64,
) -> anyhow::Result<Option<PendingCheckout>> {
let token = get_superuser_token(state).await?;
let pb_url = state.pocketbase_url.trim_end_matches('/');
let filter = active_checkout_filter(user_id, discount_coupon_id, referral_invite_id, now)?;
let url = format!(
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
urlencoding::encode(&filter)
);
let resp = state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
ensure_success_ref(&resp).await?;
let body: Value = resp.json().await?;
let item = body["items"]
.as_array()
.and_then(|items| items.first())
.cloned();
item.map(parse_pending_checkout).transpose()
}
/// Builds the PocketBase filter that selects a user's pending, unexpired
/// checkout for one specific coupon / referral invite context.
///
/// The user id must always pass `is_safe_pocketbase_id`. The coupon and
/// invite ids may be empty (they then match records whose field is empty);
/// when non-empty they must pass their respective safety checks.
fn active_checkout_filter(
    user_id: &str,
    discount_coupon_id: &str,
    referral_invite_id: &str,
    now: u64,
) -> anyhow::Result<String> {
    if !is_safe_pocketbase_id(user_id) {
        return Err(anyhow!("invalid PocketBase user id"));
    }
    let coupon_acceptable =
        discount_coupon_id.is_empty() || is_safe_stripe_session_id(discount_coupon_id);
    if !coupon_acceptable {
        return Err(anyhow!("invalid Stripe coupon id"));
    }
    let invite_acceptable =
        referral_invite_id.is_empty() || is_safe_pocketbase_id(referral_invite_id);
    if !invite_acceptable {
        return Err(anyhow!("invalid PocketBase referral invite id"));
    }
    let clauses = [
        "status=\"pending\"".to_string(),
        format!("expires_at_unix>={now}"),
        format!("user=\"{user_id}\""),
        format!("discount_coupon_id=\"{discount_coupon_id}\""),
        format!("referral_invite_id=\"{referral_invite_id}\""),
    ];
    Ok(clauses.join(" && "))
}
/// Expires pending checkout reservations whose `expires_at_unix` is before
/// `now`, handling at most one page of 50 records per call.
///
/// Each stale reservation is marked `expired`, and any referral invite it
/// references has its reservation released. Failures on individual records
/// are logged as warnings and do not abort the sweep.
pub(super) async fn expire_stale_pending_checkouts(
    state: &AppState,
    now: u64,
) -> anyhow::Result<()> {
    let auth_token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let stale_filter = format!("status=\"pending\" && expires_at_unix<{now}");
    let encoded_filter = urlencoding::encode(&stale_filter);
    let endpoint = format!(
        "{base}/api/collections/{CHECKOUT_COLLECTION}/records?filter={encoded_filter}&perPage=50"
    );
    let response = state
        .http_client
        .get(&endpoint)
        .header("Authorization", format!("Bearer {auth_token}"))
        .send()
        .await?;
    ensure_success_ref(&response).await?;
    let listing: Value = response.json().await?;
    // A listing without an `items` array is treated as "nothing to expire".
    let stale = listing["items"]
        .as_array()
        .map(Vec::as_slice)
        .unwrap_or_default();
    for record in stale {
        // Records without a string id cannot be addressed; skip them.
        let Some(reservation) = record["id"].as_str() else {
            continue;
        };
        if let Err(err) = mark_checkout_status(state, reservation, "expired").await {
            warn!(
                reservation_id = reservation,
                "Failed to expire checkout reservation: {err}"
            );
        }
        // The invite is released even when the status update above failed.
        let invite_id = record["referral_invite_id"].as_str().unwrap_or_default();
        if invite_id.is_empty() {
            continue;
        }
        if let Err(err) = release_referral_invite_reservation(state, invite_id, reservation).await {
            warn!(
                reservation_id = reservation,
                referral_invite_id = invite_id,
                "Failed to release expired referral invite reservation: {err}"
            );
        }
    }
    Ok(())
}
/// Borrowed input for `create_pending_checkout`; each field is written
/// verbatim into the new checkout reservation record.
pub(super) struct PendingCheckoutInput<'a> {
/// Stored in the record's `user` field.
pub(super) user_id: &'a str,
/// Stored as `amount_pence`.
pub(super) amount_pence: u64,
/// Stored as `expected_total_pence`.
pub(super) expected_total_pence: u64,
/// Stored as `currency`.
pub(super) currency: &'a str,
/// Stored as `discount_coupon_id`.
pub(super) discount_coupon_id: &'a str,
/// Stored as `referral_invite_id`.
pub(super) referral_invite_id: &'a str,
/// Stored as `expires_at_unix`.
pub(super) expires_at_unix: u64,
}
/// Creates a new `pending` checkout reservation in PocketBase and returns
/// its record id.
///
/// Stripe-related fields, the completion time and the reversal reason start
/// out empty, and `paid_amount_pence` starts at 0. Fails when the response is
/// not successful or carries no string `id`.
pub(super) async fn create_pending_checkout(
    state: &AppState,
    input: PendingCheckoutInput<'_>,
) -> anyhow::Result<String> {
    let auth_token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let endpoint = format!("{base}/api/collections/{CHECKOUT_COLLECTION}/records");
    let new_record = serde_json::json!({
        "user": input.user_id,
        "stripe_session_id": "",
        "stripe_payment_intent_id": "",
        "checkout_url": "",
        "amount_pence": input.amount_pence,
        "expected_total_pence": input.expected_total_pence,
        "currency": input.currency,
        "discount_coupon_id": input.discount_coupon_id,
        "referral_invite_id": input.referral_invite_id,
        "status": "pending",
        "expires_at_unix": input.expires_at_unix,
        "paid_amount_pence": 0,
        "completed_at_unix": "",
        "reversal_reason": "",
    });
    let response = state
        .http_client
        .post(&endpoint)
        .header("Authorization", format!("Bearer {auth_token}"))
        .json(&new_record)
        .send()
        .await?;
    ensure_success_ref(&response).await?;
    let created: Value = response.json().await?;
    match created["id"].as_str() {
        Some(id) => Ok(id.to_string()),
        None => Err(anyhow!("PocketBase checkout reservation missing id")),
    }
}
/// Stores the Stripe session id and checkout URL on an existing checkout
/// reservation via a PATCH request.
pub(super) async fn attach_stripe_session(
    state: &AppState,
    reservation_id: &str,
    stripe_session_id: &str,
    checkout_url: &str,
) -> anyhow::Result<()> {
    let auth_token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let endpoint =
        format!("{base}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let session_fields = serde_json::json!({
        "stripe_session_id": stripe_session_id,
        "checkout_url": checkout_url,
    });
    let response = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {auth_token}"))
        .json(&session_fields)
        .send()
        .await?;
    let outcome = ensure_success(response).await;
    outcome.context("PocketBase checkout session attach failed")
}
/// Overwrites the `status` field of a checkout reservation with `status`.
///
/// The error context names the reservation whose update failed.
pub(super) async fn mark_checkout_status(
    state: &AppState,
    reservation_id: &str,
    status: &str,
) -> anyhow::Result<()> {
    let auth_token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let endpoint =
        format!("{base}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let status_change = serde_json::json!({ "status": status });
    let response = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {auth_token}"))
        .json(&status_change)
        .send()
        .await?;
    let outcome = ensure_success(response).await;
    outcome
        .with_context(|| format!("PocketBase checkout status update failed for {reservation_id}"))
}
/// Marks a checkout reservation as `reversed`, recording the reversal reason
/// and the Stripe payment intent id on the record.
pub(super) async fn mark_checkout_reversed(
    state: &AppState,
    reservation_id: &str,
    reason: &str,
    payment_intent_id: &str,
) -> anyhow::Result<()> {
    let auth_token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let endpoint =
        format!("{base}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let reversal = serde_json::json!({
        "status": "reversed",
        "reversal_reason": reason,
        "stripe_payment_intent_id": payment_intent_id,
    });
    let response = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {auth_token}"))
        .json(&reversal)
        .send()
        .await?;
    let outcome = ensure_success(response).await;
    outcome
        .with_context(|| format!("PocketBase checkout reversal update failed for {reservation_id}"))
}
/// Returns a previously reversed checkout reservation to `completed` and
/// clears its `reversal_reason`.
///
/// `_reason` is accepted but not written anywhere by this function.
pub(super) async fn mark_checkout_reinstated(
    state: &AppState,
    reservation_id: &str,
    _reason: &str,
) -> anyhow::Result<()> {
    let auth_token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let endpoint =
        format!("{base}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let reinstatement = serde_json::json!({
        "status": "completed",
        "reversal_reason": "",
    });
    let response = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {auth_token}"))
        .json(&reinstatement)
        .send()
        .await?;
    let outcome = ensure_success(response).await;
    outcome.with_context(|| {
        format!("PocketBase checkout reinstatement update failed for {reservation_id}")
    })
}
pub(super) async fn find_checkout_by_stripe_session(
state: &AppState,
stripe_session_id: &str,
) -> anyhow::Result<Option<PendingCheckout>> {
let token = get_superuser_token(state).await?;
let pb_url = state.pocketbase_url.trim_end_matches('/');
let filter = format!("stripe_session_id=\"{}\"", stripe_session_id);
let url = format!(
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
urlencoding::encode(&filter)
);
let resp = state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
ensure_success_ref(&resp).await?;
let body: Value = resp.json().await?;
let item = body["items"]
.as_array()
.and_then(|items| items.first())
.cloned();
item.map(parse_pending_checkout).transpose()
}
async fn find_checkout_by_payment_intent(
state: &AppState,
payment_intent_id: &str,
) -> anyhow::Result<Option<PendingCheckout>> {
let token = get_superuser_token(state).await?;
let pb_url = state.pocketbase_url.trim_end_matches('/');
let filter = format!("stripe_payment_intent_id=\"{}\"", payment_intent_id);
let url = format!(
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
urlencoding::encode(&filter)
);
let resp = state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
ensure_success_ref(&resp).await?;
let body: Value = resp.json().await?;
let item = body["items"]
.as_array()
.and_then(|items| items.first())
.cloned();
item.map(parse_pending_checkout).transpose()
}
/// Resolves a checkout reservation from a Stripe payment intent id.
///
/// First looks for a reservation that already stores the intent id. Failing
/// that, asks Stripe for the checkout session belonging to the intent and
/// looks the reservation up by session id. A reservation found that way gets
/// the intent id attached if it had none; if it already stores a different
/// intent id it is marked `invalid` and an error is returned.
pub(super) async fn find_checkout_by_payment_intent_or_checkout_session(
    state: &AppState,
    payment_intent_id: &str,
) -> anyhow::Result<Option<PendingCheckout>> {
    let direct_match = find_checkout_by_payment_intent(state, payment_intent_id).await?;
    if direct_match.is_some() {
        return Ok(direct_match);
    }
    let session_lookup =
        fetch_stripe_checkout_session_id_for_payment_intent(state, payment_intent_id).await?;
    let Some(session_id) = session_lookup else {
        return Ok(None);
    };
    let by_session = find_checkout_by_stripe_session(state, &session_id).await?;
    let Some(mut checkout) = by_session else {
        return Ok(None);
    };
    if checkout.payment_intent_id.is_empty() {
        // First time this reservation is linked to an intent: persist it.
        attach_payment_intent_to_checkout(state, &checkout.id, payment_intent_id).await?;
        checkout.payment_intent_id = payment_intent_id.to_string();
        return Ok(Some(checkout));
    }
    if checkout.payment_intent_id == payment_intent_id {
        return Ok(Some(checkout));
    }
    // The reservation is tied to a different intent than the one being resolved.
    mark_checkout_status(state, &checkout.id, "invalid").await?;
    Err(anyhow!(
        "checkout reservation payment intent changed before reversal"
    ))
}
/// Writes the Stripe payment intent id onto an existing checkout reservation
/// via a PATCH request.
async fn attach_payment_intent_to_checkout(
    state: &AppState,
    reservation_id: &str,
    payment_intent_id: &str,
) -> anyhow::Result<()> {
    let auth_token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let endpoint =
        format!("{base}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let intent_field = serde_json::json!({
        "stripe_payment_intent_id": payment_intent_id,
    });
    let response = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {auth_token}"))
        .json(&intent_field)
        .send()
        .await?;
    let outcome = ensure_success(response).await;
    outcome.context("PocketBase checkout payment intent attach failed")
}
/// Reports whether the user has a `completed` checkout other than the given
/// one.
///
/// A completed record counts as "other" when both its id differs from
/// `reservation_id` and its stored payment intent differs from
/// `payment_intent_id`. Only the first page of up to 50 completed records is
/// inspected. All three ids are validated before any request is made.
pub(super) async fn has_other_completed_checkout_for_user(
    state: &AppState,
    user_id: &str,
    reservation_id: &str,
    payment_intent_id: &str,
) -> anyhow::Result<bool> {
    let pocketbase_ids_ok = is_safe_pocketbase_id(user_id) && is_safe_pocketbase_id(reservation_id);
    if !pocketbase_ids_ok {
        return Err(anyhow!("invalid PocketBase id"));
    }
    if !is_safe_stripe_session_id(payment_intent_id) {
        return Err(anyhow!("invalid Stripe payment intent id"));
    }
    let auth_token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let completed_filter = format!("user=\"{user_id}\" && status=\"completed\"");
    let encoded_filter = urlencoding::encode(&completed_filter);
    let endpoint = format!(
        "{base}/api/collections/{CHECKOUT_COLLECTION}/records?filter={encoded_filter}&perPage=50"
    );
    let response = state
        .http_client
        .get(&endpoint)
        .header("Authorization", format!("Bearer {auth_token}"))
        .send()
        .await?;
    ensure_success_ref(&response).await?;
    let listing: Value = response.json().await?;
    let Some(completed) = listing["items"].as_array() else {
        return Ok(false);
    };
    for record in completed {
        let record_id = record["id"].as_str().unwrap_or_default();
        let record_intent = record["stripe_payment_intent_id"]
            .as_str()
            .unwrap_or_default();
        if record_id != reservation_id && record_intent != payment_intent_id {
            return Ok(true);
        }
    }
    Ok(false)
}
/// Converts a PocketBase checkout record into a `PendingCheckout`.
///
/// `id`, `user`, `amount_pence` and `expected_total_pence` are required and
/// produce an error when missing. Every other text field defaults to an empty
/// string, `paid_amount_pence` defaults to 0, and `currency` is lower-cased.
fn parse_pending_checkout(item: Value) -> anyhow::Result<PendingCheckout> {
    // Optional text fields: absent or non-string values become "".
    let optional_text = |key: &str| item[key].as_str().unwrap_or_default().to_string();

    // Required fields are checked in the same order the struct lists them,
    // so the first missing one determines the error.
    let Some(id) = item["id"].as_str() else {
        return Err(anyhow!("checkout reservation missing id"));
    };
    let Some(user_id) = item["user"].as_str() else {
        return Err(anyhow!("checkout reservation missing user"));
    };
    let Some(amount_pence) = number_field(&item, "amount_pence") else {
        return Err(anyhow!("checkout reservation missing amount_pence"));
    };
    let Some(expected_total_pence) = number_field(&item, "expected_total_pence") else {
        return Err(anyhow!("checkout reservation missing expected_total_pence"));
    };

    Ok(PendingCheckout {
        id: id.to_string(),
        user_id: user_id.to_string(),
        stripe_session_id: optional_text("stripe_session_id"),
        checkout_url: optional_text("checkout_url"),
        amount_pence,
        expected_total_pence,
        currency: item["currency"]
            .as_str()
            .unwrap_or_default()
            .to_ascii_lowercase(),
        referral_invite_id: optional_text("referral_invite_id"),
        status: optional_text("status"),
        payment_intent_id: optional_text("stripe_payment_intent_id"),
        paid_amount_pence: number_field(&item, "paid_amount_pence").unwrap_or(0),
        reversal_reason: optional_text("reversal_reason"),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A standard checkout has neither coupon nor invite; both clauses must
    /// still be present and compare against the empty string.
    #[test]
    fn active_checkout_filter_includes_empty_context_for_standard_checkout() {
        let expected = "status=\"pending\" && expires_at_unix>=42 && user=\"abc123\" && discount_coupon_id=\"\" && referral_invite_id=\"\"";
        let actual = active_checkout_filter("abc123", "", "", 42).unwrap();
        assert_eq!(actual, expected);
    }

    /// Coupon and invite ids are embedded verbatim when they are supplied.
    #[test]
    fn active_checkout_filter_includes_referral_context() {
        let expected = "status=\"pending\" && expires_at_unix>=99 && user=\"user123\" && discount_coupon_id=\"coupon_30\" && referral_invite_id=\"invite123\"";
        let actual = active_checkout_filter("user123", "coupon_30", "invite123", 99).unwrap();
        assert_eq!(actual, expected);
    }

    /// Any id that fails its safety check makes the whole filter an error.
    #[test]
    fn active_checkout_filter_rejects_unsafe_context_values() {
        let rejected = [
            ("user123", "bad\"coupon", ""),
            ("user123", "", "bad-invite"),
            ("bad-user", "", ""),
        ];
        for (user, coupon, invite) in rejected {
            assert!(active_checkout_filter(user, coupon, invite, 1).is_err());
        }
    }
}

View file

@ -0,0 +1,312 @@
//! Referral invite bookkeeping: reserving an invite for an in-flight checkout,
//! releasing the reservation on failure/expiry, and recording final usage when
//! a verified payment completes.
use anyhow::{anyhow, Context};
use serde_json::Value;
use tracing::warn;
use crate::pocketbase::get_superuser_token;
use crate::state::AppState;
use super::{
ensure_success, ensure_success_ref, is_safe_pocketbase_id, now_unix_secs, number_field,
CHECKOUT_COLLECTION,
};
/// Records that `user_id` consumed the referral invite `invite_id` and clears
/// any reservation held on it.
///
/// An empty `invite_id` is a no-op. Nothing is written when the invite is
/// already recorded for this user or was already used by another account
/// (the latter is logged). A reservation that points at a different user or
/// checkout is logged and then overwritten.
pub async fn mark_referral_invite_used(
    state: &AppState,
    invite_id: &str,
    user_id: &str,
    reservation_id: &str,
) -> anyhow::Result<()> {
    if invite_id.is_empty() {
        return Ok(());
    }
    let all_ids_safe = [invite_id, user_id, reservation_id]
        .iter()
        .all(|id| is_safe_pocketbase_id(id));
    if !all_ids_safe {
        return Err(anyhow!("invalid PocketBase id"));
    }
    let auth_token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let invite = fetch_invite_record(state, base, &auth_token, invite_id).await?;
    // A verified Stripe payment keeps its entitlement even when the local
    // invite bookkeeping expired or moved before the webhook arrived, so the
    // non-recording outcomes return Ok rather than an error.
    match referral_invite_completion_action(&invite, user_id, reservation_id) {
        ReferralInviteCompletionAction::AlreadyRecorded => {
            return Ok(());
        }
        ReferralInviteCompletionAction::AlreadyUsedByAnother => {
            warn!(
                invite_id,
                user_id,
                existing_used_by = invite["used_by_id"].as_str().unwrap_or_default(),
                "Referral invite was already used by another account; preserving verified checkout entitlement"
            );
            return Ok(());
        }
        ReferralInviteCompletionAction::Record {
            reservation_reassigned: true,
        } => {
            warn!(
                invite_id,
                user_id,
                reservation_id,
                reserved_by_id = invite["reserved_by_id"].as_str().unwrap_or_default(),
                reserved_checkout_id = invite["reserved_checkout_id"].as_str().unwrap_or_default(),
                "Referral invite reservation moved before webhook completion; verified checkout will consume it"
            );
        }
        ReferralInviteCompletionAction::Record {
            reservation_reassigned: false,
        } => {}
    }
    let endpoint = format!("{base}/api/collections/invites/records/{invite_id}");
    let usage = serde_json::json!({
        "used_by_id": user_id,
        "used_at": now_unix_secs().to_string(),
        "reserved_by_id": "",
        "reserved_checkout_id": "",
        "reserved_until_unix": 0,
    });
    let response = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {auth_token}"))
        .json(&usage)
        .send()
        .await?;
    let outcome = ensure_success(response).await;
    outcome.context("PocketBase invite usage update failed")
}
/// Outcome of `referral_invite_completion_action`: what a completing checkout
/// should do with its referral invite record.
#[derive(Debug, PartialEq, Eq)]
enum ReferralInviteCompletionAction {
/// The invite's `used_by_id` already equals the completing user.
AlreadyRecorded,
/// The invite's `used_by_id` is non-empty and names a different user.
AlreadyUsedByAnother,
/// The invite is unused and should be recorded. `reservation_reassigned`
/// is true when it is currently reserved by a different user or checkout.
Record { reservation_reassigned: bool },
}
/// Decides what a completing checkout should do with its invite record.
///
/// The `used_by_id` comparison against `user_id` is made first, before the
/// emptiness check. For an unused invite, the reservation counts as
/// reassigned when either reservation field is non-empty and differs from
/// the completing user / checkout.
fn referral_invite_completion_action(
    invite: &Value,
    user_id: &str,
    reservation_id: &str,
) -> ReferralInviteCompletionAction {
    /// Reads a string field, treating missing or non-string values as "".
    fn text<'a>(record: &'a Value, key: &str) -> &'a str {
        record[key].as_str().unwrap_or_default()
    }
    /// True when `current` is set and names something other than `expected`.
    fn held_by_other(current: &str, expected: &str) -> bool {
        !current.is_empty() && current != expected
    }

    let used_by = text(invite, "used_by_id");
    if used_by == user_id {
        ReferralInviteCompletionAction::AlreadyRecorded
    } else if !used_by.is_empty() {
        ReferralInviteCompletionAction::AlreadyUsedByAnother
    } else {
        let other_user = held_by_other(text(invite, "reserved_by_id"), user_id);
        let other_checkout = held_by_other(text(invite, "reserved_checkout_id"), reservation_id);
        ReferralInviteCompletionAction::Record {
            reservation_reassigned: other_user || other_checkout,
        }
    }
}
/// Fetches a single record from the `invites` collection and returns its
/// JSON body.
///
/// `pb_url` is used as given (callers pass it with the trailing slash already
/// trimmed) and `token` is sent as a bearer token.
async fn fetch_invite_record(
    state: &AppState,
    pb_url: &str,
    token: &str,
    invite_id: &str,
) -> anyhow::Result<Value> {
    let endpoint = format!("{pb_url}/api/collections/invites/records/{invite_id}");
    let authorization = format!("Bearer {token}");
    let response = state
        .http_client
        .get(&endpoint)
        .header("Authorization", authorization)
        .send()
        .await?;
    ensure_success_ref(&response).await?;
    let record: Value = response.json().await?;
    Ok(record)
}
/// Reserves a referral invite for an in-flight checkout.
///
/// Fails when the invite is already used, or when it carries a live
/// reservation (one whose `reserved_until_unix` is at or after the current
/// time) that belongs to a different checkout or a different user. Otherwise
/// the reservation fields are overwritten with the given values.
pub(super) async fn reserve_referral_invite(
    state: &AppState,
    invite_id: &str,
    user_id: &str,
    reservation_id: &str,
    reserved_until_unix: u64,
) -> anyhow::Result<()> {
    let all_ids_safe = [invite_id, user_id, reservation_id]
        .iter()
        .all(|id| is_safe_pocketbase_id(id));
    if !all_ids_safe {
        return Err(anyhow!("invalid PocketBase id"));
    }
    let auth_token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let invite = fetch_invite_record(state, base, &auth_token, invite_id).await?;
    if !invite["used_by_id"].as_str().unwrap_or_default().is_empty() {
        return Err(anyhow!("referral invite already used"));
    }
    let now = now_unix_secs();
    let holder = invite["reserved_by_id"].as_str().unwrap_or_default();
    let holding_checkout = invite["reserved_checkout_id"].as_str().unwrap_or_default();
    let held_until = number_field(&invite, "reserved_until_unix").unwrap_or(0);
    // Expired reservations are ignored and may simply be taken over.
    if held_until >= now {
        if !holding_checkout.is_empty() && holding_checkout != reservation_id {
            return Err(anyhow!("referral invite already has an active checkout"));
        }
        if !holder.is_empty() && holder != user_id {
            return Err(anyhow!("referral invite reserved by another account"));
        }
    }
    let endpoint = format!("{base}/api/collections/invites/records/{invite_id}");
    let reservation = serde_json::json!({
        "reserved_by_id": user_id,
        "reserved_checkout_id": reservation_id,
        "reserved_until_unix": reserved_until_unix,
    });
    let response = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {auth_token}"))
        .json(&reservation)
        .send()
        .await?;
    let outcome = ensure_success(response).await;
    outcome.context("PocketBase invite reservation update failed")
}
/// Clears the reservation on a referral invite, but only if the invite is
/// still unused and currently reserved by `reservation_id`.
///
/// In every other case (already used, or reserved by a different checkout,
/// or not reserved) the invite is left untouched and `Ok(())` is returned.
pub(super) async fn release_referral_invite_reservation(
    state: &AppState,
    invite_id: &str,
    reservation_id: &str,
) -> anyhow::Result<()> {
    let ids_safe = is_safe_pocketbase_id(invite_id) && is_safe_pocketbase_id(reservation_id);
    if !ids_safe {
        return Err(anyhow!("invalid PocketBase id"));
    }
    let auth_token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let invite = fetch_invite_record(state, base, &auth_token, invite_id).await?;
    let unused = invite["used_by_id"].as_str().unwrap_or_default().is_empty();
    let held_by_this_checkout =
        invite["reserved_checkout_id"].as_str().unwrap_or_default() == reservation_id;
    if !(unused && held_by_this_checkout) {
        return Ok(());
    }
    let endpoint = format!("{base}/api/collections/invites/records/{invite_id}");
    let cleared = serde_json::json!({
        "reserved_by_id": "",
        "reserved_checkout_id": "",
        "reserved_until_unix": 0,
    });
    let response = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {auth_token}"))
        .json(&cleared)
        .send()
        .await?;
    let outcome = ensure_success(response).await;
    outcome.context("PocketBase invite reservation release failed")
}
/// Returns the `user` of the first pending, unexpired checkout that
/// references the given referral invite, if there is one.
///
/// "Unexpired" means `expires_at_unix` is at or after the current time as
/// reported by `now_unix_secs()`.
pub async fn active_referral_checkout_user(
    state: &AppState,
    invite_id: &str,
) -> anyhow::Result<Option<String>> {
    if !is_safe_pocketbase_id(invite_id) {
        return Err(anyhow!("invalid PocketBase invite id"));
    }
    let now = now_unix_secs();
    let auth_token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let invite_filter = format!(
        "status=\"pending\" && expires_at_unix>={now} && referral_invite_id=\"{invite_id}\""
    );
    let encoded_filter = urlencoding::encode(&invite_filter);
    let endpoint = format!(
        "{base}/api/collections/{CHECKOUT_COLLECTION}/records?filter={encoded_filter}&perPage=1"
    );
    let response = state
        .http_client
        .get(&endpoint)
        .header("Authorization", format!("Bearer {auth_token}"))
        .send()
        .await?;
    ensure_success_ref(&response).await?;
    let listing: Value = response.json().await?;
    let first_record = listing["items"].as_array().and_then(|records| records.first());
    let owner = match first_record {
        Some(record) => record["user"].as_str().map(str::to_string),
        None => None,
    };
    Ok(owner)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An invite with no usage and no reservation is recorded without a
    /// reassignment flag.
    #[test]
    fn referral_invite_completion_records_available_invite() {
        let untouched = serde_json::json!({
            "used_by_id": "",
            "reserved_by_id": "",
            "reserved_checkout_id": "",
        });
        let action = referral_invite_completion_action(&untouched, "user123", "checkout123");
        let expected = ReferralInviteCompletionAction::Record {
            reservation_reassigned: false,
        };
        assert_eq!(action, expected);
    }

    /// A reservation held by someone else is still recorded, but flagged.
    #[test]
    fn referral_invite_completion_records_reassigned_reservation() {
        let reserved_elsewhere = serde_json::json!({
            "used_by_id": "",
            "reserved_by_id": "otheruser",
            "reserved_checkout_id": "othercheckout",
        });
        let action =
            referral_invite_completion_action(&reserved_elsewhere, "user123", "checkout123");
        let expected = ReferralInviteCompletionAction::Record {
            reservation_reassigned: true,
        };
        assert_eq!(action, expected);
    }

    /// Existing usage is classified by whether it belongs to the same user.
    #[test]
    fn referral_invite_completion_detects_existing_usage() {
        let cases = [
            ("user123", ReferralInviteCompletionAction::AlreadyRecorded),
            (
                "otheruser",
                ReferralInviteCompletionAction::AlreadyUsedByAnother,
            ),
        ];
        for (used_by, expected) in cases {
            let invite = serde_json::json!({ "used_by_id": used_by });
            assert_eq!(
                referral_invite_completion_action(&invite, "user123", "checkout123"),
                expected
            );
        }
    }
}

View file

@ -0,0 +1,175 @@
//! Stripe API interaction: creating checkout sessions, verifying coupon
//! configuration, and looking up sessions by payment intent.
use anyhow::{anyhow, Context};
use serde_json::Value;
use crate::auth::PocketBaseUser;
use crate::state::AppState;
use super::lifecycle::expected_total_for_checkout;
use super::{
ensure_success_ref, is_safe_stripe_session_id, CHECKOUT_CURRENCY, REFERRAL_DISCOUNT_PERCENT,
};
const CHECKOUT_PRODUCT_NAME: &str = "Perfect Postcodes Lifetime License";
/// Checks that a Stripe coupon grants exactly the expected referral discount.
///
/// The coupon is fetched from Stripe and accepted only when it is flagged
/// `valid`, carries no numeric `amount_off`, and its `percent_off` is within
/// 0.001 of `REFERRAL_DISCOUNT_PERCENT`. This keeps a misconfigured or
/// swapped coupon id from granting a larger discount than the server's
/// pricing math assumed.
async fn verify_stripe_coupon_discount(state: &AppState, coupon_id: &str) -> anyhow::Result<()> {
    if !is_safe_stripe_session_id(coupon_id) {
        return Err(anyhow!("unsafe stripe coupon id"));
    }
    let encoded_id = urlencoding::encode(coupon_id);
    let endpoint = format!("https://api.stripe.com/v1/coupons/{encoded_id}");
    let response = state
        .http_client
        .get(&endpoint)
        .basic_auth(&state.stripe_secret_key, None::<&str>)
        .send()
        .await
        .context("Stripe coupon fetch failed")?;
    ensure_success_ref(&response)
        .await
        .context("Stripe coupon fetch returned error")?;
    let coupon: Value = response
        .json()
        .await
        .context("Failed to parse Stripe coupon response")?;
    // A missing or non-boolean `valid` is treated the same as `false`.
    if coupon["valid"].as_bool() != Some(true) {
        return Err(anyhow!("stripe coupon is not valid"));
    }
    if coupon["amount_off"].is_number() {
        return Err(anyhow!(
            "stripe coupon uses amount_off; only percent_off is permitted"
        ));
    }
    let Some(percent_off) = coupon["percent_off"].as_f64() else {
        return Err(anyhow!("stripe coupon missing percent_off"));
    };
    let deviation = (percent_off - REFERRAL_DISCOUNT_PERCENT as f64).abs();
    if percent_off.is_nan() || deviation > 0.001 {
        return Err(anyhow!(
            "stripe coupon percent_off ({percent_off}) does not match expected {REFERRAL_DISCOUNT_PERCENT}"
        ));
    }
    Ok(())
}
/// Creates a Stripe Checkout session for a one-off card payment and returns
/// `(session_id, checkout_url)`.
///
/// A non-empty coupon id is verified with `verify_stripe_coupon_discount`
/// before the session is created and is then attached both as a discount and
/// as metadata. The reservation id, expected amount, expected total and
/// expected currency are recorded in the session metadata. Fails when the
/// response lacks an id that passes `is_safe_stripe_session_id` or lacks a
/// non-empty URL.
#[allow(clippy::too_many_arguments)]
pub(super) async fn create_stripe_session(
    state: &AppState,
    user: &PocketBaseUser,
    reservation_id: &str,
    price_pence: u64,
    success_url: &str,
    cancel_url: &str,
    expires_at_unix: u64,
    discount_coupon_id: Option<&str>,
) -> anyhow::Result<(String, String)> {
    // An empty coupon id means "no coupon", same as `None`.
    let coupon = discount_coupon_id.filter(|id| !id.is_empty());
    if let Some(coupon_id) = coupon {
        verify_stripe_coupon_discount(state, coupon_id).await?;
    }
    let price = price_pence.to_string();
    // The helper receives the caller's original option, not the filtered one.
    let expected_total = expected_total_for_checkout(price_pence, discount_coupon_id).to_string();
    let currency = CHECKOUT_CURRENCY.to_string();
    let mut form_params: Vec<(&str, String)> = vec![
        ("mode", "payment".to_string()),
        ("payment_method_types[0]", "card".to_string()),
        ("line_items[0][price_data][unit_amount]", price.clone()),
        ("line_items[0][price_data][currency]", currency.clone()),
        (
            "line_items[0][price_data][product_data][name]",
            CHECKOUT_PRODUCT_NAME.to_string(),
        ),
        ("line_items[0][quantity]", "1".to_string()),
        ("success_url", success_url.to_string()),
        ("cancel_url", cancel_url.to_string()),
        ("expires_at", expires_at_unix.to_string()),
        ("client_reference_id", user.id.clone()),
        ("customer_email", user.email.clone()),
        ("metadata[pending_checkout_id]", reservation_id.to_string()),
        ("metadata[expected_amount_pence]", price),
        ("metadata[expected_total_pence]", expected_total),
        ("metadata[expected_currency]", currency),
    ];
    if let Some(coupon_id) = coupon {
        form_params.push(("discounts[0][coupon]", coupon_id.to_string()));
        form_params.push(("metadata[discount_coupon_id]", coupon_id.to_string()));
    }
    let response = state
        .http_client
        .post("https://api.stripe.com/v1/checkout/sessions")
        .basic_auth(&state.stripe_secret_key, None::<&str>)
        .form(&form_params)
        .send()
        .await
        .context("Stripe checkout request failed")?;
    ensure_success_ref(&response)
        .await
        .context("Stripe checkout failed")?;
    let session: Value = response
        .json()
        .await
        .context("Failed to parse Stripe response")?;
    let session_id = match session["id"].as_str() {
        Some(id) if is_safe_stripe_session_id(id) => id.to_string(),
        _ => return Err(anyhow!("Stripe session missing valid id")),
    };
    let checkout_url = match session["url"].as_str() {
        Some(url) if !url.is_empty() => url.to_string(),
        _ => return Err(anyhow!("Stripe session missing URL")),
    };
    Ok((session_id, checkout_url))
}
/// Asks Stripe for the checkout session that belongs to a payment intent and
/// returns that session's id.
///
/// Returns `Ok(None)` when the `data` list is empty or missing, or when the
/// first entry's id is absent or fails `is_safe_stripe_session_id`.
pub(super) async fn fetch_stripe_checkout_session_id_for_payment_intent(
    state: &AppState,
    payment_intent_id: &str,
) -> anyhow::Result<Option<String>> {
    let encoded_intent = urlencoding::encode(payment_intent_id);
    let endpoint = format!(
        "https://api.stripe.com/v1/checkout/sessions?payment_intent={encoded_intent}&limit=1"
    );
    let response = state
        .http_client
        .get(&endpoint)
        .basic_auth(&state.stripe_secret_key, None::<&str>)
        .send()
        .await
        .context("Stripe checkout session lookup failed")?;
    ensure_success_ref(&response)
        .await
        .context("Stripe checkout session lookup returned error")?;
    let listing: Value = response
        .json()
        .await
        .context("Failed to parse Stripe checkout session lookup")?;
    let first_id = listing["data"]
        .as_array()
        .and_then(|sessions| sessions.first())
        .and_then(|session| session["id"].as_str());
    let session_id = match first_id {
        Some(id) if is_safe_stripe_session_id(id) => Some(id.to_string()),
        _ => None,
    };
    Ok(session_id)
}

View file

@ -0,0 +1,688 @@
//! Integration-style tests for the money paths: Stripe webhook verification →
//! license granting, checkout reservation bookkeeping, and invite redemption.
//!
//! PocketBase (an external HTTP service in production) is replaced by an
//! in-process axum mock listening on an ephemeral local port. The mock keeps
//! records in memory, evaluates the small PocketBase filter subset the server
//! actually uses, and records every mutating request so tests can assert that
//! e.g. a replayed webhook does not grant a license twice.
//!
//! Stripe itself is NOT mocked (its API URL is hardcoded to
//! `https://api.stripe.com`), so tests that reach the Stripe call assert the
//! failure-cleanup behaviour instead: the reservation is marked `failed` and
//! referral invite reservations are released.
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use axum::body::Bytes;
use axum::extract::{Path, Query, State};
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{Extension, Json, Router};
use hmac::{Hmac, KeyInit, Mac};
use serde_json::{json, Value};
use sha2::Sha256;
use crate::auth::{OptionalUser, PocketBaseUser};
use crate::routes::{post_redeem_invite, post_stripe_webhook};
use crate::state::{AppState, SharedState};
use super::start_license_checkout;
// ---------------------------------------------------------------------------
// Mock PocketBase
// ---------------------------------------------------------------------------
/// In-memory stand-in for PocketBase, shared by the mock HTTP handlers in
/// this module.
#[derive(Default)]
struct MockPocketBase {
/// collection name → records (each a JSON object with an "id").
records: Mutex<HashMap<String, Vec<Value>>>,
/// Counter used by `seed` to mint ids of the form `mockidNNNNNN` for
/// records that arrive without a non-empty "id".
next_id: AtomicUsize,
/// Every mutating request: (method, path, body).
log: Mutex<Vec<(String, String, Value)>>,
}
impl MockPocketBase {
    /// Inserts `record` into `collection` and returns its id.
    ///
    /// A non-empty string "id" on the record is kept; otherwise a fresh
    /// `mockidNNNNNN` id is minted from the counter and written into it.
    fn seed(&self, collection: &str, mut record: Value) -> String {
        let supplied = record["id"]
            .as_str()
            .filter(|candidate| !candidate.is_empty())
            .map(str::to_string);
        let id = supplied.unwrap_or_else(|| {
            format!("mockid{:06}", self.next_id.fetch_add(1, Ordering::SeqCst))
        });
        record["id"] = json!(id);
        let mut store = self.records.lock().unwrap();
        store.entry(collection.to_string()).or_default().push(record);
        id
    }

    /// Returns a copy of the first record in `collection` with the given id.
    fn record(&self, collection: &str, id: &str) -> Option<Value> {
        let store = self.records.lock().unwrap();
        let list = store.get(collection)?;
        list.iter()
            .find(|candidate| candidate["id"].as_str() == Some(id))
            .cloned()
    }

    /// Returns a copy of every record in `collection` (empty if unknown).
    fn records_in(&self, collection: &str) -> Vec<Value> {
        let store = self.records.lock().unwrap();
        match store.get(collection) {
            Some(list) => list.clone(),
            None => Vec::new(),
        }
    }

    /// Number of PATCHes that set the user's subscription to "licensed" —
    /// i.e. how many times a license was granted.
    fn license_grant_count(&self, user_id: &str) -> usize {
        let user_path = format!("/api/collections/users/records/{user_id}");
        let entries = self.log.lock().unwrap();
        let mut grants = 0;
        for (method, request_path, body) in entries.iter() {
            let is_grant = method == "PATCH"
                && request_path == &user_path
                && body["subscription"].as_str() == Some("licensed");
            if is_grant {
                grants += 1;
            }
        }
        grants
    }
}
/// Evaluate the PocketBase filter subset used by the server:
/// `a="x" && b>=N && c<N && (d="" || d="y")`.
///
/// An empty filter matches every record; otherwise every ` && `-separated
/// clause must match.
fn record_matches(record: &Value, filter: &str) -> bool {
    filter.is_empty()
        || filter
            .split(" && ")
            .map(str::trim)
            .all(|clause| clause_matches(record, clause))
}
/// Evaluates a single filter clause against `record`.
///
/// Forms are tried in this order: a parenthesised group of ` || `
/// alternatives, `field>=N`, `field<N`, then `field="value"`. Unparseable
/// numeric bounds never match. Panics on a clause with none of these
/// operators so that unsupported filters fail loudly in tests.
fn clause_matches(record: &Value, clause: &str) -> bool {
    let grouped = clause
        .strip_prefix('(')
        .and_then(|rest| rest.strip_suffix(')'));
    if let Some(alternatives) = grouped {
        for alternative in alternatives.split(" || ") {
            if clause_matches(record, alternative.trim()) {
                return true;
            }
        }
        return false;
    }
    if let Some((field, raw_bound)) = clause.split_once(">=") {
        // An unparseable bound becomes i64::MAX, so only i64::MAX itself could match.
        let lower_bound = raw_bound.trim().parse::<i64>().unwrap_or(i64::MAX);
        return record_number(record, field) >= lower_bound;
    }
    if let Some((field, raw_bound)) = clause.split_once('<') {
        // An unparseable bound becomes i64::MIN, which nothing is below.
        let upper_bound = raw_bound.trim().parse::<i64>().unwrap_or(i64::MIN);
        return record_number(record, field) < upper_bound;
    }
    match clause.split_once('=') {
        Some((field, raw_value)) => {
            let expected = raw_value.trim().trim_matches('"');
            record_string(record, field) == expected
        }
        None => panic!("mock PocketBase cannot evaluate filter clause: {clause}"),
    }
}
/// Reads `field` from `record` as an `i64` for numeric comparisons.
///
/// Integers are used directly, other numbers are cast from `f64`, and
/// strings are parsed; anything else (including a missing field) is 0.
fn record_number(record: &Value, field: &str) -> i64 {
    let value = &record[field];
    if let Some(integer) = value.as_i64() {
        return integer;
    }
    if let Some(float) = value.as_f64() {
        return float as i64;
    }
    match value.as_str() {
        Some(text) => text.parse::<i64>().unwrap_or(0),
        None => 0,
    }
}
fn record_string(record: &Value, field: &str) -> String {
let value = &record[field];
value.as_str().map(str::to_string).unwrap_or_else(|| {
if value.is_null() {
String::new()
} else {
value.to_string()
}
})
}
/// Mock auth endpoint: always responds with a fixed superuser token.
async fn auth_handler() -> Json<Value> {
    let credentials = json!({ "token": "testsuperusertoken" });
    Json(credentials)
}
/// Mock list endpoint: returns the records of `collection` that satisfy the
/// `filter` query parameter.
///
/// `totalItems` is the number of matches before truncation; `items` holds at
/// most `perPage` of them (30 when the parameter is absent or unparseable).
async fn list_records(
    State(pb): State<Arc<MockPocketBase>>,
    Path(collection): Path<String>,
    Query(params): Query<HashMap<String, String>>,
) -> Json<Value> {
    let filter = params.get("filter").map_or("", String::as_str);
    let per_page = params
        .get("perPage")
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or(30);
    let mut items: Vec<Value> = {
        let store = pb.records.lock().unwrap();
        match store.get(&collection) {
            Some(list) => list
                .iter()
                .filter(|record| record_matches(record, filter))
                .cloned()
                .collect(),
            None => Vec::new(),
        }
    };
    let total = items.len();
    items.truncate(per_page);
    Json(json!({ "items": items, "totalItems": total }))
}
/// Creates a record in `collection` and returns it as JSON.
///
/// Every call is appended to the request log first. For the
/// `checkout_locks` collection a record whose `name` equals an existing one
/// is refused with 400, emulating the unique `name` constraint on the
/// distributed-lock collection so concurrent acquisitions conflict like they
/// do against real PocketBase.
async fn create_record(
    State(pb): State<Arc<MockPocketBase>>,
    Path(collection): Path<String>,
    Json(body): Json<Value>,
) -> Response {
    let logged_path = format!("/api/collections/{collection}/records");
    pb.log
        .lock()
        .unwrap()
        .push(("POST".to_string(), logged_path, body.clone()));

    if collection == "checkout_locks" {
        // NOTE(review): the uniqueness check and the insert below take the
        // store lock separately; presumably two overlapping requests could
        // both pass the check — confirm against `MockPocketBase::seed`.
        let name_taken = {
            let store = pb.records.lock().unwrap();
            store.get(&collection).is_some_and(|records| {
                records
                    .iter()
                    .any(|existing| existing["name"] == body["name"])
            })
        };
        if name_taken {
            let conflict = Json(json!({ "message": "name must be unique" }));
            return (StatusCode::BAD_REQUEST, conflict).into_response();
        }
    }

    let id = pb.seed(&collection, body);
    let created = pb.record(&collection, &id).expect("record just created");
    Json(created).into_response()
}
/// Returns the record `id` of `collection` as JSON, or 404 when it is absent.
async fn get_record(
    State(pb): State<Arc<MockPocketBase>>,
    Path((collection, id)): Path<(String, String)>,
) -> Response {
    pb.record(&collection, &id).map_or_else(
        || StatusCode::NOT_FOUND.into_response(),
        |record| Json(record).into_response(),
    )
}
/// Merges the top-level fields of the request body into record `id` of
/// `collection` and returns the updated record.
///
/// The call is logged before the lookup. Responds 404 when the collection or
/// the record does not exist. When either the record or the body is not a
/// JSON object, the record is returned unchanged.
async fn patch_record(
    State(pb): State<Arc<MockPocketBase>>,
    Path((collection, id)): Path<(String, String)>,
    Json(body): Json<Value>,
) -> Response {
    let logged_path = format!("/api/collections/{collection}/records/{id}");
    pb.log
        .lock()
        .unwrap()
        .push(("PATCH".to_string(), logged_path, body.clone()));

    let mut store = pb.records.lock().unwrap();
    let found = store.get_mut(&collection).and_then(|list| {
        list.iter_mut()
            .find(|candidate| candidate["id"].as_str() == Some(id.as_str()))
    });
    let Some(record) = found else {
        return StatusCode::NOT_FOUND.into_response();
    };

    if let (Some(fields), Some(updates)) = (record.as_object_mut(), body.as_object()) {
        fields.extend(
            updates
                .iter()
                .map(|(key, value)| (key.clone(), value.clone())),
        );
    }
    Json(record.clone()).into_response()
}
/// Deletes record `id` from `collection`.
///
/// The call is logged with a null body. Responds 204 when a record was
/// removed and 404 when the collection or the record does not exist.
async fn delete_record(
    State(pb): State<Arc<MockPocketBase>>,
    Path((collection, id)): Path<(String, String)>,
) -> Response {
    let logged_path = format!("/api/collections/{collection}/records/{id}");
    pb.log
        .lock()
        .unwrap()
        .push(("DELETE".to_string(), logged_path, Value::Null));

    let mut store = pb.records.lock().unwrap();
    let removed = match store.get_mut(&collection) {
        Some(list) => {
            let before = list.len();
            list.retain(|record| record["id"].as_str() != Some(id.as_str()));
            list.len() != before
        }
        None => false,
    };

    let status = if removed {
        StatusCode::NO_CONTENT
    } else {
        StatusCode::NOT_FOUND
    };
    status.into_response()
}
/// Builds the router for the mock PocketBase: the superuser auth endpoint plus
/// list/create on a collection and get/patch/delete on a single record, all
/// sharing `pb` as state.
fn mock_pb_router(pb: Arc<MockPocketBase>) -> Router {
    let auth_route = post(auth_handler);
    let collection_routes = get(list_records).post(create_record);
    let record_routes = get(get_record).patch(patch_record).delete(delete_record);

    Router::new()
        .route(
            "/api/collections/_superusers/auth-with-password",
            auth_route,
        )
        .route("/api/collections/{collection}/records", collection_routes)
        .route("/api/collections/{collection}/records/{id}", record_routes)
        .with_state(pb)
}
// ---------------------------------------------------------------------------
// Test harness
// ---------------------------------------------------------------------------
/// Everything a test needs: the application state under test and the mock
/// PocketBase store behind it. Built by `setup`.
struct TestEnv {
// Shared application state, created from `AppState::for_tests` with the
// mock server's URL.
shared: Arc<SharedState>,
// The in-memory mock; tests seed records into it and inspect them afterwards.
pb: Arc<MockPocketBase>,
}
/// Starts a mock PocketBase server on an ephemeral localhost port and returns
/// a `TestEnv` whose application state points at it.
///
/// The server runs on a spawned task for the rest of the test.
async fn setup() -> TestEnv {
    let pb = Arc::new(MockPocketBase::default());

    // Port 0 lets the OS pick a free port.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind mock PocketBase listener");
    let addr = listener.local_addr().expect("mock PocketBase address");

    let router = mock_pb_router(Arc::clone(&pb));
    tokio::spawn(async move {
        axum::serve(listener, router)
            .await
            .expect("mock PocketBase serve");
    });

    let base_url = format!("http://{addr}");
    let shared = Arc::new(SharedState::new(AppState::for_tests(base_url)));
    TestEnv { shared, pb }
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Panics if the system clock is set before the epoch.
fn now_unix() -> u64 {
    let since_epoch = UNIX_EPOCH.elapsed().expect("clock after epoch");
    since_epoch.as_secs()
}
/// Builds an unprivileged, free-tier `PocketBaseUser` with the given id and
/// the address `<id>@test.example`.
fn test_user(id: &str) -> PocketBaseUser {
    let email = format!("{id}@test.example");
    PocketBaseUser {
        id: id.to_owned(),
        email,
        is_admin: false,
        subscription: String::from("free"),
        newsletter: false,
        can_see_listings: false,
    }
}
/// Inserts a free-tier, non-admin user record with the given id into the
/// mock's `users` collection.
fn seed_user(pb: &MockPocketBase, id: &str) {
    let record = json!({
        "id": id,
        "email": format!("{id}@test.example"),
        "subscription": "free",
        "is_admin": false,
    });
    pb.seed("users", record);
}
fn seed_pending_checkout(pb: &MockPocketBase, user_id: &str, session_id: &str) -> String {
pb.seed(
"checkout_sessions",
json!({
"user": user_id,
"stripe_session_id": session_id,
"stripe_payment_intent_id": "",
"checkout_url": "https://checkout.stripe.test/session",
"amount_pence": 999,
"expected_total_pence": 999,
"currency": "gbp",
"discount_coupon_id": "",
"referral_invite_id": "",
"status": "pending",
"expires_at_unix": now_unix() + 1800,
"paid_amount_pence": 0,
"completed_at_unix": "",
"reversal_reason": "",
}),
)
}
/// Serializes a `checkout.session.completed` event for `session_id`, paid in
/// GBP by `user_id`, with a 999 subtotal and the given `amount_total`.
fn checkout_completed_event(session_id: &str, user_id: &str, amount_total: u64) -> Vec<u8> {
    let session = json!({
        "id": session_id,
        "payment_intent": "pi_test_1",
        "client_reference_id": user_id,
        "payment_status": "paid",
        "currency": "gbp",
        "amount_subtotal": 999,
        "amount_total": amount_total,
    });
    let event = json!({
        "id": "evt_test_1",
        "type": "checkout.session.completed",
        "data": { "object": session },
    });
    serde_json::to_vec(&event).expect("event serializes")
}
/// Produces a `t=<timestamp>,v1=<hex>` signature header for `payload`: the
/// HMAC-SHA256, keyed with `secret`, of the current Unix timestamp, a dot, and
/// the payload bytes — the scheme Stripe uses for webhooks.
fn stripe_signature_header(payload: &[u8], secret: &str) -> String {
    let timestamp = now_unix();

    // The signed message is "{timestamp}." immediately followed by the payload.
    let mut signed_message = format!("{timestamp}.").into_bytes();
    signed_message.extend_from_slice(payload);

    let mut mac =
        Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key length");
    mac.update(&signed_message);
    let digest = mac.finalize().into_bytes();

    format!("t={timestamp},v1={}", hex::encode(digest))
}
/// Calls the Stripe webhook handler directly with `payload` as the body and,
/// when provided, `signature` as the `stripe-signature` header.
async fn deliver_webhook(env: &TestEnv, payload: Vec<u8>, signature: Option<&str>) -> Response {
    let mut headers = HeaderMap::new();
    match signature {
        Some(raw) => {
            let value = HeaderValue::from_str(raw).expect("signature header value");
            headers.insert("stripe-signature", value);
        }
        None => {}
    }
    let body = Bytes::from(payload);
    post_stripe_webhook(State(env.shared.clone()), headers, body).await
}
/// Calls the invite-redemption handler directly as `user` with the given
/// invite `code`.
async fn redeem_invite(env: &TestEnv, user: PocketBaseUser, code: &str) -> Response {
    let request =
        serde_json::from_value(json!({ "code": code })).expect("redeem request deserializes");
    let caller = Extension(OptionalUser(Some(user)));
    post_redeem_invite(State(env.shared.clone()), caller, Json(request)).await
}
/// Reads a response body (up to 1 MiB) and parses it as JSON.
///
/// Panics when the body cannot be read or is not valid JSON.
async fn response_json(response: Response) -> Value {
    const MAX_BODY_BYTES: usize = 1 << 20;
    let body = response.into_body();
    let bytes = axum::body::to_bytes(body, MAX_BODY_BYTES)
        .await
        .expect("response body");
    serde_json::from_slice(&bytes).expect("response body is JSON")
}
// ---------------------------------------------------------------------------
// Stripe webhook → license granting
// ---------------------------------------------------------------------------
#[tokio::test]
async fn webhook_with_valid_signature_grants_license() {
    // Arrange: a free user with a pending reservation for the session.
    let env = setup().await;
    seed_user(&env.pb, "user1");
    let checkout_id = seed_pending_checkout(&env.pb, "user1", "cs_test_abc");

    // Act: deliver a correctly signed completion event.
    let event = checkout_completed_event("cs_test_abc", "user1", 999);
    let header = stripe_signature_header(&event, "whsec_test_secret");
    let delivered = deliver_webhook(&env, event, Some(&header)).await;
    assert_eq!(delivered.status(), StatusCode::OK);

    // Assert: the user is licensed and the reservation is settled.
    let account = env.pb.record("users", "user1").expect("user exists");
    assert_eq!(account["subscription"], json!("licensed"));

    let reservation = env
        .pb
        .record("checkout_sessions", &checkout_id)
        .expect("checkout exists");
    assert_eq!(reservation["status"], json!("completed"));
    assert_eq!(reservation["paid_amount_pence"], json!(999));
    assert_eq!(reservation["stripe_payment_intent_id"], json!("pi_test_1"));

    assert_eq!(env.pb.license_grant_count("user1"), 1);
}
#[tokio::test]
async fn webhook_with_invalid_signature_is_rejected() {
    let env = setup().await;
    seed_user(&env.pb, "user1");
    let checkout_id = seed_pending_checkout(&env.pb, "user1", "cs_test_bad");
    let event = checkout_completed_event("cs_test_bad", "user1", 999);

    // A signature computed with the wrong secret is refused.
    let forged = stripe_signature_header(&event, "whsec_wrong_secret");
    let with_forged = deliver_webhook(&env, event.clone(), Some(&forged)).await;
    assert_eq!(with_forged.status(), StatusCode::BAD_REQUEST);

    // So is a request without any signature header.
    let unsigned = deliver_webhook(&env, event, None).await;
    assert_eq!(unsigned.status(), StatusCode::BAD_REQUEST);

    // Nothing was granted or modified.
    let account = env.pb.record("users", "user1").expect("user exists");
    assert_eq!(account["subscription"], json!("free"));

    let reservation = env
        .pb
        .record("checkout_sessions", &checkout_id)
        .expect("checkout exists");
    assert_eq!(reservation["status"], json!("pending"));

    assert_eq!(env.pb.license_grant_count("user1"), 0);
}
#[tokio::test]
async fn replayed_webhook_does_not_double_grant() {
    let env = setup().await;
    seed_user(&env.pb, "user1");
    let checkout_id = seed_pending_checkout(&env.pb, "user1", "cs_test_replay");

    let event = checkout_completed_event("cs_test_replay", "user1", 999);
    let header = stripe_signature_header(&event, "whsec_test_secret");

    // Deliver the same signed event twice; both deliveries are acknowledged.
    let original = deliver_webhook(&env, event.clone(), Some(&header)).await;
    assert_eq!(original.status(), StatusCode::OK);
    let duplicate = deliver_webhook(&env, event, Some(&header)).await;
    assert_eq!(duplicate.status(), StatusCode::OK);

    let account = env.pb.record("users", "user1").expect("user exists");
    assert_eq!(account["subscription"], json!("licensed"));

    let reservation = env
        .pb
        .record("checkout_sessions", &checkout_id)
        .expect("checkout exists");
    assert_eq!(reservation["status"], json!("completed"));

    // The duplicate must not produce a second grant.
    assert_eq!(env.pb.license_grant_count("user1"), 1);
}
#[tokio::test]
async fn webhook_with_tampered_amount_is_rejected_and_marks_reservation_invalid() {
    let env = setup().await;
    seed_user(&env.pb, "user1");
    let checkout_id = seed_pending_checkout(&env.pb, "user1", "cs_test_amount");

    // The event is signed correctly, but its amount_total (500) differs from
    // the 999 recorded on the reservation.
    let event = checkout_completed_event("cs_test_amount", "user1", 500);
    let header = stripe_signature_header(&event, "whsec_test_secret");
    let delivered = deliver_webhook(&env, event, Some(&header)).await;

    // A rejection is still answered with 200 so Stripe stops retrying.
    assert_eq!(delivered.status(), StatusCode::OK);

    let account = env.pb.record("users", "user1").expect("user exists");
    assert_eq!(account["subscription"], json!("free"));

    let reservation = env
        .pb
        .record("checkout_sessions", &checkout_id)
        .expect("checkout exists");
    assert_eq!(reservation["status"], json!("invalid"));

    assert_eq!(env.pb.license_grant_count("user1"), 0);
}
// ---------------------------------------------------------------------------
// Checkout session creation (up to the hardcoded Stripe API call)
// ---------------------------------------------------------------------------
#[tokio::test]
async fn checkout_start_reserves_then_marks_failed_when_stripe_is_unreachable() {
    let env = setup().await;
    seed_user(&env.pb, "user9");
    let app_state = env.shared.load_state();
    let buyer = test_user("user9");

    // The Stripe API URL is hardcoded to https://api.stripe.com, so the
    // session-creation call fails in tests (no network / dummy key). What is
    // asserted is the reservation bookkeeping around that call.
    let outcome = start_license_checkout(
        &app_state,
        &buyer,
        "https://x/success",
        "https://x/cancel",
        None,
        None,
    )
    .await;
    assert!(outcome.is_err(), "Stripe call must fail in tests");

    let reservations = env.pb.records_in("checkout_sessions");
    assert_eq!(reservations.len(), 1, "exactly one reservation created");
    let reservation = &reservations[0];
    assert_eq!(reservation["user"], json!("user9"));
    // 0 licensed users → public count 120 → second tier price (999p).
    assert_eq!(reservation["amount_pence"], json!(999));
    assert_eq!(reservation["expected_total_pence"], json!(999));
    assert_eq!(reservation["currency"], json!("gbp"));
    // A failed Stripe call must not leave a live pending reservation behind.
    assert_eq!(reservation["status"], json!("failed"));

    // The cross-instance pricing lock was released.
    assert!(env.pb.records_in("checkout_locks").is_empty());
    assert_eq!(env.pb.license_grant_count("user9"), 0);
}
// ---------------------------------------------------------------------------
// Invite redemption
// ---------------------------------------------------------------------------
#[tokio::test]
async fn admin_invite_redemption_grants_license() {
    let env = setup().await;
    seed_user(&env.pb, "user2");
    let unused_admin_invite = json!({
        "code": "admininvite1",
        "invite_type": "admin",
        "created_by": "adminuser1",
        "used_by_id": "",
        "used_at": "",
    });
    let invite_id = env.pb.seed("invites", unused_admin_invite);

    let redeemed = redeem_invite(&env, test_user("user2"), "admininvite1").await;
    assert_eq!(redeemed.status(), StatusCode::OK);
    let payload = response_json(redeemed).await;
    assert_eq!(payload["result"], json!("licensed"));

    // The invite is marked as consumed by the redeeming user...
    let stored_invite = env.pb.record("invites", &invite_id).expect("invite exists");
    assert_eq!(stored_invite["used_by_id"], json!("user2"));

    // ...and that user now holds exactly one license grant.
    let account = env.pb.record("users", "user2").expect("user exists");
    assert_eq!(account["subscription"], json!("licensed"));
    assert_eq!(env.pb.license_grant_count("user2"), 1);
}
#[tokio::test]
async fn invalid_and_oversized_invite_codes_are_rejected() {
    let env = setup().await;

    // One character over the 20-character limit.
    let oversized = "a".repeat(21);
    // In order: non-alphanumeric characters, too long, empty.
    let rejected_codes = ["bad-code!", oversized.as_str(), ""];

    for code in rejected_codes.iter().copied() {
        let response = redeem_invite(&env, test_user("user2"), code).await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    assert_eq!(env.pb.license_grant_count("user2"), 0);
}
#[tokio::test]
async fn already_used_invite_is_rejected() {
    let env = setup().await;
    seed_user(&env.pb, "user2");
    // An admin invite that another user has already consumed.
    let consumed_invite = json!({
        "code": "usedinvite12",
        "invite_type": "admin",
        "created_by": "adminuser1",
        "used_by_id": "otheruser9",
        "used_at": "1700000000",
    });
    env.pb.seed("invites", consumed_invite);

    let attempt = redeem_invite(&env, test_user("user2"), "usedinvite12").await;
    assert_eq!(attempt.status(), StatusCode::NOT_FOUND);

    let account = env.pb.record("users", "user2").expect("user exists");
    assert_eq!(account["subscription"], json!("free"));
    assert_eq!(env.pb.license_grant_count("user2"), 0);
}
#[tokio::test]
async fn referral_invite_redemption_releases_reservation_when_stripe_is_unreachable() {
    let env = setup().await;
    seed_user(&env.pb, "user3");
    let unreserved_referral = json!({
        "code": "refcode12345",
        "invite_type": "referral",
        "created_by": "licenseduser1",
        "used_by_id": "",
        "used_at": "",
        "reserved_by_id": "",
        "reserved_checkout_id": "",
        "reserved_until_unix": 0,
    });
    let invite_id = env.pb.seed("invites", unreserved_referral);

    // The redemption itself is valid; it only fails at the hardcoded Stripe
    // call (coupon verification / session creation), which must roll the
    // reservation back cleanly.
    let attempt = redeem_invite(&env, test_user("user3"), "refcode12345").await;
    assert_eq!(attempt.status(), StatusCode::BAD_GATEWAY);

    let reservations = env.pb.records_in("checkout_sessions");
    assert_eq!(reservations.len(), 1, "referral reservation was created");
    let reservation = &reservations[0];
    assert_eq!(reservation["referral_invite_id"], json!(invite_id));
    assert_eq!(reservation["status"], json!("failed"));

    // The invite reservation was released and the invite is still unused.
    let stored_invite = env.pb.record("invites", &invite_id).expect("invite exists");
    assert_eq!(stored_invite["used_by_id"], json!(""));
    assert_eq!(stored_invite["reserved_checkout_id"], json!(""));
    assert_eq!(stored_invite["reserved_by_id"], json!(""));

    assert_eq!(env.pb.license_grant_count("user3"), 0);
}

View file

@ -182,8 +182,7 @@ impl CrimeByYearData {
// Force-coverage calendar (optional column: legacy parquets predate it; // Force-coverage calendar (optional column: legacy parquets predate it;
// their postcodes are treated as fully covered). A row with an empty // their postcodes are treated as fully covered). A row with an empty
// list is meaningful — zero covered years — so it IS inserted. // list is meaningful — zero covered years — so it IS inserted.
let mut covered_years_by_postcode: FxHashMap<String, Vec<i32>> = let mut covered_years_by_postcode: FxHashMap<String, Vec<i32>> = FxHashMap::default();
FxHashMap::default();
if let Ok(col) = df.column(COVERAGE_COLUMN) { if let Ok(col) = df.column(COVERAGE_COLUMN) {
let list_ca = col let list_ca = col
.list() .list()
@ -195,12 +194,12 @@ impl CrimeByYearData {
}; };
let mut years: Vec<i32> = Vec::with_capacity(inner.len()); let mut years: Vec<i32> = Vec::with_capacity(inner.len());
if !inner.is_empty() { if !inner.is_empty() {
let structs = inner.struct_().with_context(|| { let structs = inner
format!("Inner of '{COVERAGE_COLUMN}' is not a struct") .struct_()
})?; .with_context(|| format!("Inner of '{COVERAGE_COLUMN}' is not a struct"))?;
let year_field = structs.field_by_name("year").with_context(|| { let year_field = structs
format!("Missing 'year' field in '{COVERAGE_COLUMN}'") .field_by_name("year")
})?; .with_context(|| format!("Missing 'year' field in '{COVERAGE_COLUMN}'"))?;
for idx in 0..inner.len() { for idx in 0..inner.len() {
match year_field.get(idx).ok() { match year_field.get(idx).ok() {
Some(AnyValue::Int32(y)) => years.push(y), Some(AnyValue::Int32(y)) => years.push(y),

View file

@ -742,6 +742,29 @@ impl PlaceData {
} }
} }
#[cfg(test)]
impl PlaceData {
    /// Builds a `PlaceData` in which every column and index is empty, for
    /// integration tests that need an `AppState` but never read place data.
    pub(crate) fn empty_for_tests() -> Self {
        Self {
            name: vec![],
            name_lower: vec![],
            name_search: vec![],
            place_type: InternedColumn::build(&[]),
            type_rank: vec![],
            population: vec![],
            lat: vec![],
            lon: vec![],
            city: vec![],
            travel_destination: vec![],
            token_index: FxHashMap::default(),
            token_prefix_index: FxHashMap::default(),
            fuzzy_trigram_index: FxHashMap::default(),
        }
    }
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -588,6 +588,29 @@ impl POIData {
} }
} }
#[cfg(test)]
impl POIData {
    /// Builds a `POIData` in which every buffer and column is empty, for
    /// integration tests that need an `AppState` but never read POI data.
    pub(crate) fn empty_for_tests() -> Self {
        Self {
            id_buffer: String::new(),
            id_offsets: vec![],
            id_lengths: vec![],
            group: InternedColumn::build(&[]),
            category: InternedColumn::build(&[]),
            icon_category: InternedColumn::build(&[]),
            name: vec![],
            lat: vec![],
            lng: vec![],
            emoji: InternedColumn::build(&[]),
            priority: vec![],
            school_meta_idx: vec![],
            school_meta: vec![],
        }
    }
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,973 @@
//! Address search: tokenization, query parsing, inverted/prefix indexes and the
//! ranked per-row search over property addresses.
use rustc_hash::{FxHashMap, FxHashSet};
use super::PropertyData;
/// Upper bound on rows scored per query. Intersection keeps most candidate sets far below
/// this; only a single very common road word (e.g. "high") approaches it, and the in-area
/// priority sort keeps a refined query's matches ahead of the cut.
const ADDRESS_SEARCH_CANDIDATE_LIMIT: usize = 150_000;
/// Shortest token prefix (in bytes) that `build_address_prefix_index` stores; tokens
/// shorter than this get no prefix entries.
const ADDRESS_SEARCH_PREFIX_MIN_LEN: usize = 4;
/// Longest token prefix (in bytes) that `build_address_prefix_index` stores;
/// `address_prefix_key` truncates longer terms to this length.
const ADDRESS_SEARCH_PREFIX_MAX_LEN: usize = 8;
/// One text term of a parsed query together with its interchangeable spellings.
#[derive(Clone, Debug)]
pub(super) struct AddressTermGroup {
/// The token as typed, followed by its aliases from `address_token_aliases`
/// (e.g. "rd" then "road"), without duplicates.
alternatives: Vec<String>,
}
/// A free-text address query broken into the parts the search uses. Built by
/// `parse_address_query`.
#[derive(Debug)]
pub(super) struct AddressQuery {
/// Complete postcode found in the query, in canonical uppercase
/// "OUTWARD INWARD" form, if any.
full_postcode: Option<String>,
/// Compact uppercase outward code (optionally with a sector digit) recovered when the
/// user appended a partial postcode like "NW1" or "NW1 6". Used as an additive ranking
/// bias, never as a hard filter — so the disambiguating hint is honoured without
/// excluding the same road in other areas.
postcode_area: Option<String>,
/// One group per text token that is not purely a stop word; adjacent identical
/// groups are collapsed.
text_groups: Vec<AddressTermGroup>,
/// All-digit tokens plus short mixed tokens such as "1a"; sorted and deduplicated.
numeric_terms: Vec<String>,
/// Non-stop-word alternatives and mixed alphanumeric tokens, in first-seen order
/// without duplicates; `search_addresses` passes these to the candidate-row lookup.
candidate_terms: Vec<String>,
}
/// Splits address text into lowercase ASCII-alphanumeric tokens.
///
/// Apostrophe-like characters (`'`, the typographic `’` and the backtick) are
/// dropped without ending the current token, so "John's" becomes "johns".
/// Every other non-alphanumeric character — including non-ASCII letters —
/// acts as a separator. Empty tokens are never produced.
fn tokenize_address_text(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    for ch in text.chars() {
        if ch.is_ascii_alphanumeric() {
            current.push(ch.to_ascii_lowercase());
        } else if matches!(ch, '\'' | '\u{2019}' | '`') {
            // Apostrophes join the letters around them instead of splitting.
            continue;
        } else if !current.is_empty() {
            tokens.push(std::mem::take(&mut current));
        }
    }
    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}
/// Whether `compact` (a postcode with its space removed) has the shape of a
/// complete postcode: 5–7 bytes, ending in digit-letter-letter, preceded by an
/// outward part that starts with a letter, is entirely ASCII alphanumeric and
/// contains at least one digit.
fn is_full_postcode_compact(compact: &str) -> bool {
    let bytes = compact.as_bytes();
    if bytes.len() < 5 || bytes.len() > 7 {
        return false;
    }
    // With 5..=7 bytes in total the outward part is always 2..=4 bytes long.
    let (outward, inward) = bytes.split_at(bytes.len() - 3);

    let inward_ok = matches!(
        inward,
        [digit, first, second]
            if digit.is_ascii_digit()
                && first.is_ascii_alphabetic()
                && second.is_ascii_alphabetic()
    );
    if !inward_ok {
        return false;
    }

    let starts_with_letter = outward[0].is_ascii_alphabetic();
    let all_alphanumeric = outward.iter().all(u8::is_ascii_alphanumeric);
    let has_digit = outward.iter().any(u8::is_ascii_digit);
    starts_with_letter && all_alphanumeric && has_digit
}
/// Uppercases a compact postcode and inserts a space before its last three
/// characters ("sw1a1aa" → "SW1A 1AA").
///
/// Callers are expected to pass a validated compact postcode; input shorter
/// than three bytes panics.
fn canonical_postcode_from_compact(compact: &str) -> String {
    let upper = compact.to_ascii_uppercase();
    let (outward, inward) = upper.split_at(upper.len() - 3);
    [outward, inward].join(" ")
}
/// Finds the first complete postcode among `tokens`.
///
/// Single tokens are checked first (a postcode typed without a space), then
/// each pair of adjacent tokens joined together (typed with a space). Returns
/// the canonical postcode and the indices of the token(s) that formed it.
fn extract_full_postcode(tokens: &[String]) -> Option<(String, Vec<usize>)> {
    let upper: Vec<String> = tokens
        .iter()
        .map(|token| token.to_ascii_uppercase())
        .collect();

    if let Some(idx) = upper
        .iter()
        .position(|compact| is_full_postcode_compact(compact))
    {
        return Some((canonical_postcode_from_compact(&upper[idx]), vec![idx]));
    }

    upper.windows(2).enumerate().find_map(|(idx, pair)| {
        let compact = format!("{}{}", pair[0], pair[1]);
        if is_full_postcode_compact(&compact) {
            Some((canonical_postcode_from_compact(&compact), vec![idx, idx + 1]))
        } else {
            None
        }
    })
}
/// Whether `token` has the shape of an outward code: 2–4 bytes, starting with
/// an ASCII letter, entirely ASCII alphanumeric, with at least one digit.
fn looks_like_postcode_fragment(token: &str) -> bool {
    match token.as_bytes() {
        [first, rest @ ..] if (1..=3).contains(&rest.len()) => {
            // The first byte is a letter, so any digit must be in the rest.
            first.is_ascii_alphabetic()
                && rest.iter().all(u8::is_ascii_alphanumeric)
                && rest.iter().any(u8::is_ascii_digit)
        }
        _ => false,
    }
}
/// Whether `token` consists solely of ASCII digits.
///
/// The empty string is not numeric: without the explicit check, `all` would
/// be vacuously true for it.
fn is_numeric_address_token(token: &str) -> bool {
    !token.is_empty() && token.bytes().all(|byte| byte.is_ascii_digit())
}
/// Interchangeable spellings for common address words and abbreviations.
///
/// The returned list starts with `token` itself, followed by its aliases; it
/// is empty for a token that has none.
fn address_token_aliases(token: &str) -> Vec<&'static str> {
    let aliases: &[&'static str] = match token {
        "apt" => &["apt", "apartment"],
        "apartment" => &["apartment", "apt"],
        "ave" => &["ave", "avenue"],
        "avenue" => &["avenue", "ave"],
        "blvd" => &["blvd", "boulevard"],
        "boulevard" => &["boulevard", "blvd"],
        "cl" => &["cl", "close"],
        "close" => &["close", "cl"],
        "ct" => &["ct", "court"],
        "court" => &["court", "ct"],
        "cres" => &["cres", "crescent"],
        "crescent" => &["crescent", "cres"],
        "dr" => &["dr", "drive"],
        "drive" => &["drive", "dr"],
        "fl" => &["fl", "flat"],
        "flat" => &["flat", "fl"],
        "gdns" => &["gdns", "gardens", "garden"],
        "garden" => &["garden", "gardens", "gdns"],
        "gardens" => &["gardens", "garden", "gdns"],
        "hse" => &["hse", "house"],
        "house" => &["house", "hse"],
        "ln" => &["ln", "lane"],
        "lane" => &["lane", "ln"],
        "rd" => &["rd", "road"],
        "road" => &["road", "rd"],
        "sq" => &["sq", "square"],
        "square" => &["square", "sq"],
        // "st" is ambiguous between "street" and "saint".
        "st" => &["st", "street", "saint"],
        "street" => &["street", "st"],
        "saint" => &["saint", "st"],
        "terr" => &["terr", "terrace"],
        "terrace" => &["terrace", "terr"],
        _ => &[],
    };
    aliases.to_vec()
}
/// Whether `token` is a generic address word (article, unit or building
/// word, road type or its abbreviation) rather than a distinctive one.
fn is_address_stop_token(token: &str) -> bool {
    const STOP_TOKENS: &[&str] = &[
        "a",
        "an",
        "and",
        "apartment",
        "apt",
        "avenue",
        "ave",
        "block",
        "building",
        "bungalow",
        "close",
        "cl",
        "court",
        "ct",
        "cres",
        "crescent",
        "drive",
        "dr",
        "estate",
        "flat",
        "fl",
        "floor",
        "garden",
        "gardens",
        "gdns",
        "grove",
        "house",
        "hse",
        "lane",
        "ln",
        "lodge",
        "mansions",
        "mews",
        "of",
        "park",
        "place",
        "road",
        "rd",
        "room",
        "row",
        "saint",
        "sq",
        "square",
        "st",
        "street",
        "terr",
        "terrace",
        "the",
        "unit",
        "view",
        "villas",
        "walk",
        "way",
        "yard",
    ];
    STOP_TOKENS.contains(&token)
}
/// Builds the term group for a text token: the token followed by its aliases.
///
/// Returns `None` for tokens that cannot serve as a text term — shorter than
/// three bytes, all digits, or shaped like a postcode fragment — and for
/// tokens whose every alternative is a stop word.
fn address_term_group(token: &str) -> Option<AddressTermGroup> {
    let too_short = token.len() < 3;
    if too_short || is_numeric_address_token(token) || looks_like_postcode_fragment(token) {
        return None;
    }

    let mut alternatives = vec![token.to_string()];
    for alias in address_token_aliases(token) {
        let already_listed = alternatives.iter().any(|existing| existing == alias);
        if !already_listed {
            alternatives.push(alias.to_string());
        }
    }

    let only_stop_words = alternatives
        .iter()
        .all(|alternative| is_address_stop_token(alternative));
    if only_stop_words {
        None
    } else {
        Some(AddressTermGroup { alternatives })
    }
}
/// Tokenizes `text` and returns its searchable tokens, sorted ascending and
/// without duplicates.
pub(super) fn address_search_tokens(text: &str) -> Vec<String> {
    // An ordered set yields the same sorted, duplicate-free sequence.
    let unique: std::collections::BTreeSet<String> = tokenize_address_text(text)
        .into_iter()
        .filter(|token| is_address_search_token(token))
        .collect();
    unique.into_iter().collect()
}
/// Whether `token` is kept as a searchable token: postcode fragments are
/// excluded, all-digit tokens are always kept, tokens containing a digit need
/// at least two bytes, and plain words need at least three.
fn is_address_search_token(token: &str) -> bool {
    if looks_like_postcode_fragment(token) {
        false
    } else if is_numeric_address_token(token) {
        true
    } else {
        let has_digit = token.chars().any(|ch| ch.is_ascii_digit());
        let min_len = if has_digit { 2 } else { 3 };
        token.len() >= min_len
    }
}
/// Whether `token` is distinctive enough to serve as a candidate term:
/// never an all-digit token or a postcode fragment; otherwise either a token
/// containing a digit, or a word of at least three bytes that is not a stop
/// word.
pub(super) fn is_address_candidate_token(token: &str) -> bool {
    if is_numeric_address_token(token) || looks_like_postcode_fragment(token) {
        return false;
    }
    if token.chars().any(|ch| ch.is_ascii_digit()) {
        return true;
    }
    token.len() >= 3 && !is_address_stop_token(token)
}
fn address_prefix_key(term: &str) -> &str {
if term.len() > ADDRESS_SEARCH_PREFIX_MAX_LEN {
&term[..ADDRESS_SEARCH_PREFIX_MAX_LEN]
} else {
term
}
}
/// Builds a map from token prefix to the tokens that start with it.
///
/// For every key of `address_token_index`, each prefix of
/// `ADDRESS_SEARCH_PREFIX_MIN_LEN` to `ADDRESS_SEARCH_PREFIX_MAX_LEN` bytes
/// (capped at the token's own length) is indexed. Tokens shorter than the
/// minimum contribute nothing. Each token list is sorted and deduplicated.
pub(super) fn build_address_prefix_index(
    address_token_index: &FxHashMap<String, Vec<u32>>,
) -> FxHashMap<String, Vec<String>> {
    let mut by_prefix: FxHashMap<String, Vec<String>> = FxHashMap::default();

    for token in address_token_index.keys() {
        let longest = token.len().min(ADDRESS_SEARCH_PREFIX_MAX_LEN);
        // An empty range when the token is shorter than the minimum prefix.
        (ADDRESS_SEARCH_PREFIX_MIN_LEN..=longest).for_each(|prefix_len| {
            by_prefix
                .entry(token[..prefix_len].to_string())
                .or_default()
                .push(token.clone());
        });
    }

    by_prefix.values_mut().for_each(|matching| {
        matching.sort_unstable();
        matching.dedup();
    });
    by_prefix
}
/// Returns the row ids present in both ascending-sorted slices, in ascending
/// order.
fn intersect_sorted(left: &[u32], right: &[u32]) -> Vec<u32> {
    use std::cmp::Ordering;

    let mut common = Vec::new();
    let (mut a, mut b) = (left, right);
    // Advance whichever side holds the smaller head; keep heads that agree.
    while let (Some(&x), Some(&y)) = (a.first(), b.first()) {
        match x.cmp(&y) {
            Ordering::Less => a = &a[1..],
            Ordering::Greater => b = &b[1..],
            Ordering::Equal => {
                common.push(x);
                a = &a[1..];
                b = &b[1..];
            }
        }
    }
    common
}
/// Merges two ascending-sorted row-id slices into one ascending list; an id
/// present in both inputs appears once.
fn union_sorted(left: &[u32], right: &[u32]) -> Vec<u32> {
    use std::cmp::Ordering;

    let mut merged = Vec::with_capacity(left.len() + right.len());
    let (mut a, mut b) = (left, right);
    while let (Some(&x), Some(&y)) = (a.first(), b.first()) {
        match x.cmp(&y) {
            Ordering::Less => {
                merged.push(x);
                a = &a[1..];
            }
            Ordering::Greater => {
                merged.push(y);
                b = &b[1..];
            }
            Ordering::Equal => {
                merged.push(x);
                a = &a[1..];
                b = &b[1..];
            }
        }
    }
    // At most one side still has elements; append whatever is left.
    merged.extend_from_slice(a);
    merged.extend_from_slice(b);
    merged
}
/// An ordinal like "1st", "2nd", "3rd", "21st" — part of the street name ("2nd Avenue"), not a
/// house-number prefix.
///
/// True when `token` is one or more ASCII digits followed by "st", "nd", "rd"
/// or "th". The suffix is matched with `strip_suffix` rather than a byte-index
/// split, so input containing multi-byte characters returns `false` instead
/// of panicking on a non-character boundary.
fn is_ordinal_token(token: &str) -> bool {
    ["st", "nd", "rd", "th"].iter().any(|suffix| {
        token.strip_suffix(suffix).is_some_and(|digits| {
            !digits.is_empty() && digits.chars().all(|ch| ch.is_ascii_digit())
        })
    })
}
/// Whether a leading address token denotes a unit or house number rather than
/// the street itself.
///
/// Ordinals ("2nd") never count. Otherwise a token counts when it is a unit
/// word, a single byte long, all digits, or starts with a digit and contains a
/// letter (e.g. "12a").
fn is_house_prefix_token(token: &str) -> bool {
    if is_ordinal_token(token) {
        return false;
    }
    let is_unit_word = matches!(
        token,
        "flat" | "fl" | "apartment" | "apt" | "unit" | "no" | "block" | "floor" | "room"
    );
    if is_unit_word || token.len() == 1 {
        return true;
    }
    if token.chars().all(|ch| ch.is_ascii_digit()) {
        return true;
    }
    let starts_with_digit = token.chars().next().is_some_and(|ch| ch.is_ascii_digit());
    starts_with_digit && token.chars().any(|ch| ch.is_ascii_alphabetic())
}
/// Street-level key for an address: the tokenized address with its leading
/// house-number / flat tokens removed, joined by single spaces, so that
/// "12 Baker Street" and "5 Baker Street" share one key.
///
/// When every token is a house-prefix token, nothing is removed.
fn street_key(address: &str) -> String {
    let tokens = tokenize_address_text(address);
    let street_start = tokens
        .iter()
        .position(|token| !is_house_prefix_token(token));
    match street_start {
        Some(start) => tokens[start..].join(" "),
        None => tokens.join(" "),
    }
}
/// Road-type words. Their presence (with no house number) marks a road browse, which we
/// collapse to one result per street.
///
/// Entries are lowercase, matching the output of `tokenize_address_text`, and
/// include common abbreviations alongside the full words. `query_has_road_type`
/// checks query tokens against this list.
const ROAD_TYPE_TOKENS: &[&str] = &[
"street",
"st",
"road",
"rd",
"lane",
"ln",
"avenue",
"ave",
"close",
"cl",
"drive",
"dr",
"way",
"court",
"ct",
"crescent",
"cres",
"place",
"terrace",
"terr",
"grove",
"gardens",
"gdns",
"walk",
"row",
"square",
"sq",
"hill",
"parade",
"mews",
"embankment",
"broadway",
"boulevard",
"blvd",
];
/// Whether any token of `query` is one of the `ROAD_TYPE_TOKENS`.
fn query_has_road_type(query: &str) -> bool {
    let tokens = tokenize_address_text(query);
    for token in &tokens {
        if ROAD_TYPE_TOKENS.iter().any(|road_type| token == road_type) {
            return true;
        }
    }
    false
}
/// The outward code of a canonical postcode: everything before the first
/// space, or the whole string when it contains no space.
fn outcode_of(postcode: &str) -> &str {
    match postcode.find(' ') {
        Some(space) => &postcode[..space],
        None => postcode,
    }
}
/// Parses a free-text address query into the parts the search uses.
///
/// Steps, in order:
/// 1. Tokenize the query.
/// 2. Extract a complete postcode, if one is present, remembering which
///    token(s) formed it.
/// 3. Otherwise, recover a trailing partial postcode (outcode, optionally
///    followed by a one-digit sector) as `postcode_area`.
/// 4. Classify every remaining token as a numeric term, a text group, or a
///    short mixed alphanumeric term; tokens shaped like postcode fragments are
///    dropped.
fn parse_address_query(query: &str) -> AddressQuery {
let tokens = tokenize_address_text(query);
// A complete postcode and the indices of the token(s) it was built from.
let (full_postcode, postcode_token_indices) = extract_full_postcode(&tokens)
.map(|(postcode, indices)| (Some(postcode), indices))
.unwrap_or((None, Vec::new()));
let skip_postcode_tokens: FxHashSet<usize> = postcode_token_indices.into_iter().collect();
// Recover an appended partial postcode (outcode, or outcode + sector digit) as a ranking
// bias rather than discarding it — but only from the TRAILING position, so a leading road
// designation like "A4 Great West Road" is not mistaken for an area refinement.
let mut postcode_area: Option<String> = None;
let mut consumed_partial_tokens: FxHashSet<usize> = FxHashSet::default();
if full_postcode.is_none() && !tokens.is_empty() {
let last = tokens.len() - 1;
if !skip_postcode_tokens.contains(&last) {
// A lone trailing digit is treated as a sector digit only when the token
// before it looks like an outcode ("nw1 6" → "NW16").
let sector_digit =
tokens[last].len() == 1 && tokens[last].chars().all(|ch| ch.is_ascii_digit());
if last >= 1
&& sector_digit
&& !skip_postcode_tokens.contains(&(last - 1))
&& looks_like_postcode_fragment(&tokens[last - 1])
{
postcode_area = Some(format!(
"{}{}",
tokens[last - 1].to_ascii_uppercase(),
tokens[last]
));
consumed_partial_tokens.insert(last);
consumed_partial_tokens.insert(last - 1);
} else if looks_like_postcode_fragment(&tokens[last]) {
// Trailing outcode on its own ("baker street nw1").
postcode_area = Some(tokens[last].to_ascii_uppercase());
consumed_partial_tokens.insert(last);
}
}
}
let mut text_groups = Vec::new();
let mut numeric_terms = Vec::new();
let mut candidate_terms = Vec::new();
for (idx, token) in tokens.iter().enumerate() {
// Skip tokens already consumed as a postcode, and any token shaped like a
// postcode fragment regardless of its position.
if skip_postcode_tokens.contains(&idx)
|| consumed_partial_tokens.contains(&idx)
|| looks_like_postcode_fragment(token)
{
continue;
}
if is_numeric_address_token(token) {
numeric_terms.push(token.clone());
continue;
}
if let Some(group) = address_term_group(token) {
// Only the non-stop-word alternatives become candidate terms; the whole
// group is kept as a text group.
for alternative in &group.alternatives {
if !is_address_stop_token(alternative)
&& !candidate_terms.iter().any(|term| term == alternative)
{
candidate_terms.push(alternative.clone());
}
}
text_groups.push(group);
} else if token.chars().any(|ch| ch.is_ascii_digit()) && token.len() >= 2 {
// Mixed alphanumeric token that produced no text group (e.g. "1a"): treated
// as both a numeric term and a candidate term.
numeric_terms.push(token.clone());
if !candidate_terms.iter().any(|term| term == token) {
candidate_terms.push(token.clone());
}
}
}
// `dedup_by` only collapses adjacent equal groups.
text_groups.dedup_by(|left, right| left.alternatives == right.alternatives);
numeric_terms.sort_unstable();
numeric_terms.dedup();
AddressQuery {
full_postcode,
postcode_area,
text_groups,
numeric_terms,
candidate_terms,
}
}
/// Whether an address token matches a text query term: always on equality,
/// and by prefix only when the query term is at least three bytes long.
fn token_matches_query_term(token: &str, query_term: &str) -> bool {
    if token == query_term {
        return true;
    }
    query_term.len() >= 3 && token.starts_with(query_term)
}
/// Whether an address token matches a numeric query term: the token must
/// start with the term, which also covers exact equality.
fn token_matches_numeric_term(token: &str, query_term: &str) -> bool {
    token.starts_with(query_term)
}
/// Test helper: whether any of `tokens` matches any alternative of `group`
/// according to `token_matches_query_term`.
#[cfg(test)]
fn address_tokens_match_group(tokens: &[String], group: &AddressTermGroup) -> bool {
    for alternative in &group.alternatives {
        for token in tokens {
            if token_matches_query_term(token, alternative) {
                return true;
            }
        }
    }
    false
}
impl PropertyData {
fn row_address_search_tokens(&self, row: usize) -> &[lasso::Spur] {
let offset = self.address_search_token_offsets[row] as usize;
let length = self.address_search_token_lengths[row] as usize;
&self.address_search_token_keys[offset..offset + length]
}
/// Search individual property addresses, returning `(row, score)` ranked best-first.
///
/// Candidate rows come from intersecting the posting lists of the distinctive words the
/// user typed in full (so "Cherry Hinton Road" narrows to rows containing both), unioned
/// with the exact-postcode rows when a complete postcode is present (so a postcode is a
/// boost, not an all-or-nothing gate). An appended partial postcode keeps in-area rows
/// ahead of the candidate cut and adds a scoring bias. With a road-type word and no house
/// number, results collapse to one row per street.
pub fn search_addresses(&self, query: &str, limit: usize) -> Vec<(usize, i32)> {
if limit == 0 {
return Vec::new();
}
let parsed = parse_address_query(query);
if parsed.full_postcode.is_none()
&& parsed.text_groups.is_empty()
&& parsed.numeric_terms.is_empty()
{
return Vec::new();
}
let mut candidate_rows = self.address_candidate_rows(&parsed.candidate_terms);
// A complete postcode contributes its rows too, instead of replacing the road match.
if let Some(postcode) = parsed.full_postcode.as_deref() {
if let Some(rows) = self
.postcode_interner
.get(postcode)
.and_then(|key| self.postcode_row_index.get(&key))
{
candidate_rows = if candidate_rows.is_empty() {
rows.clone()
} else {
union_sorted(&candidate_rows, rows)
};
}
}
if candidate_rows.is_empty() {
return Vec::new();
}
// When the user appended a partial postcode, keep in-area rows ahead of the cut so the
// refinement still surfaces even for very common roads. Single pass (stable partition) so
// the postcode check — which allocates — runs exactly once per candidate.
if let Some(area) = parsed.postcode_area.as_deref() {
let mut in_area = Vec::new();
let mut others = Vec::new();
for &row in &candidate_rows {
if self.row_postcode_in_area(row as usize, area) {
in_area.push(row);
} else {
others.push(row);
}
}
in_area.extend(others);
candidate_rows = in_area;
}
candidate_rows.truncate(ADDRESS_SEARCH_CANDIDATE_LIMIT);
let mut scored: Vec<(i32, usize, usize)> = candidate_rows
.into_iter()
.filter_map(|row| {
let row = row as usize;
self.address_match_score(row, &parsed)
.map(|score| (score, self.address(row).len(), row))
})
.collect();
scored.sort_unstable_by(|left, right| {
right
.0
.cmp(&left.0)
.then(left.1.cmp(&right.1))
.then(left.2.cmp(&right.2))
});
// Collapse a road browse (road-type word, no house number) to one row per street.
let collapse_streets = parsed.numeric_terms.is_empty() && query_has_road_type(query);
let mut seen = FxHashSet::default();
let mut results = Vec::with_capacity(limit);
for (score, _, row) in scored {
let address = self.address(row).trim();
if address.is_empty() {
continue;
}
let key = if collapse_streets {
format!(
"{}\n{}",
street_key(address),
outcode_of(self.postcode(row))
)
} else {
format!("{}\n{}", address.to_ascii_lowercase(), self.postcode(row))
};
if !seen.insert(key) {
continue;
}
results.push((row, score));
if results.len() == limit {
break;
}
}
results
}
/// True when the row's postcode begins with the compact partial-postcode `area`
/// (e.g. "NW1" or "NW16" matches "NW1 6XE").
///
/// The stored postcode is compared with whitespace removed and ASCII letters upper-cased;
/// `area` is compared exactly as given. An empty `area` matches every row.
fn row_postcode_in_area(&self, row: usize, area: &str) -> bool {
    // Stream the normalized postcode instead of building a temporary String: this check
    // runs for every candidate row of a search.
    let mut compact = self
        .postcode(row)
        .chars()
        .filter(|ch| !ch.is_whitespace())
        .map(|ch| ch.to_ascii_uppercase());
    area.chars().all(|expected| compact.next() == Some(expected))
}
/// Candidate rows for the distinctive query words.
///
/// Every word that is an indexed token in its own right contributes its exact posting list,
/// and the lists are intersected starting from the shortest. Only when no word has an exact
/// posting list does the lookup fall back to `prefix_seed_rows` (the user is still typing).
fn address_candidate_rows(&self, terms: &[String]) -> Vec<u32> {
    let mut posting_lists: Vec<&[u32]> = Vec::new();
    for term in terms {
        if let Some(rows) = self.address_token_index.get(term) {
            posting_lists.push(rows.as_slice());
        }
    }
    if posting_lists.is_empty() {
        return self.prefix_seed_rows(terms);
    }
    // Shortest list first keeps every intermediate intersection as small as possible.
    posting_lists.sort_by_key(|rows| rows.len());
    let mut narrowed = posting_lists[0].to_vec();
    for rows in posting_lists.iter().skip(1) {
        if narrowed.is_empty() {
            break;
        }
        narrowed = intersect_sorted(&narrowed, rows);
    }
    narrowed
}
/// Seed rows from the smallest prefix-expanded term — used only when no word matched an
/// indexed token exactly (i.e. the user is still typing the final word).
///
/// For each term of at least `ADDRESS_SEARCH_PREFIX_MIN_LEN` bytes, the rows of every
/// indexed token starting with that term are merged; the smallest non-empty merge wins
/// (the earliest term wins ties). Returns an empty vector when no term expands to any row.
fn prefix_seed_rows(&self, terms: &[String]) -> Vec<u32> {
    let mut smallest: Option<Vec<u32>> = None;
    for term in terms
        .iter()
        .filter(|term| term.len() >= ADDRESS_SEARCH_PREFIX_MIN_LEN)
    {
        let Some(tokens) = self.address_prefix_index.get(address_prefix_key(term)) else {
            continue;
        };
        // The prefix bucket may hold tokens that share only the bucket key, so re-check
        // each token against the full typed term before merging its rows.
        let expanded = tokens
            .iter()
            .filter(|token| token.starts_with(term.as_str()))
            .filter_map(|token| self.address_token_index.get(token))
            .fold(Vec::new(), |merged: Vec<u32>, rows| {
                if merged.is_empty() {
                    rows.clone()
                } else {
                    union_sorted(&merged, rows)
                }
            });
        if expanded.is_empty() {
            continue;
        }
        let is_improvement = match &smallest {
            None => true,
            Some(current) => expanded.len() < current.len(),
        };
        if is_improvement {
            smallest = Some(expanded);
        }
    }
    smallest.unwrap_or_default()
}
/// Score how well `row` matches the parsed query, or `None` when the row must be excluded.
///
/// A row is excluded when its address is blank, when any text group has no matching token,
/// or when house numbers were typed but none of them matches the row. Otherwise the score
/// is the sum of:
/// - 1000 when the row's postcode equals the complete postcode in the query,
/// - 200 per text group in the query,
/// - 90 per matched numeric term, plus 50 when every numeric term matched,
/// - 400 when the row sits inside the appended partial postcode.
fn address_match_score(&self, row: usize, parsed: &AddressQuery) -> Option<i32> {
    if self.address(row).trim().is_empty() {
        return None;
    }
    let tokens = self.row_address_search_tokens(row);
    // Every text group is mandatory: one unmatched group disqualifies the row.
    if parsed
        .text_groups
        .iter()
        .any(|group| !self.address_tokens_match_group(tokens, group))
    {
        return None;
    }
    let numeric_matches = parsed
        .numeric_terms
        .iter()
        .filter(|term| {
            tokens.iter().any(|token| {
                token_matches_numeric_term(self.address_search_interner.resolve(token), term)
            })
        })
        .count();
    // Numbers are "at least one must match": partial hits still score, zero hits do not.
    if !parsed.numeric_terms.is_empty() && numeric_matches == 0 {
        return None;
    }
    let mut score = 0;
    // The postcode bonus is a per-row boost: candidates are road matches joined with the
    // postcode's rows, so only rows actually in the typed postcode may earn it. Granting
    // it unconditionally would add the same constant to every row and rank nothing.
    if let Some(postcode) = parsed.full_postcode.as_deref() {
        if self.postcode(row) == postcode {
            score += 1_000;
        }
    }
    score += (parsed.text_groups.len() as i32) * 200;
    score += (numeric_matches as i32) * 90;
    if numeric_matches == parsed.numeric_terms.len() && numeric_matches > 0 {
        score += 50;
    }
    // Additive bias (never a filter) when the row sits in the appended partial postcode.
    if let Some(area) = parsed.postcode_area.as_deref() {
        if self.row_postcode_in_area(row, area) {
            score += 400;
        }
    }
    Some(score)
}
/// True when any alternative of `group` matches any of the row's interned tokens
/// (exactly, or as a prefix of three or more bytes).
fn address_tokens_match_group(&self, tokens: &[lasso::Spur], group: &AddressTermGroup) -> bool {
    for alternative in &group.alternatives {
        for token in tokens {
            let text = self.address_search_interner.resolve(token);
            if token_matches_query_term(text, alternative) {
                return true;
            }
        }
    }
    false
}
}
// Unit tests for address-query parsing, tokenization, the sorted row-id set helpers,
// street collapsing and term matching. None of them needs a loaded `PropertyData`.
#[cfg(test)]
mod tests {
use super::*;
// Complete postcodes in compact form are recognised; outcodes, plain words and
// house numbers are not.
#[test]
fn full_postcode_detection_accepts_common_formats() {
assert!(is_full_postcode_compact("SW1A1AA"));
assert!(is_full_postcode_compact("E142DG"));
assert!(is_full_postcode_compact("M11AE"));
assert!(!is_full_postcode_compact("E14"));
assert!(!is_full_postcode_compact("DOWNING"));
assert!(!is_full_postcode_compact("10A"));
}
// The postcode is extracted, "Flat" and "St" produce no candidate terms, and numbers
// come out in sorted string order ("10" before "2").
#[test]
fn address_query_parsing_skips_postcodes_and_street_suffixes() {
let parsed = parse_address_query("Flat 2, 10 Downing St, SW1A 2AA");
assert_eq!(parsed.full_postcode.as_deref(), Some("SW1A 2AA"));
assert_eq!(
parsed.numeric_terms,
vec!["10".to_string(), "2".to_string()]
);
assert_eq!(parsed.candidate_terms, vec!["downing".to_string()]);
assert_eq!(parsed.text_groups.len(), 1);
assert_eq!(
parsed.text_groups[0].alternatives,
vec!["downing".to_string()]
);
}
// A lowercase postcode typed without a space is normalised to "SW1A 1AA".
#[test]
fn address_query_parsing_handles_compact_postcodes() {
let parsed = parse_address_query("10 downing street sw1a1aa");
assert_eq!(parsed.full_postcode.as_deref(), Some("SW1A 1AA"));
assert_eq!(parsed.numeric_terms, vec!["10".to_string()]);
assert_eq!(parsed.candidate_terms, vec!["downing".to_string()]);
}
// A trailing outcode becomes the area bias without becoming a full postcode.
#[test]
fn address_query_recovers_appended_partial_postcode_as_bias() {
let parsed = parse_address_query("Baker Street NW1");
assert_eq!(parsed.full_postcode, None);
assert_eq!(parsed.postcode_area.as_deref(), Some("NW1"));
// The road words are still searchable; the postcode fragment did not consume them.
assert_eq!(parsed.candidate_terms, vec!["baker".to_string()]);
assert!(parsed.numeric_terms.is_empty());
}
// Outcode plus a separate sector digit is joined into one compact area ("CR02").
#[test]
fn address_query_recovers_outcode_plus_sector_without_a_phantom_house_number() {
let parsed = parse_address_query("High Street CR0 2");
assert_eq!(parsed.postcode_area.as_deref(), Some("CR02"));
// The lone sector digit must not be treated as a house number.
assert!(parsed.numeric_terms.is_empty());
assert_eq!(parsed.candidate_terms, vec!["high".to_string()]);
}
// When a complete postcode is present, no partial-postcode area is reported.
#[test]
fn full_postcode_takes_precedence_over_partial_bias() {
let parsed = parse_address_query("Baker Street NW1 6XE");
assert_eq!(parsed.full_postcode.as_deref(), Some("NW1 6XE"));
assert_eq!(parsed.postcode_area, None);
}
// Set operations on sorted row-id slices: intersection, disjoint inputs, union with
// a shared element, and union with an empty side.
#[test]
fn intersect_and_union_sorted_row_ids() {
assert_eq!(
intersect_sorted(&[1, 2, 3, 5], &[2, 3, 4, 5]),
vec![2, 3, 5]
);
assert_eq!(intersect_sorted(&[1, 2], &[3, 4]), Vec::<u32>::new());
assert_eq!(union_sorted(&[1, 3, 5], &[2, 3, 4]), vec![1, 2, 3, 4, 5]);
assert_eq!(union_sorted(&[], &[2, 4]), vec![2, 4]);
}
// House numbers (including "221B") and flat prefixes are stripped from the street key.
#[test]
fn street_key_collapses_house_numbers_and_flats() {
assert_eq!(street_key("12 Baker Street"), "baker street");
assert_eq!(street_key("5 Baker Street"), "baker street");
assert_eq!(street_key("Flat 2, 10 Downing Street"), "downing street");
assert_eq!(street_key("221B Baker Street"), "baker street");
}
#[test]
fn street_key_keeps_ordinal_street_names() {
// Ordinals are part of the street name, not a house-number prefix.
assert_eq!(street_key("2nd Avenue"), "2nd avenue");
assert_eq!(street_key("12 3rd Avenue"), "3rd avenue");
assert!(is_ordinal_token("21st"));
assert!(!is_ordinal_token("21"));
assert!(!is_ordinal_token("221b"));
}
#[test]
fn postcode_area_recovered_only_from_the_trailing_position() {
// A leading road designation must NOT be taken as an area refinement.
let parsed = parse_address_query("A4 Great West Road");
assert_eq!(parsed.postcode_area, None);
// A genuine trailing outcode still is.
let trailing = parse_address_query("Great West Road W4");
assert_eq!(trailing.postcode_area.as_deref(), Some("W4"));
}
// Road-type words ("street", "avenue") are detected; bare names and place names are not.
#[test]
fn road_type_detection() {
assert!(query_has_road_type("high street"));
assert!(query_has_road_type("acacia avenue"));
assert!(!query_has_road_type("acacia"));
assert!(!query_has_road_type("london"));
}
// A half-typed final word ("cour") is kept both as a candidate term and as its own
// text group.
#[test]
fn address_query_parsing_keeps_partial_terms_for_row_matching() {
let parsed = parse_address_query("settlers cour");
assert_eq!(parsed.full_postcode, None);
assert_eq!(parsed.numeric_terms, Vec::<String>::new());
assert_eq!(
parsed.candidate_terms,
vec!["settlers".to_string(), "cour".to_string()]
);
assert_eq!(parsed.text_groups.len(), 2);
assert_eq!(
parsed.text_groups[0].alternatives,
vec!["settlers".to_string()]
);
assert_eq!(parsed.text_groups[1].alternatives, vec!["cour".to_string()]);
}
// Row-side tokenization keeps every word and number of the address ("flat" included)
// and returns them lowercased in sorted order.
#[test]
fn address_search_tokens_keep_actual_address_terms_for_scoring() {
let tokens = address_search_tokens("Flat 2, 10 Downing Cour");
assert_eq!(
tokens,
vec![
"10".to_string(),
"2".to_string(),
"cour".to_string(),
"downing".to_string(),
"flat".to_string()
]
);
}
// The prefix index maps typed prefixes to the full indexed tokens sharing them, and has
// no entry for the two-letter prefix "do".
#[test]
fn address_prefix_index_finds_partial_address_terms() {
let mut token_index: FxHashMap<String, Vec<u32>> = FxHashMap::default();
token_index.insert("downing".to_string(), vec![1]);
token_index.insert("downton".to_string(), vec![2]);
token_index.insert("market".to_string(), vec![3]);
let prefix_index = build_address_prefix_index(&token_index);
assert_eq!(
prefix_index.get("down").cloned().unwrap_or_default(),
vec!["downing".to_string(), "downton".to_string()]
);
assert_eq!(
prefix_index.get("downi").cloned().unwrap_or_default(),
vec!["downing".to_string()]
);
assert_eq!(
prefix_index.get("downt").cloned().unwrap_or_default(),
vec!["downton".to_string()]
);
assert!(!prefix_index.contains_key("do"));
}
// A group matches through a prefix ("down" -> "downing") or through any one of its
// aliases ("st" / "street").
#[test]
fn address_term_matching_allows_prefixes_and_aliases() {
let tokens = tokenize_address_text("10 Downing Street");
let prefix_group = address_term_group("down").expect("prefix term should be searchable");
let alias_group = AddressTermGroup {
alternatives: vec!["st".to_string(), "street".to_string()],
};
assert!(address_tokens_match_group(&tokens, &prefix_group));
assert!(address_tokens_match_group(&tokens, &alias_group));
}
// A three-letter partial term ("cou") matches the actual address token "court".
#[test]
fn address_term_matching_uses_actual_token_prefixes() {
let tokens = tokenize_address_text("12 Settlers Court");
let prefix_group = address_term_group("cou").expect("partial term should be searchable");
assert!(address_tokens_match_group(&tokens, &prefix_group));
}
}

View file

@ -0,0 +1,34 @@
//! H3 spatial cell precomputation for property rows.
use anyhow::Context;
use rayon::prelude::*;
use crate::consts::H3_PRECOMPUTE_MAX;
/// Precompute H3 cell IDs for all rows at the maximum resolution only.
/// Parent cells for lower resolutions are derived on the fly via `CellIndex::parent()`.
///
/// `lat` and `lon` are parallel slices with one entry per row; the result has one cell ID
/// per row, in the same order.
///
/// # Errors
/// Returns an error (rather than panicking inside the worker pool) when the two slices
/// differ in length, when `H3_PRECOMPUTE_MAX` is not a valid H3 resolution, or when any
/// row holds coordinates that `h3o::LatLng::new` rejects; the message names the row.
pub fn precompute_h3(lat: &[f32], lon: &[f32]) -> anyhow::Result<Vec<u64>> {
    // `zip` would silently stop at the shorter slice and leave trailing rows without a cell.
    anyhow::ensure!(
        lat.len() == lon.len(),
        "Latitude/longitude length mismatch: {} latitudes vs {} longitudes",
        lat.len(),
        lon.len()
    );
    let res = H3_PRECOMPUTE_MAX;
    tracing::info!("Precomputing H3 cells at resolution {}", res);
    let h3_res =
        h3o::Resolution::try_from(res).with_context(|| format!("Invalid H3 resolution: {res}"))?;
    let cells = lat
        .par_iter()
        .zip(lon.par_iter())
        .enumerate()
        .map(|(i, (&latitude, &longitude))| {
            h3o::LatLng::new(f64::from(latitude), f64::from(longitude))
                .map(|coord| u64::from(coord.to_cell(h3_res)))
                .map_err(|err| {
                    anyhow::anyhow!(
                        "Invalid coordinates at row {}: lat={}, lon={}: {}",
                        i,
                        latitude,
                        longitude,
                        err
                    )
                })
        })
        // Collecting into a `Result` stops at a failing row and surfaces it to the caller.
        .collect::<anyhow::Result<Vec<u64>>>()?;
    tracing::info!("H3 precomputation complete ({} cells)", cells.len());
    Ok(cells)
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,238 @@
//! Property data: the row-major quantized feature matrix plus the side tables
//! (addresses, postcodes, renovation/price history, POI metrics) built from the
//! properties + postcode parquet files.
//!
//! Split by concern:
//! - [`loading`]: parquet ingestion, validation, spatial sort and matrix build
//! - [`stats`]: histograms, percentiles and slider-bound computation
//! - [`quant`]: u16 quantization encode/decode
//! - [`poi_metrics`]: postcode-level POI metric side table
//! - [`address_search`]: address tokenization, indexing and ranked search
//! - [`h3`]: H3 cell precomputation
mod address_search;
mod h3;
mod loading;
mod poi_metrics;
mod quant;
mod stats;
pub use h3::precompute_h3;
pub use poi_metrics::PostcodePoiMetrics;
pub use quant::QuantRef;
pub use stats::{FeatureStats, Histogram};
use rustc_hash::FxHashMap;
use serde::Serialize;
use crate::consts::NAN_U16;
/// One entry in a property row's renovation history (see `PropertyData::renovation_history`).
#[derive(Serialize, Clone)]
pub struct RenovationEvent {
/// Year associated with the event.
pub year: i32,
/// Event description; the exact vocabulary is defined by the loader — TODO confirm.
pub event: String,
}
/// One historical sale transaction for a property row (see `PropertyData::historical_prices`).
#[derive(Serialize, Clone)]
pub struct HistoricalPrice {
/// Year of the transaction.
pub year: i32,
/// Month of the transaction — presumably 1-12; TODO confirm against the loader.
pub month: u8,
/// Sale price; currency unit is not visible here — presumably whole pounds, TODO confirm.
pub price: i64,
}
/// In-memory property dataset: the row-major quantized feature matrix plus the per-row
/// side tables (addresses, postcodes, address-search index, renovation/price history and
/// postcode-level POI metrics). Private fields are read through the accessor methods.
pub struct PropertyData {
/// Latitude values — presumably one per property row, in degrees; TODO confirm.
pub lat: Vec<f32>,
/// Longitude values — presumably one per property row, in degrees; TODO confirm.
pub lon: Vec<f32>,
/// Feature names — presumably indexed by feature column; TODO confirm.
pub feature_names: Vec<String>,
/// Number of feature columns per row (the row stride of `feature_data`).
pub num_features: usize,
/// Number of numeric features (enum features start at this index).
pub num_numeric: usize,
/// Row-major flat array: feature_data[row * num_features + feat_idx].
/// Quantized to u16. NaN sentinel = u16::MAX (65535).
/// Numeric features: encoded via (val - min) / range * 65534.
/// Enum features: stored directly as u16 cast of the f32 index.
pub feature_data: Vec<u16>,
/// Per-feature: range / QUANT_SCALE for fast decode.
dequant_a: Vec<f32>,
/// Per-feature: minimum value (offset for dequantization).
quant_min: Vec<f32>,
/// Per-feature: max - min (for encoding filter bounds).
quant_range: Vec<f32>,
/// Slider bounds and histogram per feature — presumably indexed like `feature_names`.
pub feature_stats: Vec<FeatureStats>,
/// Postcode-level POI metric side table with its per-property row mapping.
pub poi_metrics: PostcodePoiMetrics,
/// Unquantized last sale price used by the price-history chart.
last_known_price_raw: Vec<f32>,
/// Contiguous buffer holding all address strings end-to-end.
address_buffer: String,
/// Byte offset into `address_buffer` where each row's address starts.
address_offsets: Vec<u32>,
/// Length in bytes of each row's address.
address_lengths: Vec<u16>,
/// Interned postcodes: reader is thread-safe, keys index into it.
postcode_interner: lasso::RodeoReader,
postcode_keys: Vec<lasso::Spur>,
/// Rows for each postcode, keyed by the interned postcode key.
postcode_row_index: FxHashMap<lasso::Spur, Vec<u32>>,
/// Inverted index from address tokens to property rows.
address_token_index: FxHashMap<String, Vec<u32>>,
/// Prefix lookup from typed address-token prefix to indexed full address tokens.
address_prefix_index: FxHashMap<String, Vec<String>>,
/// Interned normalized address-search tokens used for per-row scoring.
address_search_interner: lasso::RodeoReader,
/// Flat per-row normalized address-search token keys.
address_search_token_keys: Vec<lasso::Spur>,
/// Offset into `address_search_token_keys` for each row.
address_search_token_offsets: Vec<u32>,
/// Number of normalized address-search token keys for each row.
address_search_token_lengths: Vec<u16>,
/// For enum features: maps feature index to list of possible string values.
/// Index in values list corresponds to the u16 value stored in feature_data.
pub enum_values: rustc_hash::FxHashMap<usize, Vec<String>>,
/// For enum features: maps feature index to per-value global counts (same order as enum_values).
pub enum_counts: rustc_hash::FxHashMap<usize, Vec<u64>>,
/// Per-row flag: true = construction date is approximate (from EPC band),
/// false = exact (from new-build transaction date).
/// Bit-packed: byte `row / 8`, bit `row % 8`. 8x smaller than Vec<bool>.
approx_build_date_bits: Vec<u8>,
/// Per-row renovation events. Keyed by (permuted) row index.
/// Only rows with events are present in the map.
renovation_history: FxHashMap<u32, Vec<RenovationEvent>>,
/// Per-row historical sale transactions (Land Registry price-paid).
/// Keyed by (permuted) row index. Only rows with prices are present.
historical_prices: FxHashMap<u32, Vec<HistoricalPrice>>,
/// Property sub-type keyed by row index; only rows that have one are present.
property_sub_type: FxHashMap<u32, String>,
/// Price qualifier keyed by row index; only rows that have one are present.
price_qualifier: FxHashMap<u32, String>,
}
impl PropertyData {
    /// Address text of `row`, borrowed from the shared address buffer.
    pub fn address(&self, row: usize) -> &str {
        let start = self.address_offsets[row] as usize;
        let end = start + self.address_lengths[row] as usize;
        &self.address_buffer[start..end]
    }

    /// Postcode text of `row`, resolved through the postcode interner.
    pub fn postcode(&self, row: usize) -> &str {
        let key = &self.postcode_keys[row];
        self.postcode_interner.resolve(key)
    }

    /// Interner and per-row keys as separate borrows, so callers can hold them while also
    /// borrowing other fields (avoids conflicting borrows with feature_data).
    pub fn postcode_parts(&self) -> (&lasso::RodeoReader, &[lasso::Spur]) {
        let interner = &self.postcode_interner;
        let keys = self.postcode_keys.as_slice();
        (interner, keys)
    }

    /// Property rows recorded for `postcode`; empty when the postcode is unknown.
    pub fn rows_for_postcode(&self, postcode: &str) -> &[u32] {
        let Some(key) = self.postcode_interner.get(postcode) else {
            return &[];
        };
        let Some(rows) = self.postcode_row_index.get(&key) else {
            return &[];
        };
        rows.as_slice()
    }

    /// Whether the construction date of `row` is flagged as approximate (bit-packed lookup).
    pub fn is_approx_build_date(&self, row: usize) -> bool {
        // Byte `row / 8`, bit `row % 8`.
        let mask = 1u8 << (row % 8);
        self.approx_build_date_bits[row / 8] & mask != 0
    }

    /// Renovation events recorded for `row`; empty when there are none.
    pub fn renovation_history(&self, row: usize) -> &[RenovationEvent] {
        let Some(events) = self.renovation_history.get(&(row as u32)) else {
            return &[];
        };
        events.as_slice()
    }

    /// Historical sale transactions recorded for `row`; empty when there are none.
    pub fn historical_prices(&self, row: usize) -> &[HistoricalPrice] {
        let Some(prices) = self.historical_prices.get(&(row as u32)) else {
            return &[];
        };
        prices.as_slice()
    }

    /// Property sub-type of `row`, if one is recorded.
    pub fn property_sub_type(&self, row: usize) -> Option<&str> {
        let key = row as u32;
        self.property_sub_type.get(&key).map(|value| value.as_str())
    }

    /// Price qualifier of `row`, if one is recorded.
    pub fn price_qualifier(&self, row: usize) -> Option<&str> {
        let key = row as u32;
        self.price_qualifier.get(&key).map(|value| value.as_str())
    }

    /// Unquantized last sale price of `row`, for charting.
    #[inline]
    pub fn last_known_price_raw(&self, row: usize) -> f32 {
        self.last_known_price_raw[row]
    }

    /// Decode one feature value of `row` from quantized u16 storage.
    ///
    /// The NaN sentinel decodes to `f32::NAN`; enum features (index >= `num_numeric`) are
    /// returned as their stored index; numeric features are rescaled with the per-feature
    /// multiplier and minimum.
    #[inline]
    pub fn get_feature(&self, row: usize, feat_idx: usize) -> f32 {
        let raw = self.feature_data[row * self.num_features + feat_idx];
        if raw == NAN_U16 {
            f32::NAN
        } else if feat_idx >= self.num_numeric {
            f32::from(raw)
        } else {
            f32::from(raw) * self.dequant_a[feat_idx] + self.quant_min[feat_idx]
        }
    }

    /// Borrow the quantization parameters as a `QuantRef` for aggregation/filter functions.
    pub fn quant_ref(&self) -> QuantRef<'_> {
        QuantRef {
            dequant_a: self.dequant_a.as_slice(),
            quant_min: self.quant_min.as_slice(),
            quant_range: self.quant_range.as_slice(),
            num_numeric: self.num_numeric,
        }
    }
}
#[cfg(test)]
impl PropertyData {
    /// Build a `PropertyData` with zero rows and zero features, for integration tests that
    /// need an `AppState` but never touch property data (e.g. checkout/webhook/invite flows).
    pub(crate) fn empty_for_tests() -> Self {
        Self {
            // Coordinates and feature matrix.
            lat: vec![],
            lon: vec![],
            feature_names: vec![],
            num_features: 0,
            num_numeric: 0,
            feature_data: vec![],
            // Quantization parameters and statistics.
            dequant_a: vec![],
            quant_min: vec![],
            quant_range: vec![],
            feature_stats: vec![],
            poi_metrics: PostcodePoiMetrics::empty(0),
            last_known_price_raw: vec![],
            // Address and postcode storage.
            address_buffer: String::new(),
            address_offsets: vec![],
            address_lengths: vec![],
            postcode_interner: lasso::Rodeo::default().into_reader(),
            postcode_keys: vec![],
            postcode_row_index: FxHashMap::default(),
            // Address-search index.
            address_token_index: FxHashMap::default(),
            address_prefix_index: FxHashMap::default(),
            address_search_interner: lasso::Rodeo::default().into_reader(),
            address_search_token_keys: vec![],
            address_search_token_offsets: vec![],
            address_search_token_lengths: vec![],
            // Enum metadata and per-row side tables.
            enum_values: FxHashMap::default(),
            enum_counts: FxHashMap::default(),
            approx_build_date_bits: vec![],
            renovation_history: FxHashMap::default(),
            historical_prices: FxHashMap::default(),
            property_sub_type: FxHashMap::default(),
            price_qualifier: FxHashMap::default(),
        }
    }
}

View file

@ -0,0 +1,200 @@
//! Postcode-level POI metric side table: dynamic POI features are stored once
//! per postcode (not per property row) to keep the hot row-major feature matrix
//! narrow, with a per-property row mapping for lookups.
use anyhow::Context;
use polars::prelude::*;
use rayon::prelude::*;
use rustc_hash::FxHashMap;
use crate::consts::{NAN_U16, QUANT_SCALE};
use crate::features::{self, Bounds};
use super::quant::QuantRef;
use super::stats::{column_to_f32_vec, compute_feature_stats, FeatureStats};
/// Sentinel stored in the per-property row mapping for rows that have no postcode metric
/// row; `metric_row_for_property` turns it into `None`.
pub(super) const NO_POI_METRIC_ROW: u32 = u32::MAX;
/// Postcode-level POI metrics: one quantized column per metric, indexed by postcode metric
/// row, plus a mapping from property row to postcode metric row.
pub struct PostcodePoiMetrics {
/// Metric names, in column order.
pub feature_names: Vec<String>,
/// Metric name -> index into `feature_names` / `columns`.
pub name_to_index: FxHashMap<String, usize>,
/// Metric-major storage: columns[metric_idx][postcode_metric_idx].
pub columns: Vec<Vec<u16>>,
/// Slider bounds and histogram per metric, in column order.
pub feature_stats: Vec<FeatureStats>,
/// Per-property row lookup into the postcode metric table.
row_to_metric_idx: Vec<u32>,
/// Per-metric decode multiplier: range / QUANT_SCALE, or 0 when the range is 0.
dequant_a: Vec<f32>,
/// Per-metric minimum (decode offset).
quant_min: Vec<f32>,
/// Per-metric max - min, or 0 when max is not above min.
quant_range: Vec<f32>,
}
impl PostcodePoiMetrics {
pub(super) fn empty(row_count: usize) -> Self {
Self {
feature_names: Vec::new(),
name_to_index: FxHashMap::default(),
columns: Vec::new(),
feature_stats: Vec::new(),
row_to_metric_idx: vec![NO_POI_METRIC_ROW; row_count],
dequant_a: Vec::new(),
quant_min: Vec::new(),
quant_range: Vec::new(),
}
}
pub(super) fn from_postcode_df(
df: &DataFrame,
feature_names: Vec<String>,
) -> anyhow::Result<Self> {
if feature_names.is_empty() {
return Ok(Self::empty(0));
}
tracing::info!(
metrics = feature_names.len(),
postcodes = df.height(),
"Building postcode POI metric side table"
);
let col_major: Vec<Vec<f32>> = feature_names
.par_iter()
.map(|name| {
let column = df
.column(name.as_str())
.with_context(|| format!("Missing POI metric column '{name}'"))?;
column_to_f32_vec(column)
})
.collect::<anyhow::Result<Vec<_>>>()?;
let feature_stats: Vec<FeatureStats> = col_major
.par_iter()
.enumerate()
.map(|(metric_idx, vals)| {
let name = feature_names[metric_idx].as_str();
let bounds = features::bounds_for(name)
.with_context(|| format!("No bounds config for POI metric '{name}'"))?;
Ok(compute_feature_stats(
vals,
&bounds,
features::has_integer_bins(name),
))
})
.collect::<anyhow::Result<Vec<_>>>()?;
let mut quant_min = Vec::with_capacity(feature_names.len());
let mut quant_range = Vec::with_capacity(feature_names.len());
for (metric_idx, stats) in feature_stats.iter().enumerate() {
let (min, max) = match features::bounds_for(feature_names[metric_idx].as_str()) {
Some(Bounds::Fixed { min, max }) => (min, max),
_ => (stats.histogram.min, stats.histogram.max),
};
quant_min.push(min);
quant_range.push(if max > min { max - min } else { 0.0 });
}
let dequant_a: Vec<f32> = quant_range
.iter()
.map(|&range| {
if range > 0.0 {
range / QUANT_SCALE
} else {
0.0
}
})
.collect();
let columns: Vec<Vec<u16>> = col_major
.par_iter()
.enumerate()
.map(|(metric_idx, vals)| {
let range = quant_range[metric_idx];
let min = quant_min[metric_idx];
vals.iter()
.map(|&value| {
if !value.is_finite() {
NAN_U16
} else if range > 0.0 {
let normalized = (value - min) / range;
(normalized * QUANT_SCALE).round().clamp(0.0, QUANT_SCALE) as u16
} else {
0
}
})
.collect()
})
.collect();
let name_to_index = feature_names
.iter()
.enumerate()
.map(|(idx, name)| (name.clone(), idx))
.collect();
Ok(Self {
feature_names,
name_to_index,
columns,
feature_stats,
row_to_metric_idx: Vec::new(),
dequant_a,
quant_min,
quant_range,
})
}
pub(super) fn set_row_mapping(&mut self, row_to_metric_idx: Vec<u32>) {
self.row_to_metric_idx = row_to_metric_idx;
}
pub fn is_empty(&self) -> bool {
self.feature_names.is_empty()
}
pub fn num_features(&self) -> usize {
self.feature_names.len()
}
pub fn quant_ref(&self) -> QuantRef<'_> {
QuantRef {
dequant_a: &self.dequant_a,
quant_min: &self.quant_min,
quant_range: &self.quant_range,
num_numeric: self.feature_names.len(),
}
}
#[inline]
pub fn metric_row_for_property(&self, row: usize) -> Option<usize> {
self.row_to_metric_idx
.get(row)
.copied()
.filter(|&idx| idx != NO_POI_METRIC_ROW)
.map(|idx| idx as usize)
}
#[inline]
pub fn raw_for_metric_row(&self, metric_row: usize, metric_idx: usize) -> u16 {
self.columns[metric_idx][metric_row]
}
#[inline]
pub fn raw_for_property_row(&self, row: usize, metric_idx: usize) -> u16 {
let Some(metric_row) = self.metric_row_for_property(row) else {
return NAN_U16;
};
self.raw_for_metric_row(metric_row, metric_idx)
}
#[inline]
pub fn decode_raw(&self, metric_idx: usize, raw: u16) -> f32 {
if raw == NAN_U16 {
f32::NAN
} else {
raw as f32 * self.dequant_a[metric_idx] + self.quant_min[metric_idx]
}
}
#[inline]
pub fn get_for_property_row(&self, row: usize, metric_idx: usize) -> f32 {
self.decode_raw(metric_idx, self.raw_for_property_row(row, metric_idx))
}
}

View file

@ -0,0 +1,46 @@
//! u16 quantization: decoding stored feature values and encoding filter bounds.
use crate::consts::{NAN_U16, QUANT_SCALE};
/// Lightweight reference to quantization parameters for decoding u16 feature data.
pub struct QuantRef<'a> {
/// Per-feature multiplier applied to the raw u16 when decoding.
pub dequant_a: &'a [f32],
/// Per-feature offset added when decoding and subtracted when encoding bounds.
pub quant_min: &'a [f32],
/// Per-feature divisor used when encoding filter bounds; 0 disables encoding.
pub quant_range: &'a [f32],
/// Features at or beyond this index are decoded as their raw stored value.
pub num_numeric: usize,
}
impl QuantRef<'_> {
    /// Decode a raw u16 value back to f32.
    ///
    /// The NaN sentinel yields `f32::NAN`; features at or beyond `num_numeric` yield the raw
    /// value itself; all others are rescaled with the per-feature multiplier and minimum.
    #[inline]
    pub fn decode(&self, feat_idx: usize, raw: u16) -> f32 {
        if raw == NAN_U16 {
            f32::NAN
        } else if feat_idx >= self.num_numeric {
            f32::from(raw)
        } else {
            f32::from(raw) * self.dequant_a[feat_idx] + self.quant_min[feat_idx]
        }
    }

    /// Encode a filter minimum bound to u16 (floors to include boundary values).
    ///
    /// A non-finite bound or a zero-width feature range encodes as 0, the lowest code.
    #[inline]
    pub fn encode_min(&self, feat_idx: usize, value: f32) -> u16 {
        // Test the value first: a non-finite bound must not touch the range table at all.
        if !value.is_finite() {
            return 0;
        }
        let range = self.quant_range[feat_idx];
        if range == 0.0 {
            return 0;
        }
        let scaled = (value - self.quant_min[feat_idx]) / range * QUANT_SCALE;
        scaled.floor().clamp(0.0, QUANT_SCALE) as u16
    }

    /// Encode a filter maximum bound to u16 (ceils to include boundary values).
    ///
    /// A non-finite bound or a zero-width feature range encodes as `QUANT_SCALE`, the
    /// highest code.
    #[inline]
    pub fn encode_max(&self, feat_idx: usize, value: f32) -> u16 {
        // Test the value first: a non-finite bound must not touch the range table at all.
        if !value.is_finite() {
            return QUANT_SCALE as u16;
        }
        let range = self.quant_range[feat_idx];
        if range == 0.0 {
            return QUANT_SCALE as u16;
        }
        let scaled = (value - self.quant_min[feat_idx]) / range * QUANT_SCALE;
        scaled.ceil().clamp(0.0, QUANT_SCALE) as u16
    }
}

View file

@ -0,0 +1,544 @@
//! Feature statistics: outlier-bracketed histograms, percentile estimation and
//! slider-bound computation.
use anyhow::Context;
use polars::prelude::*;
use serde::Serialize;
use crate::consts::HISTOGRAM_BINS;
use crate::features::Bounds;
/// Histogram with outlier buckets at the edges.
/// - Bin 0: [min, p1) — low outliers
/// - Bins 1 to n-2: [p1, p99) — main distribution, evenly divided
/// - Bin n-1: [p99, max] — high outliers
#[derive(Serialize, Clone)]
pub struct Histogram {
/// Lower end of the reported value range.
pub min: f32,
/// Upper end of the reported value range.
pub max: f32,
/// 1st percentile (left edge of main distribution)
pub p1: f32,
/// 99th percentile (right edge of main distribution)
pub p99: f32,
/// Per-bin counts; the first and last entries are the outlier bins.
pub counts: Vec<u64>,
}
impl Histogram {
    /// Bin index for `value` under the outlier-bracket layout: 0 below `p1`, the last bin
    /// at or above `p99`, otherwise an evenly spaced middle bin (or the centre bin when
    /// there are no middle bins or `p99` is not above `p1`).
    #[cfg(test)]
    pub fn bin_for_value(&self, value: f32) -> usize {
        let num_bins = self.counts.len();
        if value < self.p1 {
            return 0;
        }
        if value >= self.p99 {
            return num_bins - 1;
        }
        let middle_bins = num_bins.saturating_sub(2);
        if middle_bins > 0 && self.p99 > self.p1 {
            let width = (self.p99 - self.p1) / middle_bins as f32;
            let offset = ((value - self.p1) / width) as usize;
            // Clamp so rounding at the upper edge can never spill into the outlier bin.
            return (offset + 1).min(num_bins - 2);
        }
        num_bins / 2
    }

    /// Width of a single middle bin (bins 1..n-2); 0 when there are no middle bins or
    /// `p99` is not above `p1`.
    #[cfg(test)]
    pub fn middle_bin_width(&self) -> f32 {
        let middle_bins = self.counts.len().saturating_sub(2);
        let has_width = middle_bins > 0 && self.p99 > self.p1;
        if !has_width {
            return 0.0;
        }
        (self.p99 - self.p1) / middle_bins as f32
    }
}
/// Per-feature statistics: the slider range plus the outlier-bracketed histogram.
pub struct FeatureStats {
/// Lower slider bound, derived from the feature's `Bounds` config.
pub slider_min: f32,
/// Upper slider bound, derived from the feature's `Bounds` config.
pub slider_max: f32,
/// Distribution of the feature's finite values.
pub histogram: Histogram,
}
/// Estimate a percentile from a uniformly-binned histogram.
///
/// `prelim_counts` are equal-width bins over [min, max] and `count` is the number of
/// samples they hold. The bin where the running total first exceeds the target rank is
/// located and the result is linearly interpolated inside it. Returns `min` when there are
/// no samples or no bins, and `max` when the target rank lies beyond every bin.
fn percentile_from_uniform_histogram(
    count: usize,
    min: f32,
    max: f32,
    prelim_counts: &[u64],
    percentile: f32,
) -> f32 {
    if count == 0 || prelim_counts.is_empty() {
        return min;
    }
    // Zero-based rank of the sample that sits at the requested percentile.
    let target = (count as f64 * percentile as f64 / 100.0).floor() as u64;
    let bin_width = (max - min) / prelim_counts.len() as f32;
    let mut below = 0u64;
    for (index, &in_bin) in prelim_counts.iter().enumerate() {
        let through = below + in_bin;
        if through > target {
            // Position of the target rank within this bin, as a fraction of its samples.
            let fraction = if in_bin > 0 {
                (target - below) as f32 / in_bin as f32
            } else {
                0.0
            };
            let bin_start = min + index as f32 * bin_width;
            return bin_start + fraction * bin_width;
        }
        below = through;
    }
    max
}
/// Build a histogram and compute slider bounds based on the feature's Bounds config.
pub fn compute_feature_stats(vals: &[f32], bounds: &Bounds, integer_bins: bool) -> FeatureStats {
// Single pass: min, max, count (skipping NaN and infinity)
let mut min = f32::INFINITY;
let mut max = f32::NEG_INFINITY;
let mut count = 0usize;
for &value in vals {
if value.is_finite() {
if value < min {
min = value;
}
if value > max {
max = value;
}
count += 1;
}
}
if count == 0 {
let (slider_min, slider_max) = match bounds {
Bounds::Fixed {
min: fmin,
max: fmax,
} => (*fmin, *fmax),
Bounds::Percentile { .. } => (0.0, 0.0),
};
return FeatureStats {
slider_min,
slider_max,
histogram: Histogram {
min: 0.0,
max: 0.0,
p1: 0.0,
p99: 0.0,
counts: vec![0; HISTOGRAM_BINS],
},
};
}
// Build preliminary histogram with uniform bins to compute percentiles
// Use full HISTOGRAM_BINS for percentile precision
let range = if max == min { 1.0 } else { max - min };
let prelim_max = min + range * (1.0 + 1e-6);
let prelim_bin_width = (prelim_max - min) / HISTOGRAM_BINS as f32;
let mut prelim_counts = vec![0u64; HISTOGRAM_BINS];
for &value in vals {
if value.is_finite() {
let bin = ((value - min) / prelim_bin_width) as usize;
prelim_counts[bin.min(HISTOGRAM_BINS - 1)] += 1;
}
}
// Compute p1 and p99 from preliminary histogram
let mut p1 = percentile_from_uniform_histogram(count, min, max, &prelim_counts, 1.0);
let mut p99 = percentile_from_uniform_histogram(count, min, max, &prelim_counts, 99.0);
// Iterative refinement for outlier-dominated distributions.
// When extreme outliers (e.g. 317M sqm from web scraping) dominate the range,
// the uniform histogram puts all real data in one bin, making percentile
// estimation useless. Zoom into the estimated data region and recompute.
let mut refined_counts = prelim_counts;
let mut refined_count = count;
let mut refined_min = min;
let mut refined_max = max;
for _ in 0..3 {
let iqr = p99 - p1;
if iqr <= 0.0 || (refined_max - refined_min) <= 5.0 * iqr {
break;
}
let new_min = (p1 - iqr).max(min);
let new_max = p99 + iqr;
if new_max <= new_min {
break;
}
let bin_width = (new_max - new_min) / HISTOGRAM_BINS as f32;
let mut counts = vec![0u64; HISTOGRAM_BINS];
let mut cnt = 0usize;
for &value in vals {
if value.is_finite() && value >= new_min && value <= new_max {
let bin = ((value - new_min) / bin_width) as usize;
counts[bin.min(HISTOGRAM_BINS - 1)] += 1;
cnt += 1;
}
}
if cnt == 0 {
break;
}
p1 = percentile_from_uniform_histogram(cnt, new_min, new_max, &counts, 1.0);
p99 = percentile_from_uniform_histogram(cnt, new_min, new_max, &counts, 99.0);
refined_counts = counts;
refined_count = cnt;
refined_min = new_min;
refined_max = new_max;
}
// For integer-binned features, snap p1/p99 to integer boundaries
// so each middle bin is exactly 1 unit wide.
if integer_bins {
p1 = p1.floor();
p99 = p99.ceil();
}
// Determine number of histogram bins
let num_bins = if integer_bins && p99 > p1 {
// One middle bin per integer + 2 outlier bins
(p99 - p1) as usize + 2
} else {
// Count unique values within the p1–p99 range to cap histogram bins.
// Using the full-range cardinality would over-allocate bins when outliers
// inflate it (e.g. bedrooms: 1137 unique values but only ~10 within p1–p99).
let cardinality = {
let mut unique_set = rustc_hash::FxHashSet::default();
for &val in vals {
if val.is_finite() && val >= p1 && val <= p99 {
unique_set.insert(val.to_bits());
}
}
unique_set.len()
};
HISTOGRAM_BINS.min(cardinality).max(3)
};
// Build final histogram with outlier bins at edges:
// - Bin 0: [min, p1) — low outliers
// - Bins 1 to n-2: [p1, p99) — main distribution, evenly divided
// - Bin n-1: [p99, max] — high outliers
let mut counts = vec![0u64; num_bins];
let middle_bins = num_bins.saturating_sub(2);
let middle_width = if middle_bins > 0 && p99 > p1 {
(p99 - p1) / middle_bins as f32
} else {
0.0
};
for &value in vals {
if value.is_finite() {
let bin = if value < p1 {
0 // Low outlier bin
} else if value >= p99 {
num_bins - 1 // High outlier bin
} else if middle_width > 0.0 {
// Middle bins (1 to n-2)
let middle_bin = ((value - p1) / middle_width) as usize;
(1 + middle_bin).min(num_bins - 2)
} else {
num_bins / 2 // Fallback if p1 == p99
};
counts[bin] += 1;
}
}
let histogram = Histogram {
min: refined_min,
max: refined_max,
p1,
p99,
counts,
};
// Compute slider bounds (use refined histogram for accurate percentiles)
let (slider_min, slider_max) = match bounds {
Bounds::Fixed {
min: fmin,
max: fmax,
} => (*fmin, *fmax),
Bounds::Percentile { low, high } => {
let p_low = percentile_from_uniform_histogram(
refined_count,
refined_min,
refined_max,
&refined_counts,
*low as f32,
);
let p_high = percentile_from_uniform_histogram(
refined_count,
refined_min,
refined_max,
&refined_counts,
*high as f32,
);
(p_low, p_high)
}
};
FeatureStats {
slider_min,
slider_max,
histogram,
}
}
/// Converts `column` into a plain `Vec<f32>`.
///
/// The column is first cast to `Float32`; each missing (null) entry is then
/// replaced by `f32::NAN` so the output has one value per input row.
///
/// # Errors
/// Returns an error if the cast to `Float32` fails or the cast result cannot
/// be viewed as an `f32` chunked array.
pub(super) fn column_to_f32_vec(column: &Column) -> anyhow::Result<Vec<f32>> {
    let as_float = column
        .cast(&DataType::Float32)
        .context("Failed to cast column to Float32")?;
    let float_values = as_float
        .f32()
        .context("Failed to get f32 chunked array")?;

    // Nulls are surfaced as NaN rather than dropped, keeping row alignment.
    let dense: Vec<f32> = float_values
        .into_iter()
        .map(|slot| match slot {
            Some(present) => present,
            None => f32::NAN,
        })
        .collect();
    Ok(dense)
}
// Unit tests for `compute_feature_stats` and the `Histogram` helpers.
// All calls below pass `false` as the third (`integer_bins`) argument.
#[cfg(test)]
mod tests {
use super::*;
use crate::consts::QUANT_SCALE;
use crate::features::Bounds;
// Shorthand for `Bounds::Fixed { min, max }`.
fn make_fixed_bounds(min: f32, max: f32) -> Bounds {
Bounds::Fixed { min, max }
}
// Shorthand for `Bounds::Percentile { low, high }`.
fn make_percentile_bounds(low: f64, high: f64) -> Bounds {
Bounds::Percentile { low, high }
}
// Empty input: fixed bounds are passed through to the slider and the
// histogram holds no counts.
#[test]
fn histogram_empty_data() {
let data: Vec<f32> = vec![];
let bounds = make_fixed_bounds(0.0, 100.0);
let stats = compute_feature_stats(&data, &bounds, false);
assert_eq!(stats.slider_min, 0.0);
assert_eq!(stats.slider_max, 100.0);
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 0);
}
// A single finite value becomes both histogram min and max and is counted once.
#[test]
fn histogram_single_value() {
let data = vec![50.0_f32];
let bounds = make_fixed_bounds(0.0, 100.0);
let stats = compute_feature_stats(&data, &bounds, false);
assert_eq!(stats.histogram.min, 50.0);
assert_eq!(stats.histogram.max, 50.0);
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1);
}
// 0..=99 evenly spread: min/max follow the data and every value lands in a bin.
#[test]
fn histogram_uniform_distribution() {
let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
let bounds = make_fixed_bounds(0.0, 100.0);
let stats = compute_feature_stats(&data, &bounds, false);
assert_eq!(stats.histogram.min, 0.0);
assert_eq!(stats.histogram.max, 99.0);
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 100);
}
// NaN entries are ignored for both the counts and the min/max.
#[test]
fn histogram_with_nan_values() {
let data = vec![10.0_f32, f32::NAN, 20.0, f32::NAN, 30.0];
let bounds = make_fixed_bounds(0.0, 100.0);
let stats = compute_feature_stats(&data, &bounds, false);
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 3);
assert_eq!(stats.histogram.min, 10.0);
assert_eq!(stats.histogram.max, 30.0);
}
// Input consisting only of NaN yields a histogram with no counts.
#[test]
fn histogram_all_nan() {
let data = vec![f32::NAN, f32::NAN, f32::NAN];
let bounds = make_fixed_bounds(0.0, 100.0);
let stats = compute_feature_stats(&data, &bounds, false);
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 0);
}
// Degenerate range (min == max): p1 and p99 collapse onto the single value
// and no sample is lost.
#[test]
fn histogram_all_same_value() {
let data = vec![42.0_f32; 1000];
let bounds = make_fixed_bounds(0.0, 100.0);
let stats = compute_feature_stats(&data, &bounds, false);
assert_eq!(stats.histogram.min, 42.0);
assert_eq!(stats.histogram.max, 42.0);
assert_eq!(stats.histogram.p1, 42.0);
assert_eq!(stats.histogram.p99, 42.0);
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1000);
}
// Percentile bounds (2nd..98th) must place the slider strictly inside the
// two injected outliers (0 and 1000).
#[test]
fn histogram_percentile_bounds() {
let mut data: Vec<f32> = vec![0.0]; // Low outlier
data.extend((1..99).map(|i| 50.0 + i as f32 * 0.01));
data.push(1000.0); // High outlier
let bounds = make_percentile_bounds(2.0, 98.0);
let stats = compute_feature_stats(&data, &bounds, false);
assert!(stats.slider_min > 0.0);
assert!(stats.slider_max < 1000.0);
}
// Fixed bounds win over the data: a value above the fixed max does not
// raise the slider cap.
#[test]
fn fixed_price_bounds_keep_slider_cap() {
let data = vec![400_000.0_f32, 2_500_000.0, 3_750_000.0];
let bounds = make_fixed_bounds(0.0, 2_500_000.0);
let stats = compute_feature_stats(&data, &bounds, false);
assert_eq!(stats.slider_min, 0.0);
assert_eq!(stats.slider_max, 2_500_000.0);
}
// `bin_for_value`: below p1 -> first bin, above p99 -> last bin, values in
// between -> one of the middle bins (1..=8 for a 10-bin histogram).
#[test]
fn histogram_bin_for_value() {
let hist = Histogram {
min: 0.0,
max: 100.0,
p1: 10.0,
p99: 90.0,
counts: vec![0; 10],
};
assert_eq!(hist.bin_for_value(5.0), 0); // Low outlier bin
assert_eq!(hist.bin_for_value(95.0), 9); // High outlier bin
let mid_value = 50.0;
let bin = hist.bin_for_value(mid_value);
assert!((1..=8).contains(&bin));
}
// `middle_bin_width`: (p99 - p1) split across the bins that remain after
// reserving the two outlier bins (10 - 2 = 8).
#[test]
fn histogram_middle_bin_width() {
let hist = Histogram {
min: 0.0,
max: 100.0,
p1: 10.0,
p99: 90.0,
counts: vec![0; 10],
};
let expected_width = (90.0 - 10.0) / 8.0;
assert!((hist.middle_bin_width() - expected_width).abs() < 0.001);
}
// Only three distinct values: the bin count is capped by cardinality
// (never below the minimum of 3).
#[test]
fn histogram_cardinality_caps_bins() {
let data = vec![1.0_f32, 1.0, 2.0, 2.0, 3.0, 3.0];
let bounds = make_fixed_bounds(0.0, 100.0);
let stats = compute_feature_stats(&data, &bounds, false);
assert_eq!(stats.histogram.counts.len(), 3);
}
// NOTE(review): this test re-implements the finite-only min/max scan inline
// and does not call `compute_feature_stats`; it pins the idiom, not the
// production code path.
#[test]
fn min_max_skips_nan() {
let values = vec![10.0_f32, f32::NAN, 20.0, f32::NAN, 5.0];
let mut min = f32::INFINITY;
let mut max = f32::NEG_INFINITY;
for &v in &values {
if v.is_finite() {
if v < min {
min = v;
}
if v > max {
max = v;
}
}
}
assert_eq!(min, 5.0);
assert_eq!(max, 20.0);
}
// NOTE(review): like `min_max_skips_nan`, this checks `is_finite` filtering
// on a local array only, without touching `compute_feature_stats`.
#[test]
fn count_skips_nan() {
let values = [1.0_f32, f32::NAN, 2.0, f32::NAN, 3.0];
let count = values.iter().filter(|v| v.is_finite()).count();
assert_eq!(count, 3);
}
// +/- infinity are treated like NaN: excluded from min/max and counts.
#[test]
fn infinity_values_excluded() {
let data = vec![f32::INFINITY, f32::NEG_INFINITY, 50.0];
let bounds = Bounds::Fixed {
min: 0.0,
max: 100.0,
};
let stats = compute_feature_stats(&data, &bounds, false);
assert_eq!(stats.histogram.min, 50.0);
assert_eq!(stats.histogram.max, 50.0);
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1);
}
// Baseline with only finite input: nothing is dropped.
#[test]
fn only_finite_values() {
let data = vec![10.0_f32, 20.0, 30.0];
let bounds = Bounds::Fixed {
min: 0.0,
max: 100.0,
};
let stats = compute_feature_stats(&data, &bounds, false);
assert_eq!(stats.histogram.min, 10.0);
assert_eq!(stats.histogram.max, 30.0);
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 3);
}
// Regression test for outlier-dominated data: one absurd value must not
// stretch the histogram range so far that normal values become
// indistinguishable after u16 quantization.
#[test]
fn extreme_outlier_does_not_destroy_quantization() {
// Simulate floor area: 10k normal values (50-200 sqm) + one 317M outlier
let mut data: Vec<f32> = (0..10_000).map(|i| 50.0 + (i % 150) as f32).collect();
data.push(317_000_000.0); // Extreme outlier from web scraping
let bounds = make_percentile_bounds(0.0, 98.0);
let stats = compute_feature_stats(&data, &bounds, false);
// After refinement, histogram range should be much tighter than 317M
assert!(
stats.histogram.max < 1_000_000.0,
"histogram.max should be refined, got {}",
stats.histogram.max,
);
// p1 should be near 50, not millions
assert!(
stats.histogram.p1 < 100.0,
"p1 should be near real data, got {}",
stats.histogram.p1,
);
// Slider min should reflect actual data range
assert!(
stats.slider_min < 100.0,
"slider_min should be near real data, got {}",
stats.slider_min,
);
// Quantization using histogram.min/max should give usable range
let qmin = stats.histogram.min;
let qrange = stats.histogram.max - stats.histogram.min;
assert!(qrange > 0.0 && qrange < 1_000_000.0);
// A typical floor area (100 sqm) should be distinguishable from min
let normalized = (100.0 - qmin) / qrange;
let encoded = (normalized * QUANT_SCALE).round() as u16;
assert!(
encoded > 100,
"100 sqm should encode to a meaningful u16 value, got {}",
encoded,
);
}
}

View file

@ -272,6 +272,21 @@ pub fn slugify(name: &str) -> String {
result result
} }
#[cfg(test)]
impl TravelTimeStore {
    /// Test-only constructor: a store with an empty base directory, no
    /// modes, no destinations, no slug mappings and a capacity-1 cache.
    ///
    /// Intended for integration tests that must build an `AppState` but
    /// never read travel time data.
    pub(crate) fn empty_for_tests() -> Self {
        let cache = Mutex::new(LruCache::new(1));
        Self {
            cache,
            slug_to_file: FxHashMap::default(),
            destinations: FxHashMap::default(),
            available_modes: Vec::new(),
            base_dir: PathBuf::new(),
        }
    }
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -1042,7 +1042,44 @@ async fn main() -> anyhow::Result<()> {
listener, listener,
app.into_make_service_with_connect_info::<std::net::SocketAddr>(), app.into_make_service_with_connect_info::<std::net::SocketAddr>(),
) )
.with_graceful_shutdown(shutdown_signal())
.await .await
.context("Server error")?; .context("Server error")?;
info!("Server shut down cleanly");
Ok(()) Ok(())
} }
/// Completes once the process receives SIGINT (Ctrl-C) or, on Unix, SIGTERM.
///
/// Used as the graceful-shutdown trigger so in-flight requests (exports,
/// checkouts) get a chance to finish. Because the realtime SSE proxy
/// connections never complete on their own, a watchdog task is started as
/// soon as the signal arrives: it force-exits the process after 8 seconds,
/// i.e. before Docker's default 10s stop grace period ends in a SIGKILL.
///
/// # Panics
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install SIGINT handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        let mut stream =
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                .expect("Failed to install SIGTERM handler");
        stream.recv().await;
    };
    // No SIGTERM outside Unix: use a future that never resolves so only
    // Ctrl-C can trigger shutdown.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    tokio::select! {
        () = interrupt => {},
        () = sigterm => {},
    }

    info!("Shutdown signal received; draining in-flight requests");

    // Watchdog: bound the drain so never-ending connections cannot keep the
    // process alive until the container runtime kills it.
    let drain_limit = Duration::from_secs(8);
    tokio::spawn(async move {
        tokio::time::sleep(drain_limit).await;
        tracing::warn!("Graceful shutdown drain timed out after 8s; forcing exit");
        std::process::exit(0);
    });
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,448 @@
//! The `POST /api/ai-filters` route handler: rate limiting, the Gemini
//! function-calling conversation loop, and zero-match refinement.
use std::sync::Arc;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::Json;
use axum::Extension;
use metrics::counter;
use serde_json::{json, Value};
use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WEEKLY_TOKEN_LIMIT};
use crate::pocketbase::log_ai_query;
use crate::state::SharedState;
use crate::utils::gemini_chat;
use super::matching::count_matching_rows;
use super::parsing::{
normalize_context_filters, strip_markdown_fences, validate_and_convert,
validate_travel_time_filters,
};
use super::tools::{build_tool_declarations, execute_destination_search};
use super::usage::{current_week_number, fetch_ai_usage, record_ai_request_usage};
use super::{AiFiltersRequest, AiFiltersResponse};
/// Budget limits for the Gemini conversation loop. Separate counters prevent
/// tool calls (destination searches) from starving JSON retries or zero-match
/// refinements.
///
/// Each counter is checked with `>` after being incremented, so the limit is
/// the number of attempts that are still honoured.
// Tool (function) calls executed before the model is told to stop calling tools.
const MAX_TOOL_CALLS: usize = 4;
// Shared budget for empty-text and invalid-JSON retries.
const MAX_RETRIES: usize = 3;
// Zero-match "please relax the filters" rounds before giving up.
const MAX_REFINEMENTS: u32 = 3;
// Hard cap on Gemini requests per handler invocation, whatever their kind.
const MAX_TOTAL_ROUNDS: usize = 10;
// Maximum accepted query length, counted in `char`s (not bytes).
const MAX_AI_QUERY_CHARS: usize = 5000;
pub async fn post_ai_filters(
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<AiFiltersRequest>,
) -> Result<Json<AiFiltersResponse>, (StatusCode, String)> {
let state = shared.load_state();
// Auth check
let user = user
.0
.ok_or((StatusCode::UNAUTHORIZED, "Login required".into()))?;
if req.query.chars().count() > MAX_AI_QUERY_CHARS {
counter!("ai_requests_total", "status" => "query_too_long").increment(1);
return Err((
StatusCode::PAYLOAD_TOO_LARGE,
format!("Query too long (max {MAX_AI_QUERY_CHARS} chars)"),
));
}
// Check weekly token usage
let current_week = current_week_number();
let (stored_tokens, stored_week) = fetch_ai_usage(&state, &user.id).await?;
let tokens_used = if stored_week == current_week {
stored_tokens
} else {
0
};
if tokens_used >= AI_FILTERS_WEEKLY_TOKEN_LIMIT {
counter!("ai_requests_total", "status" => "rate_limited").increment(1);
return Err((
StatusCode::TOO_MANY_REQUESTS,
"Weekly AI usage limit reached. Resets next week.".into(),
));
}
info!(query = %req.query, user_id = %user.id, "POST /api/ai-filters");
let tools = build_tool_declarations(&state);
// Build user message with optional context for conversational refinement
let user_text = if let Some(ref ctx) = req.context {
let mut msg = String::new();
msg.push_str("Currently active filters:\n");
let normalized_filters = normalize_context_filters(&ctx.filters);
msg.push_str(&serde_json::to_string(&normalized_filters).unwrap_or_default());
if !ctx.travel_time.is_empty() {
msg.push_str("\nCurrently active travel time filters:\n");
for tt in &ctx.travel_time {
let bounds = match (tt.min, tt.max) {
(Some(min), Some(max)) => format!("{}-{} min", min, max),
(Some(min), None) => format!("min {} min", min),
(None, Some(max)) => format!("max {} min", max),
(None, None) => "no range".to_string(),
};
msg.push_str(&format!("- {} to {} ({})\n", tt.mode, tt.label, bounds));
}
}
msg.push_str(&format!("\nUser request: {}", req.query));
msg
} else {
req.query.clone()
};
let mut contents = vec![json!({
"role": "user",
"parts": [{ "text": user_text }]
})];
let mut total_tokens_accumulated: u64 = 0;
let mut tool_call_count = 0usize;
let mut retry_count = 0usize;
let mut refinement_attempts = 0u32;
// Function calling loop: model may call search_destinations, we execute and feed back
for round in 0..MAX_TOTAL_ROUNDS {
let body = json!({
"systemInstruction": {
"parts": [{ "text": state.ai_filters_system_prompt }]
},
"contents": contents,
"tools": tools,
"generationConfig": {
"temperature": AI_FILTERS_TEMPERATURE,
"maxOutputTokens": AI_FILTERS_MAX_TOKENS,
"thinkingConfig": { "thinkingLevel": "LOW" },
}
});
let json_resp = match gemini_chat(
&state.http_client,
&state.gemini_api_key,
&state.gemini_model,
&body,
)
.await
{
Ok(resp) => resp,
Err(err) => {
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"llm_error",
)
.await;
return Err(err);
}
};
// Accumulate token usage
total_tokens_accumulated += json_resp
.get("usageMetadata")
.and_then(|md| md.get("totalTokenCount"))
.and_then(|tc| tc.as_u64())
.unwrap_or(0);
let candidate = match json_resp
.get("candidates")
.and_then(|cs| cs.get(0))
.and_then(|c| c.get("content"))
{
Some(candidate) => candidate,
None => {
warn!("Malformed Gemini response: missing candidates[0].content");
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"malformed_response",
)
.await;
return Err((StatusCode::BAD_GATEWAY, "Malformed Gemini response".into()));
}
};
let parts = match candidate.get("parts").and_then(|p| p.as_array()) {
Some(parts) => parts,
None => {
warn!("Malformed Gemini response: missing parts array");
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"malformed_response",
)
.await;
return Err((StatusCode::BAD_GATEWAY, "Malformed Gemini response".into()));
}
};
// Check if the model made a function call.
// Find the full part (includes thoughtSignature required by Gemini 3 models).
if let Some(fc_part) = parts.iter().find(|part| part.get("functionCall").is_some()) {
let fc = fc_part.get("functionCall").unwrap();
let fn_name = fc.get("name").and_then(|n| n.as_str()).unwrap_or("");
let fn_args = fc.get("args").cloned().unwrap_or(json!({}));
tool_call_count += 1;
info!(
function = fn_name,
round = round,
tool_call = tool_call_count,
"AI called tool"
);
if tool_call_count > MAX_TOOL_CALLS {
warn!("Tool call budget exhausted, forcing text output");
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": "Tool call limit reached. Output your best JSON now using the destinations you already found. Do not call any more tools." }]
}));
continue;
}
let fn_result = if fn_name == "search_destinations" {
let query = fn_args.get("query").and_then(|q| q.as_str()).unwrap_or("");
let mode = fn_args
.get("mode")
.and_then(|m| m.as_str())
.unwrap_or("transit");
execute_destination_search(&state, query, mode)
} else {
json!({"error": "unknown function"})
};
// Append the model's full response (preserves thoughtSignature) + our function result
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{
"functionResponse": {
"name": fn_name,
"response": fn_result
}
}]
}));
// Continue the loop — model will process the results
continue;
}
// Model returned text — extract and parse as JSON
let text = parts
.iter()
.find_map(|part| part.get("text").and_then(|t| t.as_str()))
.unwrap_or("");
let text = strip_markdown_fences(text);
let text = text.trim();
if text.is_empty() {
retry_count += 1;
warn!(
"Gemini returned empty text content (round {}, retry {})",
round, retry_count
);
if retry_count > MAX_RETRIES {
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"empty_response",
)
.await;
return Err((
StatusCode::BAD_GATEWAY,
"AI returned empty responses".into(),
));
}
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": "Your response was empty. Please output the JSON object." }]
}));
continue;
}
let raw: Value = match serde_json::from_str(text) {
Ok(val) => val,
Err(err) => {
retry_count += 1;
warn!(error = %err, round = round, retry = retry_count, "Failed to parse Gemini JSON output");
if retry_count > MAX_RETRIES {
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"invalid_json",
)
.await;
return Err((StatusCode::BAD_GATEWAY, "AI returned invalid JSON".into()));
}
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": "That was not valid JSON. Please output ONLY the JSON object with numeric_filters, enum_filters, travel_time_filters, and notes." }]
}));
continue;
}
};
let filters = validate_and_convert(&raw, &state.features_response);
let travel_time_filters = validate_travel_time_filters(&raw, &state);
let notes = raw
.get("notes")
.and_then(|val| val.as_str())
.unwrap_or("")
.to_string();
// Count matching properties and refine if too restrictive
let (match_count, match_bounds) =
count_matching_rows(&state, &filters, &travel_time_filters);
info!(
match_count = match_count,
round = round,
"AI filter match count"
);
if match_count == 0 {
refinement_attempts += 1;
let total_rows = state.data.lat.len();
info!(
attempt = refinement_attempts,
"0 matches out of {total_rows} — asking AI to relax filters"
);
if refinement_attempts > MAX_REFINEMENTS {
warn!("Refinement budget exhausted, returning filters with 0 matches");
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"zero_matches",
)
.await;
let notes = if notes.is_empty() {
"No properties match these filters. Try relaxing some constraints.".to_string()
} else {
format!(
"{}. No properties match. Try relaxing some constraints.",
notes
)
};
return Ok(Json(AiFiltersResponse {
filters,
travel_time_filters,
notes,
match_count: 0,
match_bounds: None,
}));
}
let feedback = match refinement_attempts {
1 => format!(
"Your proposed filters matched 0 properties out of {total_rows} total. \
The combination is too restrictive. Please widen some numeric ranges \
or add more enum values while keeping the user's intent. \
Output the adjusted JSON."
),
2 => format!(
"Still 0 matches out of {total_rows}. Please widen ranges further. \
Output the adjusted JSON."
),
_ => format!(
"Still 0 matches out of {total_rows}. Please remove additional filters \
until some properties match, keeping the user's core priority. \
Output the adjusted JSON."
),
};
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": feedback }]
}));
continue;
}
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"success",
)
.await;
// Log the query to PocketBase (fire-and-forget)
let filters_json = serde_json::to_string(&filters).unwrap_or_default();
let log_state = state.clone();
let log_user_id = user.id.clone();
let log_query = req.query.clone();
let log_notes = notes.clone();
let log_rounds = (round + 1) as u64;
tokio::spawn(async move {
log_ai_query(
&log_state,
&log_user_id,
&log_query,
&filters_json,
&log_notes,
total_tokens_accumulated,
log_rounds,
)
.await;
});
return Ok(Json(AiFiltersResponse {
filters,
travel_time_filters,
notes,
match_count,
match_bounds,
}));
}
// Exhausted total round budget without getting a valid response
warn!(
"AI exhausted {} total rounds without final response (tools={}, retries={}, refinements={})",
MAX_TOTAL_ROUNDS, tool_call_count, retry_count, refinement_attempts
);
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"incomplete",
)
.await;
Err((
StatusCode::BAD_GATEWAY,
"AI could not complete the request".into(),
))
}

View file

@ -0,0 +1,158 @@
//! Counting properties that match the AI-proposed property and travel time
//! filters, and computing a camera-friendly bounding box of the matches.
use serde_json::Value;
use tracing::warn;
use crate::data::travel_time::TravelData;
use crate::parsing::{parse_filters_with_poi, row_passes_filters, row_passes_poi_filters};
use crate::state::AppState;
use super::{MatchBounds, TravelTimeFilter};
/// Computes a bounding box around the matched coordinates.
///
/// With 20 or more points, each axis is trimmed to its 5th–95th percentile
/// (dropping `len / 20` points from both ends after sorting) so a few
/// far-away outliers don't zoom the camera out to all of England. With fewer
/// points the full extent is used.
///
/// Returns `None` when there are no points or the two vectors differ in
/// length.
fn percentile_trimmed_bounds(mut lats: Vec<f32>, mut lons: Vec<f32>) -> Option<MatchBounds> {
    let point_count = lats.len();
    if point_count == 0 || point_count != lons.len() {
        return None;
    }

    // Each axis is sorted independently: the box edges are per-axis order
    // statistics, not coordinates of one particular point.
    lats.sort_unstable_by(|a, b| a.total_cmp(b));
    lons.sort_unstable_by(|a, b| a.total_cmp(b));

    let trim = if point_count >= 20 { point_count / 20 } else { 0 };
    let low_idx = trim;
    let high_idx = point_count - 1 - trim;

    Some(MatchBounds {
        south: lats[low_idx],
        north: lats[high_idx],
        west: lons[low_idx],
        east: lons[high_idx],
    })
}
/// Serializes validated filter JSON into the `;;`-separated filter string
/// understood by `parse_filters`.
///
/// - Numeric entry `{"name": [min, max]}` becomes `name:min:max`
/// - Enum entry `{"name": ["val1", "val2"]}` becomes `name:val1|val2`
///
/// Entries of any other shape are skipped; a non-object input produces an
/// empty string.
fn filters_to_filter_string(filters: &Value) -> String {
    let Some(entries) = filters.as_object() else {
        return String::new();
    };

    let encoded: Vec<String> = entries
        .iter()
        .filter_map(|(name, value)| {
            let items = value.as_array()?;
            match items.as_slice() {
                // Exactly two numbers: a numeric range.
                [lower, upper] if lower.is_number() && upper.is_number() => {
                    let min = lower.as_f64().unwrap_or(0.0);
                    let max = upper.as_f64().unwrap_or(0.0);
                    Some(format!("{name}:{min}:{max}"))
                }
                // Leading string: an enum list (non-string items are dropped).
                [first, ..] if first.is_string() => {
                    let values: Vec<&str> =
                        items.iter().filter_map(|item| item.as_str()).collect();
                    if values.is_empty() {
                        None
                    } else {
                        Some(format!("{name}:{}", values.join("|")))
                    }
                }
                _ => None,
            }
        })
        .collect();

    encoded.join(";;")
}
/// Count how many rows in the property dataset pass the given property filters
/// AND travel time filters. Travel time data is loaded from the TravelTimeStore
/// and checked per-postcode (same logic as hexagons.rs).
///
/// Returns the match count together with a percentile-trimmed bounding box of
/// the matching rows (`None` when nothing matches).
///
/// Failure handling:
/// - If the filter string cannot be parsed, a warning is logged and
///   `(0, None)` is returned.
/// - If the travel time data for a filter cannot be loaded, that filter is
///   skipped (best-effort) and a warning is logged so the overstated count
///   is traceable.
pub(super) fn count_matching_rows(
    state: &AppState,
    filters: &Value,
    travel_time_filters: &[TravelTimeFilter],
) -> (usize, Option<MatchBounds>) {
    let filter_str = filters_to_filter_string(filters);
    let quant = state.data.quant_ref();
    let poi_quant = state.data.poi_metrics.quant_ref();
    let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = if filter_str.is_empty() {
        (Vec::new(), Vec::new(), Vec::new())
    } else {
        match parse_filters_with_poi(
            Some(&filter_str),
            &state.feature_name_to_index,
            &state.data.enum_values,
            &quant,
            &state.data.poi_metrics.name_to_index,
            &poi_quant,
        ) {
            Ok(f) => f,
            Err(err) => {
                warn!("Failed to parse filters for match count: {err}");
                return (0, None);
            }
        }
    };
    // Load travel time data for each filter entry
    let travel_data: Vec<(TravelData, Option<f32>, Option<f32>)> = travel_time_filters
        .iter()
        .filter_map(
            |ttf| match state.travel_time_store.get(&ttf.mode, &ttf.slug) {
                Ok(data) => Some((data, ttf.min, ttf.max)),
                Err(_) => {
                    // Previously dropped silently; the filter is still
                    // skipped, but the gap is now visible in the logs.
                    warn!(
                        mode = %ttf.mode,
                        slug = %ttf.slug,
                        "Failed to load travel time data; filter ignored for match count"
                    );
                    None
                }
            },
        )
        .collect();
    let has_travel = !travel_data.is_empty();
    let feature_data = &state.data.feature_data;
    let num_features = state.data.num_features;
    let num_rows = state.data.lat.len();
    let (pc_interner, pc_keys) = state.data.postcode_parts();
    let has_poi_filters = !parsed_poi_filters.is_empty();
    let mut count = 0usize;
    let mut matched_lats: Vec<f32> = Vec::new();
    let mut matched_lons: Vec<f32> = Vec::new();
    for (row, pc_key) in pc_keys.iter().enumerate().take(num_rows) {
        if !row_passes_filters(
            row,
            &parsed_filters,
            &parsed_enum_filters,
            feature_data,
            num_features,
        ) {
            continue;
        }
        if has_poi_filters
            && !row_passes_poi_filters(row, &parsed_poi_filters, &state.data.poi_metrics)
        {
            continue;
        }
        if has_travel {
            // The postcode is only resolved for rows that survived the
            // property and POI filters.
            let postcode = pc_interner.resolve(pc_key);
            let mut passes_travel = true;
            for (data, fmin, fmax) in &travel_data {
                let pass = if let Some(mins) = data.get(postcode).map(|r| r.minutes as f32) {
                    fmin.is_none_or(|min| mins >= min) && fmax.is_none_or(|max| mins <= max)
                } else {
                    false // no travel data → postcode not reachable
                };
                if !pass {
                    passes_travel = false;
                    break;
                }
            }
            if !passes_travel {
                continue;
            }
        }
        count += 1;
        matched_lats.push(state.data.lat[row]);
        matched_lons.push(state.data.lon[row]);
    }
    (count, percentile_trimmed_bounds(matched_lats, matched_lons))
}

View file

@ -0,0 +1,81 @@
//! AI filters: translate a natural-language property query into validated
//! filter settings via Gemini.
//!
//! Split by concern:
//! - [`handler`]: the `POST /api/ai-filters` route handler and Gemini
//! conversation loop
//! - [`prompt`]: system prompt building (precomputed at startup)
//! - [`tools`]: the `search_destinations` tool declaration and execution
//! - [`parsing`]: LLM response parsing and validation against feature metadata
//! - [`matching`]: counting properties that match the proposed filters
//! - [`usage`]: weekly token usage tracking / rate limiting
mod handler;
mod matching;
mod parsing;
mod prompt;
mod tools;
mod usage;
pub use handler::post_ai_filters;
pub use prompt::build_system_prompt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Filters that are already active in the client, sent along with a query so
/// the model can refine them instead of starting from scratch.
#[derive(Deserialize)]
pub struct AiFiltersContext {
// Active property filters as raw JSON; keys may be frontend synthetic keys
// and are normalized before being shown to the model.
filters: Value,
// Active travel time filters; treated as empty when the field is absent.
#[serde(default)]
travel_time: Vec<AiTravelTimeContext>,
}
/// One currently active travel time filter, as reported by the client.
/// The handler renders it into the prompt as `- {mode} to {label} ({bounds})`.
#[derive(Deserialize)]
pub struct AiTravelTimeContext {
// Travel mode name as sent by the client.
mode: String,
// Human-readable destination label.
label: String,
// Lower bound in minutes, if set.
min: Option<f32>,
// Upper bound in minutes, if set.
max: Option<f32>,
}
/// JSON request body of `POST /api/ai-filters`.
#[derive(Deserialize)]
pub struct AiFiltersRequest {
// Natural-language query; the handler rejects queries longer than its
// character limit.
query: String,
/// Current filters for conversational refinement (e.g. "make it cheaper")
context: Option<AiFiltersContext>,
}
/// A validated travel time filter proposed by the model. `min` / `max` are
/// left out of the serialized JSON when unset.
#[derive(Serialize)]
pub struct TravelTimeFilter {
// Travel mode; used with `slug` to look up the travel time data.
mode: String,
// Destination identifier within the mode.
slug: String,
// Display label; validation falls back to the slug when the model gives none.
label: String,
// Lower bound in minutes, if any.
#[serde(skip_serializing_if = "Option::is_none")]
min: Option<f32>,
// Upper bound in minutes, if any.
#[serde(skip_serializing_if = "Option::is_none")]
max: Option<f32>,
}
/// JSON response body of `POST /api/ai-filters`. Empty collections, empty
/// notes and absent bounds are omitted from the serialized output.
#[derive(Serialize)]
pub struct AiFiltersResponse {
// Validated property filters, keyed by feature name.
filters: Value,
// Validated travel time filters; omitted when there are none.
#[serde(skip_serializing_if = "Vec::is_empty")]
travel_time_filters: Vec<TravelTimeFilter>,
/// What the LLM couldn't map to existing filters (empty if everything matched)
#[serde(skip_serializing_if = "String::is_empty")]
notes: String,
/// Number of properties matching the proposed property and travel time filters.
match_count: usize,
/// Bounding box of the matching properties so the client can move the
/// camera to where matches actually are. Absent when nothing matches.
#[serde(skip_serializing_if = "Option::is_none")]
match_bounds: Option<MatchBounds>,
}
/// Axis-aligned bounding box of matched properties. `south` / `north` are
/// taken from latitudes and `west` / `east` from longitudes.
#[derive(Serialize)]
pub struct MatchBounds {
// Lower latitude edge.
south: f32,
// Lower longitude edge.
west: f32,
// Upper latitude edge.
north: f32,
// Upper longitude edge.
east: f32,
}

View file

@ -0,0 +1,385 @@
//! LLM response parsing: stripping markdown fences, normalizing frontend
//! synthetic filter keys, and validating proposed filters against feature
//! metadata and available travel destinations.
use serde_json::{json, Map, Value};
use tracing::warn;
use crate::routes::{FeatureInfo, FeaturesResponse};
use crate::state::AppState;
use super::TravelTimeFilter;
/// Removes a surrounding markdown code fence (```json ... ``` or ``` ... ```)
/// from LLM output, since models sometimes fence JSON even when told not to.
///
/// The fence is stripped only when the trimmed text starts with "```", the
/// opening line ends with a newline, and the text ends with "```". In every
/// other case the trimmed input is returned unchanged.
pub(super) fn strip_markdown_fences(text: &str) -> &str {
    let trimmed = text.trim();

    let Some(after_open) = trimmed.strip_prefix("```") else {
        return trimmed;
    };
    // Everything up to the first newline is the optional language tag
    // (e.g. "json"); a fence without a newline is left alone.
    let Some((_language_tag, body)) = after_open.split_once('\n') else {
        return trimmed;
    };

    match body.strip_suffix("```") {
        Some(inner) => inner.trim(),
        None => trimmed,
    }
}
/// Maps a frontend school filter key of the form `Schools:<phase>:<rating>…`
/// to the backend feature name.
///
/// Only the first two `:`-separated segments after the prefix are inspected;
/// anything after them is ignored. Returns `None` for a missing prefix, a
/// missing segment, or an unknown phase/rating combination.
fn school_feature_name_from_key(name: &str) -> Option<&'static str> {
    let mut segments = name.strip_prefix("Schools:")?.split(':');
    let phase = segments.next()?;
    let rating = segments.next()?;

    let feature = match rating {
        "good" => match phase {
            "primary" => "Good+ primary school catchments",
            "secondary" => "Good+ secondary school catchments",
            _ => return None,
        },
        "outstanding" => match phase {
            "primary" => "Outstanding primary school catchments",
            "secondary" => "Outstanding secondary school catchments",
            _ => return None,
        },
        _ => return None,
    };
    Some(feature)
}
/// Extracts the feature name from a synthetic key shaped like
/// `<prefix><url-encoded name>:<id>`.
///
/// The text after the last `:` is discarded and the remainder is
/// percent-decoded. Returns `None` if the prefix does not match, there is no
/// `:` after the prefix, or decoding fails.
fn decode_synthetic_feature_key(name: &str, prefix: &str) -> Option<String> {
    let remainder = name.strip_prefix(prefix)?;
    let id_separator = remainder.rfind(':')?;
    let encoded_name = &remainder[..id_separator];

    match urlencoding::decode(encoded_name) {
        Ok(decoded) => Some(decoded.into_owned()),
        Err(_) => None,
    }
}
/// Translates a frontend synthetic filter key into the backend feature name.
///
/// The React filter UI keeps configurable cards under keys such as
/// `Political vote share:%25%20Labour:0`, while the LLM and the backend
/// validators work with the real feature name (`% Labour`). School keys are
/// tried first, then each known synthetic prefix in order. Returns `None`
/// when the key matches neither form.
fn backend_filter_name(name: &str) -> Option<String> {
    const SYNTHETIC_PREFIXES: [&str; 7] = [
        "Specific crimes:",
        "Political vote share:",
        "Ethnicities:",
        "Amenity distance:",
        "Transport distance:",
        "Amenities within 2km:",
        "Amenities within 5km:",
    ];

    school_feature_name_from_key(name)
        .map(|feature_name| feature_name.to_string())
        .or_else(|| {
            SYNTHETIC_PREFIXES
                .iter()
                .find_map(|&prefix| decode_synthetic_feature_key(name, prefix))
        })
}
/// Returns the backend feature name for `name`, or `name` itself when it is
/// not a recognised synthetic key.
fn canonical_filter_name(name: &str) -> String {
    match backend_filter_name(name) {
        Some(backend_name) => backend_name,
        None => name.to_string(),
    }
}
/// Rewrites the keys of a context filter object to canonical backend feature
/// names, leaving the values untouched.
///
/// Input that is not a JSON object is returned as an unchanged clone. If two
/// keys canonicalize to the same name, the later entry replaces the earlier.
pub(super) fn normalize_context_filters(filters: &Value) -> Value {
    match filters.as_object() {
        None => filters.clone(),
        Some(entries) => {
            let mut renamed = Map::with_capacity(entries.len());
            for (key, filter_value) in entries {
                renamed.insert(canonical_filter_name(key), filter_value.clone());
            }
            Value::Object(renamed)
        }
    }
}
/// Maximum travel-time minutes the data can contain. Matches the Java pipeline's
/// MAX_TRIP_DURATION_MINUTES and the frontend's MAX_TRAVEL_MINUTES.
const TRAVEL_TIME_MAX_MINUTES: f64 = 90.0;

/// Reads `item[key]` as a number of minutes, clamped to
/// `0.0..=TRAVEL_TIME_MAX_MINUTES`.
///
/// Returns `None` when the key is missing, not a number, or not finite.
fn travel_time_minute_field(item: &Value, key: &str) -> Option<f32> {
    let minutes = item.get(key)?.as_f64()?;
    if !minutes.is_finite() {
        return None;
    }
    Some(minutes.clamp(0.0, TRAVEL_TIME_MAX_MINUTES) as f32)
}
/// Extracts `(min, max)` minute bounds from one travel time filter item.
///
/// Two input shapes are accepted:
/// - explicit `"min"` and/or `"max"` fields, which take precedence;
/// - otherwise a single `"value"` with `"bound": "min" | "max"`.
///
/// If both bounds are present but inverted they are swapped. Either side may
/// be `None`; `(None, None)` means no usable bound was found.
fn parse_travel_time_bounds(item: &Value) -> (Option<f32>, Option<f32>) {
    let explicit_min = travel_time_minute_field(item, "min");
    let explicit_max = travel_time_minute_field(item, "max");

    let bounds = if explicit_min.is_none() && explicit_max.is_none() {
        // Fall back to the single-bound form.
        let value = travel_time_minute_field(item, "value");
        let which = item.get("bound").and_then(|val| val.as_str());
        match (which, value) {
            (Some("min"), Some(minutes)) => (Some(minutes), None),
            (Some("max"), Some(minutes)) => (None, Some(minutes)),
            _ => (None, None),
        }
    } else {
        (explicit_min, explicit_max)
    };

    match bounds {
        (Some(lower), Some(upper)) if lower > upper => (Some(upper), Some(lower)),
        ordered => ordered,
    }
}
/// Validate the `travel_time_filters` array of the LLM output against the
/// destinations the travel-time store actually has.
///
/// An item is kept only when it has string `mode` and `slug` fields, the store
/// reports that destination for that mode, and at least one bound could be
/// parsed. `label` falls back to the slug when absent. A missing or non-array
/// `travel_time_filters` yields an empty list.
pub(super) fn validate_travel_time_filters(raw: &Value, state: &AppState) -> Vec<TravelTimeFilter> {
    let Some(candidates) = raw
        .get("travel_time_filters")
        .and_then(|list| list.as_array())
    else {
        return Vec::new();
    };

    let store = &state.travel_time_store;
    let mut accepted = Vec::new();

    for candidate in candidates {
        let Some(mode) = candidate.get("mode").and_then(|field| field.as_str()) else {
            continue;
        };
        let Some(slug) = candidate.get("slug").and_then(|field| field.as_str()) else {
            continue;
        };
        let label = candidate
            .get("label")
            .and_then(|field| field.as_str())
            .unwrap_or(slug);

        // Reject destinations the model invented.
        if !store.has_destination(mode, slug) {
            warn!(
                mode = mode,
                slug = slug,
                "AI suggested non-existent destination"
            );
            continue;
        }

        let (min, max) = parse_travel_time_bounds(candidate);
        // A filter without any bound constrains nothing; drop it.
        if min.is_none() && max.is_none() {
            continue;
        }

        accepted.push(TravelTimeFilter {
            mode: mode.to_string(),
            slug: slug.to_string(),
            label: label.to_string(),
            min,
            max,
        });
    }

    accepted
}
/// Validate LLM output against feature metadata and convert to FeatureFilters format.
///
/// Input format (array-based, each numeric filter sets one bound):
/// ```json
/// {
///   "numeric_filters": [{"name": "Last known price", "bound": "max", "value": 300000}],
///   "enum_filters": [{"name": "Leasehold/Freehold", "values": ["Freehold"]}]
/// }
/// ```
///
/// Output format (FeatureFilters):
/// ```json
/// { "Last known price": [0, 300000], "Leasehold/Freehold": ["Freehold"] }
/// ```
///
/// Validation rules:
/// - Filter names are canonicalised first (synthetic frontend keys become
///   backend feature names); unknown features are dropped.
/// - Numeric values are clamped to the true data range. The side a filter does
///   not set defaults to the data min/max.
/// - A "min" and a "max" filter on the same feature are combined into one
///   two-sided range. A repeated bound on the same side replaces the earlier
///   one. If the combination would be inverted (min > max), the later filter
///   wins on its own.
/// - A numeric range that is not narrower than the slider range is omitted.
/// - Enum values are restricted to known values and de-duplicated; a selection
///   that is empty or covers every known value is omitted.
pub(super) fn validate_and_convert(raw: &Value, features: &FeaturesResponse) -> Value {
    let mut result = serde_json::Map::new();

    // Build lookup maps from feature metadata.
    // Store both slider bounds (min/max from percentiles) and true data bounds
    // (histogram.min/max) so one-sided AI filters use the full data range.
    let mut numeric_features: rustc_hash::FxHashMap<&str, (f32, f32, f32, f32)> =
        rustc_hash::FxHashMap::default();
    let mut enum_features: rustc_hash::FxHashMap<&str, &[String]> =
        rustc_hash::FxHashMap::default();
    for group in &features.groups {
        for feature in &group.features {
            match feature {
                FeatureInfo::Numeric {
                    name,
                    min,
                    max,
                    histogram,
                    ..
                } => {
                    numeric_features.insert(name, (*min, *max, histogram.min, histogram.max));
                }
                FeatureInfo::Enum { name, values, .. } => {
                    enum_features.insert(name, values);
                }
            }
        }
    }

    // Process numeric filters — each item sets one bound (min or max).
    // The unset side uses the true data min/max (from histogram), not
    // the slider bounds (percentile-based), so a "max" filter for crime
    // produces [0, value] rather than [2nd-percentile, value].
    if let Some(arr) = raw.get("numeric_filters").and_then(|val| val.as_array()) {
        // Explicit (min, max) bounds already emitted per feature, so that a
        // later opposite-side filter extends the range instead of replacing it.
        let mut explicit_bounds: rustc_hash::FxHashMap<String, (Option<f32>, Option<f32>)> =
            rustc_hash::FxHashMap::default();

        for item in arr {
            let Some(raw_name) = item.get("name").and_then(|val| val.as_str()) else {
                continue;
            };
            let name = canonical_filter_name(raw_name);
            let Some(&(slider_min, slider_max, data_min, data_max)) =
                numeric_features.get(name.as_str())
            else {
                continue;
            };
            let Some(bound) = item.get("bound").and_then(|val| val.as_str()) else {
                continue;
            };
            // Clamp value to true data range (not slider range)
            let value = match item.get("value").and_then(|val| val.as_f64()) {
                Some(v) => v.max(data_min as f64).min(data_max as f64) as f32,
                None => continue,
            };

            let (prior_min, prior_max) = explicit_bounds
                .get(name.as_str())
                .copied()
                .unwrap_or((None, None));
            let (standalone, mut combined) = match bound {
                "min" => ((Some(value), None), (Some(value), prior_max)),
                "max" => ((None, Some(value)), (prior_min, Some(value))),
                _ => continue,
            };
            // Contradictory bounds: fall back to the later filter alone.
            if let (Some(lower), Some(upper)) = combined {
                if lower > upper {
                    combined = standalone;
                }
            }

            let filter_min = combined.0.unwrap_or(data_min);
            let filter_max = combined.1.unwrap_or(data_max);

            // Only include if range is narrower than full slider range
            if filter_min > slider_min || filter_max < slider_max {
                explicit_bounds.insert(name.clone(), combined);
                result.insert(name, json!([filter_min, filter_max]));
            }
        }
    }

    // Process enum filters
    if let Some(arr) = raw.get("enum_filters").and_then(|val| val.as_array()) {
        for item in arr {
            let Some(raw_name) = item.get("name").and_then(|val| val.as_str()) else {
                continue;
            };
            let name = canonical_filter_name(raw_name);
            let Some(&valid_values) = enum_features.get(name.as_str()) else {
                continue;
            };
            let Some(selected) = item.get("values").and_then(|val| val.as_array()) else {
                continue;
            };

            // Keep known values only, once each, in the order given. Without
            // de-duplication a repeated value could make a partial selection
            // look like "everything selected" and be dropped.
            let mut valid: Vec<&str> = Vec::new();
            for candidate in selected.iter().filter_map(|entry| entry.as_str()) {
                let is_known = valid_values.iter().any(|known| known.as_str() == candidate);
                if is_known && !valid.contains(&candidate) {
                    valid.push(candidate);
                }
            }

            // Selecting nothing or everything is not a filter.
            if !valid.is_empty() && valid.len() < valid_values.len() {
                result.insert(name, json!(valid));
            }
        }
    }

    Value::Object(result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn strip_fences_json_tag() {
let input = "```json\n{\"a\": 1}\n```";
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
}
#[test]
fn strip_fences_no_tag() {
let input = "```\n{\"a\": 1}\n```";
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
}
#[test]
fn strip_fences_passthrough() {
let input = "{\"a\": 1}";
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
}
#[test]
fn strip_fences_whitespace() {
let input = " ```json\n {\"a\": 1} \n``` ";
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
}
#[test]
fn synthetic_filter_keys_are_normalized_to_backend_names() {
assert_eq!(
canonical_filter_name("Schools:primary:good:0"),
"Good+ primary school catchments"
);
// Legacy keys still carry a distance segment; it is ignored.
assert_eq!(
canonical_filter_name("Schools:primary:good:2:0"),
"Good+ primary school catchments"
);
assert_eq!(
canonical_filter_name("Specific crimes:Burglary%20%28avg%2Fyr%29:1"),
"Burglary (avg/yr)"
);
assert_eq!(
canonical_filter_name("Political vote share:%25%20Labour:0"),
"% Labour"
);
assert_eq!(
canonical_filter_name(
"Transport distance:Distance%20to%20nearest%20amenity%20%28Bus%20stop%29%20%28km%29:0"
),
"Distance to nearest amenity (Bus stop) (km)"
);
}
#[test]
fn context_filters_are_normalized_before_prompting() {
// One synthetic key (must be rewritten) and one plain key (must pass through).
let filters = json!({
"Political vote share:%25%20Green:0": [40, 100],
"Estimated current price": [0, 500000],
});
let normalized = normalize_context_filters(&filters);
assert_eq!(normalized["% Green"], json!([40, 100]));
assert_eq!(normalized["Estimated current price"], json!([0, 500000]));
}
#[test]
fn travel_time_bounds_accept_min_max_schema() {
let item = json!({ "min": 30, "max": 45 });
assert_eq!(parse_travel_time_bounds(&item), (Some(30.0), Some(45.0)));
}
#[test]
fn travel_time_bounds_accept_legacy_bound_value_schema() {
let item = json!({ "bound": "max", "value": 30 });
assert_eq!(parse_travel_time_bounds(&item), (None, Some(30.0)));
}
#[test]
fn travel_time_bounds_clamp_and_order_range() {
// Data ceiling is 90 (matches Java MAX_TRIP_DURATION_MINUTES).
// Inputs outside [0, 90] are clamped first (150 -> 90, -10 -> 0); the
// resulting inverted pair (90, 0) is then swapped into ascending order.
let item = json!({ "min": 150, "max": -10 });
assert_eq!(parse_travel_time_bounds(&item), (Some(0.0), Some(90.0)));
}
}

View file

@ -0,0 +1,282 @@
//! System prompt building for the AI filters assistant.
use crate::routes::{FeatureInfo, FeaturesResponse};
/// Build the complete system prompt for AI filters.
///
/// Contains: role instructions, feature catalogue, travel time info,
/// few-shot examples, output rules.
/// Precomputed at startup and cached in AppState.
///
/// `features` supplies the catalogue: every group becomes a `## <group>`
/// heading followed by one line per numeric or enum feature.
/// `mode_destinations` is a list of `(mode, destination count)` pairs that is
/// rendered into the travel-time section.
///
/// The prompt is assembled as a list of sections that are joined with a
/// single `\n` at the end.
pub fn build_system_prompt(
features: &FeaturesResponse,
mode_destinations: &[(String, usize)],
) -> String {
let mut parts = Vec::new();
// Role, general rules and conversational-refinement instructions.
parts.push(
"You are a UK property search assistant. \
The user describes their ideal property or area in natural language. \
Translate their description into filter settings using ONLY the features listed below.\n\
\n\
Rules:\n\
- ONLY set filters the user explicitly mentioned or clearly implied.\n\
- Leave out any filter the user did not mention. Empty arrays are fine.\n\
- Each numeric filter sets ONE bound only: \"min\" (at least this value) \
or \"max\" (at most this value). Never set two filters on the same feature.\n\
- Use EXACT feature names from the list spelling, capitalisation, and punctuation must match.\n\
- \"cheap\" / \"affordable\" = lower price range. \"expensive\" = higher price range.\n\
- \"low crime\" / \"safe\" = low values on the Serious crime (avg/yr) and Minor crime (avg/yr) \
features (area-normalised incident density near the postcode). Prefer these aggregates for broad \
area safety; use specific crime features only when the user names a crime type.\n\
- \"quiet\" = low Noise (dB). \"green\" / \"near parks\" = high Number of amenities (Park) within 2km \
or low Distance to nearest park (km), depending on wording.\n\
- \"good schools\" = Good+ school features. \"outstanding schools\" = Outstanding school features.\n\
- Amenities and transport stops are normal filters in the feature catalogue. \
For \"near a bus stop\", \"near a station\", \"near shops\", etc., use the exact \
Distance to nearest amenity (...) or Number of amenities (...) feature when available.\n\
- Politics/elections are normal filters in the Neighbours group. Use exact vote share \
features such as % Labour, % Conservative, % Liberal Democrat, % Reform UK, % Green, \
% Other parties, or Voter turnout (%) when the user asks for political character.\n\
- When the user says a number like \"under 400k\", interpret it as 400000.\n\
- When the user says \"3 bed\" or \"3 bedroom\", use Number of bedrooms & living rooms \
(note: this counts bedrooms + living rooms combined, so 3 bed ~ min 4).\n\
- If the user mentions something that has no matching filter, put it in \"notes\" \
as a short phrase (e.g. \"No filter for: garden, sea view\"). \
If everything was matched, set \"notes\" to an empty string.\n\
\n\
CONVERSATIONAL REFINEMENT:\n\
The user's message may include their currently active filters as context. \
When context is provided:\n\
- \"make it cheaper\" / \"lower the price\" = adjust the existing price filter down\n\
- \"also add ...\" / \"and good schools\" = keep existing filters and add new ones\n\
- \"remove the ...\" / \"drop the ...\" = return filters WITHOUT the mentioned one\n\
- If the request is a completely new search (not a refinement), ignore the context \
and build filters from scratch.\n\
- Always output the COMPLETE set of filters (existing + modified), not just the changes."
.to_string(),
);
// Travel time section with available modes
// One bullet per (mode, destination count) pair, substituted into the
// single `{}` placeholder of the section below.
let modes_list = mode_destinations
.iter()
.map(|(mode, count)| format!("- {} ({} destinations available)", mode, count))
.collect::<Vec<_>>()
.join("\n");
parts.push(format!(
"\n--- TRAVEL TIME FILTERS ---\n\
You can add travel time filters when the user mentions commute times, \
proximity to places, or wanting to be near/within X minutes of somewhere.\n\
\n\
Available travel-time modes (only use modes that have destinations):\n\
{}\n\
- \"car\" / \"drive\" / \"driving\" = car mode\n\
- \"cycle\" / \"bike\" / \"cycling\" = bicycle mode\n\
- \"walk\" / \"walking\" / \"on foot\" = walking mode\n\
- \"train\" / \"tube\" / \"bus\" / \"public transport\" / \"commute\" = transit mode\n\
- \"without buses\" / \"no bus\" / \"rail only\" = transit-no-bus mode\n\
- \"no change\" / \"no transfer\" / \"direct\" / \"single bus/train\" = transit-no-change mode\n\
- \"no change and no bus\" / \"direct rail/tube\" = transit-no-change-no-bus mode\n\
- If a mode appears in the available mode list but is not named above, you may still \
use the exact mode string from the list.\n\
\n\
When the user mentions a specific place, you MUST call the search_destinations \
tool to find the exact slug. Use the name and slug from the search results.\n\
If search_destinations returns an empty array, the destination is not available \
mention it in \"notes\" (e.g. \"No travel data for: Gatwick Airport\") and do NOT \
include a travel_time_filter for it.\n\
\n\
Travel time values are in MINUTES (0-90 range; data is capped at 90 min).\n\
- \"within 30 minutes\" = set \"max\": 30\n\
- \"at least 10 minutes\" = set \"min\": 10\n\
- \"30-45 minute commute\" = set \"min\": 30 and \"max\": 45 on the same travel_time_filter\n\
- If only a max is given, omit min (and vice versa). Do not use bound/value for travel time.\n\
\n\
INFERRING TRANSPORT MODE (when the user does not specify one explicitly):\n\
- \"commute\" to a major city centre or station = transit\n\
- \"near\" / \"close to\" a city centre or station = transit\n\
- \"near\" / \"close to\" a smaller town, village, or rural area = car\n\
- \"drive\" / \"driving distance\" / \"driving time\" = always car\n\
- If multiple modes are plausible, prefer transit for urban destinations \
(London, Manchester, Birmingham, Leeds, etc.) and car for everything else.",
modes_list,
));
// Feature guidance
parts.push(
"\n--- DATA SOURCE ---\n\
The data is historical property sales from the Land Registry.\n\
\n\
Use these features for price queries:\n\
- For purchase price: use \"Estimated current price\" or \"Last known price\"\n\
- For price per sqm: use \"Est. price per sqm\"\n\
- For rent estimates: use \"Estimated monthly rent\""
.to_string(),
);
// Feature catalogue
// Numeric features list their min/max rounded to whole numbers with the
// feature's prefix/suffix; enum features list every allowed value quoted.
parts.push("\n--- AVAILABLE FEATURES ---\n".to_string());
for group in &features.groups {
parts.push(format!("## {}", group.name));
for feature in &group.features {
match feature {
FeatureInfo::Numeric {
name,
min,
max,
description,
prefix,
suffix,
..
} => {
parts.push(format!(
"- \"{}\" (numeric, {}{:.0}{} to {}{:.0}{}): {}",
name, prefix, min, suffix, prefix, max, suffix, description
));
}
FeatureInfo::Enum {
name,
values,
description,
..
} => {
parts.push(format!(
"- \"{}\" (enum, values: [{}]): {}",
name,
values
.iter()
.map(|val| format!("\"{}\"", val))
.collect::<Vec<_>>()
.join(", "),
description
));
}
}
}
}
// Few-shot examples
// Each example is a user request followed by the exact JSON expected back.
parts.push("\n--- EXAMPLES ---\n".to_string());
parts.push(
"User: \"cheap freehold house under 400k\"\n\
Output: {\"numeric_filters\": [{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 400000}], \
\"enum_filters\": [{\"name\": \"Leasehold/Freehold\", \"values\": [\"Freehold\"]}, \
{\"name\": \"Property type\", \"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
\"travel_time_filters\": [], \
\"notes\": \"\"}"
.to_string(),
);
parts.push(
"\nUser: \"safe quiet area with good schools and parks\"\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Serious crime (avg/yr)\", \"bound\": \"max\", \"value\": 5}, \
{\"name\": \"Minor crime (avg/yr)\", \"bound\": \"max\", \"value\": 20}, \
{\"name\": \"Noise (dB)\", \"bound\": \"max\", \"value\": 55}, \
{\"name\": \"Good+ primary school catchments\", \"bound\": \"min\", \"value\": 2}, \
{\"name\": \"Good+ secondary school catchments\", \"bound\": \"min\", \"value\": 1}, \
{\"name\": \"Number of amenities (Park) within 2km\", \"bound\": \"min\", \"value\": 3}], \
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
.to_string(),
);
parts.push(
"\nUser: \"quiet area with outstanding schools\"\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Noise (dB)\", \"bound\": \"max\", \"value\": 55}, \
{\"name\": \"Outstanding primary school catchments\", \"bound\": \"min\", \"value\": 1}, \
{\"name\": \"Outstanding secondary school catchments\", \"bound\": \"min\", \"value\": 1}], \
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
.to_string(),
);
parts.push(
"\nUser: \"3 bed flat under 300k with fast broadband near the beach\"\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 300000}, \
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 4}], \
\"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}, \
{\"name\": \"Max available download speed (Mbps)\", \"values\": [\"100\", \"300\", \"1000\"]}], \
\"travel_time_filters\": [], \
\"notes\": \"No filter for: beach proximity\"}"
.to_string(),
);
parts.push(
"\nUser: \"within 30 minutes commute of Kings Cross, under 500k\"\n\
(After calling search_destinations for \"Kings Cross\" with mode \"transit\" \
and getting [{\"name\": \"Kings Cross\", \"slug\": \"kings-cross\", \"place_type\": \"station\"}])\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 500000}], \
\"enum_filters\": [], \
\"travel_time_filters\": [{\"mode\": \"transit\", \"slug\": \"kings-cross\", \
\"label\": \"Kings Cross\", \"max\": 30}], \
\"notes\": \"\"}"
.to_string(),
);
parts.push(
"\nUser: \"family home with garden, 45 min drive from Manchester, good schools\"\n\
(After calling search_destinations for \"Manchester\" with mode \"car\" \
and getting [{\"name\": \"Manchester\", \"slug\": \"manchester\", \"place_type\": \"city\"}])\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Total floor area (sqm)\", \"bound\": \"min\", \"value\": 100}, \
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 5}, \
{\"name\": \"Good+ primary school catchments\", \"bound\": \"min\", \"value\": 2}, \
{\"name\": \"Good+ secondary school catchments\", \"bound\": \"min\", \"value\": 1}], \
\"enum_filters\": [{\"name\": \"Property type\", \
\"values\": [\"Detached\", \"Semi-Detached\"]}], \
\"travel_time_filters\": [{\"mode\": \"car\", \"slug\": \"manchester\", \
\"label\": \"Manchester\", \"max\": 45}], \
\"notes\": \"No filter for: garden\"}"
.to_string(),
);
parts.push(
"\nUser: \"Labour-voting area with low burglary and a station nearby\"\n\
Output: {\"numeric_filters\": [\
{\"name\": \"% Labour\", \"bound\": \"min\", \"value\": 40}, \
{\"name\": \"Burglary (avg/yr)\", \"bound\": \"max\", \"value\": 10}, \
{\"name\": \"Distance to nearest amenity (Rail station) (km)\", \"bound\": \"max\", \"value\": 1}], \
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
.to_string(),
);
// Examples showing rent and price features
// NOTE(review): these two examples mention a bedroom count but emit no
// "Number of bedrooms & living rooms" filter, unlike the "3 bed flat"
// example above — confirm this is intentional.
parts.push(
"\nUser: \"2 bed flat with rent under £1500/month\"\n\
Output: {\
\"numeric_filters\": [{\"name\": \"Estimated monthly rent\", \"bound\": \"max\", \"value\": 1500}], \
\"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}], \
\"travel_time_filters\": [], \
\"notes\": \"\"}"
.to_string(),
);
parts.push(
"\nUser: \"3 bed house under 500k with good schools\"\n\
Output: {\
\"numeric_filters\": [{\"name\": \"Estimated current price\", \"bound\": \"max\", \"value\": 500000}, \
{\"name\": \"Good+ primary school catchments\", \"bound\": \"min\", \"value\": 2}], \
\"enum_filters\": [{\"name\": \"Property type\", \
\"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
\"travel_time_filters\": [], \
\"notes\": \"\"}"
.to_string(),
);
// Output format reminder
parts.push(
"\n--- OUTPUT FORMAT ---\n\
{\"numeric_filters\": [...], \"enum_filters\": [...], \
\"travel_time_filters\": [{\"mode\": \"...\", \"slug\": \"...\", \"label\": \"...\", \
\"min\": N, \"max\": N}, ...], \"notes\": \"...\"}\n\
- travel_time_filters: min and max are both optional, but include at least one. \
Use ONLY slugs returned by search_destinations. If a place isn't found, mention it in notes.\n\
Respond with ONLY the JSON object. No explanation."
.to_string(),
);
// Sections are separated by a newline; many also start with their own "\n".
parts.join("\n")
}

View file

@ -0,0 +1,188 @@
//! The `search_destinations` Gemini tool: declaration and execution against
//! PlaceData + TravelTimeStore.
use serde_json::{json, Value};
use tracing::info;
use crate::data::slugify;
use crate::state::AppState;
/// Build the Gemini tool declaration for destination search.
///
/// Declares a single function, `search_destinations`, taking a required
/// `query` string and a required `mode` string whose allowed values are the
/// modes listed in `state.travel_time_store.available_modes`.
pub(super) fn build_tool_declarations(state: &AppState) -> Value {
    let mut mode_names: Vec<&str> = Vec::new();
    for available in state.travel_time_store.available_modes.iter() {
        mode_names.push(available.as_str());
    }

    let parameters = json!({
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Place name to search for (e.g. 'Manchester', 'Kings Cross', 'Heathrow')"
            },
            "mode": {
                "type": "string",
                "enum": mode_names,
                "description": "Transport mode to search destinations for"
            }
        },
        "required": ["query", "mode"]
    });

    let search_destinations = json!({
        "name": "search_destinations",
        "description": "Search for available travel time destinations (cities, stations, towns) that have precomputed travel time data. Call this when the user mentions wanting to be near, close to, or within a certain travel time of a specific place.",
        "parameters": parameters
    });

    json!([{ "functionDeclarations": [search_destinations] }])
}
/// Decide whether a place matches a destination query.
///
/// A place matches when every query word occurs in its lowercased name, or
/// when its slug and the query slug contain one another. Empty slugs never
/// match: `str::contains("")` is always true, so without this guard a query
/// that slugifies to nothing would match every place.
fn destination_matches_query(
    name_lower: &str,
    slug: &str,
    query_words: &[&str],
    query_slug: &str,
) -> bool {
    if query_words.iter().all(|word| name_lower.contains(*word)) {
        return true;
    }
    !slug.is_empty()
        && !query_slug.is_empty()
        && (slug.contains(query_slug) || query_slug.contains(slug))
}

/// Order hits by type rank ascending, then population descending, and keep
/// the first ten.
fn rank_destination_hits(hits: &mut Vec<(usize, String, u8, u32)>) {
    hits.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(b.3.cmp(&a.3)));
    hits.truncate(10);
}

/// Execute a destination search against PlaceData + TravelTimeStore.
/// Returns matching destinations as a JSON value with `results` and optional `message`.
///
/// Uses word-based matching: all words in the query must appear somewhere in the
/// place name (order-independent). Also matches against slugs for short queries.
///
/// Behaviour:
/// - Unknown `mode`: empty `results` plus a message.
/// - Blank `query` (empty or whitespace only): empty `results` plus a message,
///   rather than matching every place.
/// - Otherwise up to ten places that match the query, are flagged as travel
///   destinations, and have travel data for `mode`.
/// - If nothing matched but the query names a rank-0 place, the places whose
///   `city` equals that place's name (and that have travel data) are returned
///   as suggestions together with a message.
pub(super) fn execute_destination_search(state: &AppState, query: &str, mode: &str) -> Value {
    let query_lower = query.to_lowercase();
    let query_words: Vec<&str> = query_lower.split_whitespace().collect();
    let query_slug = slugify(query);
    let tt_store = &state.travel_time_store;
    let pd = &state.place_data;

    let slug_set = match tt_store.destinations.get(mode) {
        Some(slugs) => slugs,
        None => {
            return json!({ "results": [], "message": format!("No travel data available for mode '{}'", mode) })
        }
    };

    // With no query words, `all()` below would be vacuously true and the
    // search would return arbitrary top-ranked places.
    if query_words.is_empty() {
        info!(mode = mode, "Destination search called with a blank query");
        return json!({
            "results": [],
            "message": "Empty destination query. Provide a place name to search for."
        });
    }

    // Shared conversion from ranked hits to the JSON shape sent to the model.
    let to_results = |hits: Vec<(usize, String, u8, u32)>| -> Vec<Value> {
        hits.into_iter()
            .map(|(idx, slug, ..)| {
                json!({
                    "name": pd.name[idx],
                    "slug": slug,
                    "place_type": pd.place_type.get(idx).to_string(),
                })
            })
            .collect()
    };

    // Find places matching the query that have travel time data.
    let mut matches: Vec<(usize, String, u8, u32)> = pd
        .name_lower
        .iter()
        .enumerate()
        .filter_map(|(idx, name_lower)| {
            if !pd.travel_destination[idx] {
                return None;
            }
            let slug = slugify(&pd.name[idx]);
            if !destination_matches_query(name_lower, &slug, &query_words, &query_slug) {
                return None;
            }
            if slug_set.contains(&slug) {
                Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
            } else {
                None
            }
        })
        .collect();
    rank_destination_hits(&mut matches);

    if !matches.is_empty() {
        return json!({ "results": to_results(matches) });
    }

    // Check if the query matched a city that lacks its own travel data.
    // If so, return nearby stations within that city as suggestions.
    let matched_city_name: Option<&str> =
        pd.name_lower
            .iter()
            .enumerate()
            .find_map(|(idx, name_lower)| {
                if !pd.travel_destination[idx] || pd.type_rank[idx] != 0 {
                    return None;
                }
                let slug = slugify(&pd.name[idx]);
                if destination_matches_query(name_lower, &slug, &query_words, &query_slug) {
                    Some(pd.name[idx].as_str())
                } else {
                    None
                }
            });

    if let Some(city_name) = matched_city_name {
        let city_lower = city_name.to_lowercase();
        let mut city_matches: Vec<(usize, String, u8, u32)> = pd
            .city
            .iter()
            .enumerate()
            .filter_map(|(idx, city_opt)| {
                if !pd.travel_destination[idx] {
                    return None;
                }
                let city = city_opt.as_deref()?;
                if city.to_lowercase() != city_lower {
                    return None;
                }
                let slug = slugify(&pd.name[idx]);
                if slug_set.contains(&slug) {
                    Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
                } else {
                    None
                }
            })
            .collect();
        rank_destination_hits(&mut city_matches);

        if !city_matches.is_empty() {
            let results = to_results(city_matches);
            info!(
                query = query,
                city = city_name,
                results = results.len(),
                "Destination search fell back to city stations"
            );
            return json!({
                "results": results,
                "message": format!(
                    "No travel data for '{}' directly. Pick one of these nearby stations:",
                    city_name
                )
            });
        }
    }

    info!(
        query = query,
        mode = mode,
        "Destination search returned no results"
    );
    json!({
        "results": [],
        "message": format!("No travel time data available for '{}' by {}. This destination cannot be used as a travel time filter.", query, mode)
    })
}

View file

@ -0,0 +1,119 @@
//! Weekly AI token usage tracking and rate limiting, persisted on the user's
//! PocketBase record.
use axum::http::StatusCode;
use metrics::counter;
use serde_json::{json, Value};
use tracing::warn;
use crate::pocketbase::get_superuser_token;
use crate::state::AppState;
/// Number of whole 7-day periods (604800 seconds each) elapsed since the Unix
/// epoch. The value goes up by one every seven days and is used as the bucket
/// key for weekly rate limiting.
pub(super) fn current_week_number() -> u64 {
    const SECONDS_PER_WEEK: u64 = 604_800;

    let since_epoch = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        // Only possible if the system clock is before 1970; fall back to
        // week 0 rather than panicking inside a request handler.
        Err(_) => std::time::Duration::default(),
    };

    since_epoch.as_secs() / SECONDS_PER_WEEK
}
/// Fetch the user's current AI token usage from PocketBase.
/// Returns `(tokens_used, week_number)`.
pub(super) async fn fetch_ai_usage(
state: &AppState,
user_id: &str,
) -> Result<(u64, u64), (StatusCode, String)> {
let token = get_superuser_token(state).await.map_err(|err| {
warn!("Failed to auth superuser for AI usage check: {err}");
(StatusCode::BAD_GATEWAY, "Internal error".into())
})?;
let pb_url = state.pocketbase_url.trim_end_matches('/');
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
let resp = state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await
.map_err(|err| {
warn!("Failed to fetch user record for AI usage: {err}");
(StatusCode::BAD_GATEWAY, "Internal error".into())
})?;
if !resp.status().is_success() {
let status = resp.status();
warn!("PocketBase user fetch failed ({status})");
return Err((StatusCode::BAD_GATEWAY, "Internal error".into()));
}
let body: Value = resp.json().await.map_err(|err| {
warn!("Failed to parse user record: {err}");
(StatusCode::BAD_GATEWAY, "Internal error".into())
})?;
let tokens_used = body
.get("ai_tokens_used")
.and_then(|val| val.as_u64())
.unwrap_or(0);
let week = body
.get("ai_tokens_week")
.and_then(|val| val.as_u64())
.unwrap_or(0);
Ok((tokens_used, week))
}
/// Write the user's AI token usage (`ai_tokens_used`, `ai_tokens_week`) to
/// their PocketBase record.
/// Best-effort — failures are logged as warnings and never propagated.
async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week: u64) {
    let auth = get_superuser_token(state).await;
    let token = match auth {
        Err(err) => {
            warn!("Failed to auth superuser for AI usage update: {err}");
            return;
        }
        Ok(token) => token,
    };

    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let record_url = format!("{pb_url}/api/collections/users/records/{user_id}");
    let payload = json!({
        "ai_tokens_used": tokens_used,
        "ai_tokens_week": week,
    });

    let outcome = state
        .http_client
        .patch(&record_url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload)
        .send()
        .await;

    match outcome {
        Err(err) => warn!("Failed to update AI usage: {err}"),
        Ok(resp) => {
            let status = resp.status();
            if !status.is_success() {
                warn!("Failed to update AI usage ({status})");
            }
        }
    }
}
/// Account for one finished AI request.
///
/// When the request consumed tokens, the user's stored total becomes
/// `existing_tokens_used + request_tokens_used` (saturating) for `week`, and
/// the `ai_tokens_total` counter grows by the request's tokens. The
/// `ai_requests_total` counter, labelled with `status`, is incremented for
/// every request regardless of token use.
pub(super) async fn record_ai_request_usage(
    state: &AppState,
    user_id: &str,
    existing_tokens_used: u64,
    week: u64,
    request_tokens_used: u64,
    status: &'static str,
) {
    let consumed_tokens = request_tokens_used != 0;
    if consumed_tokens {
        let updated_total = existing_tokens_used.saturating_add(request_tokens_used);
        update_ai_usage(state, user_id, updated_total, week).await;
        counter!("ai_tokens_total").increment(request_tokens_used);
    }

    counter!("ai_requests_total", "status" => status).increment(1);
}

View file

@ -780,31 +780,28 @@ pub async fn get_export(
// groups themselves; postcodes within a group are sorted alphabetically. // groups themselves; postcodes within a group are sorted alphabetically.
// Each group carries a rolled-up summary aggregate for its header row. // Each group carries a rolled-up summary aggregate for its header row.
let outcode_groups: Vec<OutcodeGroup> = { let outcode_groups: Vec<OutcodeGroup> = {
let mut order: Vec<String> = Vec::new(); let mut groups: Vec<OutcodeGroup> = Vec::new();
let mut by_outcode: FxHashMap<String, OutcodeGroup> = FxHashMap::default(); let mut idx_by_outcode: FxHashMap<String, usize> = FxHashMap::default();
for (i, (pc_idx, agg)) in postcode_aggs.iter().enumerate() { for (i, (pc_idx, agg)) in postcode_aggs.iter().enumerate() {
let outcode = outcode_of(&postcode_data.postcodes[*pc_idx]).to_string(); let outcode = outcode_of(&postcode_data.postcodes[*pc_idx]).to_string();
let group = by_outcode.entry(outcode.clone()).or_insert_with(|| { let idx = *idx_by_outcode.entry(outcode.clone()).or_insert_with(|| {
order.push(outcode.clone()); groups.push(OutcodeGroup {
OutcodeGroup { outcode,
outcode: outcode.clone(),
members: Vec::new(), members: Vec::new(),
summary: PostcodeExportAgg::new(total_export_features), summary: PostcodeExportAgg::new(total_export_features),
} });
groups.len() - 1
}); });
group.members.push(i); groups[idx].members.push(i);
group.summary.merge_from(agg); groups[idx].summary.merge_from(agg);
} }
for group in by_outcode.values_mut() { for group in &mut groups {
group.members.sort_by(|&a, &b| { group.members.sort_by(|&a, &b| {
postcode_data.postcodes[postcode_aggs[a].0] postcode_data.postcodes[postcode_aggs[a].0]
.cmp(&postcode_data.postcodes[postcode_aggs[b].0]) .cmp(&postcode_data.postcodes[postcode_aggs[b].0])
}); });
} }
order groups
.into_iter()
.map(|outcode| by_outcode.remove(&outcode).unwrap())
.collect()
}; };
// Build Excel workbook with two sheets // Build Excel workbook with two sheets

View file

@ -130,6 +130,11 @@ pub struct HexagonStatsResponse {
pub price_history: Vec<PricePoint>, pub price_history: Vec<PricePoint>,
#[serde(skip_serializing_if = "Vec::is_empty")] #[serde(skip_serializing_if = "Vec::is_empty")]
pub crime_by_year: Vec<CrimeYearStats>, pub crime_by_year: Vec<CrimeYearStats>,
/// Latest year in the crime dataset as a whole. When a selection's series
/// end earlier (force-level publication gap, e.g. Greater Manchester),
/// the client captions the data as stale.
#[serde(skip_serializing_if = "Option::is_none")]
pub crime_latest_year: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub central_postcode: Option<String>, pub central_postcode: Option<String>,
#[serde(skip_serializing_if = "Vec::is_empty")] #[serde(skip_serializing_if = "Vec::is_empty")]
@ -645,12 +650,19 @@ pub async fn get_hexagon_stats(
"GET /api/hexagon-stats" "GET /api/hexagon-stats"
); );
let crime_latest_year = if crime_by_year.is_empty() {
None
} else {
stats::crime_latest_available_year(&state.crime_by_year)
};
Ok(HexagonStatsResponse { Ok(HexagonStatsResponse {
count: total_count, count: total_count,
numeric_features, numeric_features,
enum_features: enum_features_out, enum_features: enum_features_out,
price_history, price_history,
crime_by_year, crime_by_year,
crime_latest_year,
central_postcode, central_postcode,
filter_exclusions, filter_exclusions,
}) })

View file

@ -36,8 +36,9 @@ fn is_allowed_pb_path(path: &str) -> bool {
/// Dedicated HTTP client for proxying — does not follow redirects so 3xx /// Dedicated HTTP client for proxying — does not follow redirects so 3xx
/// responses are passed through to the browser (needed for OAuth flows). /// responses are passed through to the browser (needed for OAuth flows).
/// No overall timeout because SSE (Server-Sent Events) connections used by /// No client-wide timeout because SSE (Server-Sent Events) connections used
/// PocketBase realtime/OAuth2 are long-lived streams. /// by PocketBase realtime/OAuth2 are long-lived streams; non-realtime
/// requests get a per-request timeout instead (see PROXY_REQUEST_TIMEOUT).
static PROXY_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| { static PROXY_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
reqwest::Client::builder() reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none()) .redirect(reqwest::redirect::Policy::none())
@ -47,6 +48,11 @@ static PROXY_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
.expect("Failed to build proxy HTTP client") .expect("Failed to build proxy HTTP client")
}); });
/// Timeout for proxied requests other than the realtime SSE stream, so a hung
/// PocketBase cannot pile up handlers indefinitely. Generous enough for file
/// uploads/downloads over slow links.
///
/// NOTE(review): reqwest's per-request timeout appears to cover the whole
/// exchange, including reading the response body. Since proxied responses are
/// streamed back to the browser rather than buffered, a download still in
/// flight after 30 s would presumably be cut off mid-stream — confirm this
/// budget against the largest expected file transfer.
const PROXY_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
pub async fn proxy_to_pocketbase( pub async fn proxy_to_pocketbase(
State(shared): State<Arc<SharedState>>, State(shared): State<Arc<SharedState>>,
req: Request, req: Request,
@ -58,10 +64,7 @@ pub async fn proxy_to_pocketbase(
let target_path = path.strip_prefix("/pb").unwrap_or(path); let target_path = path.strip_prefix("/pb").unwrap_or(path);
if !is_allowed_pb_path(target_path) { if !is_allowed_pb_path(target_path) {
warn!(path = %target_path, "Rejected PocketBase proxy request to disallowed path"); warn!(path = %target_path, "Rejected PocketBase proxy request to disallowed path");
return Response::builder() return StatusCode::NOT_FOUND.into_response();
.status(StatusCode::NOT_FOUND)
.body(Body::empty())
.unwrap();
} }
let query = req let query = req
.uri() .uri()
@ -73,6 +76,12 @@ pub async fn proxy_to_pocketbase(
let method = req.method().clone(); let method = req.method().clone();
let mut builder = PROXY_CLIENT.request(method, &url); let mut builder = PROXY_CLIENT.request(method, &url);
// The realtime SSE stream is intentionally unbounded; everything else
// must complete within the timeout.
if target_path != "/api/realtime" {
builder = builder.timeout(PROXY_REQUEST_TIMEOUT);
}
// Forward only safe headers (allowlist) // Forward only safe headers (allowlist)
const ALLOWED_HEADERS: &[&str] = &[ const ALLOWED_HEADERS: &[&str] = &[
"content-type", "content-type",
@ -96,10 +105,7 @@ pub async fn proxy_to_pocketbase(
Ok(bytes) => bytes, Ok(bytes) => bytes,
Err(err) => { Err(err) => {
warn!("Failed to read request body: {err}"); warn!("Failed to read request body: {err}");
return Response::builder() return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response();
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Failed to read request body"))
.unwrap();
} }
}; };
builder = builder.body(body_bytes); builder = builder.body(body_bytes);
@ -129,14 +135,14 @@ pub async fn proxy_to_pocketbase(
// realtime system and OAuth2 flow — buffering would hang forever // realtime system and OAuth2 flow — buffering would hang forever
// since SSE responses never complete. // since SSE responses never complete.
let body = Body::from_stream(upstream.bytes_stream()); let body = Body::from_stream(upstream.bytes_stream());
response.body(body).unwrap() response.body(body).unwrap_or_else(|err| {
warn!("Failed to build proxied response: {err}");
(StatusCode::BAD_GATEWAY, "Invalid upstream response").into_response()
})
} }
Err(err) => { Err(err) => {
warn!("PocketBase proxy error: {err}"); warn!("PocketBase proxy error: {err}");
Response::builder() (StatusCode::BAD_GATEWAY, "PocketBase unavailable").into_response()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from("PocketBase unavailable"))
.unwrap()
} }
} }
} }

View file

@ -184,12 +184,19 @@ pub async fn get_postcode_stats(
"GET /api/postcode-stats" "GET /api/postcode-stats"
); );
let crime_latest_year = if crime_by_year.is_empty() {
None
} else {
stats::crime_latest_available_year(&state.crime_by_year)
};
Ok(HexagonStatsResponse { Ok(HexagonStatsResponse {
count: total_count, count: total_count,
numeric_features, numeric_features,
enum_features: enum_features_out, enum_features: enum_features_out,
price_history, price_history,
crime_by_year, crime_by_year,
crime_latest_year,
central_postcode: None, central_postcode: None,
filter_exclusions, filter_exclusions,
}) })

View file

@ -340,16 +340,14 @@ pub fn compute_crime_by_year(
let points: Vec<CrimeYearPoint> = years let points: Vec<CrimeYearPoint> = years
.iter() .iter()
.filter_map(|&year| { .filter_map(|&year| {
let denom = fully_covered_rows let denom = fully_covered_rows + covered_counts.get(&year).copied().unwrap_or(0);
+ covered_counts.get(&year).copied().unwrap_or(0);
if denom == 0 { if denom == 0 {
// No selected postcode has published data for this year. // No selected postcode has published data for this year.
return None; return None;
} }
Some(CrimeYearPoint { Some(CrimeYearPoint {
year, year,
count: (sums.get(&year).copied().unwrap_or(0.0) / denom as f64) count: (sums.get(&year).copied().unwrap_or(0.0) / denom as f64) as f32,
as f32,
}) })
}) })
.collect(); .collect();
@ -365,6 +363,19 @@ pub fn compute_crime_by_year(
out out
} }
/// Returns the greatest year found in any entry of `years_by_type`, or `None`
/// when the by-year crime dataset contains no years at all.
///
/// The client compares the last charted year of a selection against this
/// value so that force-level publication gaps (e.g. Greater Manchester ends
/// mid-2019) are captioned as stale data instead of letting old numbers read
/// as current.
pub fn crime_latest_available_year(crime_by_year: &CrimeByYearData) -> Option<i32> {
    let mut latest: Option<i32> = None;
    for years in &crime_by_year.years_by_type {
        for &year in years {
            // Keep the running maximum; the first year seen seeds it.
            latest = Some(latest.map_or(year, |seen| seen.max(year)));
        }
    }
    latest
}
pub fn compute_poi_feature_stats( pub fn compute_poi_feature_stats(
matching_rows: &[usize], matching_rows: &[usize],
poi_metrics: &PostcodePoiMetrics, poi_metrics: &PostcodePoiMetrics,

View file

@ -87,6 +87,105 @@ pub struct AppState {
pub bugsink_frontend_config: Option<BugsinkFrontendConfig>, pub bugsink_frontend_config: Option<BugsinkFrontendConfig>,
} }
#[cfg(test)]
impl AppState {
    /// Builds a bare-bones `AppState` for integration tests of the
    /// PocketBase/Stripe money paths (checkout, webhook, licensing, invites).
    ///
    /// Every map/property dataset is empty. Only the HTTP-facing configuration
    /// (PocketBase URL, Stripe secrets, caches) carries meaningful values.
    /// `pocketbase_url` is stored as-is so a test can point the state at its
    /// own PocketBase endpoint.
    pub(crate) fn for_tests(pocketbase_url: String) -> Self {
        use crate::data::{
            ActualListingData, CrimeByYearData, OutcodeData, POIData, PlaceData, PostcodeData,
            PropertyData, TravelTimeStore,
        };
        use crate::utils::InternedColumn;
        use std::time::Duration;

        // Short timeouts so a test talking to an unresponsive endpoint fails
        // fast instead of hanging.
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(5))
            .connect_timeout(Duration::from_secs(2))
            .build()
            .expect("test HTTP client should build");

        // Empty geographic lookup tables.
        let postcode_data = PostcodeData {
            postcodes: Vec::new(),
            centroids: Vec::new(),
            aabbs: Vec::new(),
            polygons: Vec::new(),
            postcode_to_idx: FxHashMap::default(),
        };
        let outcode_data = OutcodeData {
            names: Vec::new(),
            name_lower: Vec::new(),
            centroids: Vec::new(),
            cities: Vec::new(),
        };

        // Empty listings table: every column is present but holds no rows.
        let actual_listings = ActualListingData {
            lat: Vec::new(),
            lon: Vec::new(),
            postcode: Vec::new(),
            address: Vec::new(),
            property_type: InternedColumn::build(&[]),
            property_sub_type: InternedColumn::build(&[]),
            leasehold_freehold: InternedColumn::build(&[]),
            price_qualifier: InternedColumn::build(&[]),
            bedrooms: Vec::new(),
            bathrooms: Vec::new(),
            rooms_total: Vec::new(),
            floor_area_sqm: Vec::new(),
            asking_price: Vec::new(),
            asking_price_per_sqm: Vec::new(),
            listing_url: Vec::new(),
            listing_status: InternedColumn::build(&[]),
            listing_date_iso: Vec::new(),
            features: Vec::new(),
            filter_feature_data: Vec::new(),
            poi_filter_feature_data: Vec::new(),
            grid: GridIndex::build(&[], &[], 0.01),
        };

        // Empty by-year crime dataset.
        let crime_by_year = CrimeByYearData {
            crime_types: Vec::new(),
            years_by_type: Vec::new(),
            series_by_postcode: FxHashMap::default(),
            covered_years_by_postcode: FxHashMap::default(),
        };

        Self {
            // Property / map data: all empty.
            data: PropertyData::empty_for_tests(),
            grid: GridIndex::build(&[], &[], 0.01),
            h3_cells: Vec::new(),
            feature_name_to_index: FxHashMap::default(),
            min_keys: Vec::new(),
            max_keys: Vec::new(),
            avg_keys: Vec::new(),
            features_response: FeaturesResponse { groups: Vec::new() },
            ai_filters_system_prompt: String::new(),
            poi_data: Arc::new(POIData::empty_for_tests()),
            poi_grid: Arc::new(GridIndex::build(&[], &[], 0.01)),
            place_data: Arc::new(PlaceData::empty_for_tests()),
            postcode_data: Arc::new(postcode_data),
            outcode_data: Arc::new(outcode_data),
            poi_category_groups: Arc::new(Vec::new()),
            travel_time_store: Arc::new(TravelTimeStore::empty_for_tests()),
            actual_listings: Arc::new(actual_listings),
            crime_by_year: Arc::new(crime_by_year),

            // Fresh caches.
            token_cache: Arc::new(TokenCache::new()),
            superuser_token_cache: Arc::new(SuperuserTokenCache::new()),
            share_cache: Arc::new(ShareBoundsCache::new()),

            // HTTP-facing configuration with placeholder credentials.
            screenshot_url: "http://127.0.0.1:1/screenshot".to_string(),
            public_url: "https://test.example".to_string(),
            is_dev: false,
            http_client,
            pocketbase_url,
            pocketbase_admin_email: "admin@test.example".to_string(),
            pocketbase_admin_password: "test-admin-password".to_string(),
            gemini_api_key: "test-gemini-key".to_string(),
            gemini_model: "test-model".to_string(),
            google_maps_api_key: "test-maps-key".to_string(),
            stripe_secret_key: "sk_test_dummy".to_string(),
            stripe_webhook_secret: "whsec_test_secret".to_string(),
            stripe_referral_coupon_id: "couponTest30".to_string(),
            bugsink_frontend_config: None,
        }
    }
}
/// Wraps AppState for shared access across route handlers. /// Wraps AppState for shared access across route handlers.
/// Route handlers call `load_state()` to get the current snapshot. /// Route handlers call `load_state()` to get the current snapshot.
pub struct SharedState { pub struct SharedState {