Split up
This commit is contained in:
parent
cf39ad754e
commit
f59d01227b
91 changed files with 10370 additions and 7562 deletions
|
|
@ -619,7 +619,10 @@ export default function AreaPane({
|
|||
/>
|
||||
{crimeSeries && crimeSeries.points.length > 1 && (
|
||||
<div className="mt-2">
|
||||
<CrimeYearChart points={crimeSeries.points} />
|
||||
<CrimeYearChart
|
||||
points={crimeSeries.points}
|
||||
latestAvailableYear={stats?.crime_latest_year}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
|
@ -663,7 +666,10 @@ export default function AreaPane({
|
|||
}
|
||||
chart={
|
||||
crimeSeries && crimeSeries.points.length > 1 ? (
|
||||
<CrimeYearChart points={crimeSeries.points} />
|
||||
<CrimeYearChart
|
||||
points={crimeSeries.points}
|
||||
latestAvailableYear={stats?.crime_latest_year}
|
||||
/>
|
||||
) : (
|
||||
numericStats.histogram &&
|
||||
(globalHistogram ? (
|
||||
|
|
|
|||
|
|
@ -1,14 +1,22 @@
|
|||
import { useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { CrimeYearPoint } from '../../types';
|
||||
|
||||
interface CrimeYearChartProps {
|
||||
points: CrimeYearPoint[];
|
||||
/**
|
||||
* Latest year available in the crime dataset as a whole. When the series
|
||||
* ends earlier, the area's police force stopped publishing (e.g. Greater
|
||||
* Manchester since mid-2019) and the chart is captioned as stale.
|
||||
*/
|
||||
latestAvailableYear?: number;
|
||||
}
|
||||
|
||||
const PADDING = { top: 6, right: 4, bottom: 14, left: 4 };
|
||||
const HEIGHT = 48;
|
||||
|
||||
export default function CrimeYearChart({ points }: CrimeYearChartProps) {
|
||||
export default function CrimeYearChart({ points, latestAvailableYear }: CrimeYearChartProps) {
|
||||
const { t } = useTranslation();
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const [width, setWidth] = useState(0);
|
||||
|
||||
|
|
@ -97,6 +105,11 @@ export default function CrimeYearChart({ points }: CrimeYearChartProps) {
|
|||
</text>
|
||||
</svg>
|
||||
)}
|
||||
{latestAvailableYear != null && yearMax < latestAvailableYear && (
|
||||
<p className="mt-0.5 text-[10px] leading-snug text-amber-700 dark:text-amber-400">
|
||||
{t('areaPane.crimeDataEnds', { year: yearMax })}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
67
frontend/src/components/map/DeckOverlay.tsx
Normal file
67
frontend/src/components/map/DeckOverlay.tsx
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
import { useEffect } from 'react';
|
||||
import { useControl } from 'react-map-gl/maplibre';
|
||||
import { MapboxOverlay } from '@deck.gl/mapbox';
|
||||
|
||||
/**
 * Structural view of the private members this module touches on the object
 * stored in an overlay's `_deck` field. Both members are optional because the
 * code probes for them at runtime before using them.
 */
interface DeckWithPrivateDraw {
  /** Private draw entry point; `patchNullViewportDraw` wraps it when it is a function. */
  _drawLayers?: (
    redrawReason: string,
    renderOptions?: { viewports?: unknown[]; [key: string]: unknown }
  ) => unknown;
  /** Marker set to `true` once the wrapper is installed, so it is never installed twice. */
  __propertyMapNullViewportPatch?: boolean;
}
|
||||
|
||||
/**
 * Wraps the private `_drawLayers` method of the overlay's `_deck` object so
 * that falsy entries in `renderOptions.viewports` never reach the original
 * implementation.
 *
 * The function is a no-op when the overlay has no `_deck`, when the deck is
 * already marked as patched, or when `_drawLayers` is not a function.
 *
 * Behaviour of the installed wrapper:
 * - no `viewports` in the options: delegates unchanged;
 * - every viewport falsy (or the list is empty): skips the draw, returns `undefined`;
 * - some viewports falsy: delegates with a copy of the options holding only the truthy ones;
 * - all viewports truthy: delegates with the original options object.
 */
function patchNullViewportDraw(overlay: MapboxOverlay) {
  const { _deck: target } = overlay as unknown as { _deck?: DeckWithPrivateDraw };
  if (!target) return;
  if (target.__propertyMapNullViewportPatch) return;
  if (typeof target._drawLayers !== 'function') return;

  // Keep a bound handle to the original so the wrapper can still delegate to it.
  const originalDraw = target._drawLayers.bind(target);

  target._drawLayers = (redrawReason, renderOptions) => {
    const requested = renderOptions?.viewports;
    if (!requested) {
      return originalDraw(redrawReason, renderOptions);
    }

    // Split-route startup can hand deck.gl a transient null viewport before MapLibre has sized the map.
    const usable = requested.filter(Boolean);
    if (usable.length === 0) return;
    if (usable.length === requested.length) {
      return originalDraw(redrawReason, renderOptions);
    }
    return originalDraw(redrawReason, { ...renderOptions, viewports: usable });
  };

  // Mark the deck so repeated calls (from onAdd/setProps) do not stack wrappers.
  target.__propertyMapNullViewportPatch = true;
}
|
||||
|
||||
/**
 * `MapboxOverlay` subclass that re-runs `patchNullViewportDraw` on itself after
 * both `onAdd` and `setProps`. The patch function is idempotent, so calling it
 * from both hooks installs the wrapper at most once per deck object.
 */
class SafeMapboxOverlay extends MapboxOverlay {
  /** Forwards the props to the base class, then attempts the patch. */
  setProps(props: Parameters<MapboxOverlay['setProps']>[0]) {
    super.setProps(props);
    patchNullViewportDraw(this);
  }

  /** Lets the base class attach to the map, attempts the patch, and returns the base result. */
  onAdd(map: unknown) {
    const container = super.onAdd(map);
    patchNullViewportDraw(this);
    return container;
  }
}
|
||||
|
||||
/**
 * Renderless component that registers a `SafeMapboxOverlay` (interleaved mode)
 * as a map control and keeps its `layers` / `getTooltip` props in sync.
 *
 * @param layers - Layer list; falsy entries are filtered out before being handed to the overlay.
 * @param getTooltip - Forwarded to the overlay unchanged.
 * @returns Always `null`; the component renders no DOM of its own.
 */
export function DeckOverlay({
  layers,
  getTooltip,
}: {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  layers: any[];
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  getTooltip: any;
}) {
  const deckControl = useControl(() => new SafeMapboxOverlay({ interleaved: true }));

  // Push fresh props whenever the control instance, the layers, or the tooltip getter change.
  useEffect(() => {
    const activeLayers = layers.filter(Boolean);
    deckControl.setProps({ layers: activeLayers, getTooltip });
  }, [deckControl, layers, getTooltip]);

  return null;
}
|
||||
45
frontend/src/components/map/HoverCardOverlay.tsx
Normal file
45
frontend/src/components/map/HoverCardOverlay.tsx
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import { memo, useMemo } from 'react';
|
||||
|
||||
import type { FeatureFilters, FeatureMeta, HexagonData, PostcodeFeature } from '../../types';
|
||||
import HoverCard from './HoverCard';
|
||||
|
||||
/** Props for `HoverCardOverlay`. */
interface HoverCardOverlayProps {
  /** Forwarded to `HoverCard` as `x` (presumably the card's horizontal screen position — confirm in HoverCard). */
  x: number;
  /** Forwarded to `HoverCard` as `y` (presumably the card's vertical screen position — confirm in HoverCard). */
  y: number;
  /** Identifier of the hovered item: matched against `properties.postcode` or `h3`, depending on the view. */
  id: string;
  /** When true the row is looked up in `postcodeData`, otherwise in `data`; also forwarded as `isPostcode`. */
  usePostcodeView: boolean;
  /** Hexagon rows, searched by `h3` when `usePostcodeView` is false. */
  data: HexagonData[];
  /** Postcode features, searched by `properties.postcode` when `usePostcodeView` is true. */
  postcodeData: PostcodeFeature[];
  /** Forwarded to `HoverCard` unchanged. */
  filters: FeatureFilters;
  /** Forwarded to `HoverCard` unchanged. */
  features: FeatureMeta[];
}
|
||||
|
||||
/**
 * Resolves the hovered hexagon/postcode row from the loaded map data and
 * renders the hover card for it.
 *
 * `memo` skips re-rendering only while every prop is shallow-equal, and `x`/`y`
 * are props too, so any change to them re-renders the component even when the
 * hovered item stays the same. The linear row lookup is therefore memoized
 * separately with `useMemo`, so it only reruns when the hover target, the view
 * mode, or the underlying data arrays change.
 *
 * The resolved row is `null` when no entry matches `id`.
 */
export const HoverCardOverlay = memo(function HoverCardOverlay({
  x,
  y,
  id,
  usePostcodeView,
  data,
  postcodeData,
  filters,
  features,
}: HoverCardOverlayProps) {
  const hoveredRow = useMemo(
    () =>
      usePostcodeView
        ? postcodeData.find((f) => f.properties.postcode === id)?.properties || null
        : data.find((d) => d.h3 === id) || null,
    [usePostcodeView, postcodeData, data, id]
  );

  return (
    <HoverCard
      x={x}
      y={y}
      id={id}
      isPostcode={usePostcodeView}
      data={hoveredRow}
      filters={filters}
      features={features}
    />
  );
});
|
||||
146
frontend/src/components/map/ListingPopups.tsx
Normal file
146
frontend/src/components/map/ListingPopups.tsx
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { TFunction } from 'i18next';
|
||||
|
||||
import type { ActualListing } from '../../types';
|
||||
|
||||
/**
 * Renders a price as a pound-prefixed string.
 *
 * The number is formatted with `toLocaleString()` called without arguments, so
 * digit grouping follows the runtime's default locale.
 *
 * @param price - Price value to format.
 * @returns `£` followed by the locale-formatted number.
 */
function formatListingPrice(price: number): string {
  const grouped = price.toLocaleString();
  return '£' + grouped;
}
|
||||
|
||||
/**
 * Builds the one-line summary shown under a listing's price.
 *
 * Segments, in order: translated bedroom count (`common.bedsCount`), translated
 * bathroom count (`common.bathsCount`), then the property sub-type or, when
 * that is empty, the property type. Counts are included whenever they are not
 * null/undefined, so a count of 0 is still shown.
 *
 * @param listing - Listing whose fields are summarised.
 * @param t - Translation function used for the count segments.
 * @returns The segments joined with ` · `, or `null` when there are none.
 */
function formatListingHeadline(listing: ActualListing, t: TFunction): string | null {
  const { bedrooms, bathrooms, property_sub_type, property_type } = listing;
  const segments: string[] = [];

  if (bedrooms != null) segments.push(t('common.bedsCount', { count: bedrooms }));
  if (bathrooms != null) segments.push(t('common.bathsCount', { count: bathrooms }));

  // The more specific sub-type wins; the broad type is only a fallback.
  const kind = property_sub_type || property_type;
  if (kind) segments.push(kind);

  if (segments.length === 0) return null;
  return segments.join(' · ');
}
|
||||
|
||||
/**
 * Popup body for a single listing: the whole card is a link that opens the
 * listing URL in a new tab.
 *
 * Each section is rendered only when its data is present: price (with optional
 * qualifier), headline, address, postcode, floor area (with optional price per
 * sqm) and up to the first three feature strings.
 *
 * The headline is computed once per render and reused for both the presence
 * check and the rendered text.
 */
export const ListingPopupSingleContent = memo(function ListingPopupSingleContent({
  listing,
}: {
  listing: ActualListing;
}) {
  const { t } = useTranslation();
  const headline = formatListingHeadline(listing, t);

  return (
    <a
      href={listing.listing_url}
      target="_blank"
      rel="noopener noreferrer"
      className="block px-3 py-2"
    >
      {listing.asking_price != null && (
        <div className="text-base font-bold text-teal-600 dark:text-teal-400">
          {formatListingPrice(listing.asking_price)}
          {listing.price_qualifier ? (
            <span className="ml-1 text-xs font-medium text-warm-500 dark:text-warm-400">
              {listing.price_qualifier}
            </span>
          ) : null}
        </div>
      )}
      {headline && (
        <div className="text-xs text-warm-700 dark:text-warm-200 mt-0.5">
          {headline}
        </div>
      )}
      {listing.address && (
        <div className="text-xs text-warm-500 dark:text-warm-400 mt-0.5 line-clamp-2">
          {listing.address}
        </div>
      )}
      {listing.postcode && (
        <div className="text-[11px] text-warm-400 dark:text-warm-500 mt-0.5">
          {listing.postcode}
        </div>
      )}
      {listing.floor_area_sqm != null && (
        <div className="text-[11px] text-warm-500 dark:text-warm-400 mt-0.5">
          {Math.round(listing.floor_area_sqm)} sqm
          {listing.asking_price_per_sqm != null
            ? ` · £${Math.round(listing.asking_price_per_sqm).toLocaleString()}/sqm`
            : ''}
        </div>
      )}
      {listing.features.length > 0 && (
        <ul className="mt-1.5 text-[11px] text-warm-600 dark:text-warm-300 list-disc pl-4 space-y-0.5">
          {/* Only the first three features are shown to keep the popup compact. */}
          {listing.features.slice(0, 3).map((feature, idx) => (
            <li key={idx} className="line-clamp-1">
              {feature}
            </li>
          ))}
        </ul>
      )}
      <div className="mt-1.5 text-[11px] text-teal-600 dark:text-teal-400 font-medium">
        Open listing ↗
      </div>
    </a>
  );
});
|
||||
|
||||
/**
 * Popup body for a cluster of listings.
 *
 * The header shows the cluster's total `count` and either how many of those
 * listings are included in `listings` or, when none are, a note that the items
 * are grouped near the map position. When `listings` is non-empty, a scrollable
 * list follows with one link per listing (price or a "Listing" placeholder,
 * headline, address and postcode), each opening the listing URL in a new tab.
 */
export const ListingClusterPopupContent = memo(function ListingClusterPopupContent({
  count,
  listings,
}: {
  count: number;
  listings: ActualListing[];
}) {
  const { t } = useTranslation();
  const shownCount = listings.length;
  const hasRows = shownCount > 0;
  const summary = hasRows
    ? `Showing ${shownCount.toLocaleString()} of ${count.toLocaleString()}`
    : 'Grouped near this map position';

  // Renders one clickable row; the index is part of the key because URLs may repeat.
  const renderRow = (listing: ActualListing, idx: number) => {
    const headline = formatListingHeadline(listing, t);
    const priceLabel =
      listing.asking_price != null ? formatListingPrice(listing.asking_price) : 'Listing';
    return (
      <a
        key={`${listing.listing_url}-${idx}`}
        href={listing.listing_url}
        target="_blank"
        rel="noopener noreferrer"
        className="block border-b border-warm-100 px-3 py-2 last:border-b-0 hover:bg-warm-50 dark:border-warm-700 dark:hover:bg-warm-700/60"
      >
        <div className="flex items-start justify-between gap-3">
          <div className="min-w-0">
            <div className="text-sm font-semibold text-teal-700 dark:text-teal-300">
              {priceLabel}
            </div>
            {headline && (
              <div className="mt-0.5 truncate text-xs text-warm-700 dark:text-warm-200">
                {headline}
              </div>
            )}
            {listing.address && (
              <div className="mt-0.5 line-clamp-1 text-[11px] text-warm-500 dark:text-warm-400">
                {listing.address}
              </div>
            )}
          </div>
          {listing.postcode && (
            <div className="shrink-0 text-[11px] font-medium text-warm-400 dark:text-warm-500">
              {listing.postcode}
            </div>
          )}
        </div>
      </a>
    );
  };

  return (
    <div>
      <div className="border-b border-warm-200 px-3 py-2 dark:border-warm-700">
        <div className="text-base font-bold text-red-600 dark:text-red-400">
          {count.toLocaleString()} listings
        </div>
        <div className="text-[11px] text-warm-500 dark:text-warm-400">{summary}</div>
      </div>
      {hasRows && (
        <div className="max-h-80 overflow-y-auto py-1">
          {listings.map((listing, idx) => renderRow(listing, idx))}
        </div>
      )}
    </div>
  );
});
|
||||
|
|
@ -1,10 +1,8 @@
|
|||
import { useCallback, useRef, useEffect, useState, useMemo, memo } from 'react';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { TFunction } from 'i18next';
|
||||
import { Layer, Map as MapGL, Source, useControl, ScaleControl } from 'react-map-gl/maplibre';
|
||||
import { Map as MapGL, ScaleControl } from 'react-map-gl/maplibre';
|
||||
import type { MapRef } from 'react-map-gl/maplibre';
|
||||
import { MapboxOverlay } from '@deck.gl/mapbox';
|
||||
import 'maplibre-gl/dist/maplibre-gl.css';
|
||||
import type {
|
||||
HexagonData,
|
||||
|
|
@ -17,7 +15,6 @@ import type {
|
|||
Bounds,
|
||||
MapFlyToOptions,
|
||||
ActualListing,
|
||||
SchoolMetadata,
|
||||
} from '../../types';
|
||||
|
||||
import {
|
||||
|
|
@ -26,28 +23,25 @@ import {
|
|||
getBoundsWithBottomScreenInset,
|
||||
getMapStyle,
|
||||
getMapDataBeforeId,
|
||||
getPoiIconUrl,
|
||||
getMapCenterForTargetScreenPoint,
|
||||
} from '../../lib/map-utils';
|
||||
import {
|
||||
MAP_MIN_ZOOM,
|
||||
MAP_BOUNDS,
|
||||
POI_GROUP_COLORS,
|
||||
POSTCODE_ZOOM_THRESHOLD,
|
||||
POI_AUTO_CARD_ZOOM_THRESHOLD,
|
||||
} from '../../lib/consts';
|
||||
import LocationSearch, { type SearchedLocation } from './LocationSearch';
|
||||
import MapLegend from './MapLegend';
|
||||
import HoverCard from './HoverCard';
|
||||
import { MAP_MIN_ZOOM, MAP_BOUNDS, POI_AUTO_CARD_ZOOM_THRESHOLD } from '../../lib/consts';
|
||||
import type { SearchedLocation } from './LocationSearch';
|
||||
import { LogoIcon } from '../ui/icons/LogoIcon';
|
||||
import { CloseIcon } from '../ui/icons/CloseIcon';
|
||||
import type { FeatureFilters } from '../../types';
|
||||
import { useDeckLayers } from '../../hooks/useDeckLayers';
|
||||
import { useTranslatedModes, type TravelTimeEntry } from '../../hooks/useTravelTime';
|
||||
import { ts } from '../../i18n/server';
|
||||
import { type OverlayId, OVERLAY_MIN_ZOOM } from '../../lib/overlays';
|
||||
import { useMapCardLayout } from '../../hooks/useMapCardLayout';
|
||||
import type { TravelTimeEntry } from '../../hooks/useTravelTime';
|
||||
import { type OverlayId } from '../../lib/overlays';
|
||||
import { CRIME_TYPE_VALUES } from '../../lib/crime-types';
|
||||
import type { BasemapId } from '../../lib/basemaps';
|
||||
import { DeckOverlay } from './DeckOverlay';
|
||||
import { OverlayTileLayers } from './OverlayTileLayers';
|
||||
import { MapTopCards } from './MapTopCards';
|
||||
import { PoiPopupCardContent } from './PoiPopupCard';
|
||||
import { ListingClusterPopupContent, ListingPopupSingleContent } from './ListingPopups';
|
||||
import { HoverCardOverlay } from './HoverCardOverlay';
|
||||
|
||||
interface MapProps {
|
||||
data: HexagonData[];
|
||||
|
|
@ -99,168 +93,11 @@ const EMPTY_ACTUAL_LISTINGS: ActualListing[] = [];
|
|||
const EMPTY_OVERLAYS = new Set<OverlayId>();
|
||||
const ALL_CRIME_TYPES = new Set<string>(CRIME_TYPE_VALUES);
|
||||
|
||||
function formatListingPrice(price: number): string {
|
||||
return `£${price.toLocaleString()}`;
|
||||
}
|
||||
|
||||
function formatListingHeadline(listing: ActualListing, t: TFunction): string | null {
|
||||
const parts: string[] = [];
|
||||
if (listing.bedrooms != null) parts.push(t('common.bedsCount', { count: listing.bedrooms }));
|
||||
if (listing.bathrooms != null) parts.push(t('common.bathsCount', { count: listing.bathrooms }));
|
||||
if (listing.property_sub_type) parts.push(listing.property_sub_type);
|
||||
else if (listing.property_type) parts.push(listing.property_type);
|
||||
return parts.length > 0 ? parts.join(' · ') : null;
|
||||
}
|
||||
|
||||
function ListingPopupSingleContent({ listing, t }: { listing: ActualListing; t: TFunction }) {
|
||||
return (
|
||||
<a
|
||||
href={listing.listing_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="block px-3 py-2"
|
||||
>
|
||||
{listing.asking_price != null && (
|
||||
<div className="text-base font-bold text-teal-600 dark:text-teal-400">
|
||||
{formatListingPrice(listing.asking_price)}
|
||||
{listing.price_qualifier ? (
|
||||
<span className="ml-1 text-xs font-medium text-warm-500 dark:text-warm-400">
|
||||
{listing.price_qualifier}
|
||||
</span>
|
||||
) : null}
|
||||
</div>
|
||||
)}
|
||||
{formatListingHeadline(listing, t) && (
|
||||
<div className="text-xs text-warm-700 dark:text-warm-200 mt-0.5">
|
||||
{formatListingHeadline(listing, t)}
|
||||
</div>
|
||||
)}
|
||||
{listing.address && (
|
||||
<div className="text-xs text-warm-500 dark:text-warm-400 mt-0.5 line-clamp-2">
|
||||
{listing.address}
|
||||
</div>
|
||||
)}
|
||||
{listing.postcode && (
|
||||
<div className="text-[11px] text-warm-400 dark:text-warm-500 mt-0.5">
|
||||
{listing.postcode}
|
||||
</div>
|
||||
)}
|
||||
{listing.floor_area_sqm != null && (
|
||||
<div className="text-[11px] text-warm-500 dark:text-warm-400 mt-0.5">
|
||||
{Math.round(listing.floor_area_sqm)} sqm
|
||||
{listing.asking_price_per_sqm != null
|
||||
? ` · £${Math.round(listing.asking_price_per_sqm).toLocaleString()}/sqm`
|
||||
: ''}
|
||||
</div>
|
||||
)}
|
||||
{listing.features.length > 0 && (
|
||||
<ul className="mt-1.5 text-[11px] text-warm-600 dark:text-warm-300 list-disc pl-4 space-y-0.5">
|
||||
{listing.features.slice(0, 3).map((feature, idx) => (
|
||||
<li key={idx} className="line-clamp-1">
|
||||
{feature}
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
)}
|
||||
<div className="mt-1.5 text-[11px] text-teal-600 dark:text-teal-400 font-medium">
|
||||
Open listing ↗
|
||||
</div>
|
||||
</a>
|
||||
);
|
||||
}
|
||||
|
||||
function ListingClusterPopupContent({
|
||||
count,
|
||||
listings,
|
||||
t,
|
||||
}: {
|
||||
count: number;
|
||||
listings: ActualListing[];
|
||||
t: TFunction;
|
||||
}) {
|
||||
const visibleCount = listings.length;
|
||||
return (
|
||||
<div>
|
||||
<div className="border-b border-warm-200 px-3 py-2 dark:border-warm-700">
|
||||
<div className="text-base font-bold text-red-600 dark:text-red-400">
|
||||
{count.toLocaleString()} listings
|
||||
</div>
|
||||
<div className="text-[11px] text-warm-500 dark:text-warm-400">
|
||||
{visibleCount > 0
|
||||
? `Showing ${visibleCount.toLocaleString()} of ${count.toLocaleString()}`
|
||||
: 'Grouped near this map position'}
|
||||
</div>
|
||||
</div>
|
||||
{visibleCount > 0 && (
|
||||
<div className="max-h-80 overflow-y-auto py-1">
|
||||
{listings.map((listing, idx) => {
|
||||
const headline = formatListingHeadline(listing, t);
|
||||
return (
|
||||
<a
|
||||
key={`${listing.listing_url}-${idx}`}
|
||||
href={listing.listing_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="block border-b border-warm-100 px-3 py-2 last:border-b-0 hover:bg-warm-50 dark:border-warm-700 dark:hover:bg-warm-700/60"
|
||||
>
|
||||
<div className="flex items-start justify-between gap-3">
|
||||
<div className="min-w-0">
|
||||
<div className="text-sm font-semibold text-teal-700 dark:text-teal-300">
|
||||
{listing.asking_price != null
|
||||
? formatListingPrice(listing.asking_price)
|
||||
: 'Listing'}
|
||||
</div>
|
||||
{headline && (
|
||||
<div className="mt-0.5 truncate text-xs text-warm-700 dark:text-warm-200">
|
||||
{headline}
|
||||
</div>
|
||||
)}
|
||||
{listing.address && (
|
||||
<div className="mt-0.5 line-clamp-1 text-[11px] text-warm-500 dark:text-warm-400">
|
||||
{listing.address}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{listing.postcode && (
|
||||
<div className="shrink-0 text-[11px] font-medium text-warm-400 dark:text-warm-500">
|
||||
{listing.postcode}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</a>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface PoiPopupCardData {
|
||||
name: string;
|
||||
category: string;
|
||||
icon_category?: string;
|
||||
group: string;
|
||||
emoji: string;
|
||||
school?: SchoolMetadata;
|
||||
}
|
||||
|
||||
interface Dimensions {
|
||||
width: number;
|
||||
height: number;
|
||||
}
|
||||
|
||||
const DESKTOP_TOP_CARD_WIDTH = 300;
|
||||
const DESKTOP_TOP_CARD_GAP = 8;
|
||||
const DESKTOP_TOP_CARD_HORIZONTAL_INSET = 24;
|
||||
const DESKTOP_TOP_CARDS_STACKED_MIN_MAP_WIDTH =
|
||||
DESKTOP_TOP_CARD_WIDTH + DESKTOP_TOP_CARD_HORIZONTAL_INSET;
|
||||
const DESKTOP_TOP_CARDS_ROW_MIN_MAP_WIDTH =
|
||||
DESKTOP_TOP_CARD_WIDTH * 2 + DESKTOP_TOP_CARD_GAP + DESKTOP_TOP_CARD_HORIZONTAL_INSET;
|
||||
const DESKTOP_TOP_CARD_CLASS = 'w-[300px]';
|
||||
const DESKTOP_LOCATION_SEARCH_INPUT_CLASS =
|
||||
'px-2 py-2 text-sm w-full border-none outline-none bg-transparent text-warm-700 dark:text-warm-200 placeholder-warm-400 dark:placeholder-warm-500';
|
||||
|
||||
type MapContainerStyle = CSSProperties & {
|
||||
'--map-mobile-bottom-inset'?: string;
|
||||
};
|
||||
|
|
@ -323,218 +160,6 @@ function getViewportRelativeVisibleAreaCenter(
|
|||
};
|
||||
}
|
||||
|
||||
interface DeckWithPrivateDraw {
|
||||
_drawLayers?: (
|
||||
redrawReason: string,
|
||||
renderOptions?: { viewports?: unknown[]; [key: string]: unknown }
|
||||
) => unknown;
|
||||
__propertyMapNullViewportPatch?: boolean;
|
||||
}
|
||||
|
||||
function patchNullViewportDraw(overlay: MapboxOverlay) {
|
||||
const deck = (overlay as unknown as { _deck?: DeckWithPrivateDraw })._deck;
|
||||
if (!deck || deck.__propertyMapNullViewportPatch || typeof deck._drawLayers !== 'function') {
|
||||
return;
|
||||
}
|
||||
|
||||
const drawLayers = deck._drawLayers.bind(deck);
|
||||
deck._drawLayers = (redrawReason, renderOptions) => {
|
||||
const viewports = renderOptions?.viewports;
|
||||
if (viewports) {
|
||||
// Split-route startup can hand deck.gl a transient null viewport before MapLibre has sized the map.
|
||||
const nonNullViewports = viewports.filter(Boolean);
|
||||
if (nonNullViewports.length === 0) return;
|
||||
if (nonNullViewports.length !== viewports.length) {
|
||||
return drawLayers(redrawReason, { ...renderOptions, viewports: nonNullViewports });
|
||||
}
|
||||
}
|
||||
return drawLayers(redrawReason, renderOptions);
|
||||
};
|
||||
deck.__propertyMapNullViewportPatch = true;
|
||||
}
|
||||
|
||||
class SafeMapboxOverlay extends MapboxOverlay {
|
||||
onAdd(map: unknown) {
|
||||
const element = super.onAdd(map);
|
||||
patchNullViewportDraw(this);
|
||||
return element;
|
||||
}
|
||||
|
||||
setProps(props: Parameters<MapboxOverlay['setProps']>[0]) {
|
||||
super.setProps(props);
|
||||
patchNullViewportDraw(this);
|
||||
}
|
||||
}
|
||||
|
||||
function getPoiGroupColor(group: string): [number, number, number] {
|
||||
const color = POI_GROUP_COLORS[group];
|
||||
if (!color) {
|
||||
throw new Error(`Missing POI group color for '${group}'`);
|
||||
}
|
||||
return color;
|
||||
}
|
||||
|
||||
/** Best-effort web URL from a free-text website field — GIAS stores some with
|
||||
* "http://", some without, and some as bare hostnames. */
|
||||
function normalizeSchoolWebsiteUrl(raw: string): string | null {
|
||||
const trimmed = raw.trim();
|
||||
if (!trimmed) return null;
|
||||
if (/^https?:\/\//i.test(trimmed)) return trimmed;
|
||||
if (/^[\w.-]+\.[a-z]{2,}/i.test(trimmed)) return `http://${trimmed}`;
|
||||
return null;
|
||||
}
|
||||
|
||||
function renderSchoolMetadata(school: SchoolMetadata) {
|
||||
// First line collects the headline classification (phase, type, religious
|
||||
// character) so the popup is scannable even when most fields are absent.
|
||||
const headline: string[] = [];
|
||||
if (school.phase) headline.push(school.phase);
|
||||
if (school.type) headline.push(school.type);
|
||||
|
||||
const pupilsLine =
|
||||
school.pupils !== undefined && school.capacity !== undefined
|
||||
? `${school.pupils.toLocaleString()} / ${school.capacity.toLocaleString()} pupils`
|
||||
: school.pupils !== undefined
|
||||
? `${school.pupils.toLocaleString()} pupils`
|
||||
: school.capacity !== undefined
|
||||
? `Capacity ${school.capacity.toLocaleString()}`
|
||||
: null;
|
||||
|
||||
const websiteUrl = school.website ? normalizeSchoolWebsiteUrl(school.website) : null;
|
||||
|
||||
return (
|
||||
<dl className="mt-2 grid grid-cols-[auto_1fr] gap-x-2 gap-y-0.5 text-xs text-warm-600 dark:text-warm-300">
|
||||
{headline.length > 0 && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Type</dt>
|
||||
<dd className="dark:text-warm-200">{headline.join(' · ')}</dd>
|
||||
</>
|
||||
)}
|
||||
{school.age_range && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Ages</dt>
|
||||
<dd className="dark:text-warm-200">{school.age_range}</dd>
|
||||
</>
|
||||
)}
|
||||
{school.gender && school.gender !== 'Mixed' && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Gender</dt>
|
||||
<dd className="dark:text-warm-200">{school.gender}</dd>
|
||||
</>
|
||||
)}
|
||||
{pupilsLine && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Pupils</dt>
|
||||
<dd className="dark:text-warm-200">{pupilsLine}</dd>
|
||||
</>
|
||||
)}
|
||||
{school.fsm_percent !== undefined && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Free meal</dt>
|
||||
<dd className="dark:text-warm-200">{school.fsm_percent.toFixed(1)}%</dd>
|
||||
</>
|
||||
)}
|
||||
{school.ofsted_rating && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Ofsted</dt>
|
||||
<dd className="dark:text-warm-200">{school.ofsted_rating}</dd>
|
||||
</>
|
||||
)}
|
||||
{school.sixth_form === 'Has a sixth form' && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Sixth form</dt>
|
||||
<dd className="dark:text-warm-200">Yes</dd>
|
||||
</>
|
||||
)}
|
||||
{school.religious_character &&
|
||||
school.religious_character !== 'Does not apply' &&
|
||||
school.religious_character !== 'None' && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Religion</dt>
|
||||
<dd className="dark:text-warm-200">{school.religious_character}</dd>
|
||||
</>
|
||||
)}
|
||||
{school.admissions_policy && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Admissions</dt>
|
||||
<dd className="dark:text-warm-200">{school.admissions_policy}</dd>
|
||||
</>
|
||||
)}
|
||||
{school.trust && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Trust</dt>
|
||||
<dd className="dark:text-warm-200">{school.trust}</dd>
|
||||
</>
|
||||
)}
|
||||
{(school.address || school.postcode) && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Address</dt>
|
||||
<dd className="dark:text-warm-200">
|
||||
{[school.address, school.postcode].filter(Boolean).join(', ')}
|
||||
</dd>
|
||||
</>
|
||||
)}
|
||||
{school.local_authority && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">LA</dt>
|
||||
<dd className="dark:text-warm-200">{school.local_authority}</dd>
|
||||
</>
|
||||
)}
|
||||
{school.head_name && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Head</dt>
|
||||
<dd className="dark:text-warm-200">{school.head_name}</dd>
|
||||
</>
|
||||
)}
|
||||
{websiteUrl && (
|
||||
<>
|
||||
<dt className="text-warm-500 dark:text-warm-400">Website</dt>
|
||||
<dd className="truncate">
|
||||
<a
|
||||
href={websiteUrl}
|
||||
target="_blank"
|
||||
rel="noreferrer noopener"
|
||||
className="pointer-events-auto text-teal-600 hover:underline dark:text-teal-400"
|
||||
>
|
||||
{websiteUrl.replace(/^https?:\/\//, '')}
|
||||
</a>
|
||||
</dd>
|
||||
</>
|
||||
)}
|
||||
</dl>
|
||||
);
|
||||
}
|
||||
|
||||
function PoiPopupCardContent({ poi }: { poi: PoiPopupCardData }) {
|
||||
return (
|
||||
<div className="px-3 py-2 max-w-[280px]">
|
||||
<div className="flex items-center gap-2">
|
||||
<img
|
||||
src={getPoiIconUrl(poi.category, poi.emoji, poi.icon_category, poi.name)}
|
||||
alt=""
|
||||
aria-hidden="true"
|
||||
loading="lazy"
|
||||
referrerPolicy="no-referrer"
|
||||
className="h-5 w-5 shrink-0 rounded-[4px] bg-white object-contain p-0.5"
|
||||
/>
|
||||
<div className="min-w-0">
|
||||
<div className="font-semibold dark:text-warm-100">{poi.name}</div>
|
||||
<div className="flex items-center gap-1.5 text-xs text-warm-500 dark:text-warm-400">
|
||||
<span
|
||||
className="inline-block w-2 h-2 rounded-full flex-shrink-0"
|
||||
style={{
|
||||
backgroundColor: `rgb(${getPoiGroupColor(poi.group).join(',')})`,
|
||||
}}
|
||||
/>
|
||||
{ts(poi.category)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{poi.school && renderSchoolMetadata(poi.school)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function getRenderedViewState(map: MapRef | null): ViewState | null {
|
||||
if (!map) return null;
|
||||
|
||||
|
|
@ -565,186 +190,6 @@ function getRenderedVisibleCenter(
|
|||
};
|
||||
}
|
||||
|
||||
function DeckOverlay({
|
||||
layers,
|
||||
getTooltip,
|
||||
}: {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
layers: any[];
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
getTooltip: any;
|
||||
}) {
|
||||
const overlay = useControl(() => new SafeMapboxOverlay({ interleaved: true }));
|
||||
|
||||
useEffect(() => {
|
||||
overlay.setProps({
|
||||
layers: layers.filter(Boolean),
|
||||
getTooltip,
|
||||
});
|
||||
}, [overlay, layers, getTooltip]);
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function overlayTileUrl(path: string): string {
|
||||
return `${window.location.origin}/api/overlays/${path}/{z}/{x}/{y}`;
|
||||
}
|
||||
|
||||
function OverlayTileLayers({
|
||||
activeOverlays,
|
||||
activeCrimeTypes,
|
||||
zoom,
|
||||
}: {
|
||||
activeOverlays: Set<OverlayId>;
|
||||
activeCrimeTypes: Set<string>;
|
||||
zoom: number;
|
||||
}) {
|
||||
if (zoom < POSTCODE_ZOOM_THRESHOLD || activeOverlays.size === 0) return null;
|
||||
|
||||
const showNoise = activeOverlays.has('noise');
|
||||
const showCrime = activeOverlays.has('crime-hotspots');
|
||||
const showTrees = activeOverlays.has('trees-outside-woodlands');
|
||||
const showPropertyBorders = activeOverlays.has('property-borders');
|
||||
|
||||
// Restrict the heatmap to the selected crime types. This must always be a
|
||||
// concrete expression: passing `filter={undefined}` makes react-map-gl call
|
||||
// map.addLayer({filter: undefined}), which MapLibre rejects at validation
|
||||
// ("filter: array expected, undefined found"), so the layer is never created
|
||||
// and the heatmap stays blank until a later setFilter call. An `in` over the
|
||||
// selected types matches everything when all 14 are selected.
|
||||
const crimeFilter = ['in', ['get', 'crime_type'], ['literal', Array.from(activeCrimeTypes)]];
|
||||
|
||||
return (
|
||||
<>
|
||||
{showNoise && (
|
||||
<Source
|
||||
id="overlay-noise-source"
|
||||
type="raster"
|
||||
tiles={[overlayTileUrl('noise')]}
|
||||
tileSize={256}
|
||||
minzoom={OVERLAY_MIN_ZOOM.noise}
|
||||
maxzoom={14}
|
||||
>
|
||||
<Layer
|
||||
id="overlay-noise"
|
||||
type="raster"
|
||||
minzoom={POSTCODE_ZOOM_THRESHOLD}
|
||||
paint={{
|
||||
'raster-opacity': 0.68,
|
||||
'raster-fade-duration': 120,
|
||||
}}
|
||||
/>
|
||||
</Source>
|
||||
)}
|
||||
|
||||
{showCrime && (
|
||||
<Source
|
||||
id="overlay-crime-source"
|
||||
type="vector"
|
||||
tiles={[overlayTileUrl('crime-hotspots')]}
|
||||
minzoom={OVERLAY_MIN_ZOOM['crime-hotspots']}
|
||||
maxzoom={15}
|
||||
>
|
||||
<Layer
|
||||
id="overlay-crime-heatmap"
|
||||
type="heatmap"
|
||||
source-layer="crime_hotspots"
|
||||
minzoom={POSTCODE_ZOOM_THRESHOLD}
|
||||
filter={crimeFilter as never}
|
||||
paint={
|
||||
{
|
||||
'heatmap-weight': [
|
||||
'interpolate',
|
||||
['linear'],
|
||||
['coalesce', ['get', 'count'], ['get', 'weight'], 1],
|
||||
0,
|
||||
0,
|
||||
10,
|
||||
1,
|
||||
],
|
||||
'heatmap-intensity': ['interpolate', ['linear'], ['zoom'], 15, 0.8, 18, 2.2],
|
||||
'heatmap-radius': ['interpolate', ['linear'], ['zoom'], 15, 18, 18, 30],
|
||||
'heatmap-opacity': 0.72,
|
||||
'heatmap-color': [
|
||||
'interpolate',
|
||||
['linear'],
|
||||
['heatmap-density'],
|
||||
0,
|
||||
'rgba(0, 0, 0, 0)',
|
||||
0.2,
|
||||
'rgb(253, 224, 71)',
|
||||
0.45,
|
||||
'rgb(249, 115, 22)',
|
||||
0.75,
|
||||
'rgb(220, 38, 38)',
|
||||
1,
|
||||
'rgb(127, 29, 29)',
|
||||
],
|
||||
} as never
|
||||
}
|
||||
/>
|
||||
</Source>
|
||||
)}
|
||||
|
||||
{showTrees && (
|
||||
<Source
|
||||
id="overlay-trees-source"
|
||||
type="vector"
|
||||
tiles={[overlayTileUrl('trees-outside-woodlands')]}
|
||||
minzoom={OVERLAY_MIN_ZOOM['trees-outside-woodlands']}
|
||||
maxzoom={16}
|
||||
>
|
||||
<Layer
|
||||
id="overlay-tree-polygons"
|
||||
type="fill"
|
||||
source-layer="trees_outside_woodlands"
|
||||
minzoom={POSTCODE_ZOOM_THRESHOLD}
|
||||
paint={
|
||||
{
|
||||
'fill-color': '#1f9d55',
|
||||
'fill-opacity': [
|
||||
'interpolate',
|
||||
['linear'],
|
||||
['coalesce', ['get', 'area_sqm'], 0],
|
||||
0,
|
||||
0.28,
|
||||
250,
|
||||
0.62,
|
||||
],
|
||||
'fill-outline-color': 'rgba(15, 81, 50, 0.65)',
|
||||
} as never
|
||||
}
|
||||
/>
|
||||
</Source>
|
||||
)}
|
||||
|
||||
{showPropertyBorders && (
|
||||
<Source
|
||||
id="overlay-property-borders-source"
|
||||
type="vector"
|
||||
tiles={[overlayTileUrl('property-borders')]}
|
||||
minzoom={OVERLAY_MIN_ZOOM['property-borders']}
|
||||
maxzoom={16}
|
||||
>
|
||||
<Layer
|
||||
id="overlay-property-borders"
|
||||
type="line"
|
||||
source-layer="property_borders"
|
||||
minzoom={POSTCODE_ZOOM_THRESHOLD}
|
||||
paint={
|
||||
{
|
||||
'line-color': '#b45309',
|
||||
'line-opacity': ['interpolate', ['linear'], ['zoom'], 15, 0.35, 18, 0.85],
|
||||
'line-width': ['interpolate', ['linear'], ['zoom'], 15, 0.4, 18, 1.4],
|
||||
} as never
|
||||
}
|
||||
/>
|
||||
</Source>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default memo(function Map({
|
||||
data,
|
||||
postcodeData,
|
||||
|
|
@ -790,7 +235,6 @@ export default memo(function Map({
|
|||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const mapRef = useRef<MapRef | null>(null);
|
||||
const { t } = useTranslation();
|
||||
const modes = useTranslatedModes();
|
||||
const densityLabel = densityLabelProp ?? t('mapLegend.numberOfProperties');
|
||||
const [internalViewState, setInternalViewState] = useState<ViewState>(initialViewState);
|
||||
const [dimensions, setDimensions] = useState<Dimensions>({ width: 0, height: 0 });
|
||||
|
|
@ -941,23 +385,16 @@ export default memo(function Map({
|
|||
() => (bottomScreenInset > 0 ? { '--map-mobile-bottom-inset': `${bottomScreenInset}px` } : {}),
|
||||
[bottomScreenInset]
|
||||
);
|
||||
const hideDesktopTopCardsForWidth =
|
||||
hideTopCardsWhenNarrow &&
|
||||
dimensions.width > 0 &&
|
||||
dimensions.width < DESKTOP_TOP_CARDS_STACKED_MIN_MAP_WIDTH;
|
||||
const stackDesktopTopCards =
|
||||
hideTopCardsWhenNarrow &&
|
||||
dimensions.width >= DESKTOP_TOP_CARDS_STACKED_MIN_MAP_WIDTH &&
|
||||
dimensions.width < DESKTOP_TOP_CARDS_ROW_MIN_MAP_WIDTH;
|
||||
const showLocationSearch = !hideLocationSearch && !hideDesktopTopCardsForWidth;
|
||||
const showLegend = !hideLegend && !hideDesktopTopCardsForWidth;
|
||||
const { showLocationSearch, showLegend, topCardsLayoutClass } = useMapCardLayout({
|
||||
mapWidth: dimensions.width,
|
||||
hideTopCardsWhenNarrow,
|
||||
hideLegend,
|
||||
hideLocationSearch,
|
||||
});
|
||||
const getViewportCenter = useCallback(() => {
|
||||
const center = mapRef.current?.getCenter();
|
||||
return center ? { lat: center.lat, lng: center.lng } : null;
|
||||
}, []);
|
||||
const desktopTopCardsLayoutClass = stackDesktopTopCards
|
||||
? 'flex-col items-start'
|
||||
: 'items-start justify-between';
|
||||
|
||||
const {
|
||||
layers,
|
||||
|
|
@ -1108,79 +545,29 @@ export default memo(function Map({
|
|||
) : (
|
||||
<>
|
||||
{(showLocationSearch || showLegend) && (
|
||||
<div
|
||||
className={`absolute top-3 left-3 right-3 z-20 flex gap-2 pointer-events-none ${desktopTopCardsLayoutClass}`}
|
||||
>
|
||||
{showLocationSearch && (
|
||||
<LocationSearch
|
||||
<MapTopCards
|
||||
layoutClass={topCardsLayoutClass}
|
||||
showLocationSearch={showLocationSearch}
|
||||
showLegend={showLegend}
|
||||
onFlyTo={handleFlyTo}
|
||||
onLocationSearched={onLocationSearched}
|
||||
onCurrentLocationFound={onCurrentLocationFound}
|
||||
onMouseEnter={handleMouseLeave}
|
||||
onLocationSearchMouseEnter={handleMouseLeave}
|
||||
getViewportCenter={getViewportCenter}
|
||||
className={DESKTOP_TOP_CARD_CLASS}
|
||||
inputClassName={DESKTOP_LOCATION_SEARCH_INPUT_CLASS}
|
||||
/>
|
||||
)}
|
||||
{showLegend &&
|
||||
(viewFeature && colorRange ? (
|
||||
viewFeature.startsWith('tt_') ? (
|
||||
<MapLegend
|
||||
featureLabel={t('travel.travelTime', {
|
||||
mode: modes.label(
|
||||
viewFeature.split('_')[1] as 'car' | 'bicycle' | 'walking' | 'transit'
|
||||
),
|
||||
})}
|
||||
range={colorRange}
|
||||
showCancel={viewSource === 'eye'}
|
||||
onCancel={onCancelPin}
|
||||
onResetScale={viewSource === 'eye' ? onResetPreviewScale : undefined}
|
||||
resetScaleDisabled={!canResetPreviewScale}
|
||||
mode="feature"
|
||||
theme={theme}
|
||||
suffix=" min"
|
||||
className={DESKTOP_TOP_CARD_CLASS}
|
||||
/>
|
||||
) : colorFeatureMeta ? (
|
||||
<MapLegend
|
||||
featureLabel={
|
||||
viewSource === 'eye'
|
||||
? t('mapLegend.previewing', { name: ts(colorFeatureMeta.name) })
|
||||
: ts(colorFeatureMeta.name)
|
||||
}
|
||||
range={colorRange}
|
||||
showCancel={viewSource === 'eye'}
|
||||
onCancel={onCancelPin}
|
||||
onResetScale={viewSource === 'eye' ? onResetPreviewScale : undefined}
|
||||
resetScaleDisabled={!canResetPreviewScale}
|
||||
mode="feature"
|
||||
enumValues={
|
||||
colorFeatureMeta.type === 'enum' ? colorFeatureMeta.values : undefined
|
||||
}
|
||||
featureName={colorFeatureMeta.name}
|
||||
theme={theme}
|
||||
suffix={colorFeatureMeta.suffix}
|
||||
raw={colorFeatureMeta.raw}
|
||||
className={DESKTOP_TOP_CARD_CLASS}
|
||||
/>
|
||||
) : null
|
||||
) : (
|
||||
<MapLegend
|
||||
featureLabel={densityLabel}
|
||||
range={
|
||||
usePostcodeView
|
||||
? [postcodeCountRange.min, postcodeCountRange.max]
|
||||
: [countRange.min, countRange.max]
|
||||
}
|
||||
viewFeature={viewFeature}
|
||||
colorRange={colorRange}
|
||||
viewSource={viewSource}
|
||||
onCancelPin={onCancelPin}
|
||||
onResetPreviewScale={onResetPreviewScale}
|
||||
canResetPreviewScale={canResetPreviewScale}
|
||||
colorFeatureMeta={colorFeatureMeta}
|
||||
usePostcodeView={usePostcodeView}
|
||||
countRange={countRange}
|
||||
postcodeCountRange={postcodeCountRange}
|
||||
densityLabel={densityLabel}
|
||||
totalCount={totalCountProp}
|
||||
showCancel={false}
|
||||
onCancel={onCancelPin}
|
||||
mode="density"
|
||||
theme={theme}
|
||||
className={DESKTOP_TOP_CARD_CLASS}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{autoPoiCards.map(({ poi, x, y }) => (
|
||||
<div
|
||||
|
|
@ -1247,28 +634,23 @@ export default memo(function Map({
|
|||
<CloseIcon className="w-3 h-3" />
|
||||
</button>
|
||||
{listingPopup.mode === 'single' ? (
|
||||
<ListingPopupSingleContent listing={listingPopup.listing} t={t} />
|
||||
<ListingPopupSingleContent listing={listingPopup.listing} />
|
||||
) : (
|
||||
<ListingClusterPopupContent
|
||||
count={listingPopup.count}
|
||||
listings={listingPopup.listings}
|
||||
t={t}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
{hoverPosition && hoveredHexagonId && hoveredHexagonId !== selectedHexagonId && (
|
||||
<HoverCard
|
||||
<HoverCardOverlay
|
||||
x={hoverPosition.x}
|
||||
y={hoverPosition.y}
|
||||
id={hoveredHexagonId}
|
||||
isPostcode={usePostcodeView}
|
||||
data={
|
||||
usePostcodeView
|
||||
? postcodeData.find((f) => f.properties.postcode === hoveredHexagonId)
|
||||
?.properties || null
|
||||
: data.find((d) => d.h3 === hoveredHexagonId) || null
|
||||
}
|
||||
usePostcodeView={usePostcodeView}
|
||||
data={data}
|
||||
postcodeData={postcodeData}
|
||||
filters={filters}
|
||||
features={features}
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import { Suspense, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
|
||||
import type { ActualListing, MapFlyToOptions, PostcodeGeometry } from '../../types';
|
||||
import type { ActualListing, PostcodeGeometry } from '../../types';
|
||||
import type { SearchedLocation } from './LocationSearch';
|
||||
import { useMapData } from '../../hooks/useMapData';
|
||||
import { usePOIData } from '../../hooks/usePOIData';
|
||||
|
|
@ -67,11 +67,11 @@ import {
|
|||
useMobileBackNavigationGuard,
|
||||
useScreenshotReadySignal,
|
||||
} from './map-page/effects';
|
||||
import { useMobileDrawer } from './map-page/useMobileDrawer';
|
||||
import type { MapFlyTo, MapPageProps } from './map-page/types';
|
||||
|
||||
export type { ExportState } from './map-page/types';
|
||||
|
||||
type PendingFlyTo = { lat: number; lng: number; zoom: number };
|
||||
const EMPTY_ACTUAL_LISTINGS: ActualListing[] = [];
|
||||
|
||||
export default function MapPage({
|
||||
|
|
@ -127,10 +127,11 @@ export default function MapPage({
|
|||
);
|
||||
const [leftPaneWidth, leftPaneHandlers] = usePaneResize(384, 200, 0.45, 'left');
|
||||
const [rightPaneWidth, rightPaneHandlers] = usePaneResize(384, 200, 0.45, 'right');
|
||||
const [mobileDrawerOpen, setMobileDrawerOpen] = useState(false);
|
||||
const [mobileBottomSheetHeight, setMobileBottomSheetHeight] = useState(0);
|
||||
const [poiPaneOpen, setPoiPaneOpen] = useState(false);
|
||||
const [overlayPaneOpen, setOverlayPaneOpen] = useState(false);
|
||||
// The POI and overlay panes are mutually exclusive, so a single state tracks
|
||||
// which one (if any) is open.
|
||||
const [openMapPane, setOpenMapPane] = useState<'poi' | 'overlay' | null>(null);
|
||||
const poiPaneOpen = openMapPane === 'poi';
|
||||
const overlayPaneOpen = openMapPane === 'overlay';
|
||||
const [currentLocation, setCurrentLocation] = useState<{ lat: number; lng: number } | null>(null);
|
||||
const [listingsToggleEnabled, setListingsToggleEnabled] = useState(true);
|
||||
const [pendingInitialPostcode, setPendingInitialPostcode] = useState<string | null>(
|
||||
|
|
@ -184,27 +185,21 @@ export default function MapPage({
|
|||
} = useTravelTime(initialTravelTime);
|
||||
|
||||
const mapFlyToRef = useRef<MapFlyTo | null>(null);
|
||||
const pendingCurrentLocationFlyToRef = useRef<{ lat: number; lng: number } | null>(null);
|
||||
const pendingLocationSearchFlyToRef = useRef<PendingFlyTo | null>(null);
|
||||
const mobileDrawerPanelRectRef = useRef<DOMRectReadOnly | null>(null);
|
||||
const areaPaneScrollTopRef = useRef(0);
|
||||
const propertiesPaneScrollTopRef = useRef(0);
|
||||
|
||||
const getMobileMapFlyToOptions = useCallback((): MapFlyToOptions | undefined => {
|
||||
if (!isMobile) return undefined;
|
||||
|
||||
const panelRect = mobileDrawerPanelRectRef.current;
|
||||
if (mobileDrawerOpen && panelRect) {
|
||||
const bottomInset = Math.max(0, window.innerHeight - panelRect.top);
|
||||
if (bottomInset > 0) {
|
||||
return { visibleViewportArea: { bottom: bottomInset } };
|
||||
}
|
||||
}
|
||||
|
||||
return mobileBottomSheetHeight > 0
|
||||
? { visibleArea: { bottom: mobileBottomSheetHeight } }
|
||||
: undefined;
|
||||
}, [isMobile, mobileBottomSheetHeight, mobileDrawerOpen]);
|
||||
const {
|
||||
mobileDrawerOpen,
|
||||
mobileBottomSheetHeight,
|
||||
setMobileBottomSheetHeight,
|
||||
openMobileDrawer,
|
||||
openMobileDrawerForLocationSearch,
|
||||
clearPendingLocationSearchFlyTo,
|
||||
queueCurrentLocationFlyTo,
|
||||
handleMobileDrawerPanelRectChange,
|
||||
handleMobileDrawerClose,
|
||||
getMobileMapFlyToOptions,
|
||||
} = useMobileDrawer(isMobile, mapFlyToRef);
|
||||
|
||||
const mapData = useMapData({
|
||||
filters,
|
||||
|
|
@ -217,6 +212,12 @@ export default function MapPage({
|
|||
shareCode,
|
||||
});
|
||||
|
||||
// Read the zoom through a ref inside handleAiFilterSubmit so panning/zooming
|
||||
// doesn't recreate the callback (it sits in the Filters pane's dependency
|
||||
// chain, which would otherwise re-render on every camera move).
|
||||
const currentViewZoomRef = useRef<number | undefined>(undefined);
|
||||
currentViewZoomRef.current = mapData.currentView?.zoom;
|
||||
|
||||
const handleAiFilterSubmit = useCallback(
|
||||
async (query: string) => {
|
||||
const context = {
|
||||
|
|
@ -283,7 +284,7 @@ export default function MapPage({
|
|||
mapFlyToRef.current?.(
|
||||
destination.lat,
|
||||
destination.lon,
|
||||
mapData.currentView?.zoom ?? INITIAL_VIEW_STATE.zoom,
|
||||
currentViewZoomRef.current ?? INITIAL_VIEW_STATE.zoom,
|
||||
getMobileMapFlyToOptions()
|
||||
);
|
||||
}
|
||||
|
|
@ -298,7 +299,6 @@ export default function MapPage({
|
|||
getMobileMapFlyToOptions,
|
||||
handleSetEntries,
|
||||
handleSetFilters,
|
||||
mapData.currentView?.zoom,
|
||||
]
|
||||
);
|
||||
|
||||
|
|
@ -395,20 +395,6 @@ export default function MapPage({
|
|||
journeyDest,
|
||||
});
|
||||
|
||||
const consumePendingLocationSearchFlyTo = useCallback((rect?: DOMRectReadOnly | null) => {
|
||||
const pending = pendingLocationSearchFlyToRef.current;
|
||||
const panelRect = rect ?? mobileDrawerPanelRectRef.current;
|
||||
if (!pending || !panelRect) return;
|
||||
|
||||
const bottomInset = Math.max(0, window.innerHeight - panelRect.top);
|
||||
const flyTo = mapFlyToRef.current;
|
||||
if (!flyTo) return;
|
||||
flyTo(pending.lat, pending.lng, pending.zoom, {
|
||||
visibleViewportArea: { bottom: bottomInset },
|
||||
});
|
||||
pendingLocationSearchFlyToRef.current = null;
|
||||
}, []);
|
||||
|
||||
const handleLocationSearchResult = useCallback(
|
||||
(result: SearchedLocation | null) => {
|
||||
if (result) {
|
||||
|
|
@ -428,68 +414,41 @@ export default function MapPage({
|
|||
result.focusAddress
|
||||
);
|
||||
if (isMobile) {
|
||||
pendingLocationSearchFlyToRef.current = {
|
||||
openMobileDrawerForLocationSearch({
|
||||
lat: markerLat ?? result.latitude,
|
||||
lng: markerLng ?? result.longitude,
|
||||
zoom: result.zoom,
|
||||
};
|
||||
setMobileDrawerOpen(true);
|
||||
consumePendingLocationSearchFlyTo();
|
||||
});
|
||||
}
|
||||
} else {
|
||||
setCurrentLocation(null);
|
||||
pendingLocationSearchFlyToRef.current = null;
|
||||
clearPendingLocationSearchFlyTo();
|
||||
handleCloseSelection();
|
||||
}
|
||||
},
|
||||
[consumePendingLocationSearchFlyTo, handleCloseSelection, handleLocationSearch, isMobile]
|
||||
[
|
||||
clearPendingLocationSearchFlyTo,
|
||||
handleCloseSelection,
|
||||
handleLocationSearch,
|
||||
isMobile,
|
||||
openMobileDrawerForLocationSearch,
|
||||
]
|
||||
);
|
||||
|
||||
const consumePendingCurrentLocationFlyTo = useCallback((rect?: DOMRectReadOnly | null) => {
|
||||
const pending = pendingCurrentLocationFlyToRef.current;
|
||||
const panelRect = rect ?? mobileDrawerPanelRectRef.current;
|
||||
if (!pending || !panelRect) return;
|
||||
|
||||
const bottomInset = Math.max(0, window.innerHeight - panelRect.top);
|
||||
const flyTo = mapFlyToRef.current;
|
||||
if (!flyTo) return;
|
||||
flyTo(pending.lat, pending.lng, 17, {
|
||||
visibleViewportArea: { bottom: bottomInset },
|
||||
});
|
||||
pendingCurrentLocationFlyToRef.current = null;
|
||||
}, []);
|
||||
|
||||
const handleCurrentLocationFound = useCallback(
|
||||
(lat: number, lng: number) => {
|
||||
if (isMobile) {
|
||||
pendingCurrentLocationFlyToRef.current = { lat, lng };
|
||||
consumePendingCurrentLocationFlyTo();
|
||||
queueCurrentLocationFlyTo(lat, lng);
|
||||
} else {
|
||||
mapFlyToRef.current?.(lat, lng, 17);
|
||||
}
|
||||
setCurrentLocation({ lat, lng });
|
||||
handleCurrentLocationSearch(lat, lng);
|
||||
if (isMobile) setMobileDrawerOpen(true);
|
||||
if (isMobile) openMobileDrawer();
|
||||
},
|
||||
[consumePendingCurrentLocationFlyTo, handleCurrentLocationSearch, isMobile]
|
||||
[handleCurrentLocationSearch, isMobile, openMobileDrawer, queueCurrentLocationFlyTo]
|
||||
);
|
||||
|
||||
const handleMobileDrawerPanelRectChange = useCallback(
|
||||
(rect: DOMRectReadOnly) => {
|
||||
mobileDrawerPanelRectRef.current = rect;
|
||||
consumePendingCurrentLocationFlyTo(rect);
|
||||
consumePendingLocationSearchFlyTo(rect);
|
||||
},
|
||||
[consumePendingCurrentLocationFlyTo, consumePendingLocationSearchFlyTo]
|
||||
);
|
||||
|
||||
const handleMobileDrawerClose = useCallback(() => {
|
||||
pendingCurrentLocationFlyToRef.current = null;
|
||||
pendingLocationSearchFlyToRef.current = null;
|
||||
mobileDrawerPanelRectRef.current = null;
|
||||
setMobileDrawerOpen(false);
|
||||
}, []);
|
||||
|
||||
const shareReturnViewRef = useRef(shareCode ? initialViewState : null);
|
||||
// Hide the upgrade modal as soon as the user dismisses it. We can't rely on
|
||||
// the camera fly alone to close it: flying back to the free/shared zone only
|
||||
|
|
@ -555,11 +514,7 @@ export default function MapPage({
|
|||
isMobile,
|
||||
flyTo: mapFlyToRef,
|
||||
onLocationSearch: handleLocationSearch,
|
||||
onOpenMobileDrawer: (target) => {
|
||||
pendingLocationSearchFlyToRef.current = target;
|
||||
setMobileDrawerOpen(true);
|
||||
consumePendingLocationSearchFlyTo();
|
||||
},
|
||||
onOpenMobileDrawer: openMobileDrawerForLocationSearch,
|
||||
onSettled: () => setPendingInitialPostcode(null),
|
||||
});
|
||||
useHorizontalSwipeNavigationGuard();
|
||||
|
|
@ -578,10 +533,10 @@ export default function MapPage({
|
|||
(id: string, isPostcode?: boolean, geometry?: PostcodeGeometry) => {
|
||||
handleHexagonClick(id, isPostcode, geometry);
|
||||
if (id) {
|
||||
setMobileDrawerOpen(true);
|
||||
openMobileDrawer();
|
||||
}
|
||||
},
|
||||
[handleHexagonClick]
|
||||
[handleHexagonClick, openMobileDrawer]
|
||||
);
|
||||
|
||||
const hexagonLocation = useHexagonLocation(
|
||||
|
|
@ -641,15 +596,20 @@ export default function MapPage({
|
|||
shareAndSaveView,
|
||||
]
|
||||
);
|
||||
// dashboardParams changes on every camera move; read it through a ref so the
|
||||
// save/update handlers (and the Filters pane depending on them) stay stable
|
||||
// while panning. The ref always holds the params of the latest render.
|
||||
const dashboardParamsRef = useRef(dashboardParams);
|
||||
dashboardParamsRef.current = dashboardParams;
|
||||
const handleSaveSearch = useCallback(
|
||||
async (name: string) => {
|
||||
await onSaveSearch?.(name, dashboardParams);
|
||||
await onSaveSearch?.(name, dashboardParamsRef.current);
|
||||
},
|
||||
[dashboardParams, onSaveSearch]
|
||||
[onSaveSearch]
|
||||
);
|
||||
const handleUpdateEditInPlaceWithParams = useCallback(async () => {
|
||||
await onUpdateEditInPlace?.(dashboardParams);
|
||||
}, [dashboardParams, onUpdateEditInPlace]);
|
||||
await onUpdateEditInPlace?.(dashboardParamsRef.current);
|
||||
}, [onUpdateEditInPlace]);
|
||||
const checkoutReturnPath = useMemo(
|
||||
() => `/dashboard${dashboardParams ? `?${dashboardParams}` : ''}`,
|
||||
[dashboardParams]
|
||||
|
|
@ -686,27 +646,37 @@ export default function MapPage({
|
|||
}
|
||||
}, [mapData.licenseRequired]);
|
||||
|
||||
if (screenshotMode) {
|
||||
return (
|
||||
<ScreenshotMapPage
|
||||
mapData={mapData}
|
||||
mapViewFeature={mapViewFeature}
|
||||
filterRange={filterRange}
|
||||
viewSource={viewSource}
|
||||
features={features}
|
||||
initialViewState={initialViewState}
|
||||
theme={theme}
|
||||
ogMode={ogMode}
|
||||
travelTimeEntries={entries}
|
||||
activeOverlays={activeOverlays}
|
||||
activeCrimeTypes={crimeTypes}
|
||||
basemap={basemap}
|
||||
colorOpacity={colorOpacity}
|
||||
/>
|
||||
);
|
||||
const handleUpgradeClick = useCallback(() => {
|
||||
onNavigateTo('pricing');
|
||||
}, [onNavigateTo]);
|
||||
const handleTogglePoiPane = useCallback(() => {
|
||||
setOpenMapPane((pane) => (pane === 'poi' ? null : 'poi'));
|
||||
}, []);
|
||||
const handleToggleOverlayPane = useCallback(() => {
|
||||
setOpenMapPane((pane) => (pane === 'overlay' ? null : 'overlay'));
|
||||
}, []);
|
||||
const handleClosePoiPane = useCallback(() => {
|
||||
setOpenMapPane((pane) => (pane === 'poi' ? null : pane));
|
||||
}, []);
|
||||
const handleCloseOverlayPane = useCallback(() => {
|
||||
setOpenMapPane((pane) => (pane === 'overlay' ? null : pane));
|
||||
}, []);
|
||||
const handleAreaTabClick = useCallback(() => {
|
||||
setRightPaneTab('area');
|
||||
}, [setRightPaneTab]);
|
||||
const handleMobileDrawerTabChange = useCallback(
|
||||
(tab: 'area' | 'properties') => {
|
||||
if (tab === 'properties') {
|
||||
handlePropertiesTabClick();
|
||||
} else {
|
||||
setRightPaneTab(tab);
|
||||
}
|
||||
},
|
||||
[handlePropertiesTabClick, setRightPaneTab]
|
||||
);
|
||||
|
||||
const renderAreaPane = () => (
|
||||
const renderAreaPane = useCallback(
|
||||
() => (
|
||||
<Suspense fallback={<PaneFallback />}>
|
||||
<AreaPane
|
||||
stats={areaStats}
|
||||
|
|
@ -724,13 +694,32 @@ export default function MapPage({
|
|||
isGroupExpanded={isAreaGroupExpanded}
|
||||
onToggleGroup={toggleAreaGroup}
|
||||
scrollTopRef={areaPaneScrollTopRef}
|
||||
scrollRestoreKey={selectedHexagon ? `${selectedHexagon.type}:${selectedHexagon.id}` : null}
|
||||
scrollRestoreKey={
|
||||
selectedHexagon ? `${selectedHexagon.type}:${selectedHexagon.id}` : null
|
||||
}
|
||||
scrollSaveDisabled={loadingAreaStats && areaStats == null}
|
||||
/>
|
||||
</Suspense>
|
||||
),
|
||||
[
|
||||
activeEntries,
|
||||
areaStats,
|
||||
areaStatsUseFilters,
|
||||
features,
|
||||
filters,
|
||||
hexagonLocation,
|
||||
isAreaGroupExpanded,
|
||||
loadingAreaStats,
|
||||
selectedHexagon,
|
||||
setAreaStatsUseFilters,
|
||||
shareCode,
|
||||
toggleAreaGroup,
|
||||
unfilteredAreaCount,
|
||||
]
|
||||
);
|
||||
|
||||
const renderPropertiesPane = () => (
|
||||
const renderPropertiesPane = useCallback(
|
||||
() => (
|
||||
<Suspense fallback={<PaneFallback />}>
|
||||
<PropertiesPane
|
||||
properties={properties}
|
||||
|
|
@ -739,25 +728,33 @@ export default function MapPage({
|
|||
hexagonId={selectedHexagon?.id || null}
|
||||
onLoadMore={handleLoadMoreProperties}
|
||||
scrollTopRef={propertiesPaneScrollTopRef}
|
||||
scrollRestoreKey={selectedHexagon ? `${selectedHexagon.type}:${selectedHexagon.id}` : null}
|
||||
scrollRestoreKey={
|
||||
selectedHexagon ? `${selectedHexagon.type}:${selectedHexagon.id}` : null
|
||||
}
|
||||
scrollSaveDisabled={loadingProperties && properties.length === 0}
|
||||
/>
|
||||
</Suspense>
|
||||
),
|
||||
[handleLoadMoreProperties, loadingProperties, properties, propertiesTotal, selectedHexagon]
|
||||
);
|
||||
|
||||
const renderPOIPane = () => (
|
||||
const poiPane = useMemo(
|
||||
() => (
|
||||
<Suspense fallback={<PaneFallback />}>
|
||||
<POIPane
|
||||
groups={poiCategoryGroups}
|
||||
selectedCategories={selectedPOICategories}
|
||||
onCategoriesChange={setSelectedPOICategories}
|
||||
poiCount={pois.length}
|
||||
onClose={() => setPoiPaneOpen(false)}
|
||||
onClose={handleClosePoiPane}
|
||||
/>
|
||||
</Suspense>
|
||||
),
|
||||
[handleClosePoiPane, poiCategoryGroups, pois.length, selectedPOICategories]
|
||||
);
|
||||
|
||||
const renderOverlayPane = () => (
|
||||
const overlayPane = useMemo(
|
||||
() => (
|
||||
<Suspense fallback={<PaneFallback />}>
|
||||
<OverlayPane
|
||||
selectedOverlays={activeOverlays}
|
||||
|
|
@ -769,12 +766,15 @@ export default function MapPage({
|
|||
colorOpacity={colorOpacity}
|
||||
onColorOpacityChange={setColorOpacity}
|
||||
zoomedIn={overlaysZoomedIn}
|
||||
onClose={() => setOverlayPaneOpen(false)}
|
||||
onClose={handleCloseOverlayPane}
|
||||
/>
|
||||
</Suspense>
|
||||
),
|
||||
[activeOverlays, basemap, colorOpacity, crimeTypes, handleCloseOverlayPane, overlaysZoomedIn]
|
||||
);
|
||||
|
||||
const renderFilters = (options?: { destinationDropdownPortal?: boolean }) => (
|
||||
const filtersPane = useMemo(
|
||||
() => (
|
||||
<Suspense fallback={<PaneFallback />}>
|
||||
<Filters
|
||||
features={features}
|
||||
|
|
@ -810,7 +810,7 @@ export default function MapPage({
|
|||
isLoggedIn={!!user}
|
||||
onLoginRequired={onRegisterClick}
|
||||
isLicensed={user?.subscription === 'licensed'}
|
||||
onUpgradeClick={() => onNavigateTo('pricing')}
|
||||
onUpgradeClick={handleUpgradeClick}
|
||||
onResetTutorial={!isMobile ? tutorial.resetTutorial : undefined}
|
||||
filterImpacts={filterCounts.impacts}
|
||||
onClearAll={handleClearAll}
|
||||
|
|
@ -821,31 +821,117 @@ export default function MapPage({
|
|||
editingSearch && onUpdateEditInPlace ? handleUpdateEditInPlaceWithParams : undefined
|
||||
}
|
||||
onExitEditing={onCancelEdit}
|
||||
destinationDropdownPortal={options?.destinationDropdownPortal}
|
||||
destinationDropdownPortal={isMobile ? false : undefined}
|
||||
/>
|
||||
</Suspense>
|
||||
),
|
||||
[
|
||||
activeFeature,
|
||||
aiFilterError,
|
||||
aiFilterErrorType,
|
||||
aiFilterLoading,
|
||||
aiFilterNotes,
|
||||
aiFilterSummary,
|
||||
dragValue,
|
||||
editingSearch,
|
||||
enabledFeatures,
|
||||
entries,
|
||||
features,
|
||||
filterCounts.impacts,
|
||||
filters,
|
||||
handleAddEntry,
|
||||
handleAddFilter,
|
||||
handleAiFilterSubmit,
|
||||
handleClearAll,
|
||||
handleDragChange,
|
||||
handleDragEnd,
|
||||
handleDragStart,
|
||||
handleFilterChange,
|
||||
handleRemoveFilter,
|
||||
handleSaveSearch,
|
||||
handleTimeRangeChange,
|
||||
handleToggleBest,
|
||||
handleToggleNoBuses,
|
||||
handleToggleNoChange,
|
||||
handleTogglePin,
|
||||
handleTravelTimeDragEnd,
|
||||
handleTravelTimeRemoveEntry,
|
||||
handleTravelTimeSetDestination,
|
||||
handleUpdateEditInPlaceWithParams,
|
||||
handleUpgradeClick,
|
||||
isMobile,
|
||||
onCancelEdit,
|
||||
onClearPendingInfoFeature,
|
||||
onRegisterClick,
|
||||
onSaveSearch,
|
||||
onUpdateEditInPlace,
|
||||
pendingInfoFeature,
|
||||
pinnedFeature,
|
||||
savingSearch,
|
||||
tutorial.resetTutorial,
|
||||
user,
|
||||
]
|
||||
);
|
||||
|
||||
const handleTogglePoiPane = () => {
|
||||
setOverlayPaneOpen(false);
|
||||
setPoiPaneOpen((open) => !open);
|
||||
};
|
||||
const handleToggleOverlayPane = () => {
|
||||
setPoiPaneOpen(false);
|
||||
setOverlayPaneOpen((open) => !open);
|
||||
};
|
||||
const handleMobileDrawerTabChange = (tab: 'area' | 'properties') => {
|
||||
if (tab === 'properties') {
|
||||
handlePropertiesTabClick();
|
||||
} else {
|
||||
setRightPaneTab(tab);
|
||||
const mobileLegend = useMemo(
|
||||
() => (
|
||||
<MobileMapLegend
|
||||
mapViewFeature={mapViewFeature}
|
||||
colorRange={mapData.colorRange}
|
||||
viewSource={viewSource}
|
||||
mobileLegendMeta={mobileLegendMeta}
|
||||
densityLabel={densityLabel}
|
||||
densityRange={mobileDensityRange}
|
||||
theme={theme}
|
||||
canResetPreviewScale={mapData.canResetPreviewScale}
|
||||
onCancelPin={handleCancelPin}
|
||||
onResetPreviewScale={mapData.handleResetPreviewScale}
|
||||
/>
|
||||
),
|
||||
[
|
||||
densityLabel,
|
||||
handleCancelPin,
|
||||
mapData.canResetPreviewScale,
|
||||
mapData.colorRange,
|
||||
mapData.handleResetPreviewScale,
|
||||
mapViewFeature,
|
||||
mobileDensityRange,
|
||||
mobileLegendMeta,
|
||||
theme,
|
||||
viewSource,
|
||||
]
|
||||
);
|
||||
|
||||
const toasts = useMemo(
|
||||
() => (
|
||||
<ExportToast
|
||||
notice={exportNotice}
|
||||
closeLabel={t('common.close')}
|
||||
onClose={clearExportNotice}
|
||||
/>
|
||||
),
|
||||
[clearExportNotice, exportNotice, t]
|
||||
);
|
||||
|
||||
if (screenshotMode) {
|
||||
return (
|
||||
<ScreenshotMapPage
|
||||
mapData={mapData}
|
||||
mapViewFeature={mapViewFeature}
|
||||
filterRange={filterRange}
|
||||
viewSource={viewSource}
|
||||
features={features}
|
||||
initialViewState={initialViewState}
|
||||
theme={theme}
|
||||
ogMode={ogMode}
|
||||
travelTimeEntries={entries}
|
||||
activeOverlays={activeOverlays}
|
||||
activeCrimeTypes={crimeTypes}
|
||||
basemap={basemap}
|
||||
colorOpacity={colorOpacity}
|
||||
/>
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const exportToast = (
|
||||
<ExportToast notice={exportNotice} closeLabel={t('common.close')} onClose={clearExportNotice} />
|
||||
);
|
||||
const toasts = exportToast;
|
||||
|
||||
const editingBar =
|
||||
editingSearch && isMobile ? (
|
||||
|
|
@ -940,25 +1026,12 @@ export default function MapPage({
|
|||
poiPaneOpen={poiPaneOpen}
|
||||
onTogglePoiPane={handleTogglePoiPane}
|
||||
poiButtonLabel={t('poiPane.pointsOfInterest')}
|
||||
poiPane={renderPOIPane()}
|
||||
poiPane={poiPane}
|
||||
overlayPaneOpen={overlayPaneOpen}
|
||||
onToggleOverlayPane={handleToggleOverlayPane}
|
||||
overlayPane={renderOverlayPane()}
|
||||
filtersPane={renderFilters({ destinationDropdownPortal: false })}
|
||||
mobileLegend={
|
||||
<MobileMapLegend
|
||||
mapViewFeature={mapViewFeature}
|
||||
colorRange={mapData.colorRange}
|
||||
viewSource={viewSource}
|
||||
mobileLegendMeta={mobileLegendMeta}
|
||||
densityLabel={densityLabel}
|
||||
densityRange={mobileDensityRange}
|
||||
theme={theme}
|
||||
canResetPreviewScale={mapData.canResetPreviewScale}
|
||||
onCancelPin={handleCancelPin}
|
||||
onResetPreviewScale={mapData.handleResetPreviewScale}
|
||||
/>
|
||||
}
|
||||
overlayPane={overlayPane}
|
||||
filtersPane={filtersPane}
|
||||
mobileLegend={mobileLegend}
|
||||
renderAreaPane={renderAreaPane}
|
||||
renderPropertiesPane={renderPropertiesPane}
|
||||
toasts={toasts}
|
||||
|
|
@ -975,7 +1048,7 @@ export default function MapPage({
|
|||
tutorialTheme={tutorialTheme}
|
||||
leftPaneWidth={leftPaneWidth}
|
||||
leftPaneHandlers={leftPaneHandlers}
|
||||
filtersPane={renderFilters()}
|
||||
filtersPane={filtersPane}
|
||||
mapData={mapData}
|
||||
pois={pois}
|
||||
activeOverlays={activeOverlays}
|
||||
|
|
@ -1008,15 +1081,15 @@ export default function MapPage({
|
|||
totalCount={filterCounts.total ?? undefined}
|
||||
poiPaneOpen={poiPaneOpen}
|
||||
onTogglePoiPane={handleTogglePoiPane}
|
||||
poiPane={renderPOIPane()}
|
||||
poiPane={poiPane}
|
||||
overlayPaneOpen={overlayPaneOpen}
|
||||
onToggleOverlayPane={handleToggleOverlayPane}
|
||||
overlayPane={renderOverlayPane()}
|
||||
overlayPane={overlayPane}
|
||||
showSelectionPane={!!selectedHexagon}
|
||||
rightPaneWidth={rightPaneWidth}
|
||||
rightPaneHandlers={rightPaneHandlers}
|
||||
rightPaneTab={rightPaneTab}
|
||||
onAreaTabClick={() => setRightPaneTab('area')}
|
||||
onAreaTabClick={handleAreaTabClick}
|
||||
onPropertiesTabClick={handlePropertiesTabClick}
|
||||
onCloseSelection={handleCloseSelection}
|
||||
renderAreaPane={renderAreaPane}
|
||||
|
|
|
|||
138
frontend/src/components/map/MapTopCards.tsx
Normal file
138
frontend/src/components/map/MapTopCards.tsx
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { FeatureMeta, MapFlyToOptions } from '../../types';
|
||||
import { useTranslatedModes } from '../../hooks/useTravelTime';
|
||||
import { ts } from '../../i18n/server';
|
||||
import LocationSearch, { type SearchedLocation } from './LocationSearch';
|
||||
import MapLegend from './MapLegend';
|
||||
|
||||
const DESKTOP_TOP_CARD_CLASS = 'w-[300px]';
|
||||
const DESKTOP_LOCATION_SEARCH_INPUT_CLASS =
|
||||
'px-2 py-2 text-sm w-full border-none outline-none bg-transparent text-warm-700 dark:text-warm-200 placeholder-warm-400 dark:placeholder-warm-500';
|
||||
|
||||
interface MapTopCardsProps {
|
||||
layoutClass: string;
|
||||
showLocationSearch: boolean;
|
||||
showLegend: boolean;
|
||||
onFlyTo: (lat: number, lng: number, zoom: number, options?: MapFlyToOptions) => void;
|
||||
onLocationSearched?: (location: SearchedLocation | null) => void;
|
||||
onCurrentLocationFound?: (lat: number, lng: number) => void;
|
||||
onLocationSearchMouseEnter: () => void;
|
||||
getViewportCenter: () => { lat: number; lng: number } | null;
|
||||
viewFeature: string | null;
|
||||
colorRange: [number, number] | null;
|
||||
viewSource: 'drag' | 'eye' | null;
|
||||
onCancelPin: () => void;
|
||||
onResetPreviewScale?: () => void;
|
||||
canResetPreviewScale: boolean;
|
||||
colorFeatureMeta: FeatureMeta | null;
|
||||
usePostcodeView: boolean;
|
||||
countRange: { min: number; max: number };
|
||||
postcodeCountRange: { min: number; max: number };
|
||||
densityLabel: string;
|
||||
totalCount?: number;
|
||||
theme: 'light' | 'dark';
|
||||
}
|
||||
|
||||
/**
 * Desktop overlay pinned to the top of the map: an optional location search
 * card plus, when enabled, exactly one legend variant.
 *
 * Legend selection:
 * - no `viewFeature` or no `colorRange` → density legend;
 * - `viewFeature` starting with `tt_` → travel-time legend (suffix " min");
 * - otherwise → feature legend built from `colorFeatureMeta`, or nothing when
 *   that metadata is missing.
 */
export const MapTopCards = memo(function MapTopCards(props: MapTopCardsProps) {
  const {
    layoutClass,
    showLocationSearch,
    showLegend,
    onFlyTo,
    onLocationSearched,
    onCurrentLocationFound,
    onLocationSearchMouseEnter,
    getViewportCenter,
    viewFeature,
    colorRange,
    viewSource,
    onCancelPin,
    onResetPreviewScale,
    canResetPreviewScale,
    colorFeatureMeta,
    usePostcodeView,
    countRange,
    postcodeCountRange,
    densityLabel,
    totalCount,
    theme,
  } = props;
  const { t } = useTranslation();
  const modes = useTranslatedModes();

  /** Picks the single legend variant that matches the current view state. */
  const renderLegend = () => {
    // Nothing is colouring the map: fall back to the listing-density legend.
    if (!viewFeature || !colorRange) {
      return (
        <MapLegend
          featureLabel={densityLabel}
          range={
            usePostcodeView
              ? [postcodeCountRange.min, postcodeCountRange.max]
              : [countRange.min, countRange.max]
          }
          totalCount={totalCount}
          showCancel={false}
          onCancel={onCancelPin}
          mode="density"
          theme={theme}
          className={DESKTOP_TOP_CARD_CLASS}
        />
      );
    }

    // Cancel and reset-scale are only offered while previewing via the eye icon.
    const isPreview = viewSource === 'eye';
    const resetScale = isPreview ? onResetPreviewScale : undefined;

    if (viewFeature.startsWith('tt_')) {
      // The transport mode is the second `_`-separated segment of the feature name.
      const travelMode = viewFeature.split('_')[1] as 'car' | 'bicycle' | 'walking' | 'transit';
      return (
        <MapLegend
          featureLabel={t('travel.travelTime', { mode: modes.label(travelMode) })}
          range={colorRange}
          showCancel={isPreview}
          onCancel={onCancelPin}
          onResetScale={resetScale}
          resetScaleDisabled={!canResetPreviewScale}
          mode="feature"
          theme={theme}
          suffix=" min"
          className={DESKTOP_TOP_CARD_CLASS}
        />
      );
    }

    if (!colorFeatureMeta) return null;

    const translatedName = ts(colorFeatureMeta.name);
    return (
      <MapLegend
        featureLabel={
          isPreview ? t('mapLegend.previewing', { name: translatedName }) : translatedName
        }
        range={colorRange}
        showCancel={isPreview}
        onCancel={onCancelPin}
        onResetScale={resetScale}
        resetScaleDisabled={!canResetPreviewScale}
        mode="feature"
        enumValues={colorFeatureMeta.type === 'enum' ? colorFeatureMeta.values : undefined}
        featureName={colorFeatureMeta.name}
        theme={theme}
        suffix={colorFeatureMeta.suffix}
        raw={colorFeatureMeta.raw}
        className={DESKTOP_TOP_CARD_CLASS}
      />
    );
  };

  return (
    <div
      className={`absolute top-3 left-3 right-3 z-20 flex gap-2 pointer-events-none ${layoutClass}`}
    >
      {showLocationSearch && (
        <LocationSearch
          onFlyTo={onFlyTo}
          onLocationSearched={onLocationSearched}
          onCurrentLocationFound={onCurrentLocationFound}
          onMouseEnter={onLocationSearchMouseEnter}
          getViewportCenter={getViewportCenter}
          className={DESKTOP_TOP_CARD_CLASS}
          inputClassName={DESKTOP_LOCATION_SEARCH_INPUT_CLASS}
        />
      )}
      {showLegend && renderLegend()}
    </div>
  );
});
|
||||
163
frontend/src/components/map/OverlayTileLayers.tsx
Normal file
163
frontend/src/components/map/OverlayTileLayers.tsx
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
import { Layer, Source } from 'react-map-gl/maplibre';
|
||||
|
||||
import { POSTCODE_ZOOM_THRESHOLD } from '../../lib/consts';
|
||||
import { type OverlayId, OVERLAY_MIN_ZOOM } from '../../lib/overlays';
|
||||
|
||||
/**
 * Builds the tile URL template for an overlay served from this origin.
 * The `{z}/{x}/{y}` placeholders are left in the string unexpanded.
 */
function overlayTileUrl(path: string): string {
  const { origin } = window.location;
  return `${origin}/api/overlays/${path}/{z}/{x}/{y}`;
}
|
||||
|
||||
/**
 * Map sources and layers for the optional overlays: noise raster, crime
 * heatmap, tree polygons and property borders.
 *
 * Renders nothing below `POSTCODE_ZOOM_THRESHOLD` or when no overlay is
 * active; otherwise mounts one `<Source>`/`<Layer>` pair per active overlay.
 */
export function OverlayTileLayers({
  activeOverlays,
  activeCrimeTypes,
  zoom,
}: {
  activeOverlays: Set<OverlayId>;
  activeCrimeTypes: Set<string>;
  zoom: number;
}) {
  if (activeOverlays.size === 0 || zoom < POSTCODE_ZOOM_THRESHOLD) return null;

  // The heatmap is limited to the selected crime types, and the filter must
  // always be a concrete expression. With `filter={undefined}` react-map-gl
  // calls map.addLayer({filter: undefined}), MapLibre fails validation
  // ("filter: array expected, undefined found"), the layer is never created
  // and the heatmap stays blank until a later setFilter call. An `in` over the
  // selected types matches everything when all 14 are selected.
  const crimeFilter = ['in', ['get', 'crime_type'], ['literal', [...activeCrimeTypes]]];

  const noisePaint = {
    'raster-opacity': 0.68,
    'raster-fade-duration': 120,
  };

  const crimeHeatmapPaint = {
    // Weight comes from `count`, then `weight`, then 1; mapped linearly from 0..10 onto 0..1.
    'heatmap-weight': [
      'interpolate',
      ['linear'],
      ['coalesce', ['get', 'count'], ['get', 'weight'], 1],
      0,
      0,
      10,
      1,
    ],
    'heatmap-intensity': ['interpolate', ['linear'], ['zoom'], 15, 0.8, 18, 2.2],
    'heatmap-radius': ['interpolate', ['linear'], ['zoom'], 15, 18, 18, 30],
    'heatmap-opacity': 0.72,
    // Transparent → yellow → orange → red → dark red as density rises.
    'heatmap-color': [
      'interpolate',
      ['linear'],
      ['heatmap-density'],
      0,
      'rgba(0, 0, 0, 0)',
      0.2,
      'rgb(253, 224, 71)',
      0.45,
      'rgb(249, 115, 22)',
      0.75,
      'rgb(220, 38, 38)',
      1,
      'rgb(127, 29, 29)',
    ],
  } as never;

  const treeFillPaint = {
    'fill-color': '#1f9d55',
    // Opacity scales with the `area_sqm` property (0 when absent).
    'fill-opacity': [
      'interpolate',
      ['linear'],
      ['coalesce', ['get', 'area_sqm'], 0],
      0,
      0.28,
      250,
      0.62,
    ],
    'fill-outline-color': 'rgba(15, 81, 50, 0.65)',
  } as never;

  const propertyBorderPaint = {
    'line-color': '#b45309',
    'line-opacity': ['interpolate', ['linear'], ['zoom'], 15, 0.35, 18, 0.85],
    'line-width': ['interpolate', ['linear'], ['zoom'], 15, 0.4, 18, 1.4],
  } as never;

  return (
    <>
      {activeOverlays.has('noise') && (
        <Source
          id="overlay-noise-source"
          type="raster"
          tiles={[overlayTileUrl('noise')]}
          tileSize={256}
          minzoom={OVERLAY_MIN_ZOOM.noise}
          maxzoom={14}
        >
          <Layer
            id="overlay-noise"
            type="raster"
            minzoom={POSTCODE_ZOOM_THRESHOLD}
            paint={noisePaint}
          />
        </Source>
      )}

      {activeOverlays.has('crime-hotspots') && (
        <Source
          id="overlay-crime-source"
          type="vector"
          tiles={[overlayTileUrl('crime-hotspots')]}
          minzoom={OVERLAY_MIN_ZOOM['crime-hotspots']}
          maxzoom={15}
        >
          <Layer
            id="overlay-crime-heatmap"
            type="heatmap"
            source-layer="crime_hotspots"
            minzoom={POSTCODE_ZOOM_THRESHOLD}
            filter={crimeFilter as never}
            paint={crimeHeatmapPaint}
          />
        </Source>
      )}

      {activeOverlays.has('trees-outside-woodlands') && (
        <Source
          id="overlay-trees-source"
          type="vector"
          tiles={[overlayTileUrl('trees-outside-woodlands')]}
          minzoom={OVERLAY_MIN_ZOOM['trees-outside-woodlands']}
          maxzoom={16}
        >
          <Layer
            id="overlay-tree-polygons"
            type="fill"
            source-layer="trees_outside_woodlands"
            minzoom={POSTCODE_ZOOM_THRESHOLD}
            paint={treeFillPaint}
          />
        </Source>
      )}

      {activeOverlays.has('property-borders') && (
        <Source
          id="overlay-property-borders-source"
          type="vector"
          tiles={[overlayTileUrl('property-borders')]}
          minzoom={OVERLAY_MIN_ZOOM['property-borders']}
          maxzoom={16}
        >
          <Layer
            id="overlay-property-borders"
            type="line"
            source-layer="property_borders"
            minzoom={POSTCODE_ZOOM_THRESHOLD}
            paint={propertyBorderPaint}
          />
        </Source>
      )}
    </>
  );
}
|
||||
188
frontend/src/components/map/PoiPopupCard.tsx
Normal file
188
frontend/src/components/map/PoiPopupCard.tsx
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
import { memo } from 'react';
|
||||
|
||||
import type { SchoolMetadata } from '../../types';
|
||||
import { POI_GROUP_COLORS } from '../../lib/consts';
|
||||
import { getPoiIconUrl } from '../../lib/map-utils';
|
||||
import { ts } from '../../i18n/server';
|
||||
|
||||
/** Data needed to render one point-of-interest popup card. */
export interface PoiPopupCardData {
  /** Shown as the card title; also passed to the icon lookup. */
  name: string;
  /** Shown (translated) under the title; also passed to the icon lookup. */
  category: string;
  /** Optional category passed to the icon lookup alongside `category`. */
  icon_category?: string;
  /** Key into `POI_GROUP_COLORS` for the coloured dot next to the category. */
  group: string;
  /** Passed to the icon lookup. */
  emoji: string;
  /** When present, a block of school details is rendered under the title. */
  school?: SchoolMetadata;
}
|
||||
|
||||
/**
 * Looks up the RGB triple configured for a POI group.
 *
 * @throws Error when `POI_GROUP_COLORS` has no entry for `group`.
 */
function getPoiGroupColor(group: string): [number, number, number] {
  const rgb = POI_GROUP_COLORS[group];
  if (rgb) return rgb;
  throw new Error(`Missing POI group color for '${group}'`);
}
|
||||
|
||||
/**
 * Best-effort web URL from a free-text website field — GIAS stores some with
 * "http://", some without, and some as bare hostnames.
 *
 * @param raw Website value exactly as stored.
 * @returns The value itself when it already carries an http(s) scheme, the
 *   value prefixed with `http://` when it starts like a hostname, or `null`
 *   when it is empty or does not look like a single web address.
 */
function normalizeSchoolWebsiteUrl(raw: string): string | null {
  const trimmed = raw.trim();
  // Internal whitespace means free text ("www.a.org or www.b.org"), which
  // would otherwise be turned into a broken link.
  if (!trimmed || /\s/.test(trimmed)) return null;
  // Require something after the scheme so a bare "https://" is not linked.
  if (/^https?:\/\/[^/]/i.test(trimmed)) return trimmed;
  // Bare hostname: labels followed by a dot and an alphabetic TLD-like part.
  if (/^[\w.-]+\.[a-z]{2,}/i.test(trimmed)) return `http://${trimmed}`;
  return null;
}
|
||||
|
||||
/**
 * Renders the definition list of school details shown under a school POI's
 * title. Each row appears only when its field has a value, so the list stays
 * scannable even when most fields are absent.
 *
 * @param school Metadata attached to the POI.
 * @returns The `<dl>` element, or `null` when no field has anything to show.
 */
function renderSchoolMetadata(school: SchoolMetadata) {
  const rows: { label: string; value: string | number }[] = [];
  /** Appends a row unless the value is missing or empty. */
  const addRow = (label: string, value: string | number | false | null | undefined): void => {
    if (value) rows.push({ label, value });
  };

  // Headline classification: phase and type on one line.
  addRow('Type', [school.phase, school.type].filter(Boolean).join(' · '));
  addRow('Ages', school.age_range);
  // "Mixed" is not worth a row of its own.
  addRow('Gender', school.gender !== 'Mixed' && school.gender);

  const pupils = school.pupils?.toLocaleString();
  const capacity = school.capacity?.toLocaleString();
  if (pupils !== undefined && capacity !== undefined) {
    addRow('Pupils', `${pupils} / ${capacity} pupils`);
  } else if (pupils !== undefined) {
    addRow('Pupils', `${pupils} pupils`);
  } else if (capacity !== undefined) {
    addRow('Pupils', `Capacity ${capacity}`);
  }

  // `!= null` also guards against a null value, which would crash on toFixed.
  addRow('Free meal', school.fsm_percent != null && `${school.fsm_percent.toFixed(1)}%`);
  addRow('Ofsted', school.ofsted_rating);
  addRow('Sixth form', school.sixth_form === 'Has a sixth form' && 'Yes');

  // Placeholder values that mean "no religious character" are hidden.
  const religion = school.religious_character;
  addRow('Religion', religion !== 'Does not apply' && religion !== 'None' && religion);

  addRow('Admissions', school.admissions_policy);
  addRow('Trust', school.trust);
  addRow('Address', [school.address, school.postcode].filter(Boolean).join(', '));
  addRow('LA', school.local_authority);
  addRow('Head', school.head_name);

  const websiteUrl = school.website ? normalizeSchoolWebsiteUrl(school.website) : null;

  // Avoid an empty list that would only contribute its top margin.
  if (rows.length === 0 && !websiteUrl) return null;

  return (
    <dl className="mt-2 grid grid-cols-[auto_1fr] gap-x-2 gap-y-0.5 text-xs text-warm-600 dark:text-warm-300">
      {rows.flatMap(({ label, value }) => [
        <dt key={`${label}-term`} className="text-warm-500 dark:text-warm-400">
          {label}
        </dt>,
        <dd key={`${label}-value`} className="dark:text-warm-200">
          {value}
        </dd>,
      ])}
      {websiteUrl && (
        <>
          <dt className="text-warm-500 dark:text-warm-400">Website</dt>
          <dd className="truncate">
            <a
              href={websiteUrl}
              target="_blank"
              rel="noreferrer noopener"
              className="pointer-events-auto text-teal-600 hover:underline dark:text-teal-400"
            >
              {/* Case-insensitive so "HTTP://…" loses its scheme in the label too. */}
              {websiteUrl.replace(/^https?:\/\//i, '')}
            </a>
          </dd>
        </>
      )}
    </dl>
  );
}
|
||||
|
||||
/**
 * Body of a POI popup: icon, name, a group-coloured dot with the translated
 * category, and — for schools — the school details list.
 */
export const PoiPopupCardContent = memo(function PoiPopupCardContent({
  poi,
}: {
  poi: PoiPopupCardData;
}) {
  const { name, category, emoji, group, school } = poi;
  const iconUrl = getPoiIconUrl(category, emoji, poi.icon_category, name);
  // Throws when the group has no configured colour.
  const groupDotColor = `rgb(${getPoiGroupColor(group).join(',')})`;

  return (
    <div className="px-3 py-2 max-w-[280px]">
      <div className="flex items-center gap-2">
        <img
          src={iconUrl}
          alt=""
          aria-hidden="true"
          loading="lazy"
          referrerPolicy="no-referrer"
          className="h-5 w-5 shrink-0 rounded-[4px] bg-white object-contain p-0.5"
        />
        <div className="min-w-0">
          <div className="font-semibold dark:text-warm-100">{name}</div>
          <div className="flex items-center gap-1.5 text-xs text-warm-500 dark:text-warm-400">
            <span
              className="inline-block w-2 h-2 rounded-full flex-shrink-0"
              style={{ backgroundColor: groupDotColor }}
            />
            {ts(category)}
          </div>
        </div>
      </div>
      {school && renderSchoolMetadata(school)}
    </div>
  );
});
|
||||
131
frontend/src/components/map/map-page/useMobileDrawer.ts
Normal file
131
frontend/src/components/map/map-page/useMobileDrawer.ts
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
import { useCallback, useRef, useState } from 'react';
|
||||
import type { MutableRefObject } from 'react';
|
||||
|
||||
import type { MapFlyToOptions } from '../../../types';
|
||||
import type { MapFlyTo } from './types';
|
||||
|
||||
/** Camera target of a fly-to that is waiting for the drawer panel to be measured. */
export interface PendingFlyTo {
  lat: number;
  lng: number;
  zoom: number;
}

/** Zoom level used when flying to the device's current location. */
const CURRENT_LOCATION_ZOOM = 17;

/** Distance in px from the top edge of `panelRect` to the bottom of the window, never negative. */
function bottomInsetFor(panelRect: DOMRectReadOnly): number {
  return Math.max(0, window.innerHeight - panelRect.top);
}

/**
 * State for the mobile drawer / bottom sheet, plus the fly-to plumbing that
 * keeps a selected target visible above them.
 *
 * A fly-to requested before the drawer panel has reported its rect is parked
 * in a ref and executed once the rect arrives, so the camera is offset by the
 * area the drawer covers. A parked request survives a missing `flyToRef`
 * callback and is retried on the next rect update.
 *
 * @param isMobile Whether the mobile layout is active.
 * @param flyToRef Ref holding the map's fly-to callback, or null.
 */
export function useMobileDrawer(isMobile: boolean, flyToRef: MutableRefObject<MapFlyTo | null>) {
  const [mobileDrawerOpen, setMobileDrawerOpen] = useState(false);
  const [mobileBottomSheetHeight, setMobileBottomSheetHeight] = useState(0);
  const panelRectRef = useRef<DOMRectReadOnly | null>(null);
  const pendingCurrentLocationRef = useRef<{ lat: number; lng: number } | null>(null);
  const pendingLocationSearchRef = useRef<PendingFlyTo | null>(null);

  /** Runs the parked location-search fly-to if a target, a rect and a callback are all available. */
  const consumePendingLocationSearchFlyTo = useCallback(
    (rect?: DOMRectReadOnly | null) => {
      const target = pendingLocationSearchRef.current;
      const panelRect = rect ?? panelRectRef.current;
      const flyTo = flyToRef.current;
      if (!target || !panelRect || !flyTo) return;

      flyTo(target.lat, target.lng, target.zoom, {
        visibleViewportArea: { bottom: bottomInsetFor(panelRect) },
      });
      pendingLocationSearchRef.current = null;
    },
    [flyToRef]
  );

  /** Runs the parked current-location fly-to if a target, a rect and a callback are all available. */
  const consumePendingCurrentLocationFlyTo = useCallback(
    (rect?: DOMRectReadOnly | null) => {
      const target = pendingCurrentLocationRef.current;
      const panelRect = rect ?? panelRectRef.current;
      const flyTo = flyToRef.current;
      if (!target || !panelRect || !flyTo) return;

      flyTo(target.lat, target.lng, CURRENT_LOCATION_ZOOM, {
        visibleViewportArea: { bottom: bottomInsetFor(panelRect) },
      });
      pendingCurrentLocationRef.current = null;
    },
    [flyToRef]
  );

  const openMobileDrawer = useCallback(() => {
    setMobileDrawerOpen(true);
  }, []);

  /** Opens the drawer and flies to the searched location as soon as the panel rect is known. */
  const openMobileDrawerForLocationSearch = useCallback(
    (target: PendingFlyTo) => {
      pendingLocationSearchRef.current = target;
      setMobileDrawerOpen(true);
      // Fires immediately when a rect is already stored; otherwise waits for the next rect update.
      consumePendingLocationSearchFlyTo();
    },
    [consumePendingLocationSearchFlyTo]
  );

  const clearPendingLocationSearchFlyTo = useCallback(() => {
    pendingLocationSearchRef.current = null;
  }, []);

  /** Parks a current-location fly-to until the drawer panel has measured itself. */
  const queueCurrentLocationFlyTo = useCallback(
    (lat: number, lng: number) => {
      pendingCurrentLocationRef.current = { lat, lng };
      consumePendingCurrentLocationFlyTo();
    },
    [consumePendingCurrentLocationFlyTo]
  );

  /** Stores the latest panel rect and runs any parked fly-to (current location first). */
  const handleMobileDrawerPanelRectChange = useCallback(
    (rect: DOMRectReadOnly) => {
      panelRectRef.current = rect;
      consumePendingCurrentLocationFlyTo(rect);
      consumePendingLocationSearchFlyTo(rect);
    },
    [consumePendingCurrentLocationFlyTo, consumePendingLocationSearchFlyTo]
  );

  /** Closes the drawer, dropping parked fly-tos and the stored rect. */
  const handleMobileDrawerClose = useCallback(() => {
    pendingCurrentLocationRef.current = null;
    pendingLocationSearchRef.current = null;
    panelRectRef.current = null;
    setMobileDrawerOpen(false);
  }, []);

  /**
   * Fly-to options for the current mobile layout: the open drawer's inset when
   * it is positive, else the bottom sheet height when it is positive, else
   * `undefined` (always `undefined` when not on mobile).
   */
  const getMobileMapFlyToOptions = useCallback((): MapFlyToOptions | undefined => {
    if (!isMobile) return undefined;

    const panelRect = panelRectRef.current;
    const drawerInset = mobileDrawerOpen && panelRect ? bottomInsetFor(panelRect) : 0;
    if (drawerInset > 0) {
      return { visibleViewportArea: { bottom: drawerInset } };
    }

    // The sheet height is reported through `visibleArea`, not `visibleViewportArea`.
    if (mobileBottomSheetHeight > 0) {
      return { visibleArea: { bottom: mobileBottomSheetHeight } };
    }
    return undefined;
  }, [isMobile, mobileBottomSheetHeight, mobileDrawerOpen]);

  return {
    mobileDrawerOpen,
    mobileBottomSheetHeight,
    setMobileBottomSheetHeight,
    openMobileDrawer,
    openMobileDrawerForLocationSearch,
    clearPendingLocationSearchFlyTo,
    queueCurrentLocationFlyTo,
    handleMobileDrawerPanelRectChange,
    handleMobileDrawerClose,
    getMobileMapFlyToOptions,
  };
}
|
||||
|
|
@ -7,11 +7,11 @@ interface SubNavProps {
|
|||
export function SubNav({ tabs, activeTab, onTabChange }: SubNavProps) {
|
||||
return (
|
||||
<div className="max-w-5xl mx-auto w-full px-6 pt-4">
|
||||
<div className="flex gap-2 border-b border-warm-200 dark:border-warm-700">
|
||||
<div className="flex gap-2 overflow-x-auto border-b border-warm-200 dark:border-warm-700">
|
||||
{tabs.map((tab) => (
|
||||
<button
|
||||
key={tab.key}
|
||||
className={`cursor-pointer px-4 py-2 text-sm font-medium border-b-2 ${
|
||||
className={`cursor-pointer shrink-0 whitespace-nowrap px-4 py-2 text-sm font-medium border-b-2 ${
|
||||
activeTab === tab.key
|
||||
? 'border-teal-500 text-teal-700 dark:text-teal-400'
|
||||
: 'border-transparent text-warm-500 dark:text-warm-400 hover:text-warm-700 dark:hover:text-warm-300'
|
||||
|
|
|
|||
43
frontend/src/hooks/useMapCardLayout.ts
Normal file
43
frontend/src/hooks/useMapCardLayout.ts
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
import { useMemo } from 'react';
|
||||
|
||||
/** Width of one desktop top card, in px. */
const DESKTOP_TOP_CARD_WIDTH = 300;
/** Gap between two cards placed side by side, in px. */
const DESKTOP_TOP_CARD_GAP = 8;
/** Horizontal space reserved around the cards, in px. */
const DESKTOP_TOP_CARD_HORIZONTAL_INSET = 24;
/** Narrowest map that still fits a single card. */
const DESKTOP_TOP_CARDS_STACKED_MIN_MAP_WIDTH =
  DESKTOP_TOP_CARD_WIDTH + DESKTOP_TOP_CARD_HORIZONTAL_INSET;
/** Narrowest map that fits two cards in one row. */
const DESKTOP_TOP_CARDS_ROW_MIN_MAP_WIDTH =
  DESKTOP_TOP_CARD_WIDTH * 2 + DESKTOP_TOP_CARD_GAP + DESKTOP_TOP_CARD_HORIZONTAL_INSET;

interface UseMapCardLayoutOptions {
  /** Current map width in px; 0 never triggers hiding. */
  mapWidth: number;
  /** Enables both width-based behaviours (hiding and stacking). */
  hideTopCardsWhenNarrow: boolean;
  hideLegend: boolean;
  hideLocationSearch: boolean;
}

/**
 * Layout decisions for the desktop top cards over the map.
 *
 * With `hideTopCardsWhenNarrow` set, the cards are hidden when the map is too
 * narrow for a single card and stacked vertically when it fits one card but
 * not two side by side. Without it, width is ignored.
 *
 * @returns Visibility flags for the two cards and the flex classes for their container.
 */
export function useMapCardLayout(options: UseMapCardLayoutOptions) {
  const { mapWidth, hideTopCardsWhenNarrow, hideLegend, hideLocationSearch } = options;

  return useMemo(() => {
    let tooNarrowForAnyCard = false;
    let stackCards = false;
    if (hideTopCardsWhenNarrow) {
      tooNarrowForAnyCard = mapWidth > 0 && mapWidth < DESKTOP_TOP_CARDS_STACKED_MIN_MAP_WIDTH;
      stackCards =
        mapWidth >= DESKTOP_TOP_CARDS_STACKED_MIN_MAP_WIDTH &&
        mapWidth < DESKTOP_TOP_CARDS_ROW_MIN_MAP_WIDTH;
    }

    const showLocationSearch = !(hideLocationSearch || tooNarrowForAnyCard);
    const showLegend = !(hideLegend || tooNarrowForAnyCard);
    const topCardsLayoutClass = stackCards ? 'flex-col items-start' : 'items-start justify-between';

    return { showLocationSearch, showLegend, topCardsLayoutClass };
  }, [mapWidth, hideTopCardsWhenNarrow, hideLegend, hideLocationSearch]);
}
|
||||
|
|
@ -880,6 +880,7 @@ const de: Translations = {
|
|||
walk: 'Zu Fuß',
|
||||
cycle: 'Fahrrad',
|
||||
nationalAvg: 'England-Schnitt',
|
||||
crimeDataEnds: 'Polizeidaten für dieses Gebiet enden {{year}}',
|
||||
},
|
||||
|
||||
// ── Street View ────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -864,6 +864,7 @@ const en = {
|
|||
walk: 'Walk',
|
||||
cycle: 'Cycle',
|
||||
nationalAvg: 'National avg',
|
||||
crimeDataEnds: 'Police data for this area ends {{year}}',
|
||||
},
|
||||
|
||||
// ── Street View ────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -893,6 +893,7 @@ const fr: Translations = {
|
|||
walk: 'Marche',
|
||||
cycle: 'Vélo',
|
||||
nationalAvg: 'Moyenne nationale',
|
||||
crimeDataEnds: 'Les données de police pour cette zone s\'arrêtent en {{year}}',
|
||||
},
|
||||
|
||||
// ── Street View ────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -852,6 +852,7 @@ const hi: Translations = {
|
|||
walk: 'पैदल',
|
||||
cycle: 'साइकिल',
|
||||
nationalAvg: 'राष्ट्रीय औसत',
|
||||
crimeDataEnds: 'इस क्षेत्र के लिए पुलिस डेटा {{year}} में समाप्त होता है',
|
||||
},
|
||||
|
||||
streetView: {
|
||||
|
|
|
|||
|
|
@ -881,6 +881,7 @@ const hu: Translations = {
|
|||
walk: 'Gyalog',
|
||||
cycle: 'Kerékpár',
|
||||
nationalAvg: 'Országos átlag',
|
||||
crimeDataEnds: 'A körzet rendőrségi adatai {{year}}-ig érhetők el',
|
||||
},
|
||||
|
||||
// ── Street View ────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -823,6 +823,7 @@ const zh: Translations = {
|
|||
walk: '步行',
|
||||
cycle: '骑行',
|
||||
nationalAvg: '全国平均',
|
||||
crimeDataEnds: '该地区的警方数据截至{{year}}年',
|
||||
},
|
||||
|
||||
// ── Street View ────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -303,6 +303,12 @@ export interface HexagonStatsResponse {
|
|||
price_history?: PricePoint[];
|
||||
/** Per-crime-type per-year counts averaged across the selection. */
|
||||
crime_by_year?: CrimeYearStats[];
|
||||
/**
|
||||
* Latest year in the crime dataset as a whole. A selection whose series end
|
||||
* earlier sits in a force-level publication gap (e.g. Greater Manchester
|
||||
* since mid-2019) and its crime figures are captioned as stale.
|
||||
*/
|
||||
crime_latest_year?: number;
|
||||
central_postcode?: string;
|
||||
filter_exclusions?: FilterExclusion[];
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,10 +24,11 @@ from pathlib import Path
|
|||
import numpy as np
|
||||
import polars as pl
|
||||
|
||||
from pipeline.utils.normalize import collapse_whitespace, replace_non_alnum_lower
|
||||
|
||||
_NOISE_WORDS = re.compile(
|
||||
r"\b(the|of|and|c\s*of\s*e|cofe|ce|rc|voluntary|aided|controlled|va|vc)\b"
|
||||
)
|
||||
_NON_ALNUM = re.compile(r"[^a-z0-9 ]")
|
||||
_SCHOOL_WORDS = re.compile(
|
||||
r"\b(school|academy|primary|secondary|junior|infant|community|college|high)\b"
|
||||
)
|
||||
|
|
@ -35,16 +36,16 @@ _SCHOOL_WORDS = re.compile(
|
|||
|
||||
def normalize_name(name: str, strip_school_words: bool = False) -> str:
|
||||
s = name.lower().replace("&", " and ").replace("st.", "st ").replace("'", "")
|
||||
s = _NON_ALNUM.sub(" ", s)
|
||||
s = replace_non_alnum_lower(s)
|
||||
s = _NOISE_WORDS.sub(" ", s)
|
||||
if strip_school_words:
|
||||
s = _SCHOOL_WORDS.sub(" ", s)
|
||||
return " ".join(s.split())
|
||||
return collapse_whitespace(s)
|
||||
|
||||
|
||||
def normalize_la(la: str) -> str:
|
||||
s = _NON_ALNUM.sub(" ", la.lower().replace("&", " and "))
|
||||
return " ".join(s.replace("city of", "").split())
|
||||
s = replace_non_alnum_lower(la.lower().replace("&", " and "))
|
||||
return collapse_whitespace(s.replace("city of", ""))
|
||||
|
||||
|
||||
def load_ground_truth(directory: Path) -> pl.DataFrame:
|
||||
|
|
|
|||
|
|
@ -171,41 +171,86 @@ def parse_contained_range(contained_range: str) -> tuple[str, str] | None:
|
|||
return start, end
|
||||
|
||||
|
||||
def select_coverage_archives(archives: list[CrimeArchive]) -> list[CrimeArchive]:
|
||||
"""Select non-overlapping snapshots that still cover the available history.
|
||||
def _index_to_month(index: int) -> str:
|
||||
year, month_num = divmod(index, 12)
|
||||
if month_num == 0:
|
||||
year -= 1
|
||||
month_num = 12
|
||||
return f"{year:04d}-{month_num:02d}"
|
||||
|
||||
|
||||
def select_coverage_archives(
|
||||
archives: list[CrimeArchive], *, allow_gaps: bool = False
|
||||
) -> list[CrimeArchive]:
|
||||
"""Select snapshots whose ranges chain together to cover the available history.
|
||||
|
||||
The source publishes rolling multi-year snapshots. Downloading every monthly
|
||||
snapshot mostly fetches duplicate data; for our aggregate LSOA counts we only
|
||||
need continuous month coverage.
|
||||
need continuous month coverage. Greedy interval cover, newest first: anchor
|
||||
on the snapshot with the latest end month, then repeatedly take the archive
|
||||
reaching furthest back among those adjacent to or overlapping the covered
|
||||
range. Accepting an overlapping snapshot (rather than only an exactly
|
||||
adjacent one) matters when the adjacent snapshot is missing from the index:
|
||||
skipping it would leave a multi-month hole, while overlap only costs
|
||||
download time because extraction skips already-extracted months. A hole no
|
||||
archive can bridge is a publication gap in the source — a hard error unless
|
||||
``allow_gaps``, since the run would otherwise be stamped complete with
|
||||
artificial dips in every crime-over-time series.
|
||||
"""
|
||||
selected: list[CrimeArchive] = []
|
||||
earliest_covered_start: int | None = None
|
||||
|
||||
def sort_key(archive: CrimeArchive) -> int:
|
||||
parsed_range = parse_contained_range(archive.contained_range)
|
||||
if parsed_range is not None:
|
||||
return _month_to_index(parsed_range[1])
|
||||
return _month_to_index(archive.month)
|
||||
|
||||
for archive in sorted(archives, key=sort_key, reverse=True):
|
||||
ranged: list[tuple[int, int, CrimeArchive]] = []
|
||||
for archive in archives:
|
||||
parsed_range = parse_contained_range(archive.contained_range)
|
||||
if parsed_range is None:
|
||||
selected.append(archive)
|
||||
continue
|
||||
|
||||
start, end = parsed_range
|
||||
start_index = _month_to_index(start)
|
||||
end_index = _month_to_index(end)
|
||||
if earliest_covered_start is None or end_index < earliest_covered_start:
|
||||
if (
|
||||
earliest_covered_start is not None
|
||||
and end_index < earliest_covered_start - 1
|
||||
):
|
||||
print(
|
||||
"Warning: archive ranges are not adjacent; "
|
||||
f"coverage gap before {archive.filename}",
|
||||
file=sys.stderr,
|
||||
else:
|
||||
ranged.append(
|
||||
(
|
||||
_month_to_index(parsed_range[0]),
|
||||
_month_to_index(parsed_range[1]),
|
||||
archive,
|
||||
)
|
||||
)
|
||||
|
||||
earliest_covered_start: int | None = None
|
||||
while True:
|
||||
if earliest_covered_start is None:
|
||||
eligible = ranged
|
||||
else:
|
||||
eligible = [item for item in ranged if item[0] < earliest_covered_start]
|
||||
if not eligible:
|
||||
break
|
||||
|
||||
if earliest_covered_start is None:
|
||||
# Anchor: latest end month, reaching as far back as available.
|
||||
start_index, _, archive = max(
|
||||
eligible, key=lambda item: (item[1], -item[0])
|
||||
)
|
||||
else:
|
||||
chained = [
|
||||
item for item in eligible if item[1] >= earliest_covered_start - 1
|
||||
]
|
||||
if not chained:
|
||||
hole_start = max(item[1] for item in eligible) + 1
|
||||
message = (
|
||||
"no archive covers "
|
||||
f"{_index_to_month(hole_start)} to "
|
||||
f"{_index_to_month(earliest_covered_start - 1)}"
|
||||
)
|
||||
if not allow_gaps:
|
||||
raise RuntimeError(
|
||||
f"Coverage gap: {message}. Rerun with --allow-gaps to "
|
||||
"accept the hole."
|
||||
)
|
||||
print(f"Warning: coverage gap: {message}", file=sys.stderr)
|
||||
chained = eligible
|
||||
# Furthest backward reach; on a start tie prefer the newer
|
||||
# snapshot, whose data for the months around the boundary carries
|
||||
# the latest revisions.
|
||||
start_index, _, archive = min(
|
||||
chained, key=lambda item: (item[0], -item[1])
|
||||
)
|
||||
|
||||
selected.append(archive)
|
||||
earliest_covered_start = start_index
|
||||
|
||||
|
|
@ -331,14 +376,24 @@ def extract_csvs(
|
|||
*,
|
||||
overwrite: bool = False,
|
||||
street_only: bool = True,
|
||||
extracted_this_run: set[PurePosixPath] | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""Extract CSVs from one ZIP. Returns (extracted, skipped)."""
|
||||
"""Extract CSVs from one ZIP. Returns (extracted, skipped).
|
||||
|
||||
``extracted_this_run`` is shared across the archives of one run, processed
|
||||
newest-snapshot first: a member already written by a newer snapshot is
|
||||
skipped even with ``overwrite``, so an older overlapping archive can never
|
||||
replace a month with a less-revised copy.
|
||||
"""
|
||||
extracted = 0
|
||||
skipped = 0
|
||||
|
||||
with zipfile.ZipFile(zip_path) as archive:
|
||||
for info, rel_path in _safe_csv_members(archive, street_only=street_only):
|
||||
dest = output_dir.joinpath(*rel_path.parts)
|
||||
if extracted_this_run is not None and rel_path in extracted_this_run:
|
||||
skipped += 1
|
||||
continue
|
||||
if dest.exists() and not overwrite:
|
||||
skipped += 1
|
||||
continue
|
||||
|
|
@ -347,6 +402,8 @@ def extract_csvs(
|
|||
with archive.open(info) as source, dest.open("wb") as target:
|
||||
shutil.copyfileobj(source, target)
|
||||
extracted += 1
|
||||
if extracted_this_run is not None:
|
||||
extracted_this_run.add(rel_path)
|
||||
|
||||
return extracted, skipped
|
||||
|
||||
|
|
@ -489,8 +546,22 @@ def main() -> None:
|
|||
)
|
||||
parser.add_argument(
|
||||
"--overwrite-extracted",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help=(
|
||||
"Replace previously extracted CSVs with this run's snapshot data "
|
||||
"(police.uk revises the trailing 36 months in every release, so "
|
||||
"keeping old extractions freezes stale revisions; within a run the "
|
||||
"newest snapshot still wins for overlapping months)"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow-gaps",
|
||||
action="store_true",
|
||||
help="Overwrite CSVs when extracting overlapping archive snapshots",
|
||||
help=(
|
||||
"Continue past months no archive covers instead of failing "
|
||||
"(coverage strategy only)"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-verify",
|
||||
|
|
@ -521,7 +592,7 @@ def main() -> None:
|
|||
limit=args.limit,
|
||||
)
|
||||
archives = (
|
||||
select_coverage_archives(available_archives)
|
||||
select_coverage_archives(available_archives, allow_gaps=args.allow_gaps)
|
||||
if args.archive_strategy == "coverage"
|
||||
else available_archives
|
||||
)
|
||||
|
|
@ -570,6 +641,7 @@ def main() -> None:
|
|||
|
||||
total_extracted = 0
|
||||
total_skipped = 0
|
||||
extracted_this_run: set[PurePosixPath] = set()
|
||||
for index, archive in enumerate(archives, start=1):
|
||||
print(f"[{index}/{len(archives)}] {archive.label} ({archive.size})")
|
||||
zip_path = download_archive(
|
||||
|
|
@ -585,6 +657,7 @@ def main() -> None:
|
|||
args.output,
|
||||
overwrite=args.overwrite_extracted,
|
||||
street_only=street_only,
|
||||
extracted_this_run=extracted_this_run,
|
||||
)
|
||||
total_extracted += extracted
|
||||
total_skipped += skipped
|
||||
|
|
|
|||
|
|
@ -16,12 +16,12 @@ License: Open Government Licence v3.0
|
|||
"""
|
||||
|
||||
import argparse
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import polars as pl
|
||||
|
||||
from pipeline.utils import ENGLAND_LSOA_COUNT_2021, download_nomis_csv
|
||||
|
||||
pl.Config.set_tbl_cols(-1)
|
||||
|
||||
# NOMIS API: Census 2021 TS021 (ethnic group, 20 categories) by LSOA 2021
|
||||
|
|
@ -35,7 +35,6 @@ BASE_URL = (
|
|||
"&measures=20100"
|
||||
"&select=GEOGRAPHY_CODE,C2021_ETH_20_NAME,OBS_VALUE"
|
||||
)
|
||||
PAGE_SIZE = 25000
|
||||
|
||||
# Map the 19 detailed NOMIS C2021_ETH_20 leaf categories to our 7 output groups.
|
||||
# The Asian split:
|
||||
|
|
@ -150,24 +149,7 @@ def _ethnicity_percentages(df: pl.DataFrame) -> pl.DataFrame:
|
|||
|
||||
def download_and_convert(output_path: Path) -> None:
|
||||
print("Downloading Census 2021 ethnic group (TS021) by LSOA from NOMIS...")
|
||||
frames = []
|
||||
offset = 0
|
||||
while True:
|
||||
url = f"{BASE_URL}&recordoffset={offset}"
|
||||
response = httpx.get(url, follow_redirects=True, timeout=120)
|
||||
response.raise_for_status()
|
||||
if len(response.content) == 0:
|
||||
break
|
||||
chunk = pl.read_csv(BytesIO(response.content))
|
||||
if chunk.height == 0:
|
||||
break
|
||||
frames.append(chunk)
|
||||
print(f" Fetched {chunk.height} rows (offset={offset})")
|
||||
if chunk.height < PAGE_SIZE:
|
||||
break
|
||||
offset += PAGE_SIZE
|
||||
|
||||
df = pl.concat(frames)
|
||||
df = download_nomis_csv(BASE_URL)
|
||||
print(f"Total rows: {df.height}")
|
||||
|
||||
# Filter to England only (E-prefixed LSOA codes); the merge joins on the
|
||||
|
|
@ -177,6 +159,11 @@ def download_and_convert(output_path: Path) -> None:
|
|||
wide = _ethnicity_percentages(df)
|
||||
|
||||
print(f"England LSOAs: {wide.height}")
|
||||
if wide.height != ENGLAND_LSOA_COUNT_2021:
|
||||
raise ValueError(
|
||||
f"Expected {ENGLAND_LSOA_COUNT_2021} England LSOAs, "
|
||||
f"got {wide.height}: truncated NOMIS download?"
|
||||
)
|
||||
print(f"Columns: {wide.columns}")
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
|
|
|||
|
|
@ -241,9 +241,11 @@ def transform(zip_bytes: bytes) -> pl.DataFrame:
|
|||
"""Convert the GIAS extract ZIP into a clean schools DataFrame."""
|
||||
raw = _read_csv_from_zip(zip_bytes)
|
||||
|
||||
# Filter to currently-open establishments; the CSV also includes closed,
|
||||
# proposed-to-open, and proposed-to-close rows we do not want on a map.
|
||||
df = raw.filter(pl.col("EstablishmentStatus (name)") == "Open")
|
||||
# Filter to currently-open establishments; the CSV also includes closed and
|
||||
# proposed-to-open rows we do not want on a map. "Open, but proposed to
|
||||
# close" schools are open, operating establishments (GIAS can keep that
|
||||
# status for years, e.g. pending amalgamations), so they must stay.
|
||||
df = raw.filter(pl.col("EstablishmentStatus (name)").str.starts_with("Open"))
|
||||
|
||||
df = df.with_columns(
|
||||
pl.col("URN").cast(pl.Int64),
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from pathlib import Path
|
|||
import osmium
|
||||
import polars as pl
|
||||
from pyproj import Transformer
|
||||
from shapely import wkb
|
||||
from shapely import make_valid, wkb
|
||||
from shapely.errors import GEOSException
|
||||
from shapely.geometry import MultiPolygon, Polygon
|
||||
from tqdm import tqdm
|
||||
|
|
@ -56,6 +56,22 @@ def _to_bng_polygon(geom):
|
|||
return geom
|
||||
|
||||
|
||||
def _polygonal_part(geom):
|
||||
"""The Polygon/MultiPolygon content of a geometry, or None if there is none."""
|
||||
if geom.geom_type in ("Polygon", "MultiPolygon"):
|
||||
return geom
|
||||
if geom.geom_type == "GeometryCollection":
|
||||
polygons = []
|
||||
for part in geom.geoms:
|
||||
if part.geom_type == "Polygon":
|
||||
polygons.append(part)
|
||||
elif part.geom_type == "MultiPolygon":
|
||||
polygons.extend(part.geoms)
|
||||
if polygons:
|
||||
return MultiPolygon(polygons)
|
||||
return None
|
||||
|
||||
|
||||
def _matches_tags(tags):
|
||||
"""Check if an OSM element's tags match our greenspace/water criteria."""
|
||||
for key, values in GREENSPACE_TAGS.items():
|
||||
|
|
@ -91,7 +107,13 @@ class GreenspaceHandler(osmium.SimpleHandler):
|
|||
)
|
||||
return
|
||||
|
||||
if geom.is_empty or not geom.is_valid:
|
||||
# Invalid geometries are often the largest, most complex park/water
|
||||
# multipolygons (self-touching rings from OSM) — repair like pois.py
|
||||
# rather than silently dropping them. make_valid may return a
|
||||
# GeometryCollection with stray lines/points; keep only the polygons.
|
||||
if not geom.is_valid:
|
||||
geom = _polygonal_part(make_valid(geom))
|
||||
if geom is None or geom.is_empty:
|
||||
return
|
||||
|
||||
# Reproject to BNG for area calculation
|
||||
|
|
|
|||
|
|
@ -5,18 +5,14 @@ License: Open Government Licence v3.0
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pyogrio
|
||||
from pipeline.utils import download_arcgis_hub_export
|
||||
|
||||
URL = (
|
||||
"https://opendata-historicengland.hub.arcgis.com/api/download/v1/items/"
|
||||
"767f279327a24845bf47dfe5eae9862b/geoPackage?layers=0"
|
||||
)
|
||||
POLL_INTERVAL_S = 5
|
||||
POLL_TIMEOUT_S = 600
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
|
@ -28,37 +24,9 @@ def main() -> None:
|
|||
)
|
||||
args = parser.parse_args()
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = args.output.with_name(f"{args.output.stem}.tmp{args.output.suffix}")
|
||||
|
||||
print("Downloading Historic England listed-building points...")
|
||||
deadline = time.monotonic() + POLL_TIMEOUT_S
|
||||
with httpx.Client(follow_redirects=True, timeout=300) as client:
|
||||
while True:
|
||||
with client.stream("GET", URL) as response:
|
||||
if response.status_code == 202:
|
||||
response.read()
|
||||
if time.monotonic() > deadline:
|
||||
raise TimeoutError(
|
||||
f"Export did not finish within {POLL_TIMEOUT_S}s: "
|
||||
f"{response.text}"
|
||||
)
|
||||
time.sleep(POLL_INTERVAL_S)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
with tmp_path.open("wb") as fh:
|
||||
for chunk in response.iter_bytes():
|
||||
fh.write(chunk)
|
||||
break
|
||||
|
||||
info = pyogrio.read_info(tmp_path)
|
||||
features = info.get("features", 0)
|
||||
geometry_type = str(info.get("geometry_type") or "")
|
||||
if features <= 0:
|
||||
raise ValueError("Downloaded listed-buildings file contains no features")
|
||||
if "Point" not in geometry_type:
|
||||
raise ValueError(f"Expected point geometry, got {geometry_type!r}")
|
||||
|
||||
tmp_path.replace(args.output)
|
||||
features = download_arcgis_hub_export(URL, args.output, expected_geometry="Point")
|
||||
size_mb = args.output.stat().st_size / (1024 * 1024)
|
||||
print(
|
||||
f"Saved {features} listed-building points to {args.output} ({size_mb:.1f} MB)"
|
||||
|
|
|
|||
|
|
@ -10,21 +10,19 @@ of the 0-4, 10-14 and 15-19 bands (one fifth per single year of age).
|
|||
"""
|
||||
|
||||
import argparse
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import polars as pl
|
||||
|
||||
from pipeline.utils import ENGLAND_LSOA_COUNT_2021, download_nomis_csv
|
||||
|
||||
# NOMIS API: Census 2021 TS007A (age, five-year bands) by LSOA 2021 (TYPE151).
|
||||
# c2021_age_19 codes: 1 = 0-4, 2 = 5-9, 3 = 10-14, 4 = 15-19.
|
||||
# NOMIS paginates at 25,000 rows by default, so we paginate with recordoffset.
|
||||
BASE_URL = (
|
||||
"https://www.nomisweb.co.uk/api/v01/dataset/NM_2020_1.data.csv"
|
||||
"?date=latest&geography=TYPE151&measures=20100&c2021_age_19=1,2,3,4"
|
||||
"&select=GEOGRAPHY_CODE,C2021_AGE_19,OBS_VALUE"
|
||||
)
|
||||
PAGE_SIZE = 25000
|
||||
|
||||
AGE_BAND_COLUMNS = {
|
||||
1: "aged_0_4",
|
||||
|
|
@ -36,24 +34,7 @@ AGE_BAND_COLUMNS = {
|
|||
|
||||
def download_and_convert(output_path: Path) -> None:
|
||||
print("Downloading Census 2021 LSOA age bands from NOMIS...")
|
||||
frames = []
|
||||
offset = 0
|
||||
while True:
|
||||
url = f"{BASE_URL}&recordoffset={offset}"
|
||||
response = httpx.get(url, follow_redirects=True, timeout=120)
|
||||
response.raise_for_status()
|
||||
if len(response.content) == 0:
|
||||
break
|
||||
chunk = pl.read_csv(BytesIO(response.content))
|
||||
if chunk.height == 0:
|
||||
break
|
||||
frames.append(chunk)
|
||||
print(f" Fetched {chunk.height} rows (offset={offset})")
|
||||
if chunk.height < PAGE_SIZE:
|
||||
break
|
||||
offset += PAGE_SIZE
|
||||
|
||||
df = pl.concat(frames)
|
||||
df = download_nomis_csv(BASE_URL)
|
||||
print(f"Total rows: {df.height}")
|
||||
|
||||
result = (
|
||||
|
|
@ -70,6 +51,11 @@ def download_and_convert(output_path: Path) -> None:
|
|||
raise ValueError(f"NOMIS response missing age bands: {missing}")
|
||||
|
||||
print(f"England LSOAs: {result.height}")
|
||||
if result.height != ENGLAND_LSOA_COUNT_2021:
|
||||
raise ValueError(
|
||||
f"Expected {ENGLAND_LSOA_COUNT_2021} England LSOAs, "
|
||||
f"got {result.height}: truncated NOMIS download?"
|
||||
)
|
||||
for name in AGE_BAND_COLUMNS.values():
|
||||
print(f" {name}: total {result[name].sum():,}")
|
||||
|
||||
|
|
|
|||
|
|
@ -5,39 +5,20 @@ License: Open Government Licence v3.0
|
|||
"""
|
||||
|
||||
import argparse
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import polars as pl
|
||||
|
||||
from pipeline.utils import ENGLAND_LSOA_COUNT_2021, download_nomis_csv
|
||||
|
||||
# NOMIS API: Census 2021 TS001 (usual residents) by LSOA 2021 (TYPE151)
|
||||
# c2021_restype_3=0 selects "Total: All usual residents"
|
||||
# NOMIS paginates at 25,000 rows by default, so we paginate with recordoffset.
|
||||
BASE_URL = "https://www.nomisweb.co.uk/api/v01/dataset/NM_2021_1.data.csv?date=latest&geography=TYPE151&measures=20100&c2021_restype_3=0&select=GEOGRAPHY_CODE,OBS_VALUE"
|
||||
PAGE_SIZE = 25000
|
||||
|
||||
|
||||
def download_and_convert(output_path: Path) -> None:
|
||||
print("Downloading Census 2021 LSOA population from NOMIS...")
|
||||
frames = []
|
||||
offset = 0
|
||||
while True:
|
||||
url = f"{BASE_URL}&recordoffset={offset}"
|
||||
response = httpx.get(url, follow_redirects=True, timeout=120)
|
||||
response.raise_for_status()
|
||||
if len(response.content) == 0:
|
||||
break
|
||||
chunk = pl.read_csv(BytesIO(response.content))
|
||||
if chunk.height == 0:
|
||||
break
|
||||
frames.append(chunk)
|
||||
print(f" Fetched {chunk.height} rows (offset={offset})")
|
||||
if chunk.height < PAGE_SIZE:
|
||||
break
|
||||
offset += PAGE_SIZE
|
||||
|
||||
df = pl.concat(frames)
|
||||
df = download_nomis_csv(BASE_URL)
|
||||
print(f"Total rows: {df.height}")
|
||||
|
||||
result = df.rename(
|
||||
|
|
@ -50,6 +31,11 @@ def download_and_convert(output_path: Path) -> None:
|
|||
result = result.filter(pl.col("lsoa21").str.starts_with("E"))
|
||||
|
||||
print(f"England LSOAs: {result.height}")
|
||||
if result.height != ENGLAND_LSOA_COUNT_2021:
|
||||
raise ValueError(
|
||||
f"Expected {ENGLAND_LSOA_COUNT_2021} England LSOAs, "
|
||||
f"got {result.height}: truncated NOMIS download?"
|
||||
)
|
||||
print(
|
||||
f"Population range: {result['population'].min()} - {result['population'].max()}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import base64
|
|||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from io import BytesIO
|
||||
|
|
@ -120,17 +121,28 @@ def collect_twemoji_codes() -> list[str]:
|
|||
return sorted({f"{ord(e[0]):x}" for e in emojis})
|
||||
|
||||
|
||||
DOWNLOAD_ATTEMPTS = 3
|
||||
RETRY_BACKOFF_S = 2.0
|
||||
|
||||
|
||||
def download_file(url: str, dest: Path) -> tuple[bool, str]:
|
||||
"""Download a single file. Returns (success, url)."""
|
||||
"""Download a single file, retrying transient errors. Returns (success, url)."""
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
for attempt in range(DOWNLOAD_ATTEMPTS):
|
||||
if attempt:
|
||||
time.sleep(RETRY_BACKOFF_S * 2 ** (attempt - 1))
|
||||
try:
|
||||
urllib.request.urlretrieve(url, dest)
|
||||
return True, url
|
||||
except urllib.error.HTTPError as e:
|
||||
# 4xx is a permanent answer (bad glyph range / missing emoji);
|
||||
# retrying won't change it.
|
||||
if 400 <= e.code < 500:
|
||||
print(f" {e.code} {url}", file=sys.stderr)
|
||||
return False, url
|
||||
print(f" {e.code} {url} (attempt {attempt + 1})", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f" ERROR {url}: {e}", file=sys.stderr)
|
||||
print(f" ERROR {url}: {e} (attempt {attempt + 1})", file=sys.stderr)
|
||||
return False, url
|
||||
|
||||
|
||||
|
|
@ -389,37 +401,38 @@ def main():
|
|||
url = f"{POI_ICON_BASE}/{icon_path}"
|
||||
tasks.append((url, poi_icons_dir / icon_path))
|
||||
|
||||
# Skip already-downloaded files
|
||||
remaining = [(url, dest) for url, dest in tasks]
|
||||
|
||||
print(f"Downloading {len(remaining) + len(DERIVED_POI_ICON_PATHS)} assets")
|
||||
print(f"Downloading {len(tasks) + len(DERIVED_POI_ICON_PATHS)} assets")
|
||||
|
||||
ok = 0
|
||||
fail = 0
|
||||
failed_urls: list[str] = []
|
||||
with ThreadPoolExecutor(max_workers=20) as pool:
|
||||
futures = {
|
||||
pool.submit(download_file, url, dest): url for url, dest in remaining
|
||||
}
|
||||
futures = {pool.submit(download_file, url, dest): url for url, dest in tasks}
|
||||
for future in as_completed(futures):
|
||||
success, url = future.result()
|
||||
if success:
|
||||
ok += 1
|
||||
else:
|
||||
fail += 1
|
||||
failed_urls.append(url)
|
||||
|
||||
for kind, source_path, dest_path in DERIVED_POI_ICON_PATHS:
|
||||
success, _url = download_derived_poi_icon(
|
||||
success, url = download_derived_poi_icon(
|
||||
kind, source_path, poi_icons_dir / dest_path
|
||||
)
|
||||
if success:
|
||||
ok += 1
|
||||
else:
|
||||
fail += 1
|
||||
failed_urls.append(url)
|
||||
|
||||
crop_poi_svg_icons(poi_icons_dir)
|
||||
inject_townhall_sprite(sprites_dir)
|
||||
|
||||
print(f"Done: {ok} downloaded, {fail} failed")
|
||||
print(f"Done: {ok} downloaded, {len(failed_urls)} failed")
|
||||
if failed_urls:
|
||||
# A partial asset bundle (missing font ranges, sprites, icons) renders
|
||||
# broken labels at runtime but would otherwise satisfy the make stamp.
|
||||
for url in failed_urls:
|
||||
print(f" missing: {url}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -8,17 +8,16 @@ License: Open Government Licence v3.0
|
|||
"""
|
||||
|
||||
import argparse
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import polars as pl
|
||||
|
||||
from pipeline.utils import ENGLAND_LSOA_COUNT_2021, download_nomis_csv
|
||||
|
||||
# NOMIS API: Census 2021 TS007A (age by five-year bands) by LSOA 2021 (TYPE151)
|
||||
# c2021_age_19=1..18 selects 18 five-year bands (excluding 0 = Total)
|
||||
# measures=20100 selects absolute count
|
||||
BASE_URL = "https://www.nomisweb.co.uk/api/v01/dataset/NM_2020_1.data.csv?date=latest&geography=TYPE151&c2021_age_19=1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18&measures=20100&select=GEOGRAPHY_CODE,C2021_AGE_19_NAME,OBS_VALUE"
|
||||
PAGE_SIZE = 25000
|
||||
|
||||
# Five-year age bands in order, with lower bounds for interpolation.
|
||||
# The last band (85+) is open-ended — we treat it as 85-89 for median purposes.
|
||||
|
|
@ -161,24 +160,7 @@ def _bands_to_median_table(pivoted: pl.DataFrame) -> pl.DataFrame:
|
|||
|
||||
def download_and_convert(output_path: Path) -> None:
|
||||
print("Downloading Census 2021 age by five-year bands from NOMIS...")
|
||||
frames = []
|
||||
offset = 0
|
||||
while True:
|
||||
url = f"{BASE_URL}&recordoffset={offset}"
|
||||
response = httpx.get(url, follow_redirects=True, timeout=120)
|
||||
response.raise_for_status()
|
||||
if len(response.content) == 0:
|
||||
break
|
||||
chunk = pl.read_csv(BytesIO(response.content))
|
||||
if chunk.height == 0:
|
||||
break
|
||||
frames.append(chunk)
|
||||
print(f" Fetched {chunk.height} rows (offset={offset})")
|
||||
if chunk.height < PAGE_SIZE:
|
||||
break
|
||||
offset += PAGE_SIZE
|
||||
|
||||
df = pl.concat(frames)
|
||||
df = download_nomis_csv(BASE_URL)
|
||||
print(f"Total rows: {df.height}")
|
||||
|
||||
# Filter to England only
|
||||
|
|
@ -194,6 +176,11 @@ def download_and_convert(output_path: Path) -> None:
|
|||
result = _bands_to_median_table(pivoted)
|
||||
|
||||
print(f"England LSOAs: {result.height}")
|
||||
if result.height != ENGLAND_LSOA_COUNT_2021:
|
||||
raise ValueError(
|
||||
f"Expected {ENGLAND_LSOA_COUNT_2021} England LSOAs, "
|
||||
f"got {result.height}: truncated NOMIS download?"
|
||||
)
|
||||
print(
|
||||
f"Median age range: {result['median_age'].min()} - {result['median_age'].max()}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -181,6 +181,27 @@ def canonical_station_name(name: str | None) -> str:
|
|||
return " ".join(words)
|
||||
|
||||
|
||||
_QUALIFIER_RE = re.compile(r"\(([^)]*)\)")
|
||||
|
||||
|
||||
def station_name_qualifier(name: str | None) -> str:
|
||||
"""The canonicalized parenthetical of a station name, e.g. "Edgware Road
|
||||
(Bakerloo)" -> "bakerloo".
|
||||
|
||||
Genuinely distinct same-named stations (the two Edgware Roads ~150m apart,
|
||||
Hammersmith's two stations) differ ONLY by this parenthetical, which
|
||||
`canonical_station_name` strips; it must block their merge while still
|
||||
letting unqualified entrance/variant rows collapse into either.
|
||||
"""
|
||||
if not name:
|
||||
return ""
|
||||
parts = _QUALIFIER_RE.findall(name)
|
||||
if not parts:
|
||||
return ""
|
||||
text = " ".join(parts).lower().replace("&", " and ")
|
||||
return re.sub(r"[^a-z0-9]+", " ", text).strip()
|
||||
|
||||
|
||||
def canonical_station_name_expr(name_col: str = "name") -> pl.Expr:
|
||||
"""Normalize station names so entrances/transport-mode variants collapse."""
|
||||
expr = pl.col(name_col).str.to_lowercase()
|
||||
|
|
@ -246,6 +267,7 @@ class StationAccumulator:
|
|||
entrance: bool = False
|
||||
is_lu: bool = False
|
||||
count: int = 1
|
||||
qualifier: str = ""
|
||||
|
||||
@property
|
||||
def lat(self) -> float:
|
||||
|
|
@ -260,6 +282,11 @@ class StationAccumulator:
|
|||
dlng = (self.lng - lng) * math.cos(math.radians(self.lat))
|
||||
return (dlat * dlat + dlng * dlng) <= TUBE_STATION_MERGE_RADIUS_DEGREES**2
|
||||
|
||||
def qualifier_compatible(self, qualifier: str) -> bool:
    """Whether a row carrying ``qualifier`` may be merged into this group.

    Two different non-empty parentheticals mark distinct same-named
    stations and are incompatible. If either side is unqualified, the row
    is allowed to join.
    """
    mine = self.qualifier
    if not qualifier or not mine:
        return True
    return qualifier == mine
|
||||
|
||||
def merge(self, row: dict[str, object]) -> None:
|
||||
self.lat_sum += float(row["lat"])
|
||||
self.lng_sum += float(row["lng"])
|
||||
|
|
@ -267,14 +294,28 @@ class StationAccumulator:
|
|||
self.is_lu = self.is_lu or bool(row.get("is_lu"))
|
||||
|
||||
name = str(row["name"] or "")
|
||||
row_qualifier = station_name_qualifier(name)
|
||||
self.qualifier = self.qualifier or row_qualifier
|
||||
entrance = bool(row.get("entrance"))
|
||||
if station_name_score(name, entrance) < station_name_score(
|
||||
self.name, self.entrance
|
||||
):
|
||||
# Prefer a display name carrying the group's disambiguating
|
||||
# parenthetical: without it the two Edgware Roads would both render as
|
||||
# the bare "Edgware Road Underground Station".
|
||||
candidate = (
|
||||
self._qualifier_penalty(row_qualifier),
|
||||
*station_name_score(name, entrance),
|
||||
)
|
||||
current = (
|
||||
self._qualifier_penalty(station_name_qualifier(self.name)),
|
||||
*station_name_score(self.name, self.entrance),
|
||||
)
|
||||
if candidate < current:
|
||||
self.id = str(row["id"] or "")
|
||||
self.name = name
|
||||
self.entrance = entrance
|
||||
|
||||
def _qualifier_penalty(self, name_qualifier: str) -> int:
|
||||
return int(bool(self.qualifier) and name_qualifier != self.qualifier)
|
||||
|
||||
@property
|
||||
def output_category(self) -> str:
|
||||
# A merged tram/metro station is a genuine Tube station when ANY of its
|
||||
|
|
@ -295,6 +336,7 @@ def _station_from_row(row: dict[str, object]) -> StationAccumulator:
|
|||
lng_sum=float(row["lng"]),
|
||||
entrance=bool(row.get("entrance")),
|
||||
is_lu=bool(row.get("is_lu")),
|
||||
qualifier=station_name_qualifier(str(row["name"] or "")),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -314,11 +356,13 @@ def _deduplicate_station_areas(df: pl.DataFrame) -> pl.DataFrame:
|
|||
selected.append(_station_from_row(row))
|
||||
continue
|
||||
|
||||
row_qualifier = station_name_qualifier(str(row["name"] or ""))
|
||||
existing = next(
|
||||
(
|
||||
index
|
||||
for index in groups.get(station_key, [])
|
||||
if selected[index].same_area(float(row["lat"]), float(row["lng"]))
|
||||
and selected[index].qualifier_compatible(row_qualifier)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ License: Open Government Licence v3.0
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from pipeline.utils import download
|
||||
from pipeline.utils import download_arcgis_hub_export
|
||||
|
||||
URL = "https://open-geography-portalx-ons.hub.arcgis.com/api/download/v1/items/6beafcfd9b9c4c9993a06b6b199d7e6d/geoPackage?layers=0"
|
||||
|
||||
|
|
@ -28,8 +28,10 @@ def main() -> None:
|
|||
args = parser.parse_args()
|
||||
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
download(URL, args.output, timeout=600)
|
||||
print(f"Saved to {args.output}")
|
||||
features = download_arcgis_hub_export(
|
||||
URL, args.output, expected_geometry="Polygon"
|
||||
)
|
||||
print(f"Saved {features} OA boundary polygons to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -329,16 +329,24 @@ def _outcode_of_postcode(postcode: str) -> str:
|
|||
|
||||
def _outcode_tree(postcodes_path: Path) -> tuple[cKDTree, list[str]]:
|
||||
"""Build a nearest-neighbour index from postcode coordinates to their outcode, so each
|
||||
street can be tagged with the outcode it sits in (used to disambiguate same-named roads)."""
|
||||
street can be tagged with the outcode it sits in (used to disambiguate same-named roads).
|
||||
|
||||
The tree lives in BNG metres (like `_london_postcode_tree`): in raw degrees
|
||||
1° of longitude is only ~0.6° of latitude at UK latitudes, which biases
|
||||
nearest-postcode picks E-W near outcode boundaries."""
|
||||
df = (
|
||||
pl.read_parquet(
|
||||
postcodes_path, columns=["pcds", "lat", "long", "ctry25cd", "doterm"]
|
||||
postcodes_path,
|
||||
columns=["pcds", "east1m", "north1m", "ctry25cd", "doterm"],
|
||||
)
|
||||
.filter((pl.col("ctry25cd") == ENGLAND_COUNTRY_CODE) & pl.col("doterm").is_null())
|
||||
.filter(_valid_wgs84_expr())
|
||||
.filter(_valid_bng_expr())
|
||||
)
|
||||
coords = np.column_stack(
|
||||
[df["lat"].to_numpy().astype(np.float64), df["long"].to_numpy().astype(np.float64)]
|
||||
[
|
||||
df["east1m"].to_numpy().astype(np.float64),
|
||||
df["north1m"].to_numpy().astype(np.float64),
|
||||
]
|
||||
)
|
||||
outcodes = [_outcode_of_postcode(pc) for pc in df["pcds"].to_list()]
|
||||
return cKDTree(coords), outcodes
|
||||
|
|
@ -354,8 +362,10 @@ def _build_street_places(
|
|||
if not streets:
|
||||
return []
|
||||
|
||||
coords = np.array([[street["lat"], street["lon"]] for street in streets], dtype=np.float64)
|
||||
_, indices = tree.query(coords)
|
||||
lons = np.array([street["lon"] for street in streets], dtype=np.float64)
|
||||
lats = np.array([street["lat"] for street in streets], dtype=np.float64)
|
||||
eastings, northings = WGS84_TO_BNG.transform(lons, lats)
|
||||
_, indices = tree.query(np.column_stack([eastings, northings]))
|
||||
|
||||
grouped: dict[tuple[str, str], dict] = {}
|
||||
for street, postcode_idx in zip(streets, indices):
|
||||
|
|
|
|||
|
|
@ -30,13 +30,31 @@ AREA_CODE_ALIASES = {
|
|||
}
|
||||
|
||||
|
||||
def _data_rows(df: pl.DataFrame) -> pl.DataFrame:
    """Return the rows that follow Table 1's header row.

    The header is found by content: the first row whose ``column_1`` value,
    stripped and lowercased, equals "time period". The number of preamble
    lines above it varies (title, optional "This worksheet contains..."
    note), so a fixed row offset would leave the header in the data whenever
    ONS adds or removes a note line.

    Raises:
        ValueError: if no row has "time period" in ``column_1``.
    """
    normalized_first_col = (
        pl.col("column_1").cast(pl.String).str.strip_chars().str.to_lowercase()
    )
    matches = df.with_row_index("_row").filter(normalized_first_col == "time period")
    if matches.is_empty():
        raise ValueError("PIPR Table 1: no 'Time period' header row found")
    first_header = int(matches["_row"][0])
    return df.slice(first_header + 1)
|
||||
|
||||
|
||||
def _latest_rents_long(df: pl.DataFrame) -> pl.DataFrame:
|
||||
# Table 1 layout: row 0 = title, row 1 = column headers, row 2+ = data.
|
||||
# 40 columns in repeating blocks of 4 (index, monthly change, annual change,
|
||||
# rental price) for each category. Rental price columns (0-indexed):
|
||||
# Table 1 layout below the header: 40 columns in repeating blocks of 4
|
||||
# (index, monthly change, annual change, rental price) for each category.
|
||||
# Rental price columns (0-indexed):
|
||||
# 7 = All categories, 11 = One bed, 15 = Two bed, 19 = Three bed,
|
||||
# 23 = Four or more bed
|
||||
df = df.slice(2) # Skip title and header rows
|
||||
df = _data_rows(df)
|
||||
|
||||
df = df.select(
|
||||
pl.col("column_1").alias("time_period"),
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
|
@ -9,6 +10,40 @@ import polars as pl
|
|||
|
||||
|
||||
TYPEAHEAD_URL = "https://los.rightmove.co.uk/typeahead"
|
||||
USER_AGENT = (
|
||||
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
|
||||
"(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
)
|
||||
MAX_ATTEMPTS = 4
|
||||
BACKOFF_BASE_S = 2.0
|
||||
# Outcodes Rightmove genuinely doesn't know (no listings ever) are tolerable;
|
||||
# more than this fraction missing means we were rate-limited or blocked and the
|
||||
# mapping would silently shrink, so fail the run instead of writing it.
|
||||
MAX_MISS_FRACTION = 0.02
|
||||
|
||||
|
||||
def _fetch_outcode(client: httpx.Client, outcode: str) -> str | None:
    """Look up the Rightmove location ID for one outcode.

    Queries the typeahead endpoint and returns the ID of the first match
    whose type is "OUTCODE" and whose display name equals ``outcode``,
    ignoring case and spaces.

    Any exception from the request, the status check, or JSON decoding is
    treated as transient: the call is retried up to MAX_ATTEMPTS times with
    exponential backoff (BACKOFF_BASE_S, doubling per retry). Errors raised
    while inspecting a decoded response are not retried.

    Args:
        client: HTTP client used for the GET request.
        outcode: Postcode outcode to resolve.

    Returns:
        The location ID as a string, or None when a response was received
        and decoded but contains no matching OUTCODE entry.

    Raises:
        RuntimeError: when every attempt failed; the last underlying error
            is attached as ``__cause__``.
    """
    # Loop-invariant: normalize the requested outcode once, not per match.
    wanted = outcode.upper().replace(" ", "")
    last_error: Exception | None = None
    for attempt in range(MAX_ATTEMPTS):
        if attempt:
            # No delay before the first try; 1x, 2x, 4x... base afterwards.
            time.sleep(BACKOFF_BASE_S * 2 ** (attempt - 1))
        try:
            resp = client.get(TYPEAHEAD_URL, params={"query": outcode, "limit": "5"})
            resp.raise_for_status()
            data = resp.json()
        except Exception as e:  # noqa: BLE001 - retried, re-raised after cap
            last_error = e
            continue
        for m in data.get("matches", []):
            if (
                m["type"] == "OUTCODE"
                and m["displayName"].upper().replace(" ", "") == wanted
            ):
                return str(m["id"])
        # A decoded response without our outcode is a definitive miss.
        return None
    # Chain the root cause so the traceback shows why the last attempt failed.
    raise RuntimeError(
        f"Rightmove typeahead failed for {outcode}: {last_error}"
    ) from last_error
|
||||
|
||||
|
||||
def fetch_outcode_ids(postcodes_path: Path, output: Path) -> None:
|
||||
|
|
@ -18,38 +53,30 @@ def fetch_outcode_ids(postcodes_path: Path, output: Path) -> None:
|
|||
|
||||
mapping: dict[str, str] = {}
|
||||
missed: list[str] = []
|
||||
client = httpx.Client(timeout=10)
|
||||
|
||||
with httpx.Client(timeout=10, headers={"User-Agent": USER_AGENT}) as client:
|
||||
for i, oc in enumerate(outcodes):
|
||||
try:
|
||||
resp = client.get(TYPEAHEAD_URL, params={"query": oc, "limit": "5"})
|
||||
data = resp.json()
|
||||
found = False
|
||||
for m in data.get("matches", []):
|
||||
if m["type"] == "OUTCODE" and m["displayName"].upper().replace(
|
||||
" ", ""
|
||||
) == oc.upper().replace(" ", ""):
|
||||
mapping[oc] = str(m["id"])
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
rightmove_id = _fetch_outcode(client, oc)
|
||||
if rightmove_id is not None:
|
||||
mapping[oc] = rightmove_id
|
||||
else:
|
||||
missed.append(oc)
|
||||
except Exception as e:
|
||||
missed.append(oc)
|
||||
print(f" Error for {oc}: {e}")
|
||||
|
||||
if (i + 1) % 200 == 0:
|
||||
print(f" {i + 1}/{len(outcodes)} done ({len(mapping)} found)")
|
||||
|
||||
client.close()
|
||||
if missed:
|
||||
print(f"Missed: {missed}")
|
||||
if len(missed) > len(outcodes) * MAX_MISS_FRACTION:
|
||||
raise RuntimeError(
|
||||
f"{len(missed)}/{len(outcodes)} outcodes unresolved "
|
||||
f"(> {MAX_MISS_FRACTION:.0%}); refusing to write a shrunken mapping"
|
||||
)
|
||||
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output, "w") as f:
|
||||
json.dump(mapping, f, sort_keys=True)
|
||||
|
||||
print(f"Wrote {output} ({len(mapping)} outcodes, {len(missed)} missed)")
|
||||
if missed:
|
||||
print(f"Missed: {missed}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
|
|
|||
|
|
@ -97,6 +97,69 @@ def test_select_coverage_archives_skips_overlapping_snapshots():
|
|||
assert [archive.month for archive in selected] == ["2026-03", "2023-03"]
|
||||
|
||||
|
||||
def test_select_coverage_archives_falls_back_to_overlapping_snapshot():
    # No snapshot ending exactly at Mar 2023 is listed. The 2023-06 snapshot
    # overlaps the newest one and must be picked anyway, otherwise
    # Apr-Jun 2023 would be missing from the history.
    newest = _archive("2026-03", "Contains data from Apr 2023 to Mar 2026")
    overlapping = _archive("2023-06", "Contains data from Jul 2020 to Jun 2023")

    chosen_months = [
        archive.month for archive in select_coverage_archives([newest, overlapping])
    ]

    assert chosen_months == ["2026-03", "2023-06"]
|
||||
|
||||
|
||||
def test_select_coverage_archives_raises_on_publication_gap():
    """A genuine coverage gap raises unless the caller opts in to gaps.

    The two archives leave 2022-01 to 2023-03 uncovered; the default call must
    fail with a message naming that range, while ``allow_gaps=True`` returns
    both archives.
    """
    archives = [
        _archive("2026-03", "Contains data from Apr 2023 to Mar 2026"),
        _archive("2021-12", "Contains data from Jan 2019 to Dec 2021"),
    ]

    raised = None
    try:
        select_coverage_archives(archives)
    except RuntimeError as exc:
        raised = exc

    if raised is None:
        raise AssertionError("Expected RuntimeError for the 2022 hole")
    assert "2022-01 to 2023-03" in str(raised)

    tolerated = select_coverage_archives(archives, allow_gaps=True)
    assert [entry.month for entry in tolerated] == ["2026-03", "2021-12"]
|
||||
|
||||
|
||||
def test_extract_csvs_newest_snapshot_wins_within_a_run(tmp_path):
    """Within one run, the newest snapshot's copy of a month is kept.

    Archives are processed newest first and share one extracted-set, so an
    older overlapping snapshot must not replace a month the newer one already
    wrote. Files left over from previous runs ARE replaced, because police.uk
    revises the trailing 36 months in every release.
    """
    shared_member = "2023-01/2023-01-city-street.csv"
    older_only_member = "2022-12/2022-12-city-street.csv"
    output = tmp_path / "crime"

    # Leftover from an earlier run; the newer snapshot must overwrite it.
    stale = output / "2023-01" / "2023-01-city-street.csv"
    stale.parent.mkdir(parents=True)
    stale.write_text("stale revision from a previous run\n")

    newer_zip = tmp_path / "newer.zip"
    with ZipFile(newer_zip, "w") as bundle:
        bundle.writestr(shared_member, "revised\n")

    older_zip = tmp_path / "older.zip"
    with ZipFile(older_zip, "w") as bundle:
        bundle.writestr(shared_member, "older snapshot\n")
        bundle.writestr(older_only_member, "unique month\n")

    # Newest first, sharing the same extracted-set across both calls.
    extracted_this_run: set = set()
    for snapshot in (newer_zip, older_zip):
        extract_csvs(
            snapshot, output, overwrite=True, extracted_this_run=extracted_this_run
        )

    assert stale.read_text() == "revised\n"
    older_only_path = output / "2022-12" / "2022-12-city-street.csv"
    assert older_only_path.read_text() == "unique month\n"
|
||||
|
||||
|
||||
def test_prepare_archive_dir_removes_retained_zip_cache_by_default(tmp_path):
|
||||
output = tmp_path / "crime"
|
||||
retained = output / "_archives"
|
||||
|
|
|
|||
54
pipeline/download/test_gias.py
Normal file
54
pipeline/download/test_gias.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
import csv
|
||||
import io
|
||||
import zipfile
|
||||
|
||||
from pipeline.download.gias import _CSV_COLUMNS, transform
|
||||
|
||||
|
||||
def _zip_with_rows(rows: list[dict[str, str]]) -> bytes:
    """Build an in-memory zip holding one cp1252-encoded CSV of ``rows``.

    Every row is written with the full ``_CSV_COLUMNS`` header; columns a row
    does not supply are emitted as empty strings.
    """
    csv_text = io.StringIO()
    dict_writer = csv.DictWriter(csv_text, fieldnames=_CSV_COLUMNS)
    dict_writer.writeheader()
    dict_writer.writerows(
        {col: row.get(col, "") for col in _CSV_COLUMNS} for row in rows
    )
    payload = csv_text.getvalue().encode("cp1252")

    zipped = io.BytesIO()
    with zipfile.ZipFile(zipped, "w") as archive:
        archive.writestr("edubasealldata20260611.csv", payload)
    return zipped.getvalue()
|
||||
|
||||
|
||||
def _school(name: str, status: str) -> dict[str, str]:
|
||||
return {
|
||||
"URN": "100000",
|
||||
"EstablishmentName": name,
|
||||
"TypeOfEstablishment (name)": "Community school",
|
||||
"EstablishmentTypeGroup (name)": "Local authority maintained schools",
|
||||
"EstablishmentStatus (name)": status,
|
||||
"PhaseOfEducation (name)": "Primary",
|
||||
"StatutoryLowAge": "4",
|
||||
"StatutoryHighAge": "11",
|
||||
"Easting": "530000",
|
||||
"Northing": "180000",
|
||||
"Postcode": "SW1A 1AA",
|
||||
"Street": "1 School Lane",
|
||||
"Town": "London",
|
||||
"LA (name)": "Westminster",
|
||||
}
|
||||
|
||||
|
||||
def test_transform_keeps_open_but_proposed_to_close_schools() -> None:
    """Schools that are open but proposed to close stay in the output.

    Such establishments are still operating (GIAS can keep the status for
    years); only closed and proposed-to-open rows are out of scope for the map.
    """
    status_by_name = {
        "Open School": "Open",
        "Closing School": "Open, but proposed to close",
        "Closed School": "Closed",
        "Future School": "Proposed to open",
    }
    archive_bytes = _zip_with_rows(
        [_school(name, status) for name, status in status_by_name.items()]
    )

    kept_names = transform(archive_bytes)["name"].to_list()

    assert sorted(kept_names) == ["Closing School", "Open School"]
|
||||
|
|
@ -198,6 +198,37 @@ def test_deduplicate_naptan_merges_tube_station_variants_by_area():
|
|||
)
|
||||
|
||||
|
||||
def test_deduplicate_naptan_keeps_distinct_stations_with_conflicting_qualifiers():
    """Stations whose parenthetical qualifiers conflict are not merged.

    The two Edgware Road stations sit roughly 150m apart and differ only by
    the line name in parentheses, which the canonical key strips. Conflicting
    parentheticals must block the area merge, while a row with no qualifier
    (an entrance) may still join either group.
    """
    bakerloo_name = "Edgware Road (Bakerloo) Underground Station"
    circle_name = "Edgware Road (Circle/District) Underground Station"
    names = [bakerloo_name, circle_name, "Edgware Road Underground Station"]
    df = pl.DataFrame(
        {
            "id": ["bakerloo", "circle", "entrance"],
            "name": names,
            "category": ["Tube station" for _ in names],
            "lat": [51.5204, 51.5199, 51.5203],
            "lng": [-0.1700, -0.1679, -0.1701],
            "locality": ["LOC1" for _ in names],
        }
    )

    deduped = deduplicate_naptan(df).sort("lng")

    assert len(deduped) == 2
    assert deduped["name"].to_list() == [bakerloo_name, circle_name]
    # The unqualified entrance was folded into the Bakerloo group, so that
    # group's latitude is the mean of the two rows.
    expected_lat = (51.5204 + 51.5203) / 2
    assert deduped["lat"][0] == pytest.approx(expected_lat)
|
||||
|
||||
|
||||
def test_deduplicate_naptan_does_not_merge_missing_locality_bus_stops():
|
||||
df = pl.DataFrame(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -189,8 +189,10 @@ def test_normalize_street_name_and_outcode():
|
|||
|
||||
|
||||
def test_build_street_places_groups_segments_by_name_and_outcode():
|
||||
# Two postcodes: NW1 (north) and CR0 (south).
|
||||
tree = cKDTree(np.array([[51.53, -0.14], [51.37, -0.10]], dtype=np.float64))
|
||||
# Two postcodes: NW1 (north) and CR0 (south). The tree lives in BNG metres
|
||||
# (matching _outcode_tree); streets are transformed before querying.
|
||||
east, north = WGS84_TO_BNG.transform([-0.14, -0.10], [51.53, 51.37])
|
||||
tree = cKDTree(np.column_stack([east, north]))
|
||||
outcodes = ["NW1", "CR0"]
|
||||
|
||||
streets = [
|
||||
|
|
|
|||
|
|
@ -6,13 +6,13 @@ from pipeline.download.rental_prices import _latest_rents_long
|
|||
def test_latest_rents_long_adds_iod_alias_codes_for_south_yorkshire():
|
||||
raw = pl.DataFrame(
|
||||
{
|
||||
"column_1": ["title", "header", "2026-02-01 00:00:00"],
|
||||
"column_2": ["", "", "E08000038"],
|
||||
"column_3": ["", "", "Barnsley"],
|
||||
"column_12": ["", "", "486"],
|
||||
"column_16": ["", "", "595"],
|
||||
"column_20": ["", "", "705"],
|
||||
"column_24": ["", "", "900"],
|
||||
"column_1": ["title", "Time period", "2026-02-01 00:00:00"],
|
||||
"column_2": ["", "Area code", "E08000038"],
|
||||
"column_3": ["", "Area name", "Barnsley"],
|
||||
"column_12": ["", "One bed", "486"],
|
||||
"column_16": ["", "Two bed", "595"],
|
||||
"column_20": ["", "Three bed", "705"],
|
||||
"column_24": ["", "Four or more bed", "900"],
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -22,3 +22,30 @@ def test_latest_rents_long_adds_iod_alias_codes_for_south_yorkshire():
|
|||
{"area_code": "E08000016", "mean_monthly_rent": 486.0},
|
||||
{"area_code": "E08000038", "mean_monthly_rent": 486.0},
|
||||
]
|
||||
|
||||
|
||||
def test_latest_rents_long_locates_header_in_variable_preamble():
    """The header row is found even when the preamble is three rows long.

    The live workbook has a title, a contents note and then the header. A
    fixed two-row slice used to leave the header in the data, where only the
    area-code filter happened to drop it.
    """

    def column(header: str, value: str) -> list[str]:
        # Two blank preamble cells, then the header cell, then the data cell.
        return ["", "", header, value]

    raw = pl.DataFrame(
        {
            "column_1": [
                "title",
                "This worksheet contains one table.",
                "Time period",
                "2026-02-01 00:00:00",
            ],
            "column_2": column("Area code", "E08000038"),
            "column_3": column("Area name", "Barnsley"),
            "column_12": column("One bed", "486"),
            "column_16": column("Two bed", "595"),
            "column_20": column("Three bed", "705"),
            "column_24": column("Four or more bed", "900"),
        }
    )

    long_rents = _latest_rents_long(raw)

    barnsley_rows = long_rents.filter(pl.col("area_code") == "E08000038")
    assert barnsley_rows.height == 5
    assert long_rents["mean_monthly_rent"].null_count() == 0
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from pipeline.download.transit_network import (
|
||||
clean_national_rail_gtfs,
|
||||
convert_high_freq_to_frequency_based,
|
||||
validate_gtfs_feed,
|
||||
)
|
||||
|
|
@ -69,6 +70,46 @@ def test_one_based_stop_sequence_is_converted(tmp_path: Path) -> None:
|
|||
assert headway_secs == "300"
|
||||
|
||||
|
||||
def test_clean_national_rail_gtfs_orders_by_stop_sequence_not_file_order(
    tmp_path: Path,
) -> None:
    """Rows are ordered by their original stop_sequence, not by file position.

    dtd2mysql exports happen to list each trip's rows in stop_sequence order,
    but nothing guarantees it. Out-of-order rows must be sorted by their
    original stop_sequence before the backwards-time check and the 0-based
    renumbering; using file order would flag the trip as backwards and drop
    it, or scramble its stops.
    """
    src = tmp_path / "in.zip"
    dst = tmp_path / "out.zip"
    feed_members = {
        "stops.txt": (
            "stop_id,stop_lat,stop_lon\nSTOP_A,51.5,-0.1\nSTOP_B,51.6,-0.1\n"
        ),
        "routes.txt": "route_id,route_type\nR1,2\n",
        "trips.txt": "trip_id,route_id,service_id\nT1,R1,S1\n",
        # File order is seq 2 then seq 1: read top to bottom the departures
        # look backwards (07:00 then 06:00); in sequence order they are fine.
        "stop_times.txt": (
            "trip_id,stop_id,stop_sequence,departure_time\n"
            "T1,STOP_B,2,07:00:00\n"
            "T1,STOP_A,1,06:00:00\n"
        ),
    }
    with zipfile.ZipFile(src, "w") as feed:
        for member, text in feed_members.items():
            feed.writestr(member, text)

    clean_national_rail_gtfs(src, dst)

    with zipfile.ZipFile(dst, "r") as cleaned:
        trips = cleaned.read("trips.txt").decode("utf-8").splitlines()
        stop_times = cleaned.read("stop_times.txt").decode("utf-8").splitlines()

    assert trips == ["trip_id,route_id,service_id", "T1,R1,S1"]
    assert stop_times == [
        "trip_id,stop_id,stop_sequence,departure_time",
        "T1,STOP_A,0,06:00:00",
        "T1,STOP_B,1,07:00:00",
    ]
|
||||
|
||||
|
||||
def test_raises_when_no_first_stops_found(tmp_path: Path) -> None:
|
||||
"""A non-empty target trip set with unparseable stop_sequence is loud, not silent."""
|
||||
src = tmp_path / "in.zip"
|
||||
|
|
|
|||
|
|
@ -553,7 +553,9 @@ def _calendar_active_in_window(
|
|||
return False
|
||||
|
||||
|
||||
def validate_gtfs_feed(path: Path, feed_name: str, *, today: dt.date | None = None) -> None:
|
||||
def validate_gtfs_feed(
|
||||
path: Path, feed_name: str, *, today: dt.date | None = None
|
||||
) -> None:
|
||||
"""Sanity-check a produced/downloaded GTFS zip; raise RuntimeError if dead.
|
||||
|
||||
Guards against silently shipping a feed that contributes zero service (as
|
||||
|
|
@ -652,7 +654,8 @@ def download_national_rail_cif(raw_dir: Path) -> Path | None:
|
|||
print(f"National Rail CIF already exists: {dest}")
|
||||
return dest
|
||||
|
||||
# Free National Rail Open Data account. Credentials are read from the
# environment only; no fallback value is baked into source, so an unset
# variable yields an empty string and is caught by the check that follows.
# NOTE(review): the account details that used to be hard-coded here as
# defaults are in version-control history and should be rotated.
email = os.environ.get("NATIONAL_RAIL_EMAIL", "")
password = os.environ.get("NATIONAL_RAIL_PASSWORD", "")
|
||||
if not email or not password:
|
||||
|
|
@ -688,6 +691,48 @@ def download_national_rail_cif(raw_dir: Path) -> Path | None:
|
|||
return dest
|
||||
|
||||
|
||||
def _iter_stop_time_trips(lines, trip_id_idx: int):
    """Yield ``(trip_id, rows)`` for each run of consecutive same-trip rows.

    dtd2mysql currently writes stop_times grouped by trip and ordered by
    stop_sequence, but GTFS guarantees neither. The grouping is checked here:
    a trip_id that shows up again after its run has ended raises rather than
    silently splitting the trip. Order inside a run is NOT assumed; callers
    sort each group by its original stop_sequence.

    Args:
        lines: Iterable of raw stop_times lines (header already consumed).
        trip_id_idx: Column index of ``trip_id`` in a parsed row.

    Yields:
        Tuples of the trip id and the list of parsed rows in file order.

    Raises:
        ValueError: If a trip_id reappears after a different trip's rows.
    """
    started: set[str] = set()
    active: str | None = None
    bucket: list[list[str]] = []

    for raw in lines:
        fields = _parse_csv_line(raw)
        if not fields:
            # Lines that parse to nothing are ignored.
            continue

        tid = fields[trip_id_idx].strip('"')
        if tid == active:
            bucket.append(fields)
            continue

        # A new run begins: hand over the finished one before validating.
        if active is not None:
            yield active, bucket
        if tid in started:
            raise ValueError(
                "stop_times.txt is not grouped by trip_id "
                f"({tid} reappears); the dtd2mysql export order changed"
            )
        started.add(tid)
        active, bucket = tid, [fields]

    if active is not None:
        yield active, bucket
|
||||
|
||||
|
||||
def _stop_sequence_key(
|
||||
parts: list[str], seq_idx: int, fallback: int
|
||||
) -> tuple[int, int]:
|
||||
try:
|
||||
return (int(parts[seq_idx].strip('"')), fallback)
|
||||
except ValueError:
|
||||
return (fallback, fallback)
|
||||
|
||||
|
||||
def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
|
||||
"""Fix R5-incompatible entries in dtd2mysql-generated National Rail GTFS.
|
||||
|
||||
|
|
@ -722,23 +767,24 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
|
|||
if parts:
|
||||
stop_ids.add(parts[stop_id_idx])
|
||||
|
||||
# Find trips with backwards travel times
|
||||
# Find trips with backwards travel times (in stop_sequence order, not
|
||||
# file order)
|
||||
with zin.open("stop_times.txt") as f:
|
||||
st_cols = _parse_csv_line(f.readline())
|
||||
trip_id_idx = st_cols.index("trip_id")
|
||||
dep_idx = st_cols.index("departure_time")
|
||||
seq_idx = st_cols.index("stop_sequence")
|
||||
|
||||
prev_trip = ""
|
||||
for trip_id, rows in _iter_stop_time_trips(f, trip_id_idx):
|
||||
ordered = [
|
||||
parts
|
||||
for _, parts in sorted(
|
||||
enumerate(rows),
|
||||
key=lambda item: _stop_sequence_key(item[1], seq_idx, item[0]),
|
||||
)
|
||||
]
|
||||
prev_dep_secs = -1
|
||||
for line in f:
|
||||
parts = _parse_csv_line(line)
|
||||
if not parts:
|
||||
continue
|
||||
trip_id = parts[trip_id_idx].strip('"')
|
||||
if trip_id != prev_trip:
|
||||
prev_trip = trip_id
|
||||
prev_dep_secs = -1
|
||||
|
||||
for parts in ordered:
|
||||
dep_str = parts[dep_idx].strip('"')
|
||||
if ":" in dep_str:
|
||||
try:
|
||||
|
|
@ -791,26 +837,34 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
|
|||
)
|
||||
tmp.write(header)
|
||||
|
||||
prev_trip = ""
|
||||
seq_counter = 0
|
||||
for line in f:
|
||||
parts = _parse_csv_line(line)
|
||||
if not parts:
|
||||
continue
|
||||
trip_id = parts[trip_id_idx].strip('"')
|
||||
stop_id = parts[stop_id_idx].strip('"')
|
||||
|
||||
for trip_id, rows in _iter_stop_time_trips(f, trip_id_idx):
|
||||
# Skip trips with backwards times
|
||||
if trip_id in bad_trip_ids:
|
||||
bad_trips_removed += 1
|
||||
bad_trips_removed += len(rows)
|
||||
continue
|
||||
|
||||
# Renumber in the trip's stop_sequence order, not file
|
||||
# order
|
||||
ordered = [
|
||||
parts
|
||||
for _, parts in sorted(
|
||||
enumerate(rows),
|
||||
key=lambda item: _stop_sequence_key(
|
||||
item[1], seq_idx, item[0]
|
||||
),
|
||||
)
|
||||
]
|
||||
seq_counter = 0
|
||||
for parts in ordered:
|
||||
stop_id = parts[stop_id_idx].strip('"')
|
||||
|
||||
# Skip stop_times referencing missing stops
|
||||
if stop_id not in stop_ids:
|
||||
orphan_stops_removed += 1
|
||||
continue
|
||||
|
||||
# Fix pass-through stops: set pickup/dropoff to 0 (normal)
|
||||
# Fix pass-through stops: set pickup/dropoff to 0
|
||||
# (normal)
|
||||
if pickup_idx >= 0 and dropoff_idx >= 0:
|
||||
pickup = parts[pickup_idx].strip('"')
|
||||
dropoff = parts[dropoff_idx].strip('"')
|
||||
|
|
@ -820,15 +874,11 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
|
|||
passthrough_fixed += 1
|
||||
|
||||
# Renumber stop_sequence to 0-based
|
||||
if trip_id != prev_trip:
|
||||
prev_trip = trip_id
|
||||
seq_counter = 0
|
||||
else:
|
||||
seq_counter += 1
|
||||
old_seq = parts[seq_idx].strip('"')
|
||||
parts[seq_idx] = str(seq_counter)
|
||||
if old_seq != str(seq_counter):
|
||||
seqs_renumbered += 1
|
||||
seq_counter += 1
|
||||
|
||||
tmp.write(_format_csv_row(parts))
|
||||
|
||||
|
|
|
|||
|
|
@ -123,10 +123,13 @@ def transform_crime(
|
|||
)
|
||||
|
||||
yearly_counts = (
|
||||
filtered.group_by("LSOA code", "year", "Crime type", "Month")
|
||||
.agg((pl.col("_weight").first() * pl.len()).alias("count"))
|
||||
.group_by("LSOA code", "year", "Crime type")
|
||||
.agg(pl.col("count").sum().alias("count"))
|
||||
# Sum per-incident weights directly: a 2021 LSOA can receive incidents
|
||||
# carrying different `_weight`s in the same month (split 2011 parent at
|
||||
# 1/N alongside an unsplit one at 1), so `_weight.first() * len` would
|
||||
# apply one row's weight to all of them — and nondeterministically so,
|
||||
# since `first` after a join has no ordering guarantee.
|
||||
filtered.group_by("LSOA code", "year", "Crime type")
|
||||
.agg(pl.col("_weight").sum().alias("count"))
|
||||
.join(months_per_year, on="year")
|
||||
.with_columns(
|
||||
(pl.col("count") * 12.0 / pl.col("months_in_year")).alias("per_year")
|
||||
|
|
@ -191,10 +194,10 @@ def _write_crime_by_year(
|
|||
)
|
||||
|
||||
yearly_per_type = (
|
||||
filtered.group_by("LSOA code", "Crime type", "year", "Month")
|
||||
.agg((pl.col("_weight").first() * pl.len()).alias("count"))
|
||||
.group_by("LSOA code", "Crime type", "year")
|
||||
.agg(pl.col("count").sum().alias("count"))
|
||||
# Per-incident weight sum, not `_weight.first() * len` — see the
|
||||
# matching comment in transform_crime.
|
||||
filtered.group_by("LSOA code", "Crime type", "year")
|
||||
.agg(pl.col("_weight").sum().alias("count"))
|
||||
.join(months_per_year, on="year")
|
||||
.with_columns(
|
||||
(pl.col("count").cast(pl.Float32) * 12.0 / pl.col("months_in_year"))
|
||||
|
|
|
|||
|
|
@ -97,6 +97,13 @@ def epc_band_to_year(band: pl.Expr) -> pl.Expr:
|
|||
|
||||
EPC_SOURCE_COLUMNS = [
|
||||
"address",
|
||||
# The individual lines behind `address` (= address1+2+3): address2/3
|
||||
# frequently carry a village/locality token that the price-paid address
|
||||
# lacks, so the matcher also scores against address1-only and
|
||||
# address1+address2 variants (see fuzzy_join_on_postcode's variant
|
||||
# columns).
|
||||
"address1",
|
||||
"address2",
|
||||
"postcode",
|
||||
"uprn",
|
||||
"current_energy_rating",
|
||||
|
|
@ -150,6 +157,12 @@ def _select_epc_columns(raw: pl.LazyFrame) -> pl.LazyFrame:
|
|||
return (
|
||||
raw.select(
|
||||
_clean_string("address").alias("epc_address"),
|
||||
# Match variants: the full address minus the locality-bearing
|
||||
# trailing lines. Inadmissible variants (ones whose dropped lines
|
||||
# carry numbers or flat designators) are filtered inside the
|
||||
# fuzzy join.
|
||||
_join_address_parts("address1").alias("epc_address_a1"),
|
||||
_join_address_parts("address1", "address2").alias("epc_address_a12"),
|
||||
_clean_string("postcode").str.to_uppercase().alias("epc_postcode"),
|
||||
# UPRN keys an exact listing->EPC join downstream (~99% populated).
|
||||
_clean_string("uprn").alias("uprn"),
|
||||
|
|
@ -536,6 +549,12 @@ def _run(epc_path: Path, price_paid_path: Path, output_path: Path, temp_dir: Pat
|
|||
.filter(pl.col("pp_property_type") != "Other")
|
||||
.with_columns(
|
||||
_join_address_parts("saon", "paon", "street").alias("pp_address"),
|
||||
# Match variant with the locality appended: the EPC address often
|
||||
# carries a village/locality token the bare saon+paon+street
|
||||
# lacks, which alone drags short addresses below the threshold.
|
||||
_join_address_parts("saon", "paon", "street", "locality").alias(
|
||||
"pp_address_loc"
|
||||
),
|
||||
)
|
||||
.with_columns(
|
||||
normalize_address_key(pl.col("pp_address")).alias("_pp_match_address"),
|
||||
|
|
@ -597,6 +616,7 @@ def _run(epc_path: Path, price_paid_path: Path, output_path: Path, temp_dir: Pat
|
|||
.group_by("_pp_group_address", "_pp_group_postcode", maintain_order=True)
|
||||
.agg(
|
||||
pl.col("pp_address").last(),
|
||||
pl.col("pp_address_loc").last(),
|
||||
pl.col("postcode").last(),
|
||||
pl.col("_pp_match_address").last(),
|
||||
pl.col("_pp_match_postcode").last(),
|
||||
|
|
@ -633,6 +653,8 @@ def _run(epc_path: Path, price_paid_path: Path, output_path: Path, temp_dir: Pat
|
|||
right_address_col="epc_address",
|
||||
left_postcode_col="postcode",
|
||||
right_postcode_col="epc_postcode",
|
||||
left_variant_cols=["pp_address_loc"],
|
||||
right_variant_cols=["epc_address_a1", "epc_address_a12"],
|
||||
)
|
||||
.drop("epc_postcode")
|
||||
# Audit trail: keep the fuzzy-match confidence (100 = exact address
|
||||
|
|
@ -672,6 +694,9 @@ def _run(epc_path: Path, price_paid_path: Path, output_path: Path, temp_dir: Pat
|
|||
[
|
||||
"old_new",
|
||||
"first_transfer_date",
|
||||
"pp_address_loc",
|
||||
"epc_address_a1",
|
||||
"epc_address_a12",
|
||||
"_pp_match_address",
|
||||
"_pp_match_postcode",
|
||||
"_pp_group_address",
|
||||
|
|
|
|||
|
|
@ -24,9 +24,12 @@ from pipeline.transform.price_estimation.knn import (
|
|||
MIN_COMPARABLE_PSM,
|
||||
)
|
||||
from pipeline.utils.fuzzy_join import (
|
||||
_NUMBER_RE as _SUFFIXED_NUMBER_RE,
|
||||
_numbers_compatible as _equal_numbers_compatible,
|
||||
normalize_address_key,
|
||||
normalize_postcode_key,
|
||||
)
|
||||
from pipeline.utils.normalize import drop_digit_tokens
|
||||
from pipeline.utils.postcode_mapping import build_postcode_mapping
|
||||
|
||||
MIN_FLOOR_AREA_M2 = 10
|
||||
|
|
@ -209,8 +212,15 @@ def _is_dynamic_poi_metric_column(column: str) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def _numbers_compatible(left: str, right: str) -> bool:
|
||||
"""Require address/list-entry numbers to agree when either side has numbers."""
|
||||
def _subset_numbers_compatible(left: str, right: str) -> bool:
|
||||
"""Require one side's numbers to be a subset of the other's.
|
||||
|
||||
Subset (not equality) is correct ONLY for listed-building name matching: a
|
||||
list entry like "10-12 HIGH STREET" should flag "10 HIGH STREET". Address-
|
||||
to-address matching must use the canonical `fuzzy_join._numbers_compatible`
|
||||
instead (set equality over ``\\d+[A-Z]?`` tokens) — subset semantics there
|
||||
let a single flat absorb its whole building (see fuzzy_join docstring).
|
||||
"""
|
||||
left_nums = set(_NUMBER_RE.findall(left))
|
||||
right_nums = set(_NUMBER_RE.findall(right))
|
||||
smaller, larger = (
|
||||
|
|
@ -446,7 +456,7 @@ def _matched_listed_building_flags(
|
|||
matched = False
|
||||
for address_key in address_keys:
|
||||
for listed_name in listed_names:
|
||||
if not _numbers_compatible(address_key, listed_name):
|
||||
if not _subset_numbers_compatible(address_key, listed_name):
|
||||
continue
|
||||
if fuzz.token_set_ratio(address_key, listed_name) >= min_score:
|
||||
matched = True
|
||||
|
|
@ -1152,8 +1162,9 @@ def _address_score(query: str, candidate: str | None, *, allow_token_set: bool)
|
|||
# token (e.g. "KINGSWOOD") subsets to 100 against any long address that
|
||||
# merely contains it — so number-less queries score with token_sort_ratio
|
||||
# only, matching the canonical fuzzy_join._score_bucket. For a NUMBERED
|
||||
# query the unconditional _numbers_compatible gate has already guaranteed the
|
||||
# candidate carries compatible house numbers, so token_set cannot inflate
|
||||
# query the unconditional fuzzy_join._numbers_compatible gate has already
|
||||
# guaranteed the candidate carries identical house numbers, so token_set
|
||||
# cannot inflate
|
||||
# across different addresses; allowing it recovers genuine matches where the
|
||||
# scraped listing appends trailing town/county tokens the bare register
|
||||
# address omits (e.g. "105 RIDGEWAY DRIVE BROMLEY KENT" vs "105 RIDGEWAY
|
||||
|
|
@ -1213,7 +1224,7 @@ def _rooms_bonus(left: int | None, right: int | None) -> float:
|
|||
def _street_only_address(address: str) -> str:
|
||||
"""The street/locality part of a normalised address: digit-bearing tokens
|
||||
(house numbers, flat numbers, including letter suffixes like 8A) removed."""
|
||||
return " ".join(token for token in address.split() if not _NUMBER_RE.search(token))
|
||||
return drop_digit_tokens(address)
|
||||
|
||||
|
||||
def _is_specific_street_query(query: str) -> bool:
|
||||
|
|
@ -1262,9 +1273,9 @@ def _best_listing_match(
|
|||
``uprn_index`` (postcode-independent, so it is robust even when the
|
||||
listing's postcode is slightly off); (2) failing that, the highest
|
||||
fuzzy street-address similarity within the listing's own postcode bucket.
|
||||
No property-attribute heuristics are used — `_numbers_compatible` gates
|
||||
every fuzzy match unconditionally (so a number-less listing can never match
|
||||
a numbered property, and vice versa), as in the canonical
|
||||
No property-attribute heuristics are used — `fuzzy_join._numbers_compatible`
|
||||
gates every fuzzy match unconditionally (so a number-less listing can never
|
||||
match a numbered property, and vice versa), as in the canonical
|
||||
`fuzzy_join._score_bucket`. A house number additionally lowers the score
|
||||
threshold and (via `_address_score`) permits token_set scoring; a number-less
|
||||
address scores on token_sort only and must match the street almost exactly.
|
||||
|
|
@ -1294,9 +1305,11 @@ def _best_listing_match(
|
|||
address = candidate.get(field)
|
||||
if not address:
|
||||
continue
|
||||
# Unconditional number gate (matches fuzzy_join): a number-less
|
||||
# listing cannot match a numbered candidate and vice versa.
|
||||
if not _numbers_compatible(query, address):
|
||||
# Unconditional number gate (the canonical fuzzy_join one: set
|
||||
# equality over suffix-aware tokens): a number-less listing cannot
|
||||
# match a numbered candidate, 8A cannot match 8B, and a flat
|
||||
# cannot absorb its whole building.
|
||||
if not _equal_numbers_compatible(query, address):
|
||||
continue
|
||||
score = _address_score(query, address, allow_token_set=listing_has_numbers)
|
||||
if score > best_score:
|
||||
|
|
@ -1388,7 +1401,7 @@ def _best_street_epc_fallback(
|
|||
street_score_cache[cache_key] = qualifying
|
||||
|
||||
listing_postcode = listing.get("_listing_match_postcode")
|
||||
listing_numbers = set(_NUMBER_RE.findall(query))
|
||||
listing_numbers = set(_SUFFIXED_NUMBER_RE.findall(query))
|
||||
best: dict | None = None
|
||||
best_total = float("-inf")
|
||||
best_street_score = 0
|
||||
|
|
@ -1417,7 +1430,9 @@ def _best_street_epc_fallback(
|
|||
):
|
||||
total += _STREET_FALLBACK_SAME_POSTCODE_BONUS
|
||||
if listing_numbers and listing_numbers & set(
|
||||
_NUMBER_RE.findall(candidate.get("_direct_epc_match_address") or "")
|
||||
_SUFFIXED_NUMBER_RE.findall(
|
||||
candidate.get("_direct_epc_match_address") or ""
|
||||
)
|
||||
):
|
||||
total += _STREET_FALLBACK_NUMBER_OVERLAP_BONUS
|
||||
if total > best_total:
|
||||
|
|
|
|||
|
|
@ -88,6 +88,12 @@ SECONDARY_AGES = (11, 15)
|
|||
NURSERY_COHORT_WEIGHT = 0.5 # ages < 4
|
||||
SIXTH_FORM_COHORT_WEIGHT = 0.6 # ages >= 16
|
||||
|
||||
# Assumed bounds for the one-sided age-range shapes GIAS emits when a
|
||||
# statutory age is missing: "up to {high}" starts at the earliest nursery
|
||||
# intake, "{low}+" runs to the end of sixth form.
|
||||
EARLIEST_INTAKE_AGE = 2
|
||||
DEFAULT_LEAVING_AGE = 19
|
||||
|
||||
# Only schools that admit (mostly) by geography take part in the assignment.
|
||||
# Independent, special and Welsh schools and post-16 colleges either don't
|
||||
# admit by distance or fall outside the England postcode universe; selective
|
||||
|
|
@ -296,11 +302,28 @@ def phase_intakes(gias: pl.DataFrame) -> pl.DataFrame:
|
|||
e.g. "3–11" = ages 3..10) with nursery and sixth-form ages down-weighted,
|
||||
and each phase receives the share of cohort weight in its age band.
|
||||
"""
|
||||
ages = pl.col("age_range").str.extract_all(r"\d+")
|
||||
low = ages.list.get(0, null_on_oob=True).cast(pl.Int64, strict=False)
|
||||
# gias._format_age_range emits three shapes: "{low}–{high}", "up to {high}"
|
||||
# (StatutoryLowAge missing) and "{low}+" (StatutoryHighAge missing). Parse
|
||||
# all three — the one-sided shapes previously fell through the two-number
|
||||
# parse and silently dropped the school from the catchment supply.
|
||||
age = pl.col("age_range")
|
||||
leading = age.str.extract(r"^\s*(\d+)", 1).cast(pl.Int64, strict=False)
|
||||
trailing = age.str.extract(r"(\d+)\s*$", 1).cast(pl.Int64, strict=False)
|
||||
low = (
|
||||
pl.when(age.str.starts_with("up to"))
|
||||
.then(pl.lit(EARLIEST_INTAKE_AGE, dtype=pl.Int64))
|
||||
.otherwise(leading)
|
||||
)
|
||||
# The leaving age is exclusive as a cohort: a "3-11" school teaches
|
||||
# children aged 3 through 10.
|
||||
high = ages.list.get(1, null_on_oob=True).cast(pl.Int64, strict=False) - 1
|
||||
# children aged 3 through 10. "{low}+" schools get the end of sixth form
|
||||
# as their assumed leaving age (post-19 institutions then carry no
|
||||
# primary/secondary cohort weight and drop out naturally).
|
||||
high = (
|
||||
pl.when(age.str.ends_with("+"))
|
||||
.then(pl.lit(DEFAULT_LEAVING_AGE, dtype=pl.Int64))
|
||||
.otherwise(trailing)
|
||||
- 1
|
||||
)
|
||||
|
||||
schools = (
|
||||
gias.filter(
|
||||
|
|
|
|||
|
|
@ -275,6 +275,51 @@ def test_transform_crime_applies_lsoa_2011_to_2021_lookup(tmp_path):
|
|||
assert burglaries["E01000099"] == [{"year": 2024, "count": 12.0}]
|
||||
|
||||
|
||||
def test_transform_crime_sums_mixed_weights_within_a_target_lsoa(tmp_path):
    """Incidents with different weights in one target LSOA are summed per row.

    Irregular (M:N) recodes can put rows with DIFFERENT `_weight`s into the
    same (lsoa21, year, type) group. Here E01000050 receives two 0.5-weighted
    incidents from the split E01000001 plus one 1.0-weighted incident from
    E01000099. The old `_weight.first() * len` applied a single row's weight
    to all three (nondeterministically 1.5 or 3.0 instead of 2.0).
    """
    crime_dir = tmp_path / "crime"
    month_dir = crime_dir / "2024-01"
    month_dir.mkdir(parents=True)

    header = "Crime ID,Month,Reported by,Falls within,Longitude,Latitude,Location,LSOA code,LSOA name,Crime type,Last outcome category,Context"
    incidents = [("1", "E01000001"), ("2", "E01000001"), ("3", "E01000099")]
    csv_lines = [header] + [
        f"{crime_id},2024-01,F,F,-0.1,51.5,X,{lsoa11},L,Burglary,U,"
        for crime_id, lsoa11 in incidents
    ]
    (month_dir / "2024-01-test-force-street.csv").write_text(
        "\n".join(csv_lines) + "\n"
    )

    # E01000001 splits across two 2021 LSOAs; E01000099 maps onto one of them.
    lookup_path = tmp_path / "lookup.parquet"
    lookup = pl.DataFrame(
        {
            "lsoa11": ["E01000001", "E01000001", "E01000099"],
            "lsoa21": ["E01000050", "E01000051", "E01000050"],
        }
    )
    lookup.write_parquet(lookup_path)

    output = tmp_path / "crime.parquet"
    by_year_output = tmp_path / "by_year.parquet"
    transform_crime(crime_dir, output, by_year_output, lookup_path)

    # E01000050: 0.5 + 0.5 + 1.0 = 2.0 incidents -> 24/yr annualised.
    # E01000051: 0.5 + 0.5 = 1.0 incident -> 12/yr.
    averages = pl.read_parquet(output).sort("LSOA code").to_dicts()
    assert averages == [
        {"LSOA code": "E01000050", "Burglary (avg/yr)": 24.0},
        {"LSOA code": "E01000051", "Burglary (avg/yr)": 12.0},
    ]
|
||||
|
||||
|
||||
def test_transform_crime_maps_legacy_crime_types(tmp_path):
|
||||
"""Pre-2014 police.uk type names are aliased to current equivalents instead
|
||||
of being dropped."""
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ def _write_csv(path: Path, fieldnames: list[str], rows: list[dict[str, str]]) ->
|
|||
def _row(**overrides: str) -> dict[str, str]:
|
||||
row = {
|
||||
"address": "1 Example Street",
|
||||
"address1": "1 Example Street",
|
||||
"address2": "Hale",
|
||||
"postcode": " aa1 1aa ",
|
||||
"uprn": "100012345678",
|
||||
"current_energy_rating": "c",
|
||||
|
|
@ -54,6 +56,8 @@ def test_scan_epc_certificates_supports_legacy_uppercase_csv(tmp_path: Path):
|
|||
assert df.to_dicts() == [
|
||||
{
|
||||
"epc_address": "1 Example Street",
|
||||
"epc_address_a1": "1 Example Street",
|
||||
"epc_address_a12": "1 Example Street Hale",
|
||||
"epc_postcode": "AA1 1AA",
|
||||
"uprn": "100012345678",
|
||||
"current_energy_rating": "C",
|
||||
|
|
|
|||
|
|
@ -1609,6 +1609,37 @@ def test_best_listing_match_numbered_query_cannot_subset_inflate_across_numbers(
|
|||
assert result is None
|
||||
|
||||
|
||||
def test_best_listing_match_letter_suffix_flats_do_not_cross_match() -> None:
    # Regression: the gate relies on fuzzy_join's suffix-aware tokens, which
    # treat "8A" and "8B" as different numbers. With the old digit-only tokens
    # both collapsed to {8} and token_sort scored ~93, so the wrong flat's
    # record was attached whenever the true candidate was missing from the
    # bucket.
    neighbouring_flat = {"pp_address": "8B HIGH STREET"}
    match = _best_listing_match(
        listing_uprn=None,
        query="8A HIGH STREET",
        uprn_index={},
        bucket_candidates=[neighbouring_flat],
        addressed_fields=["pp_address"],
    )
    assert match is None
|
||||
|
||||
|
||||
def test_best_listing_match_building_listing_cannot_absorb_single_flat() -> None:
    # Regression: number tokens are compared by set equality (not subset), so
    # the whole-building listing "188 GREAT NORTH WAY" must no longer match
    # "FLAT 1 188 GREAT NORTH WAY" (token_set would have scored the pair 100).
    single_flat = {"pp_address": "FLAT 1 188 GREAT NORTH WAY"}
    match = _best_listing_match(
        listing_uprn=None,
        query="188 GREAT NORTH WAY",
        uprn_index={},
        bucket_candidates=[single_flat],
        addressed_fields=["pp_address"],
    )
    assert match is None
|
||||
|
||||
|
||||
def test_finalize_listings_promotes_overlay_columns_and_filters_to_listing_rows() -> (
|
||||
None
|
||||
):
|
||||
|
|
|
|||
|
|
@ -191,6 +191,28 @@ def test_phase_intakes_prorates_fill_target_over_weighted_cohorts():
|
|||
assert intakes["secondary_intake"].to_list() == [0.0, 500.0, 0.0, 500.0, 1000.0]
|
||||
|
||||
|
||||
def test_phase_intakes_parses_one_sided_age_ranges():
    """When a statutory age is missing, gias._format_age_range emits
    "up to {high}" or "{low}+". Such schools must remain in the catchment
    supply rather than being silently dropped by a two-number parse."""
    schools = pl.DataFrame(
        [
            # "up to 11" = assumed cohorts 2..10: nursery years 2-3 weigh
            # 0.5 each, primary 4..10 weighs 7 -> primary 210 * 7/8.
            _gias_row(1, age_range="up to 11", pupils=210),
            # "16+" = assumed cohorts 16..18, all sixth form: no
            # primary/secondary intake, so the school contributes nothing
            # but must not crash the parse.
            _gias_row(2, age_range="16+", pupils=400),
        ]
    )
    intakes = phase_intakes(schools).sort("urn")

    assert intakes["urn"].to_list() == [1, 2]
    assert intakes["primary_intake"].to_list() == [210.0 * 7 / 8, 0.0]
    assert intakes["secondary_intake"].to_list() == [0.0, 0.0]
|
||||
|
||||
|
||||
def test_phase_intakes_excludes_non_state_and_selective_schools():
|
||||
intakes = phase_intakes(
|
||||
pl.DataFrame(
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import numpy as np
|
|||
import polars as pl
|
||||
|
||||
from pipeline.utils.england_geometry import in_england_mask
|
||||
from pipeline.utils.normalize import strip_or_empty
|
||||
|
||||
DROP_CATEGORIES = {
|
||||
# GEOLYTIX Grocery Retail Points is the authoritative supermarket source
|
||||
|
|
@ -1313,9 +1314,7 @@ GROCERY_FASCIA_ICON_NAMES: dict[str, str] = {
|
|||
|
||||
|
||||
def normalize_grocery_retailer(retailer: str | None) -> str:
    """Map a raw grocery retailer name to the name shown to users.

    ``strip_or_empty`` trims surrounding whitespace and turns ``None`` into
    ``""``, so no separate None check or ``.strip()`` is needed here. Any
    member of ``COOP_RETAILERS`` collapses to ``"Co-op"``; otherwise the
    trimmed name is looked up in ``GROCERY_RETAILER_DISPLAY_NAME_OVERRIDES``
    and returned unchanged when no override exists.
    """
    retailer = strip_or_empty(retailer)
    if retailer in COOP_RETAILERS:
        return "Co-op"
    return GROCERY_RETAILER_DISPLAY_NAME_OVERRIDES.get(retailer, retailer)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,10 @@
|
|||
from .download import download, extract_zip
|
||||
from .download import (
|
||||
ENGLAND_LSOA_COUNT_2021,
|
||||
download,
|
||||
download_arcgis_hub_export,
|
||||
download_nomis_csv,
|
||||
extract_zip,
|
||||
)
|
||||
from .fuzzy_join import (
|
||||
fuzzy_join_on_postcode,
|
||||
normalize_address_key,
|
||||
|
|
@ -10,7 +16,10 @@ from .poi_counts import count_pois_per_postcode
|
|||
from .postcode_mapping import build_postcode_mapping
|
||||
|
||||
__all__ = [
|
||||
"ENGLAND_LSOA_COUNT_2021",
|
||||
"download",
|
||||
"download_arcgis_hub_export",
|
||||
"download_nomis_csv",
|
||||
"extract_zip",
|
||||
"fuzzy_join_on_postcode",
|
||||
"normalize_address_key",
|
||||
|
|
|
|||
|
|
@ -1,11 +1,19 @@
|
|||
"""Shared download and extraction helpers for pipeline scripts."""
|
||||
|
||||
import time
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import polars as pl
|
||||
from tqdm import tqdm
|
||||
|
||||
# Census 2021 LSOAs (TYPE151) with an E prefix. The Census 2021 geography is
|
||||
# frozen, so NOMIS England-level downloads must yield exactly this many LSOAs;
|
||||
# fewer means the download was truncated.
|
||||
ENGLAND_LSOA_COUNT_2021 = 33_755
|
||||
|
||||
|
||||
def download(url: str, output_path: Path, *, timeout: float = 120) -> None:
|
||||
"""Stream-download a URL to a local file with a tqdm progress bar."""
|
||||
|
|
@ -38,3 +46,86 @@ def extract_zip(zip_path: Path, extract_dir: Path) -> None:
|
|||
extract_dir.mkdir(parents=True, exist_ok=True)
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
zf.extractall(extract_dir)
|
||||
|
||||
|
||||
def download_nomis_csv(base_url: str, *, page_size: int = 25_000) -> pl.DataFrame:
    """Download a NOMIS CSV dataset, paging with recordoffset/RecordLimit.

    The page size is sent explicitly as ``RecordLimit``: last-page detection is
    ``rows < page_size``, so relying on NOMIS's implicit default would silently
    truncate the dataset to one page if that default ever differed.

    Args:
        base_url: Dataset URL. The paging parameters are appended with ``&``,
            so the URL is expected to already carry a query string.
        page_size: Rows requested per page; must be positive.

    Returns:
        All pages concatenated in download order.

    Raises:
        ValueError: If ``page_size`` is not positive. The ``rows < page_size``
            exit could never fire and the offset would not advance correctly,
            so the loop would not terminate normally.
        RuntimeError: If NOMIS returned no rows at all.
        httpx.HTTPStatusError: If any page request returns an error status.
    """
    if page_size <= 0:
        raise ValueError(f"page_size must be a positive integer, got {page_size}")

    frames = []
    offset = 0
    while True:
        url = f"{base_url}&RecordLimit={page_size}&recordoffset={offset}"
        response = httpx.get(url, follow_redirects=True, timeout=120)
        response.raise_for_status()  # pyright: ignore[reportUnusedCallResult]
        # An empty body means we paged past the end of the dataset.
        if len(response.content) == 0:
            break
        chunk = pl.read_csv(BytesIO(response.content))
        # A header-only page carries no data rows either.
        if chunk.height == 0:
            break
        frames.append(chunk)
        print(f"  Fetched {chunk.height} rows (offset={offset})")
        # A short page is the last page.
        if chunk.height < page_size:
            break
        offset += page_size

    if not frames:
        raise RuntimeError(f"NOMIS returned no rows for {base_url}")
    return pl.concat(frames)
|
||||
|
||||
|
||||
def download_arcgis_hub_export(
    url: str,
    output_path: Path,
    *,
    expected_geometry: str | None = None,
    poll_interval_s: float = 5,
    poll_timeout_s: float = 600,
) -> int:
    """Download an ArcGIS Hub `api/download/v1` export, handling deferred jobs.

    The endpoint returns HTTP 202 with a JSON status body while the export is
    still being prepared; a plain download would save that placeholder as the
    output file with a success exit code. Poll until the file is ready, then
    validate the result with pyogrio (feature count > 0 and, optionally, a
    geometry-type substring) before moving it into place.

    The payload is streamed to a sibling ``<stem>.tmp<suffix>`` file and only
    renamed onto ``output_path`` after validation. The temporary file is
    removed on every failure path so a rejected or partial download never
    lingers next to the real output.

    Args:
        url: Export URL to fetch.
        output_path: Final location of the validated file.
        expected_geometry: Substring that must appear in the geometry type
            reported by pyogrio; ``None`` skips the check.
        poll_interval_s: Seconds to sleep between polls while the server
            answers 202.
        poll_timeout_s: Polling budget in seconds before giving up.

    Returns:
        The feature count reported by pyogrio.

    Raises:
        TimeoutError: If the export is still pending after ``poll_timeout_s``.
        ValueError: If the file has no features or the wrong geometry type.
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    import pyogrio

    tmp_path = output_path.with_name(f"{output_path.stem}.tmp{output_path.suffix}")
    deadline = time.monotonic() + poll_timeout_s
    try:
        with httpx.Client(follow_redirects=True, timeout=300) as client:
            while True:
                with client.stream("GET", url) as response:
                    if response.status_code == 202:
                        # Load the status body so it can be quoted on timeout.
                        response.read()
                        if time.monotonic() > deadline:
                            raise TimeoutError(
                                f"Export did not finish within {poll_timeout_s}s: "
                                f"{response.text}"
                            )
                        time.sleep(poll_interval_s)
                        continue
                    response.raise_for_status()  # pyright: ignore[reportUnusedCallResult]
                    with tmp_path.open("wb") as fh:
                        for chunk in response.iter_bytes():
                            fh.write(chunk)
                    break

        info = pyogrio.read_info(tmp_path)
        features = int(info.get("features", 0))
        geometry_type = str(info.get("geometry_type") or "")
        if features <= 0:
            raise ValueError(f"Downloaded file {output_path.name} contains no features")
        if expected_geometry is not None and expected_geometry not in geometry_type:
            raise ValueError(
                f"Expected {expected_geometry!r} geometry in {output_path.name}, "
                f"got {geometry_type!r}"
            )

        tmp_path.replace(output_path)
    finally:
        # After a successful replace() the temp name no longer exists and this
        # is a no-op; on any failure it discards the unvalidated download.
        tmp_path.unlink(missing_ok=True)
    return features
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from collections import Counter
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from os import cpu_count
|
||||
from pathlib import Path
|
||||
|
|
@ -10,6 +12,7 @@ from thefuzz import fuzz
|
|||
from tqdm import tqdm
|
||||
|
||||
from pipeline.local_temp import local_tmp_dir
|
||||
from pipeline.utils.normalize import uppercase_alnum_key_expr
|
||||
|
||||
# A house-number token includes any letter suffix: 8A, 8B and plain 8 are
|
||||
# three different properties on the same street, so digit-only extraction
|
||||
|
|
@ -17,6 +20,10 @@ from pipeline.local_temp import local_tmp_dir
|
|||
# through normalize_address_key first, so tokens are uppercase and
|
||||
# space-separated and [A-Z] suffices for the suffix.
|
||||
_NUMBER_RE = re.compile(r"\d+[A-Z]?")
|
||||
# A single-letter flat designator ("FLAT B", "APARTMENT C") is a house-number-
|
||||
# grade disambiguator with no digit in it: without this, FLAT B and FLAT D in
|
||||
# the same building scored ~96 and cross-matched.
|
||||
_FLAT_LETTER_RE = re.compile(r"\b(?:FLAT|APARTMENT|APT|UNIT) ([A-Z])\b")
|
||||
_POSTCODE_RE = r"^[A-Z]{1,2}\d[A-Z\d]?\d[A-Z]{2}$"
|
||||
# A house number is a strong disambiguator, so a numbered, number-compatible
|
||||
# pair may match on a lower address-similarity score than a number-less one
|
||||
|
|
@ -24,16 +31,30 @@ _POSTCODE_RE = r"^[A-Z]{1,2}\d[A-Z\d]?\d[A-Z]{2}$"
|
|||
# be trusted. Mirrors merge.py's listings convention.
|
||||
MIN_FUZZY_SCORE = 82
|
||||
MIN_FUZZY_SCORE_WITHOUT_NUMBERS = 90
|
||||
# A score reached only through an address VARIANT (locality appended /
|
||||
# secondary address lines dropped) accepts a match the primary strings alone
|
||||
# would reject, so it must clear a near-exact bar: in the miss audit >99% of
|
||||
# genuine variant recoveries scored 100, while the rare false variant matches
|
||||
# scored in the 80s.
|
||||
MIN_VARIANT_SCORE = 90
|
||||
|
||||
# Tokens that mark a sub-unit of a building. A variant whose added/dropped
|
||||
# tokens include one of these could score a single flat's certificate as if it
|
||||
# were the whole building, so such variants are inadmissible.
|
||||
_FLAT_TOKENS = {
|
||||
"FLAT",
|
||||
"FLATS",
|
||||
"APARTMENT",
|
||||
"APT",
|
||||
"UNIT",
|
||||
"MAISONETTE",
|
||||
"STUDIO",
|
||||
"ROOM",
|
||||
}
|
||||
|
||||
|
||||
def normalize_address_key(s: pl.Expr) -> pl.Expr:
    """Polars expression producing the address key used for matching.

    The text is normalised by ``uppercase_alnum_key_expr``. A key that
    contains no letter A-Z at all is replaced with null.
    """
    normalized = uppercase_alnum_key_expr(s)
    return pl.when(normalized.str.contains(r"[A-Z]")).then(normalized).otherwise(None)
|
||||
|
||||
|
||||
|
|
@ -58,6 +79,8 @@ def fuzzy_join_on_postcode(
|
|||
right_postcode_col: str,
|
||||
min_score: int = MIN_FUZZY_SCORE,
|
||||
min_score_without_numbers: int = MIN_FUZZY_SCORE_WITHOUT_NUMBERS,
|
||||
left_variant_cols: Sequence[str] = (),
|
||||
right_variant_cols: Sequence[str] = (),
|
||||
) -> pl.LazyFrame:
|
||||
"""Fuzzy join two LazyFrames by matching addresses within postcode buckets.
|
||||
|
||||
|
|
@ -66,6 +89,19 @@ def fuzzy_join_on_postcode(
|
|||
columns (index, address, postcode) via projection pushdown, and the
|
||||
final join reads the remaining columns lazily.
|
||||
|
||||
``left_variant_cols`` / ``right_variant_cols`` name alternative address
|
||||
columns for the same property (e.g. the EPC's first address line without
|
||||
its locality suffix, or the price-paid address with its locality
|
||||
appended). A pair is scored as the best token_sort_ratio over all
|
||||
admissible variant combinations: source registers frequently disagree
|
||||
only on a trailing village/locality token, which alone drags short
|
||||
addresses below the match threshold. The number-compatibility gate is
|
||||
always evaluated on the primary addresses, and `_admissible_variants`
|
||||
rejects any variant whose added/dropped tokens carry digits or flat
|
||||
designators, so a variant can never bypass the gate or score a single
|
||||
flat as its whole building. Variant-only scores must clear
|
||||
``MIN_VARIANT_SCORE``.
|
||||
|
||||
Returns a LazyFrame with all left and right columns, plus a
|
||||
``_match_score`` (UInt8) audit column holding the token_sort_ratio of
|
||||
the accepted match (exact matches score 100). Unmatched rows have null
|
||||
|
|
@ -90,6 +126,10 @@ def fuzzy_join_on_postcode(
|
|||
normalize_postcode_key(pl.col(left_postcode_col)).alias(
|
||||
"_left_postcode"
|
||||
),
|
||||
*(
|
||||
normalize_address_key(pl.col(col)).alias(f"_left_variant_{i}")
|
||||
for i, col in enumerate(left_variant_cols)
|
||||
),
|
||||
)
|
||||
.collect(engine="streaming")
|
||||
)
|
||||
|
|
@ -104,30 +144,45 @@ def fuzzy_join_on_postcode(
|
|||
normalize_postcode_key(pl.col(right_postcode_col)).alias(
|
||||
"_right_postcode"
|
||||
),
|
||||
*(
|
||||
normalize_address_key(pl.col(col)).alias(f"_right_variant_{i}")
|
||||
for i, col in enumerate(right_variant_cols)
|
||||
),
|
||||
)
|
||||
.unique(subset=["_right_address", "_right_postcode"], keep="first")
|
||||
.collect(engine="streaming")
|
||||
)
|
||||
|
||||
left_variant_names = [f"_left_variant_{i}" for i in range(len(left_variant_cols))]
|
||||
right_variant_names = [
|
||||
f"_right_variant_{i}" for i in range(len(right_variant_cols))
|
||||
]
|
||||
|
||||
# Group right side by postcode for fast lookup
|
||||
right_by_postcode: dict[str, list[tuple[int, str]]] = {}
|
||||
for idx, postcode, address in zip(
|
||||
right_by_postcode: dict[str, list[tuple[int, str, tuple[str, ...]]]] = {}
|
||||
for idx, postcode, address, *variants in zip(
|
||||
right_match["_right_idx"],
|
||||
right_match["_right_postcode"],
|
||||
right_match["_right_address"],
|
||||
*(right_match[name] for name in right_variant_names),
|
||||
):
|
||||
if address is not None and postcode is not None:
|
||||
right_by_postcode.setdefault(postcode, []).append((idx, address))
|
||||
right_by_postcode.setdefault(postcode, []).append(
|
||||
(idx, address, _admissible_variants(address, variants))
|
||||
)
|
||||
|
||||
# Group left side by postcode
|
||||
left_by_postcode: dict[str, list[tuple[int, str]]] = {}
|
||||
for idx, postcode, address in zip(
|
||||
left_by_postcode: dict[str, list[tuple[int, str, tuple[str, ...]]]] = {}
|
||||
for idx, postcode, address, *variants in zip(
|
||||
left_match["_left_idx"],
|
||||
left_match["_left_postcode"],
|
||||
left_match["_left_address"],
|
||||
*(left_match[name] for name in left_variant_names),
|
||||
):
|
||||
if address is not None and postcode is not None:
|
||||
left_by_postcode.setdefault(postcode, []).append((idx, address))
|
||||
left_by_postcode.setdefault(postcode, []).append(
|
||||
(idx, address, _admissible_variants(address, variants))
|
||||
)
|
||||
|
||||
del left_match, right_match
|
||||
|
||||
|
|
@ -145,7 +200,12 @@ def fuzzy_join_on_postcode(
|
|||
|
||||
# Score all pairwise matches in parallel, then greedily assign from
|
||||
# highest score downward so best pairs lock in first.
|
||||
all_pairs: list[tuple[int, int, int]] = [] # (score, left_idx, right_idx)
|
||||
# Pair tuples are (score, exact, left_idx, right_idx); `exact` marks a
|
||||
# literally-equal primary pair so it wins greedy ties against a pair
|
||||
# that merely token-sorts to the same score (e.g. "APARTMENT 3 1 HIGH
|
||||
# ST" vs "APARTMENT 1 3 HIGH ST" both score 100 against each other's
|
||||
# certificates, but each has a literal twin).
|
||||
all_pairs: list[tuple[int, int, int, int]] = []
|
||||
with ProcessPoolExecutor(max_workers=cpu_count()) as executor:
|
||||
for pairs in tqdm(
|
||||
executor.map(_score_bucket, tasks, chunksize=64),
|
||||
|
|
@ -156,8 +216,9 @@ def fuzzy_join_on_postcode(
|
|||
|
||||
del tasks, left_by_postcode, right_by_postcode
|
||||
|
||||
# Sort descending by score so best matches are assigned first
|
||||
all_pairs.sort(key=lambda t: (t[0], -t[1]), reverse=True)
|
||||
# Sort so the best matches are assigned first: score, then literal
|
||||
# equality, then stable left-index order.
|
||||
all_pairs.sort(key=lambda t: (t[0], t[1], -t[2]), reverse=True)
|
||||
|
||||
# Keep the score alongside each accepted pair: it is emitted as the
|
||||
# _match_score audit column so downstream consumers can distinguish
|
||||
|
|
@ -166,7 +227,7 @@ def fuzzy_join_on_postcode(
|
|||
matched_left: set[int] = set()
|
||||
matched_right: set[int] = set()
|
||||
|
||||
for score, left_idx, right_idx in all_pairs:
|
||||
for score, _exact, left_idx, right_idx in all_pairs:
|
||||
if left_idx in matched_left or right_idx in matched_right:
|
||||
continue
|
||||
matches.append((left_idx, right_idx, score))
|
||||
|
|
@ -208,40 +269,102 @@ def fuzzy_join_on_postcode(
|
|||
return result.lazy()
|
||||
|
||||
|
||||
def _number_tokens(address: str) -> set[str]:
    """Return the distinct matches of ``_NUMBER_RE`` in ``address`` together
    with the letters captured by ``_FLAT_LETTER_RE``."""
    return {
        *_NUMBER_RE.findall(address),
        *_FLAT_LETTER_RE.findall(address),
    }
|
||||
|
||||
|
||||
def _numbers_compatible(a: str, b: str) -> bool:
    """Check that the number tokens (house/flat numbers, including any letter
    suffix, plus single-letter flat designators) of two addresses are
    IDENTICAL sets.

    Equality, not subset: subset logic let "188 GREAT NORTH WAY" absorb
    "FLAT 1 188 GREAT NORTH WAY" ({188} is a subset of {1, 188}), attaching a
    single flat's EPC facts to the whole building — tens of thousands of
    wrong-property matches. Likewise digit-only tokens made "8A" and "8B"
    both look like {8} and match each other (and plain "8"), and ungated
    letter flats let "FLAT D 39 X ST" cross-match "FLAT F 39 X ST" at ~96.
    Precision over recall: a pair whose two sources genuinely disagree on
    number tokens is safer left unmatched.

    One side numbered, the other not -> incompatible. Neither numbered ->
    compatible; such pairs are scored against the stricter no-numbers
    threshold instead.
    """
    nums_a = _number_tokens(a)
    nums_b = _number_tokens(b)
    if not nums_a and not nums_b:
        return True
    return nums_a == nums_b
|
||||
|
||||
|
||||
def _admissible_variants(
|
||||
primary: str, variants: Sequence[str | None]
|
||||
) -> tuple[str, ...]:
|
||||
"""Variants of ``primary`` that are safe to score against the other side.
|
||||
|
||||
A variant may only ADD or DROP whole tokens relative to the primary (one
|
||||
word multiset must contain the other) — never substitute, so a register
|
||||
row whose address lines disagree with the combined address can't smuggle
|
||||
in a different street. The number gate runs on the primary addresses
|
||||
only, so the added/dropped tokens must additionally carry no digits
|
||||
(house numbers) and no flat designator (a "Flat 1"-style secondary line
|
||||
dropped from an EPC address would otherwise let a single flat score as
|
||||
the whole building). The remaining admissible difference is exactly the
|
||||
harmless kind variants exist for: trailing locality/village/town words.
|
||||
"""
|
||||
primary_words = Counter(primary.split())
|
||||
admissible: list[str] = []
|
||||
for variant in variants:
|
||||
if not variant or variant == primary:
|
||||
continue
|
||||
variant_words = Counter(variant.split())
|
||||
if not (variant_words <= primary_words or primary_words <= variant_words):
|
||||
continue
|
||||
changed = (primary_words - variant_words) + (variant_words - primary_words)
|
||||
if any(
|
||||
any(ch.isdigit() for ch in token) or token in _FLAT_TOKENS
|
||||
for token in changed
|
||||
):
|
||||
continue
|
||||
admissible.append(variant)
|
||||
return tuple(dict.fromkeys(admissible))
|
||||
|
||||
|
||||
def _score_bucket(
|
||||
args: tuple[list[tuple[int, str]], list[tuple[int, str]], int, int],
|
||||
) -> list[tuple[int, int, int]]:
|
||||
args: tuple[
|
||||
list[tuple[int, str, tuple[str, ...]]],
|
||||
list[tuple[int, str, tuple[str, ...]]],
|
||||
int,
|
||||
int,
|
||||
],
|
||||
) -> list[tuple[int, int, int, int]]:
|
||||
"""Score all address pairs within a single postcode bucket."""
|
||||
left_entries, right_entries, min_score, min_score_without_numbers = args
|
||||
pairs = []
|
||||
for left_row, left_address in left_entries:
|
||||
for right_row, right_address in right_entries:
|
||||
for left_row, left_address, left_variants in left_entries:
|
||||
for right_row, right_address, right_variants in right_entries:
|
||||
if not _numbers_compatible(left_address, right_address):
|
||||
continue
|
||||
score = fuzz.token_sort_ratio(left_address, right_address)
|
||||
# Variant pairs recover same-property matches where one register
|
||||
# carries a locality suffix the other lacks; a variant-only score
|
||||
# must clear the near-exact MIN_VARIANT_SCORE bar.
|
||||
if score < 100 and (left_variants or right_variants):
|
||||
for left_variant in (left_address, *left_variants):
|
||||
for right_variant in (right_address, *right_variants):
|
||||
if (
|
||||
left_variant is left_address
|
||||
and right_variant is right_address
|
||||
):
|
||||
continue
|
||||
variant_score = fuzz.token_sort_ratio(
|
||||
left_variant, right_variant
|
||||
)
|
||||
if variant_score >= MIN_VARIANT_SCORE and variant_score > score:
|
||||
score = variant_score
|
||||
# Number-less pairs (named houses, building-name flats) lack the
|
||||
# house-number disambiguator, so require a near-exact match.
|
||||
threshold = (
|
||||
|
|
@ -250,5 +373,7 @@ def _score_bucket(
|
|||
else min_score_without_numbers
|
||||
)
|
||||
if score >= threshold:
|
||||
pairs.append((score, left_row, right_row))
|
||||
pairs.append(
|
||||
(score, int(left_address == right_address), left_row, right_row)
|
||||
)
|
||||
return pairs
|
||||
|
|
|
|||
70
pipeline/utils/normalize.py
Normal file
70
pipeline/utils/normalize.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
"""Shared low-level text-normalization primitives.
|
||||
|
||||
Address matching (``pipeline.utils.fuzzy_join``, ``pipeline.transform.merge``),
|
||||
POI retailer cleanup (``pipeline.transform.transform_poi``) and school-name
|
||||
matching (``pipeline.check_school_cutoffs``) each layer domain-specific rules
|
||||
on top of these. The primitives are deliberately tiny and single-purpose so
|
||||
that composing them preserves every caller's existing output byte-for-byte.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
import polars as pl
|
||||
|
||||
# One character outside [a-z0-9 ]. Callers lowercase first; each offending
|
||||
# character becomes a single space (runs are NOT merged here — callers apply
|
||||
# word-level rules and then collapse_whitespace).
|
||||
_NON_ALNUM_LOWER_RE = re.compile(r"[^a-z0-9 ]")
|
||||
|
||||
# Any digit marks a token as number-bearing (house/flat numbers, including
|
||||
# letter-suffixed forms such as 8A, which still contain a digit).
|
||||
_DIGIT_RE = re.compile(r"\d")
|
||||
|
||||
|
||||
def collapse_whitespace(s: str) -> str:
    """Return ``s`` with each run of whitespace reduced to one space and no
    leading or trailing whitespace."""
    words = s.split()
    return " ".join(words)
|
||||
|
||||
|
||||
def strip_or_empty(s: str | None) -> str:
|
||||
"""Strip leading/trailing whitespace, mapping None to ``""``.
|
||||
|
||||
Interior whitespace is preserved (unlike :func:`collapse_whitespace`) so
|
||||
the result can be looked up verbatim against curated dictionary keys.
|
||||
"""
|
||||
return "" if s is None else s.strip()
|
||||
|
||||
|
||||
def replace_non_alnum_lower(s: str) -> str:
    """Substitute a single space for every character matched by
    ``_NON_ALNUM_LOWER_RE`` (anything outside [a-z0-9 ]).

    Input is expected to be lowercased already, since uppercase letters are
    outside the allowed set and get replaced as well. Each offending character
    yields its own space — runs are not merged — so callers collapse
    whitespace afterwards.
    """
    spaced = _NON_ALNUM_LOWER_RE.sub(" ", s)
    return spaced
|
||||
|
||||
|
||||
def drop_digit_tokens(s: str) -> str:
    """Remove every whitespace-separated token containing at least one digit.

    ``"10A HIGH STREET" -> "HIGH STREET"``. Remaining tokens are joined with
    single spaces, which also collapses any whitespace runs in the input.
    """
    kept: list[str] = []
    for token in s.split():
        if _DIGIT_RE.search(token) is None:
            kept.append(token)
    return " ".join(kept)
|
||||
|
||||
|
||||
def uppercase_alnum_key_expr(s: pl.Expr) -> pl.Expr:
    """Build a Polars expression that casts to string, uppercases, turns each
    run of characters outside [0-9A-Z] into one space, collapses whitespace
    runs, and strips both ends.

    Letters outside ASCII are not in [0-9A-Z] even after uppercasing, so they
    turn into spaces (``"Café 1" -> "CAF 1"``).
    """
    upper = s.cast(pl.String).str.to_uppercase()
    alnum_only = upper.str.replace_all(r"[^0-9A-Z]+", " ")
    single_spaced = alnum_only.str.replace_all(r"\s+", " ")
    return single_spaced.str.strip_chars()
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import polars as pl
|
||||
|
||||
from pipeline.utils import fuzzy_join_on_postcode, normalize_postcode_key
|
||||
from pipeline.utils.fuzzy_join import _numbers_compatible
|
||||
from pipeline.utils.fuzzy_join import _admissible_variants, _numbers_compatible
|
||||
|
||||
|
||||
def test_fuzzy_join_on_postcode_matches_addresses_within_postcode():
|
||||
|
|
@ -165,7 +165,7 @@ def test_fuzzy_join_rejects_mid_score_number_less_match():
|
|||
|
||||
|
||||
def test_fuzzy_join_matches_numbered_pair_at_baseline_threshold():
|
||||
# "10 ACACIA AVENUE" vs "FLAT A 10 ACACIA AVENUE" scores exactly 82 and the
|
||||
# "10 ACACIA AVENUE" vs "10 ACACIA AVENUE OAKHAM" scores exactly 82 and the
|
||||
# house number is compatible, so the numbered baseline (>= 82) still matches.
|
||||
left = pl.LazyFrame(
|
||||
{
|
||||
|
|
@ -175,7 +175,7 @@ def test_fuzzy_join_matches_numbered_pair_at_baseline_threshold():
|
|||
)
|
||||
right = pl.LazyFrame(
|
||||
{
|
||||
"right_address": ["Flat A, 10 Acacia Avenue"],
|
||||
"right_address": ["10 Acacia Avenue, Oakham"],
|
||||
"right_postcode": ["AB1 2CD"],
|
||||
}
|
||||
)
|
||||
|
|
@ -189,7 +189,7 @@ def test_fuzzy_join_matches_numbered_pair_at_baseline_threshold():
|
|||
right_postcode_col="right_postcode",
|
||||
).collect()
|
||||
|
||||
assert result["right_address"].to_list() == ["Flat A, 10 Acacia Avenue"]
|
||||
assert result["right_address"].to_list() == ["10 Acacia Avenue, Oakham"]
|
||||
|
||||
|
||||
def test_fuzzy_join_matches_high_score_number_less_pair():
|
||||
|
|
@ -244,6 +244,151 @@ def test_numbers_compatible_number_less_and_one_sided_pairs():
|
|||
assert not _numbers_compatible("ROSE COTTAGE", "8 HIGH STREET")
|
||||
|
||||
|
||||
def test_numbers_compatible_gates_single_letter_flats():
    # "FLAT D" and "FLAT F" are separate flats even when the street number is
    # the same; without the gate they token_sort to ~96 and cross-matched.
    flat_d = "FLAT D 39 GERTRUDE STREET"
    assert not _numbers_compatible(flat_d, "FLAT F 39 GERTRUDE STREET")
    # The same flat written in a different word order stays compatible.
    assert _numbers_compatible(flat_d, "39 GERTRUDE STREET FLAT D")
    # The letter acts as a pseudo-number token, so a flat cannot match the
    # bare building address either.
    assert not _numbers_compatible("FLAT B ROSE COURT", "ROSE COURT")
    # A letter glued to a number ("A3") is a unit name, not a flat letter.
    unit = "FLAT A3 CHESHAM HEIGHTS"
    assert _numbers_compatible(unit, "FLAT A3 CHESHAM HEIGHTS")
|
||||
|
||||
|
||||
def test_admissible_variants_allows_locality_suffix_only():
    # A variant may differ from its primary by locality words only; digits
    # and flat designators are off limits (the gate ran on the primary only).
    kept = _admissible_variants("12 OAK ROAD", ["12 OAK ROAD HALE", "12 OAK ROAD"])
    assert kept == ("12 OAK ROAD HALE",)

    # Each of these must yield no admissible variant.
    rejected = [
        # Dropping "FLAT 1" (digit) or "FLAT B" (flat designator) would score
        # a single flat as the whole building.
        ("FLAT 1 188 GREAT NORTH WAY", ["188 GREAT NORTH WAY"]),
        ("FLAT B ROSE COURT", ["ROSE COURT"]),
        # Missing variants and copies of the primary are skipped.
        ("12 OAK ROAD", [None, "12 OAK ROAD"]),
        # Substitution is never admissible: a register row whose address1
        # disagrees with the combined address must not smuggle in a
        # different street for scoring.
        ("12 OAK ROAD", ["12 ELM ROAD"]),
        ("1 TOTALLY DIFFERENT ROAD", ["1 EXAMPLE STREET"]),
    ]
    for primary, variants in rejected:
        assert _admissible_variants(primary, variants) == ()
|
||||
|
||||
|
||||
def test_fuzzy_join_variant_recovers_locality_suffix_mismatch():
    # The EPC register holds "12 Oak Road, Hale" (address1 + locality line)
    # whereas price-paid has only "12 Oak Road": token_sort scores 81 < 82, so
    # the match used to be lost. The EPC's address1-only variant scores 100.
    left = pl.LazyFrame(
        {
            "left_address": ["12 Oak Road"],
            "left_postcode": ["AB1 2CD"],
            "left_with_locality": ["12 Oak Road Hale"],
        }
    )
    right = pl.LazyFrame(
        {
            "right_address": ["12 Oak Road, Hale"],
            "right_postcode": ["AB1 2CD"],
            "right_address1": ["12 Oak Road"],
        }
    )
    join_args = dict(
        left=left,
        right=right,
        left_address_col="left_address",
        right_address_col="right_address",
        left_postcode_col="left_postcode",
        right_postcode_col="right_postcode",
    )

    # Primary addresses alone fall below the threshold.
    unmatched = fuzzy_join_on_postcode(**join_args).collect()
    assert unmatched["_match_score"].to_list() == [None]

    # With the variant columns supplied, the pair is recovered.
    result = fuzzy_join_on_postcode(
        **join_args,
        left_variant_cols=["left_with_locality"],
        right_variant_cols=["right_address1"],
    ).collect()
    assert result["_match_score"].to_list() == [100]
|
||||
|
||||
|
||||
def test_fuzzy_join_variant_cannot_unlock_a_flat_for_its_building():
    # The EPC's secondary line carries the flat designator. Dropping it would
    # score the flat's certificate 100 against the whole-building price-paid
    # address, so the variant must be ruled inadmissible and the pair must
    # stay unmatched.
    building = pl.LazyFrame(
        {
            "left_address": ["188 Great North Way"],
            "left_postcode": ["AB1 2CD"],
        }
    )
    flat_certificate = pl.LazyFrame(
        {
            "right_address": ["Flat 1, 188 Great North Way"],
            "right_postcode": ["AB1 2CD"],
            "right_address1": ["188 Great North Way"],
        }
    )

    joined = fuzzy_join_on_postcode(
        left=building,
        right=flat_certificate,
        left_address_col="left_address",
        right_address_col="right_address",
        left_postcode_col="left_postcode",
        right_postcode_col="right_postcode",
        right_variant_cols=["right_address1"],
    ).collect()

    scores = joined["_match_score"].to_list()
    assert scores == [None]
|
||||
|
||||
|
||||
def test_fuzzy_join_variant_score_must_be_near_exact():
    # A score reached only through a variant must clear MIN_VARIANT_SCORE
    # (90): pairs like "2 MYRTLE COTTAGES" vs "2 LEITH VIEW COTTAGES" scored
    # in the 80s via variants and were false matches.
    myrtle = pl.LazyFrame(
        {
            "left_address": ["2 Myrtle Cottages"],
            "left_postcode": ["AB1 2CD"],
            "left_with_locality": ["2 Myrtle Cottages Dorking"],
        }
    )
    leith_view = pl.LazyFrame(
        {
            "right_address": ["2 Leith View Cottages, North Holmwood"],
            "right_postcode": ["AB1 2CD"],
            "right_address1": ["2 Leith View Cottages"],
        }
    )

    joined = fuzzy_join_on_postcode(
        left=myrtle,
        right=leith_view,
        left_address_col="left_address",
        right_address_col="right_address",
        left_postcode_col="left_postcode",
        right_postcode_col="right_postcode",
        left_variant_cols=["left_with_locality"],
        right_variant_cols=["right_address1"],
    ).collect()

    scores = joined["_match_score"].to_list()
    assert scores == [None]
|
||||
|
||||
|
||||
def test_fuzzy_join_rejects_wrong_letter_suffix_match():
|
||||
# End-to-end guard for the 8A/8B class of wrong-property matches: the only
|
||||
# candidate in the postcode bucket differs solely in the number suffix, so
|
||||
|
|
@ -294,7 +439,7 @@ def test_fuzzy_join_emits_match_score_column():
|
|||
"10 HIGH STREET",
|
||||
# Scores exactly 82 against "10 Acacia Avenue" (see
|
||||
# test_fuzzy_join_matches_numbered_pair_at_baseline_threshold).
|
||||
"Flat A, 10 Acacia Avenue",
|
||||
"10 Acacia Avenue, Oakham",
|
||||
],
|
||||
"right_postcode": ["AB1 2CD", "EF3 4GH"],
|
||||
}
|
||||
|
|
|
|||
158
pipeline/utils/test_normalize.py
Normal file
158
pipeline/utils/test_normalize.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
import polars as pl
|
||||
|
||||
from pipeline.check_school_cutoffs import normalize_la, normalize_name
|
||||
from pipeline.transform.merge import _street_only_address
|
||||
from pipeline.transform.transform_poi import normalize_grocery_retailer
|
||||
from pipeline.utils.fuzzy_join import normalize_address_key
|
||||
from pipeline.utils.normalize import (
|
||||
collapse_whitespace,
|
||||
drop_digit_tokens,
|
||||
replace_non_alnum_lower,
|
||||
strip_or_empty,
|
||||
uppercase_alnum_key_expr,
|
||||
)
|
||||
|
||||
# --- Primitives -------------------------------------------------------------
|
||||
|
||||
|
||||
def test_collapse_whitespace():
    # (input, expected) pairs, checked in order.
    cases = [
        ("", ""),
        (" ", ""),
        ("a b", "a b"),
        (" a \t b \n c ", "a b c"),
        # str.split() also splits on unicode whitespace (non-breaking space).
        ("a\u00a0b", "a b"),
    ]
    for raw, expected in cases:
        assert collapse_whitespace(raw) == expected
|
||||
|
||||
|
||||
def test_strip_or_empty():
    # (input, expected) pairs; None must map to the empty string.
    cases = [
        (None, ""),
        ("", ""),
        (" x ", "x"),
        # Interior whitespace is preserved, unlike collapse_whitespace.
        (" a b ", "a b"),
    ]
    for raw, expected in cases:
        assert strip_or_empty(raw) == expected
|
||||
|
||||
|
||||
def test_replace_non_alnum_lower():
    # NOTE(review): the spacing inside these literals is reproduced exactly as
    # shown in the diff view. The comments below talk about unmerged runs and
    # preserved spaces, so runs of spaces may have been collapsed by the
    # rendering - confirm against the repository file.
    cases = [
        ("", ""),
        ("abc 123", "abc 123"),
        # Per-character replacement: runs are not merged.
        ("a--b", "a b"),
        # Existing spaces are kept as-is.
        ("a , b", "a b"),
        # Uppercase and accented letters fall outside [a-z0-9 ].
        ("École", " cole"),
    ]
    for raw, expected in cases:
        assert replace_non_alnum_lower(raw) == expected
|
||||
|
||||
|
||||
def test_drop_digit_tokens():
    # (input, expected) pairs: any token containing a digit is removed.
    cases = [
        ("", ""),
        ("10A HIGH STREET", "HIGH STREET"),
        ("8B", ""),
        ("12 34", ""),
        ("KINGSWOOD", "KINGSWOOD"),
        # Whitespace collapses as a side effect of the token rejoin.
        (" A B ", "A B"),
    ]
    for raw, expected in cases:
        assert drop_digit_tokens(raw) == expected
|
||||
|
||||
|
||||
def test_uppercase_alnum_key_expr():
    # (input, expected key) pairs evaluated together through one String column.
    cases = [
        ("Flat 2, 10 High Street", "FLAT 2 10 HIGH STREET"),
        (" 12 High-Street ", "12 HIGH STREET"),
        ("", ""),
        (None, None),
        ("Café 1", "CAF 1"),
        ("st mary's-court", "ST MARY S COURT"),
    ]
    frame = pl.DataFrame(
        {"a": [raw for raw, _ in cases]},
        schema={"a": pl.String},
    )
    keys = frame.select(uppercase_alnum_key_expr(pl.col("a"))).to_series().to_list()
    assert keys == [expected for _, expected in cases]
|
||||
|
||||
|
||||
# --- Characterization of the call sites built on the primitives ------------
|
||||
# Expected values were captured from the pre-refactor implementations and
|
||||
# must never change: each wrapper's output is byte-for-byte pinned.
|
||||
|
||||
|
||||
def test_normalize_address_key_characterization():
    # (input, pinned key) pairs evaluated together through one String column.
    cases = [
        ("Flat 2, 10 High Street", "FLAT 2 10 HIGH STREET"),
        (" 12 High-Street ", "12 HIGH STREET"),
        ("123", None),  # digits only: no letter -> null
        ("", None),  # empty -> null
        (None, None),  # null in, null out
        ("Café 1", "CAF 1"),
        ("st mary's-court", "ST MARY S COURT"),
        ("ALREADY NORMAL", "ALREADY NORMAL"),
    ]
    frame = pl.DataFrame(
        {"a": [raw for raw, _ in cases]},
        schema={"a": pl.String},
    )
    keys = frame.select(normalize_address_key(pl.col("a"))).to_series().to_list()
    assert keys == [expected for _, expected in cases]
|
||||
|
||||
|
||||
def test_street_only_address_characterization():
    # (input, pinned output) pairs for the street-only wrapper.
    cases = [
        ("10A HIGH STREET", "HIGH STREET"),
        ("FLAT 1 188 GREAT NORTH WAY", "FLAT GREAT NORTH WAY"),
        ("", ""),
        ("OLDSTEAD ROAD", "OLDSTEAD ROAD"),
        (" A B ", "A B"),
        ("12 34", ""),
        ("8B", ""),
    ]
    for raw, expected in cases:
        assert _street_only_address(raw) == expected
|
||||
|
||||
|
||||
def test_normalize_grocery_retailer_characterization():
    # NOTE(review): literal spacing is reproduced exactly as shown in the diff
    # view; the "interior whitespace" case below may originally have contained
    # a run of spaces that the rendering collapsed - confirm against the
    # repository file.
    cases = [
        (None, ""),
        ("", ""),
        (" Tesco Express ", "Tesco Express"),
        ("Sainsburys", "Sainsbury's"),
        ("Lincolnshire Co-operative", "Co-op"),
        # Only edge whitespace is stripped; interior whitespace must survive so
        # near-miss names fall through the exact dictionary lookups unchanged.
        ("Bob's Shop", "Bob's Shop"),
        (" Marks and Spencer ", "M&S"),
    ]
    for raw, expected in cases:
        assert normalize_grocery_retailer(raw) == expected
|
||||
|
||||
|
||||
def test_normalize_name_characterization():
    # (positional args, pinned output) pairs; a second positional argument of
    # True is passed through exactly as the original call sites did.
    cases = [
        (("St. Mary's C of E Primary School",), "st marys primary school"),
        (("St. Mary's C of E Primary School", True), "st marys"),
        (("",), ""),
        (("Ham & High School",), "ham high school"),
        (("Ham & High School", True), "ham"),
        # Accented characters become spaces, splitting the word.
        (("École Élémentaire",), "cole l mentaire"),
        ((" THE KING'S ACADEMY ",), "kings academy"),
        (("Holy Trinity RC Voluntary Aided School",), "holy trinity school"),
        (("st. john's",), "st johns"),
    ]
    for args, expected in cases:
        assert normalize_name(*args) == expected
|
||||
|
||||
|
||||
def test_normalize_la_characterization():
    # (input, pinned output) pairs for local-authority name normalisation.
    cases = [
        ("City of Westminster", "westminster"),
        ("Brighton & Hove", "brighton and hove"),
        (" Kingston upon Thames ", "kingston upon thames"),
        ("", ""),
    ]
    for raw, expected in cases:
        assert normalize_la(raw) == expected
|
||||
File diff suppressed because it is too large
Load diff
589
server-rs/src/checkout_sessions/lifecycle.rs
Normal file
589
server-rs/src/checkout_sessions/lifecycle.rs
Normal file
|
|
@ -0,0 +1,589 @@
|
|||
//! The checkout session state machine: starting a checkout (with pricing and
|
||||
//! reservation under a cross-instance lock), verifying Stripe's completion
|
||||
//! payload, completing/granting, and reversing/reinstating after refunds or
|
||||
//! disputes.
|
||||
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::auth::PocketBaseUser;
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::pocketbase_locks::acquire_pocketbase_lock;
|
||||
use crate::routes::pricing::{count_licensed_users, price_for_count};
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::records::{
|
||||
attach_stripe_session, count_active_pending_checkouts, create_pending_checkout,
|
||||
expire_stale_pending_checkouts, find_active_checkout_for_user,
|
||||
find_checkout_by_payment_intent_or_checkout_session, find_checkout_by_stripe_session,
|
||||
has_other_completed_checkout_for_user, mark_checkout_completed, mark_checkout_reinstated,
|
||||
mark_checkout_reversed, mark_checkout_status, PendingCheckoutInput,
|
||||
};
|
||||
use super::referral::{
|
||||
mark_referral_invite_used, release_referral_invite_reservation, reserve_referral_invite,
|
||||
};
|
||||
use super::stripe::create_stripe_session;
|
||||
use super::{
|
||||
ensure_success, is_safe_reversal_reason, is_safe_stripe_session_id, now_unix_secs,
|
||||
number_field, CheckoutCompletion, CheckoutStart, PaymentReinstatementOutcome,
|
||||
PaymentReversalOutcome, VerifiedCheckout, CHECKOUT_CURRENCY, REFERRAL_DISCOUNT_PERCENT,
|
||||
};
|
||||
|
||||
const CHECKOUT_SESSION_TTL_SECS: u64 = 31 * 60;
|
||||
const CHECKOUT_PRICING_LOCK_NAME: &str = "checkout:pricing";
|
||||
const CHECKOUT_PRICING_LOCK_TTL_SECS: u64 = 5 * 60;
|
||||
|
||||
static CHECKOUT_RESERVATION_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
|
||||
|
||||
pub async fn start_license_checkout(
|
||||
state: &AppState,
|
||||
user: &PocketBaseUser,
|
||||
success_url: &str,
|
||||
cancel_url: &str,
|
||||
discount_coupon_id: Option<&str>,
|
||||
referral_invite_id: Option<&str>,
|
||||
) -> anyhow::Result<CheckoutStart> {
|
||||
let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
|
||||
let pricing_lock = acquire_pocketbase_lock(
|
||||
state,
|
||||
CHECKOUT_PRICING_LOCK_NAME,
|
||||
CHECKOUT_PRICING_LOCK_TTL_SECS,
|
||||
)
|
||||
.await?;
|
||||
let result = start_license_checkout_locked(
|
||||
state,
|
||||
user,
|
||||
success_url,
|
||||
cancel_url,
|
||||
discount_coupon_id,
|
||||
referral_invite_id,
|
||||
)
|
||||
.await;
|
||||
if let Err(err) = pricing_lock.release().await {
|
||||
warn!("Failed to release checkout pricing lock: {err}");
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
async fn start_license_checkout_locked(
|
||||
state: &AppState,
|
||||
user: &PocketBaseUser,
|
||||
success_url: &str,
|
||||
cancel_url: &str,
|
||||
discount_coupon_id: Option<&str>,
|
||||
referral_invite_id: Option<&str>,
|
||||
) -> anyhow::Result<CheckoutStart> {
|
||||
let now = now_unix_secs();
|
||||
expire_stale_pending_checkouts(state, now).await?;
|
||||
|
||||
if let Some(existing) = find_active_checkout_for_user(
|
||||
state,
|
||||
&user.id,
|
||||
discount_coupon_id.unwrap_or_default(),
|
||||
referral_invite_id.unwrap_or_default(),
|
||||
now,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
if !existing.checkout_url.is_empty() {
|
||||
return Ok(CheckoutStart::Stripe {
|
||||
url: existing.checkout_url,
|
||||
});
|
||||
}
|
||||
if let Err(err) = mark_checkout_status(state, &existing.id, "failed").await {
|
||||
warn!(
|
||||
reservation_id = %existing.id,
|
||||
"Failed to fail incomplete checkout reservation: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let licensed_count = count_licensed_users(state).await?;
|
||||
let pending_count = count_active_pending_checkouts(state, now).await?;
|
||||
let price_pence = price_for_count(licensed_count + pending_count);
|
||||
|
||||
if price_pence == 0 {
|
||||
grant_license(state, &user.id).await?;
|
||||
return Ok(CheckoutStart::Free);
|
||||
}
|
||||
|
||||
let expires_at_unix = now + CHECKOUT_SESSION_TTL_SECS;
|
||||
let expected_total_pence = expected_total_for_checkout(price_pence, discount_coupon_id);
|
||||
let reservation_id = create_pending_checkout(
|
||||
state,
|
||||
PendingCheckoutInput {
|
||||
user_id: &user.id,
|
||||
amount_pence: price_pence,
|
||||
expected_total_pence,
|
||||
currency: CHECKOUT_CURRENCY,
|
||||
discount_coupon_id: discount_coupon_id.unwrap_or_default(),
|
||||
referral_invite_id: referral_invite_id.unwrap_or_default(),
|
||||
expires_at_unix,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let Some(invite_id) = referral_invite_id.filter(|id| !id.is_empty()) {
|
||||
if let Err(err) =
|
||||
reserve_referral_invite(state, invite_id, &user.id, &reservation_id, expires_at_unix)
|
||||
.await
|
||||
{
|
||||
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
|
||||
warn!(
|
||||
reservation_id,
|
||||
"Failed to mark checkout reservation failed: {mark_err}"
|
||||
);
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
|
||||
let stripe_result = create_stripe_session(
|
||||
state,
|
||||
user,
|
||||
&reservation_id,
|
||||
price_pence,
|
||||
success_url,
|
||||
cancel_url,
|
||||
expires_at_unix,
|
||||
discount_coupon_id,
|
||||
)
|
||||
.await;
|
||||
|
||||
let (stripe_session_id, url) = match stripe_result {
|
||||
Ok(session) => session,
|
||||
Err(err) => {
|
||||
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
|
||||
warn!(
|
||||
reservation_id,
|
||||
"Failed to mark checkout reservation failed: {mark_err}"
|
||||
);
|
||||
}
|
||||
if let Some(invite_id) = referral_invite_id.filter(|id| !id.is_empty()) {
|
||||
if let Err(release_err) =
|
||||
release_referral_invite_reservation(state, invite_id, &reservation_id).await
|
||||
{
|
||||
warn!(
|
||||
reservation_id,
|
||||
referral_invite_id = invite_id,
|
||||
"Failed to release referral invite reservation: {release_err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = attach_stripe_session(state, &reservation_id, &stripe_session_id, &url).await
|
||||
{
|
||||
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
|
||||
warn!(
|
||||
reservation_id,
|
||||
"Failed to mark checkout reservation failed: {mark_err}"
|
||||
);
|
||||
}
|
||||
if let Some(invite_id) = referral_invite_id.filter(|id| !id.is_empty()) {
|
||||
if let Err(release_err) =
|
||||
release_referral_invite_reservation(state, invite_id, &reservation_id).await
|
||||
{
|
||||
warn!(
|
||||
reservation_id,
|
||||
referral_invite_id = invite_id,
|
||||
"Failed to release referral invite reservation: {release_err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
Ok(CheckoutStart::Stripe { url })
|
||||
}
|
||||
|
||||
pub async fn verify_checkout_completion(
|
||||
state: &AppState,
|
||||
session: &Value,
|
||||
) -> anyhow::Result<CheckoutCompletion> {
|
||||
let session_id = match session["id"].as_str() {
|
||||
Some(id) if is_safe_stripe_session_id(id) => id,
|
||||
_ => {
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"missing or invalid session id".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
let payment_intent_id = match session["payment_intent"].as_str() {
|
||||
Some(id) if is_safe_stripe_session_id(id) => id,
|
||||
_ => {
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"missing or invalid payment intent id".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let checkout = match find_checkout_by_stripe_session(state, session_id).await? {
|
||||
Some(checkout) => checkout,
|
||||
None => {
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout session has no reservation".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let already_completed = checkout.status == "completed";
|
||||
if !already_completed && checkout.status != "pending" && checkout.status != "expired" {
|
||||
return Ok(CheckoutCompletion::Rejected(format!(
|
||||
"checkout reservation is {}",
|
||||
checkout.status
|
||||
)));
|
||||
}
|
||||
if checkout.stripe_session_id != session_id {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout reservation session id mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let client_reference_id = session["client_reference_id"].as_str().unwrap_or_default();
|
||||
if client_reference_id != checkout.user_id {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout client_reference_id mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let payment_status = session["payment_status"].as_str().unwrap_or_default();
|
||||
if payment_status != "paid" {
|
||||
return Ok(CheckoutCompletion::Rejected(format!(
|
||||
"checkout payment_status is {payment_status}"
|
||||
)));
|
||||
}
|
||||
|
||||
let currency = session["currency"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_ascii_lowercase();
|
||||
if currency != checkout.currency {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout currency mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let amount_subtotal = match number_field(session, "amount_subtotal") {
|
||||
Some(amount) => amount,
|
||||
None => {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout amount_subtotal missing".into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
if amount_subtotal != checkout.amount_pence {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout amount_subtotal mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let amount_total = match number_field(session, "amount_total") {
|
||||
Some(amount) => amount,
|
||||
None => {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout amount_total missing".into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
if amount_total != checkout.expected_total_pence {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout amount_total mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let verified = VerifiedCheckout {
|
||||
reservation_id: checkout.id,
|
||||
user_id: checkout.user_id,
|
||||
stripe_session_id: session_id.to_string(),
|
||||
payment_intent_id: payment_intent_id.to_string(),
|
||||
paid_amount_pence: amount_total,
|
||||
referral_invite_id: checkout.referral_invite_id,
|
||||
};
|
||||
|
||||
if already_completed {
|
||||
Ok(CheckoutCompletion::AlreadyHandled(verified))
|
||||
} else {
|
||||
Ok(CheckoutCompletion::Grant(verified))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn complete_verified_checkout(
|
||||
state: &AppState,
|
||||
checkout: &VerifiedCheckout,
|
||||
) -> anyhow::Result<()> {
|
||||
let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
|
||||
let pricing_lock = acquire_pocketbase_lock(
|
||||
state,
|
||||
CHECKOUT_PRICING_LOCK_NAME,
|
||||
CHECKOUT_PRICING_LOCK_TTL_SECS,
|
||||
)
|
||||
.await?;
|
||||
let result = complete_verified_checkout_locked(state, checkout).await;
|
||||
if let Err(err) = pricing_lock.release().await {
|
||||
warn!("Failed to release checkout pricing lock: {err}");
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
async fn complete_verified_checkout_locked(
|
||||
state: &AppState,
|
||||
checkout: &VerifiedCheckout,
|
||||
) -> anyhow::Result<()> {
|
||||
let live_checkout = find_checkout_by_stripe_session(state, &checkout.stripe_session_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("checkout reservation disappeared before completion"))?;
|
||||
|
||||
if live_checkout.status == "completed" {
|
||||
if !checkout.referral_invite_id.is_empty() {
|
||||
mark_referral_invite_used(
|
||||
state,
|
||||
&checkout.referral_invite_id,
|
||||
&checkout.user_id,
|
||||
&checkout.reservation_id,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
if live_checkout.id != checkout.reservation_id
|
||||
|| live_checkout.user_id != checkout.user_id
|
||||
|| live_checkout.referral_invite_id != checkout.referral_invite_id
|
||||
{
|
||||
mark_checkout_status(state, &checkout.reservation_id, "invalid").await?;
|
||||
return Err(anyhow!("checkout reservation changed before completion"));
|
||||
}
|
||||
if live_checkout.status != "pending" && live_checkout.status != "expired" {
|
||||
return Err(anyhow!("checkout reservation is {}", live_checkout.status));
|
||||
}
|
||||
|
||||
grant_license(state, &checkout.user_id).await?;
|
||||
mark_checkout_completed(
|
||||
state,
|
||||
&checkout.reservation_id,
|
||||
checkout.paid_amount_pence,
|
||||
&checkout.payment_intent_id,
|
||||
)
|
||||
.await?;
|
||||
if !checkout.referral_invite_id.is_empty() {
|
||||
mark_referral_invite_used(
|
||||
state,
|
||||
&checkout.referral_invite_id,
|
||||
&checkout.user_id,
|
||||
&checkout.reservation_id,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn grant_license_with_pricing_lock(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
|
||||
let pricing_lock = acquire_pocketbase_lock(
|
||||
state,
|
||||
CHECKOUT_PRICING_LOCK_NAME,
|
||||
CHECKOUT_PRICING_LOCK_TTL_SECS,
|
||||
)
|
||||
.await?;
|
||||
let result = grant_license(state, user_id).await;
|
||||
if let Err(err) = pricing_lock.release().await {
|
||||
warn!("Failed to release checkout pricing lock: {err}");
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
|
||||
set_user_subscription(state, user_id, "licensed").await
|
||||
}
|
||||
|
||||
pub async fn reverse_license_for_payment_intent(
|
||||
state: &AppState,
|
||||
payment_intent_id: &str,
|
||||
reason: &str,
|
||||
refunded_amount_pence: Option<u64>,
|
||||
) -> anyhow::Result<PaymentReversalOutcome> {
|
||||
if !is_safe_stripe_session_id(payment_intent_id) {
|
||||
return Err(anyhow!("invalid Stripe payment intent id"));
|
||||
}
|
||||
if !is_safe_reversal_reason(reason) {
|
||||
return Err(anyhow!("invalid Stripe reversal reason"));
|
||||
}
|
||||
|
||||
let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
|
||||
let checkout = match find_checkout_by_payment_intent_or_checkout_session(
|
||||
state,
|
||||
payment_intent_id,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Some(checkout) => checkout,
|
||||
None => return Ok(PaymentReversalOutcome::NoMatchingCheckout),
|
||||
};
|
||||
|
||||
let paid_amount_pence = checkout
|
||||
.paid_amount_pence
|
||||
.max(checkout.expected_total_pence);
|
||||
if let Some(refunded_amount_pence) = refunded_amount_pence {
|
||||
if refunded_amount_pence < paid_amount_pence {
|
||||
return Ok(PaymentReversalOutcome::IgnoredPartialRefund {
|
||||
user_id: checkout.user_id,
|
||||
refunded_amount_pence,
|
||||
paid_amount_pence,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if checkout.status == "reversed" {
|
||||
return Ok(PaymentReversalOutcome::AlreadyHandled {
|
||||
user_id: checkout.user_id,
|
||||
});
|
||||
}
|
||||
|
||||
if matches!(checkout.status.as_str(), "pending" | "expired" | "failed") {
|
||||
mark_checkout_reversed(state, &checkout.id, reason, payment_intent_id).await?;
|
||||
return Ok(PaymentReversalOutcome::Applied {
|
||||
user_id: checkout.user_id,
|
||||
});
|
||||
}
|
||||
|
||||
if checkout.status != "completed" {
|
||||
return Ok(PaymentReversalOutcome::NotReversible {
|
||||
user_id: checkout.user_id,
|
||||
status: checkout.status,
|
||||
});
|
||||
}
|
||||
|
||||
let has_other_license = has_other_completed_checkout_for_user(
|
||||
state,
|
||||
&checkout.user_id,
|
||||
&checkout.id,
|
||||
payment_intent_id,
|
||||
)
|
||||
.await?;
|
||||
if !has_other_license {
|
||||
revoke_license(state, &checkout.user_id).await?;
|
||||
}
|
||||
mark_checkout_reversed(state, &checkout.id, reason, payment_intent_id).await?;
|
||||
|
||||
Ok(PaymentReversalOutcome::Applied {
|
||||
user_id: checkout.user_id,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn reinstate_license_for_payment_intent(
|
||||
state: &AppState,
|
||||
payment_intent_id: &str,
|
||||
reason: &str,
|
||||
) -> anyhow::Result<PaymentReinstatementOutcome> {
|
||||
if !is_safe_stripe_session_id(payment_intent_id) {
|
||||
return Err(anyhow!("invalid Stripe payment intent id"));
|
||||
}
|
||||
if !is_safe_reversal_reason(reason) {
|
||||
return Err(anyhow!("invalid Stripe reinstatement reason"));
|
||||
}
|
||||
|
||||
let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
|
||||
let checkout = match find_checkout_by_payment_intent_or_checkout_session(
|
||||
state,
|
||||
payment_intent_id,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Some(checkout) => checkout,
|
||||
None => return Ok(PaymentReinstatementOutcome::NoMatchingCheckout),
|
||||
};
|
||||
|
||||
if checkout.status == "completed" {
|
||||
return Ok(PaymentReinstatementOutcome::AlreadyHandled {
|
||||
user_id: checkout.user_id,
|
||||
});
|
||||
}
|
||||
|
||||
if checkout.status != "reversed" {
|
||||
return Ok(PaymentReinstatementOutcome::Ignored {
|
||||
user_id: checkout.user_id,
|
||||
reason: format!("checkout status is {}", checkout.status),
|
||||
});
|
||||
}
|
||||
|
||||
if !checkout.reversal_reason.starts_with("charge.dispute.") {
|
||||
return Ok(PaymentReinstatementOutcome::Ignored {
|
||||
user_id: checkout.user_id,
|
||||
reason: format!("checkout was reversed by {}", checkout.reversal_reason),
|
||||
});
|
||||
}
|
||||
|
||||
grant_license(state, &checkout.user_id).await?;
|
||||
mark_checkout_reinstated(state, &checkout.id, reason).await?;
|
||||
|
||||
Ok(PaymentReinstatementOutcome::Applied {
|
||||
user_id: checkout.user_id,
|
||||
})
|
||||
}
|
||||
|
||||
async fn revoke_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
|
||||
set_user_subscription(state, user_id, "free").await
|
||||
}
|
||||
|
||||
async fn set_user_subscription(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
subscription: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "subscription": subscription }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase license update failed")?;
|
||||
|
||||
state.token_cache.invalidate_by_user_id(user_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn expected_total_for_checkout(
|
||||
amount_pence: u64,
|
||||
discount_coupon_id: Option<&str>,
|
||||
) -> u64 {
|
||||
if discount_coupon_id.is_some_and(|id| !id.is_empty()) {
|
||||
return ((amount_pence * (100 - REFERRAL_DISCOUNT_PERCENT)) / 100).max(1);
|
||||
}
|
||||
amount_pence
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the discounted total: integer division rounds down and the
    /// result is clamped to at least 1; no coupon leaves the amount as is.
    #[test]
    fn expected_total_for_referral_discount_rounds_down_like_stripe_amount_math() {
        // (amount_pence, coupon id, expected total)
        let cases: [(u64, Option<&str>, u64); 3] = [
            (999, Some("coupon_30"), 699),
            (1, Some("coupon_30"), 1),
            (999, None, 999),
        ];
        for (amount_pence, coupon, expected) in cases {
            assert_eq!(expected_total_for_checkout(amount_pence, coupon), expected);
        }
    }
}
|
||||
133
server-rs/src/checkout_sessions/mod.rs
Normal file
133
server-rs/src/checkout_sessions/mod.rs
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
//! Checkout sessions: Stripe-backed lifetime-license purchases reserved and
|
||||
//! recorded in the PocketBase `checkout_sessions` collection.
|
||||
//!
|
||||
//! Split by concern:
|
||||
//! - [`lifecycle`]: the session state machine (start, verify, complete,
|
||||
//! reverse, reinstate) and license granting
|
||||
//! - [`records`]: PocketBase `checkout_sessions` record handling
|
||||
//! - [`referral`]: referral invite reservation/consumption bookkeeping
|
||||
//! - [`stripe`]: Stripe API interaction (sessions, coupons, lookups)
|
||||
|
||||
mod lifecycle;
|
||||
mod records;
|
||||
mod referral;
|
||||
mod stripe;
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use lifecycle::{
|
||||
complete_verified_checkout, grant_license_with_pricing_lock,
|
||||
reinstate_license_for_payment_intent, reverse_license_for_payment_intent,
|
||||
start_license_checkout, verify_checkout_completion,
|
||||
};
|
||||
pub use referral::active_referral_checkout_user;
|
||||
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use serde_json::Value;
|
||||
|
||||
/// Currency code stored on every reservation.
pub const CHECKOUT_CURRENCY: &str = "gbp";

/// Name of the PocketBase collection that holds checkout reservations.
const CHECKOUT_COLLECTION: &str = "checkout_sessions";

/// Percentage taken off the price when a discount coupon is applied.
const REFERRAL_DISCOUNT_PERCENT: u64 = 30;
|
||||
|
||||
/// Result of starting a license checkout.
pub enum CheckoutStart {
    /// The computed price was zero and the license was granted directly.
    Free,
    /// The user must pay through Stripe.
    Stripe {
        /// Checkout URL to send the user to.
        url: String,
    },
}
|
||||
|
||||
pub enum CheckoutCompletion {
|
||||
Grant(VerifiedCheckout),
|
||||
AlreadyHandled(VerifiedCheckout),
|
||||
Rejected(String),
|
||||
}
|
||||
|
||||
/// Result of handling a refund or dispute for a payment intent.
pub enum PaymentReversalOutcome {
    /// The reservation was marked reversed.
    Applied {
        user_id: String,
    },
    /// The reservation was already `reversed`.
    AlreadyHandled {
        user_id: String,
    },
    /// The refunded amount was below the paid amount, so nothing changed.
    IgnoredPartialRefund {
        user_id: String,
        refunded_amount_pence: u64,
        paid_amount_pence: u64,
    },
    /// No reservation matched the payment intent.
    NoMatchingCheckout,
    /// The reservation is in a status that cannot be reversed.
    NotReversible {
        user_id: String,
        status: String,
    },
}
|
||||
|
||||
/// Result of trying to reinstate a license for a payment intent.
pub enum PaymentReinstatementOutcome {
    /// The license was granted again and the reservation marked reinstated.
    Applied {
        user_id: String,
    },
    /// The reservation was already `completed`.
    AlreadyHandled {
        user_id: String,
    },
    /// Nothing was changed; `reason` says why.
    Ignored {
        user_id: String,
        reason: String,
    },
    /// No reservation matched the payment intent.
    NoMatchingCheckout,
}
|
||||
|
||||
/// A Stripe completion payload that matched its stored reservation.
pub struct VerifiedCheckout {
    /// Id of the matching checkout reservation record.
    pub reservation_id: String,
    /// Id of the user the reservation belongs to.
    pub user_id: String,
    /// Stripe checkout session id taken from the payload.
    pub stripe_session_id: String,
    /// Stripe payment intent id taken from the payload.
    pub payment_intent_id: String,
    /// The payload's `amount_total`, in pence.
    pub paid_amount_pence: u64,
    /// Referral invite id stored on the reservation; empty when none.
    pub referral_invite_id: String,
}
|
||||
|
||||
/// Current time as whole seconds since the Unix epoch.
///
/// Returns 0 when the system clock reports a time before the epoch.
pub fn now_unix_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
|
||||
|
||||
fn number_field(value: &Value, field: &str) -> Option<u64> {
|
||||
value[field].as_u64().or_else(|| {
|
||||
value[field]
|
||||
.as_f64()
|
||||
.filter(|n| n.is_finite() && *n >= 0.0 && n.fract() == 0.0)
|
||||
.map(|n| n as u64)
|
||||
})
|
||||
}
|
||||
|
||||
/// Whether `id` is safe to embed in URLs and filter strings as a Stripe id.
///
/// True only for 1 to 128 bytes, each an ASCII letter, ASCII digit, `_` or `-`.
fn is_safe_stripe_session_id(id: &str) -> bool {
    if id.is_empty() || id.len() > 128 {
        return false;
    }
    id.bytes()
        .all(|byte| matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-'))
}
|
||||
|
||||
/// Whether `id` is safe to embed in URLs and filter strings as a PocketBase
/// record id: 1 to 32 bytes, ASCII letters and digits only.
fn is_safe_pocketbase_id(id: &str) -> bool {
    let length_ok = (1..=32).contains(&id.len());
    length_ok && id.chars().all(|c| c.is_ascii_alphanumeric())
}
|
||||
|
||||
/// Whether `reason` is an acceptable reversal-reason token: 1 to 128 bytes,
/// each an ASCII letter, ASCII digit, `_`, `-` or `.`.
fn is_safe_reversal_reason(reason: &str) -> bool {
    if reason.is_empty() || reason.len() > 128 {
        return false;
    }
    for byte in reason.bytes() {
        let allowed = byte.is_ascii_alphanumeric() || matches!(byte, b'_' | b'-' | b'.');
        if !allowed {
            return false;
        }
    }
    true
}
|
||||
|
||||
async fn ensure_success(resp: reqwest::Response) -> anyhow::Result<()> {
|
||||
if resp.status().is_success() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
Err(anyhow!("upstream returned {status}: {text}"))
|
||||
}
|
||||
|
||||
async fn ensure_success_ref(resp: &reqwest::Response) -> anyhow::Result<()> {
|
||||
if resp.status().is_success() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err(anyhow!("upstream returned {}", resp.status()))
|
||||
}
|
||||
564
server-rs/src/checkout_sessions/records.rs
Normal file
564
server-rs/src/checkout_sessions/records.rs
Normal file
|
|
@ -0,0 +1,564 @@
|
|||
//! PocketBase `checkout_sessions` record handling: creating reservations,
|
||||
//! status transitions, and lookups by Stripe session / payment intent.
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use serde_json::Value;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::referral::release_referral_invite_reservation;
|
||||
use super::stripe::fetch_stripe_checkout_session_id_for_payment_intent;
|
||||
use super::{
|
||||
ensure_success, ensure_success_ref, is_safe_pocketbase_id, is_safe_stripe_session_id,
|
||||
now_unix_secs, number_field, CHECKOUT_COLLECTION,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct PendingCheckout {
|
||||
pub(super) id: String,
|
||||
pub(super) user_id: String,
|
||||
pub(super) stripe_session_id: String,
|
||||
pub(super) checkout_url: String,
|
||||
pub(super) amount_pence: u64,
|
||||
pub(super) expected_total_pence: u64,
|
||||
pub(super) currency: String,
|
||||
pub(super) referral_invite_id: String,
|
||||
pub(super) status: String,
|
||||
pub(super) payment_intent_id: String,
|
||||
pub(super) paid_amount_pence: u64,
|
||||
pub(super) reversal_reason: String,
|
||||
}
|
||||
|
||||
pub async fn mark_checkout_completed(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
paid_amount_pence: u64,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
if !is_safe_stripe_session_id(payment_intent_id) {
|
||||
return Err(anyhow!("invalid Stripe payment intent id"));
|
||||
}
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"status": "completed",
|
||||
"paid_amount_pence": paid_amount_pence,
|
||||
"completed_at_unix": now_unix_secs().to_string(),
|
||||
"stripe_payment_intent_id": payment_intent_id,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase checkout completion update failed")
|
||||
}
|
||||
|
||||
pub(super) async fn count_active_pending_checkouts(
|
||||
state: &AppState,
|
||||
now: u64,
|
||||
) -> anyhow::Result<u64> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("status=\"pending\" && expires_at_unix>={now}");
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
Ok(body["totalItems"].as_u64().unwrap_or(0))
|
||||
}
|
||||
|
||||
pub(super) async fn find_active_checkout_for_user(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
discount_coupon_id: &str,
|
||||
referral_invite_id: &str,
|
||||
now: u64,
|
||||
) -> anyhow::Result<Option<PendingCheckout>> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = active_checkout_filter(user_id, discount_coupon_id, referral_invite_id, now)?;
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let item = body["items"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.cloned();
|
||||
|
||||
item.map(parse_pending_checkout).transpose()
|
||||
}
|
||||
|
||||
fn active_checkout_filter(
|
||||
user_id: &str,
|
||||
discount_coupon_id: &str,
|
||||
referral_invite_id: &str,
|
||||
now: u64,
|
||||
) -> anyhow::Result<String> {
|
||||
if !is_safe_pocketbase_id(user_id) {
|
||||
return Err(anyhow!("invalid PocketBase user id"));
|
||||
}
|
||||
if !discount_coupon_id.is_empty() && !is_safe_stripe_session_id(discount_coupon_id) {
|
||||
return Err(anyhow!("invalid Stripe coupon id"));
|
||||
}
|
||||
if !referral_invite_id.is_empty() && !is_safe_pocketbase_id(referral_invite_id) {
|
||||
return Err(anyhow!("invalid PocketBase referral invite id"));
|
||||
}
|
||||
|
||||
Ok(format!(
|
||||
"status=\"pending\" && expires_at_unix>={now} && user=\"{user_id}\" && discount_coupon_id=\"{discount_coupon_id}\" && referral_invite_id=\"{referral_invite_id}\""
|
||||
))
|
||||
}
|
||||
|
||||
pub(super) async fn expire_stale_pending_checkouts(
|
||||
state: &AppState,
|
||||
now: u64,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("status=\"pending\" && expires_at_unix<{now}");
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=50",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let Some(items) = body["items"].as_array() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
for item in items {
|
||||
let Some(id) = item["id"].as_str() else {
|
||||
continue;
|
||||
};
|
||||
if let Err(err) = mark_checkout_status(state, id, "expired").await {
|
||||
warn!(
|
||||
reservation_id = id,
|
||||
"Failed to expire checkout reservation: {err}"
|
||||
);
|
||||
}
|
||||
if let Some(invite_id) = item["referral_invite_id"]
|
||||
.as_str()
|
||||
.filter(|invite_id| !invite_id.is_empty())
|
||||
{
|
||||
if let Err(err) = release_referral_invite_reservation(state, invite_id, id).await {
|
||||
warn!(
|
||||
reservation_id = id,
|
||||
referral_invite_id = invite_id,
|
||||
"Failed to release expired referral invite reservation: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) struct PendingCheckoutInput<'a> {
|
||||
pub(super) user_id: &'a str,
|
||||
pub(super) amount_pence: u64,
|
||||
pub(super) expected_total_pence: u64,
|
||||
pub(super) currency: &'a str,
|
||||
pub(super) discount_coupon_id: &'a str,
|
||||
pub(super) referral_invite_id: &'a str,
|
||||
pub(super) expires_at_unix: u64,
|
||||
}
|
||||
|
||||
pub(super) async fn create_pending_checkout(
|
||||
state: &AppState,
|
||||
input: PendingCheckoutInput<'_>,
|
||||
) -> anyhow::Result<String> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records");
|
||||
let resp = state
|
||||
.http_client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"user": input.user_id,
|
||||
"stripe_session_id": "",
|
||||
"stripe_payment_intent_id": "",
|
||||
"checkout_url": "",
|
||||
"amount_pence": input.amount_pence,
|
||||
"expected_total_pence": input.expected_total_pence,
|
||||
"currency": input.currency,
|
||||
"discount_coupon_id": input.discount_coupon_id,
|
||||
"referral_invite_id": input.referral_invite_id,
|
||||
"status": "pending",
|
||||
"expires_at_unix": input.expires_at_unix,
|
||||
"paid_amount_pence": 0,
|
||||
"completed_at_unix": "",
|
||||
"reversal_reason": "",
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
body["id"]
|
||||
.as_str()
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| anyhow!("PocketBase checkout reservation missing id"))
|
||||
}
|
||||
|
||||
pub(super) async fn attach_stripe_session(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
stripe_session_id: &str,
|
||||
checkout_url: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"stripe_session_id": stripe_session_id,
|
||||
"checkout_url": checkout_url,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase checkout session attach failed")
|
||||
}
|
||||
|
||||
pub(super) async fn mark_checkout_status(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
status: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "status": status }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.with_context(|| format!("PocketBase checkout status update failed for {reservation_id}"))
|
||||
}
|
||||
|
||||
pub(super) async fn mark_checkout_reversed(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
reason: &str,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"status": "reversed",
|
||||
"reversal_reason": reason,
|
||||
"stripe_payment_intent_id": payment_intent_id,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.with_context(|| format!("PocketBase checkout reversal update failed for {reservation_id}"))
|
||||
}
|
||||
|
||||
pub(super) async fn mark_checkout_reinstated(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
_reason: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"status": "completed",
|
||||
"reversal_reason": "",
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp).await.with_context(|| {
|
||||
format!("PocketBase checkout reinstatement update failed for {reservation_id}")
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) async fn find_checkout_by_stripe_session(
|
||||
state: &AppState,
|
||||
stripe_session_id: &str,
|
||||
) -> anyhow::Result<Option<PendingCheckout>> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("stripe_session_id=\"{}\"", stripe_session_id);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let item = body["items"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.cloned();
|
||||
|
||||
item.map(parse_pending_checkout).transpose()
|
||||
}
|
||||
|
||||
async fn find_checkout_by_payment_intent(
|
||||
state: &AppState,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<Option<PendingCheckout>> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("stripe_payment_intent_id=\"{}\"", payment_intent_id);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let item = body["items"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.cloned();
|
||||
|
||||
item.map(parse_pending_checkout).transpose()
|
||||
}
|
||||
|
||||
pub(super) async fn find_checkout_by_payment_intent_or_checkout_session(
|
||||
state: &AppState,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<Option<PendingCheckout>> {
|
||||
if let Some(checkout) = find_checkout_by_payment_intent(state, payment_intent_id).await? {
|
||||
return Ok(Some(checkout));
|
||||
}
|
||||
|
||||
let Some(session_id) =
|
||||
fetch_stripe_checkout_session_id_for_payment_intent(state, payment_intent_id).await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let Some(mut checkout) = find_checkout_by_stripe_session(state, &session_id).await? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
if checkout.payment_intent_id.is_empty() {
|
||||
attach_payment_intent_to_checkout(state, &checkout.id, payment_intent_id).await?;
|
||||
checkout.payment_intent_id = payment_intent_id.to_string();
|
||||
} else if checkout.payment_intent_id != payment_intent_id {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Err(anyhow!(
|
||||
"checkout reservation payment intent changed before reversal"
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Some(checkout))
|
||||
}
|
||||
|
||||
async fn attach_payment_intent_to_checkout(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"stripe_payment_intent_id": payment_intent_id,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase checkout payment intent attach failed")
|
||||
}
|
||||
|
||||
pub(super) async fn has_other_completed_checkout_for_user(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
reservation_id: &str,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<bool> {
|
||||
if !is_safe_pocketbase_id(user_id) || !is_safe_pocketbase_id(reservation_id) {
|
||||
return Err(anyhow!("invalid PocketBase id"));
|
||||
}
|
||||
if !is_safe_stripe_session_id(payment_intent_id) {
|
||||
return Err(anyhow!("invalid Stripe payment intent id"));
|
||||
}
|
||||
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("user=\"{user_id}\" && status=\"completed\"");
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=50",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let Some(items) = body["items"].as_array() else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
Ok(items.iter().any(|item| {
|
||||
let other_id = item["id"].as_str().unwrap_or_default();
|
||||
let other_payment_intent = item["stripe_payment_intent_id"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
other_id != reservation_id && other_payment_intent != payment_intent_id
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_pending_checkout(item: Value) -> anyhow::Result<PendingCheckout> {
|
||||
Ok(PendingCheckout {
|
||||
id: item["id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("checkout reservation missing id"))?
|
||||
.to_string(),
|
||||
user_id: item["user"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("checkout reservation missing user"))?
|
||||
.to_string(),
|
||||
stripe_session_id: item["stripe_session_id"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
checkout_url: item["checkout_url"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
amount_pence: number_field(&item, "amount_pence")
|
||||
.ok_or_else(|| anyhow!("checkout reservation missing amount_pence"))?,
|
||||
expected_total_pence: number_field(&item, "expected_total_pence")
|
||||
.ok_or_else(|| anyhow!("checkout reservation missing expected_total_pence"))?,
|
||||
currency: item["currency"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_ascii_lowercase(),
|
||||
referral_invite_id: item["referral_invite_id"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
status: item["status"].as_str().unwrap_or_default().to_string(),
|
||||
payment_intent_id: item["stripe_payment_intent_id"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
paid_amount_pence: number_field(&item, "paid_amount_pence").unwrap_or(0),
|
||||
reversal_reason: item["reversal_reason"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A checkout with no coupon and no invite still pins both context fields
    /// to the empty string in the filter.
    #[test]
    fn active_checkout_filter_includes_empty_context_for_standard_checkout() {
        let expected = "status=\"pending\" && expires_at_unix>=42 && user=\"abc123\" && discount_coupon_id=\"\" && referral_invite_id=\"\"";
        let actual = active_checkout_filter("abc123", "", "", 42).unwrap();
        assert_eq!(actual, expected);
    }

    /// Coupon and invite ids are carried into the filter verbatim.
    #[test]
    fn active_checkout_filter_includes_referral_context() {
        let expected = "status=\"pending\" && expires_at_unix>=99 && user=\"user123\" && discount_coupon_id=\"coupon_30\" && referral_invite_id=\"invite123\"";
        let actual = active_checkout_filter("user123", "coupon_30", "invite123", 99).unwrap();
        assert_eq!(actual, expected);
    }

    /// Any id with a character outside its allowed alphabet is rejected.
    #[test]
    fn active_checkout_filter_rejects_unsafe_context_values() {
        let rejected = [
            ("user123", "bad\"coupon", ""),
            ("user123", "", "bad-invite"),
            ("bad-user", "", ""),
        ];
        for (user, coupon, invite) in rejected {
            assert!(active_checkout_filter(user, coupon, invite, 1).is_err());
        }
    }
}
|
||||
312
server-rs/src/checkout_sessions/referral.rs
Normal file
312
server-rs/src/checkout_sessions/referral.rs
Normal file
|
|
@ -0,0 +1,312 @@
|
|||
//! Referral invite bookkeeping: reserving an invite for an in-flight checkout,
|
||||
//! releasing the reservation on failure/expiry, and recording final usage when
|
||||
//! a verified payment completes.
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use serde_json::Value;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::{
|
||||
ensure_success, ensure_success_ref, is_safe_pocketbase_id, now_unix_secs, number_field,
|
||||
CHECKOUT_COLLECTION,
|
||||
};
|
||||
|
||||
pub async fn mark_referral_invite_used(
|
||||
state: &AppState,
|
||||
invite_id: &str,
|
||||
user_id: &str,
|
||||
reservation_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
if invite_id.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
if !is_safe_pocketbase_id(invite_id)
|
||||
|| !is_safe_pocketbase_id(user_id)
|
||||
|| !is_safe_pocketbase_id(reservation_id)
|
||||
{
|
||||
return Err(anyhow!("invalid PocketBase id"));
|
||||
}
|
||||
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let invite = fetch_invite_record(state, pb_url, &token, invite_id).await?;
|
||||
|
||||
// A verified Stripe payment must not lose entitlement just because local
|
||||
// invite reservation bookkeeping expired or moved before webhook delivery.
|
||||
match referral_invite_completion_action(&invite, user_id, reservation_id) {
|
||||
ReferralInviteCompletionAction::AlreadyRecorded => return Ok(()),
|
||||
ReferralInviteCompletionAction::AlreadyUsedByAnother => {
|
||||
warn!(
|
||||
invite_id,
|
||||
user_id,
|
||||
existing_used_by = invite["used_by_id"].as_str().unwrap_or_default(),
|
||||
"Referral invite was already used by another account; preserving verified checkout entitlement"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
ReferralInviteCompletionAction::Record {
|
||||
reservation_reassigned,
|
||||
} => {
|
||||
if reservation_reassigned {
|
||||
warn!(
|
||||
invite_id,
|
||||
user_id,
|
||||
reservation_id,
|
||||
reserved_by_id = invite["reserved_by_id"].as_str().unwrap_or_default(),
|
||||
reserved_checkout_id = invite["reserved_checkout_id"].as_str().unwrap_or_default(),
|
||||
"Referral invite reservation moved before webhook completion; verified checkout will consume it"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"used_by_id": user_id,
|
||||
"used_at": now_unix_secs().to_string(),
|
||||
"reserved_by_id": "",
|
||||
"reserved_checkout_id": "",
|
||||
"reserved_until_unix": 0,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase invite usage update failed")
|
||||
}
|
||||
|
||||
/// What `mark_referral_invite_used` should do with an invite, as decided by
/// `referral_invite_completion_action`.
#[derive(Debug, PartialEq, Eq)]
enum ReferralInviteCompletionAction {
    /// The invite's `used_by_id` already equals the completing user.
    AlreadyRecorded,
    /// The invite's `used_by_id` is set to a different user.
    AlreadyUsedByAnother,
    /// The invite is unused and should be recorded as used.
    Record {
        /// True when the invite's reservation currently names a different
        /// user or a different checkout than the one completing.
        reservation_reassigned: bool,
    },
}
|
||||
|
||||
fn referral_invite_completion_action(
|
||||
invite: &Value,
|
||||
user_id: &str,
|
||||
reservation_id: &str,
|
||||
) -> ReferralInviteCompletionAction {
|
||||
let existing_used_by = invite["used_by_id"].as_str().unwrap_or_default();
|
||||
if existing_used_by == user_id {
|
||||
return ReferralInviteCompletionAction::AlreadyRecorded;
|
||||
}
|
||||
if !existing_used_by.is_empty() {
|
||||
return ReferralInviteCompletionAction::AlreadyUsedByAnother;
|
||||
}
|
||||
|
||||
let reserved_by_id = invite["reserved_by_id"].as_str().unwrap_or_default();
|
||||
let reserved_checkout_id = invite["reserved_checkout_id"].as_str().unwrap_or_default();
|
||||
let reservation_reassigned = (!reserved_by_id.is_empty() && reserved_by_id != user_id)
|
||||
|| (!reserved_checkout_id.is_empty() && reserved_checkout_id != reservation_id);
|
||||
|
||||
ReferralInviteCompletionAction::Record {
|
||||
reservation_reassigned,
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_invite_record(
|
||||
state: &AppState,
|
||||
pb_url: &str,
|
||||
token: &str,
|
||||
invite_id: &str,
|
||||
) -> anyhow::Result<Value> {
|
||||
let url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
resp.json().await.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub(super) async fn reserve_referral_invite(
|
||||
state: &AppState,
|
||||
invite_id: &str,
|
||||
user_id: &str,
|
||||
reservation_id: &str,
|
||||
reserved_until_unix: u64,
|
||||
) -> anyhow::Result<()> {
|
||||
if !is_safe_pocketbase_id(invite_id)
|
||||
|| !is_safe_pocketbase_id(user_id)
|
||||
|| !is_safe_pocketbase_id(reservation_id)
|
||||
{
|
||||
return Err(anyhow!("invalid PocketBase id"));
|
||||
}
|
||||
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let invite = fetch_invite_record(state, pb_url, &token, invite_id).await?;
|
||||
let used_by = invite["used_by_id"].as_str().unwrap_or_default();
|
||||
if !used_by.is_empty() {
|
||||
return Err(anyhow!("referral invite already used"));
|
||||
}
|
||||
|
||||
let now = now_unix_secs();
|
||||
let reserved_by_id = invite["reserved_by_id"].as_str().unwrap_or_default();
|
||||
let reserved_checkout_id = invite["reserved_checkout_id"].as_str().unwrap_or_default();
|
||||
let existing_reserved_until = number_field(&invite, "reserved_until_unix").unwrap_or(0);
|
||||
let reservation_is_live = existing_reserved_until >= now;
|
||||
if reservation_is_live
|
||||
&& !reserved_checkout_id.is_empty()
|
||||
&& reserved_checkout_id != reservation_id
|
||||
{
|
||||
return Err(anyhow!("referral invite already has an active checkout"));
|
||||
}
|
||||
if reservation_is_live && !reserved_by_id.is_empty() && reserved_by_id != user_id {
|
||||
return Err(anyhow!("referral invite reserved by another account"));
|
||||
}
|
||||
|
||||
let url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"reserved_by_id": user_id,
|
||||
"reserved_checkout_id": reservation_id,
|
||||
"reserved_until_unix": reserved_until_unix,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase invite reservation update failed")
|
||||
}
|
||||
|
||||
pub(super) async fn release_referral_invite_reservation(
|
||||
state: &AppState,
|
||||
invite_id: &str,
|
||||
reservation_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
if !is_safe_pocketbase_id(invite_id) || !is_safe_pocketbase_id(reservation_id) {
|
||||
return Err(anyhow!("invalid PocketBase id"));
|
||||
}
|
||||
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let invite = fetch_invite_record(state, pb_url, &token, invite_id).await?;
|
||||
let used_by = invite["used_by_id"].as_str().unwrap_or_default();
|
||||
let reserved_checkout_id = invite["reserved_checkout_id"].as_str().unwrap_or_default();
|
||||
if !used_by.is_empty() || reserved_checkout_id != reservation_id {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"reserved_by_id": "",
|
||||
"reserved_checkout_id": "",
|
||||
"reserved_until_unix": 0,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase invite reservation release failed")
|
||||
}
|
||||
|
||||
pub async fn active_referral_checkout_user(
|
||||
state: &AppState,
|
||||
invite_id: &str,
|
||||
) -> anyhow::Result<Option<String>> {
|
||||
if !is_safe_pocketbase_id(invite_id) {
|
||||
return Err(anyhow!("invalid PocketBase invite id"));
|
||||
}
|
||||
|
||||
let now = now_unix_secs();
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!(
|
||||
"status=\"pending\" && expires_at_unix>={now} && referral_invite_id=\"{}\"",
|
||||
invite_id
|
||||
);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
Ok(body["items"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.and_then(|item| item["user"].as_str())
|
||||
.map(str::to_string))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// An unused, unreserved invite is recorded without a reassignment flag.
    #[test]
    fn referral_invite_completion_records_available_invite() {
        let invite = serde_json::json!({
            "used_by_id": "",
            "reserved_by_id": "",
            "reserved_checkout_id": "",
        });

        let action = referral_invite_completion_action(&invite, "user123", "checkout123");

        let expected = ReferralInviteCompletionAction::Record {
            reservation_reassigned: false,
        };
        assert_eq!(action, expected);
    }

    /// An invite reserved for someone else is still recorded, but flagged.
    #[test]
    fn referral_invite_completion_records_reassigned_reservation() {
        let invite = serde_json::json!({
            "used_by_id": "",
            "reserved_by_id": "otheruser",
            "reserved_checkout_id": "othercheckout",
        });

        let action = referral_invite_completion_action(&invite, "user123", "checkout123");

        let expected = ReferralInviteCompletionAction::Record {
            reservation_reassigned: true,
        };
        assert_eq!(action, expected);
    }

    /// Existing usage is classified by whether it belongs to the same user.
    #[test]
    fn referral_invite_completion_detects_existing_usage() {
        let used_by_same_user = serde_json::json!({ "used_by_id": "user123" });
        let used_by_another_user = serde_json::json!({ "used_by_id": "otheruser" });

        let same_user_action =
            referral_invite_completion_action(&used_by_same_user, "user123", "checkout123");
        let other_user_action =
            referral_invite_completion_action(&used_by_another_user, "user123", "checkout123");

        assert_eq!(
            same_user_action,
            ReferralInviteCompletionAction::AlreadyRecorded
        );
        assert_eq!(
            other_user_action,
            ReferralInviteCompletionAction::AlreadyUsedByAnother
        );
    }
}
|
||||
175
server-rs/src/checkout_sessions/stripe.rs
Normal file
175
server-rs/src/checkout_sessions/stripe.rs
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
//! Stripe API interaction: creating checkout sessions, verifying coupon
|
||||
//! configuration, and looking up sessions by payment intent.
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::auth::PocketBaseUser;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::lifecycle::expected_total_for_checkout;
|
||||
use super::{
|
||||
ensure_success_ref, is_safe_stripe_session_id, CHECKOUT_CURRENCY, REFERRAL_DISCOUNT_PERCENT,
|
||||
};
|
||||
|
||||
/// Product name sent to Stripe as the single checkout line item's
/// `price_data[product_data][name]`.
const CHECKOUT_PRODUCT_NAME: &str = "Perfect Postcodes Lifetime License";
|
||||
|
||||
/// Fetch a Stripe coupon and ensure its `percent_off` matches the expected
|
||||
/// referral discount AND that it has no `amount_off` override. This blocks a
|
||||
/// misconfigured (or maliciously swapped) coupon ID from quietly granting a
|
||||
/// larger discount than the server's pricing math assumed.
|
||||
async fn verify_stripe_coupon_discount(state: &AppState, coupon_id: &str) -> anyhow::Result<()> {
|
||||
if !is_safe_stripe_session_id(coupon_id) {
|
||||
return Err(anyhow!("unsafe stripe coupon id"));
|
||||
}
|
||||
let url = format!(
|
||||
"https://api.stripe.com/v1/coupons/{}",
|
||||
urlencoding::encode(coupon_id)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.basic_auth(&state.stripe_secret_key, None::<&str>)
|
||||
.send()
|
||||
.await
|
||||
.context("Stripe coupon fetch failed")?;
|
||||
ensure_success_ref(&resp)
|
||||
.await
|
||||
.context("Stripe coupon fetch returned error")?;
|
||||
let body: Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Stripe coupon response")?;
|
||||
|
||||
let valid = body["valid"].as_bool().unwrap_or(false);
|
||||
if !valid {
|
||||
return Err(anyhow!("stripe coupon is not valid"));
|
||||
}
|
||||
if body["amount_off"].is_number() {
|
||||
return Err(anyhow!(
|
||||
"stripe coupon uses amount_off; only percent_off is permitted"
|
||||
));
|
||||
}
|
||||
let percent_off = body["percent_off"]
|
||||
.as_f64()
|
||||
.ok_or_else(|| anyhow!("stripe coupon missing percent_off"))?;
|
||||
if percent_off.is_nan() || (percent_off - REFERRAL_DISCOUNT_PERCENT as f64).abs() > 0.001 {
|
||||
return Err(anyhow!(
|
||||
"stripe coupon percent_off ({percent_off}) does not match expected {REFERRAL_DISCOUNT_PERCENT}"
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(super) async fn create_stripe_session(
|
||||
state: &AppState,
|
||||
user: &PocketBaseUser,
|
||||
reservation_id: &str,
|
||||
price_pence: u64,
|
||||
success_url: &str,
|
||||
cancel_url: &str,
|
||||
expires_at_unix: u64,
|
||||
discount_coupon_id: Option<&str>,
|
||||
) -> anyhow::Result<(String, String)> {
|
||||
if let Some(coupon_id) = discount_coupon_id.filter(|id| !id.is_empty()) {
|
||||
verify_stripe_coupon_discount(state, coupon_id).await?;
|
||||
}
|
||||
|
||||
let mut form_params = vec![
|
||||
("mode", "payment".to_string()),
|
||||
("payment_method_types[0]", "card".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][unit_amount]",
|
||||
price_pence.to_string(),
|
||||
),
|
||||
(
|
||||
"line_items[0][price_data][currency]",
|
||||
CHECKOUT_CURRENCY.to_string(),
|
||||
),
|
||||
(
|
||||
"line_items[0][price_data][product_data][name]",
|
||||
CHECKOUT_PRODUCT_NAME.to_string(),
|
||||
),
|
||||
("line_items[0][quantity]", "1".to_string()),
|
||||
("success_url", success_url.to_string()),
|
||||
("cancel_url", cancel_url.to_string()),
|
||||
("expires_at", expires_at_unix.to_string()),
|
||||
("client_reference_id", user.id.clone()),
|
||||
("customer_email", user.email.clone()),
|
||||
("metadata[pending_checkout_id]", reservation_id.to_string()),
|
||||
("metadata[expected_amount_pence]", price_pence.to_string()),
|
||||
(
|
||||
"metadata[expected_total_pence]",
|
||||
expected_total_for_checkout(price_pence, discount_coupon_id).to_string(),
|
||||
),
|
||||
("metadata[expected_currency]", CHECKOUT_CURRENCY.to_string()),
|
||||
];
|
||||
|
||||
if let Some(coupon_id) = discount_coupon_id.filter(|id| !id.is_empty()) {
|
||||
form_params.push(("discounts[0][coupon]", coupon_id.to_string()));
|
||||
form_params.push(("metadata[discount_coupon_id]", coupon_id.to_string()));
|
||||
}
|
||||
|
||||
let resp = state
|
||||
.http_client
|
||||
.post("https://api.stripe.com/v1/checkout/sessions")
|
||||
.basic_auth(&state.stripe_secret_key, None::<&str>)
|
||||
.form(&form_params)
|
||||
.send()
|
||||
.await
|
||||
.context("Stripe checkout request failed")?;
|
||||
|
||||
ensure_success_ref(&resp)
|
||||
.await
|
||||
.context("Stripe checkout failed")?;
|
||||
|
||||
let body: Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Stripe response")?;
|
||||
let session_id = body["id"]
|
||||
.as_str()
|
||||
.filter(|id| is_safe_stripe_session_id(id))
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| anyhow!("Stripe session missing valid id"))?;
|
||||
let url = body["url"]
|
||||
.as_str()
|
||||
.map(str::to_string)
|
||||
.filter(|url| !url.is_empty())
|
||||
.ok_or_else(|| anyhow!("Stripe session missing URL"))?;
|
||||
|
||||
Ok((session_id, url))
|
||||
}
|
||||
|
||||
pub(super) async fn fetch_stripe_checkout_session_id_for_payment_intent(
|
||||
state: &AppState,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<Option<String>> {
|
||||
let url = format!(
|
||||
"https://api.stripe.com/v1/checkout/sessions?payment_intent={}&limit=1",
|
||||
urlencoding::encode(payment_intent_id)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.basic_auth(&state.stripe_secret_key, None::<&str>)
|
||||
.send()
|
||||
.await
|
||||
.context("Stripe checkout session lookup failed")?;
|
||||
|
||||
ensure_success_ref(&resp)
|
||||
.await
|
||||
.context("Stripe checkout session lookup returned error")?;
|
||||
|
||||
let body: Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Stripe checkout session lookup")?;
|
||||
Ok(body["data"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.and_then(|item| item["id"].as_str())
|
||||
.filter(|id| is_safe_stripe_session_id(id))
|
||||
.map(str::to_string))
|
||||
}
|
||||
688
server-rs/src/checkout_sessions/tests.rs
Normal file
688
server-rs/src/checkout_sessions/tests.rs
Normal file
|
|
@ -0,0 +1,688 @@
|
|||
//! Integration-style tests for the money paths: Stripe webhook verification →
|
||||
//! license granting, checkout reservation bookkeeping, and invite redemption.
|
||||
//!
|
||||
//! PocketBase (an external HTTP service in production) is replaced by an
|
||||
//! in-process axum mock listening on an ephemeral local port. The mock keeps
|
||||
//! records in memory, evaluates the small PocketBase filter subset the server
|
||||
//! actually uses, and records every mutating request so tests can assert that
|
||||
//! e.g. a replayed webhook does not grant a license twice.
|
||||
//!
|
||||
//! Stripe itself is NOT mocked (its API URL is hardcoded to
|
||||
//! `https://api.stripe.com`), so tests that reach the Stripe call assert the
|
||||
//! failure-cleanup behaviour instead: the reservation is marked `failed` and
|
||||
//! referral invite reservations are released.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use axum::body::Bytes;
|
||||
use axum::extract::{Path, Query, State};
|
||||
use axum::http::{HeaderMap, HeaderValue, StatusCode};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Extension, Json, Router};
|
||||
use hmac::{Hmac, KeyInit, Mac};
|
||||
use serde_json::{json, Value};
|
||||
use sha2::Sha256;
|
||||
|
||||
use crate::auth::{OptionalUser, PocketBaseUser};
|
||||
use crate::routes::{post_redeem_invite, post_stripe_webhook};
|
||||
use crate::state::{AppState, SharedState};
|
||||
|
||||
use super::start_license_checkout;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock PocketBase
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// In-memory stand-in for PocketBase: holds the records per collection and a
/// log of the mutating requests it has received.
#[derive(Default)]
struct MockPocketBase {
    /// collection name → records (each a JSON object with an "id").
    records: Mutex<HashMap<String, Vec<Value>>>,
    /// Monotonic counter used to mint ids for records that arrive without one.
    next_id: AtomicUsize,
    /// Every mutating request: (method, path, body).
    log: Mutex<Vec<(String, String, Value)>>,
}
|
||||
|
||||
impl MockPocketBase {
|
||||
fn seed(&self, collection: &str, mut record: Value) -> String {
|
||||
let id = match record["id"].as_str() {
|
||||
Some(id) if !id.is_empty() => id.to_string(),
|
||||
_ => format!("mockid{:06}", self.next_id.fetch_add(1, Ordering::SeqCst)),
|
||||
};
|
||||
record["id"] = json!(id);
|
||||
self.records
|
||||
.lock()
|
||||
.unwrap()
|
||||
.entry(collection.to_string())
|
||||
.or_default()
|
||||
.push(record);
|
||||
id
|
||||
}
|
||||
|
||||
fn record(&self, collection: &str, id: &str) -> Option<Value> {
|
||||
self.records
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(collection)?
|
||||
.iter()
|
||||
.find(|record| record["id"].as_str() == Some(id))
|
||||
.cloned()
|
||||
}
|
||||
|
||||
fn records_in(&self, collection: &str) -> Vec<Value> {
|
||||
self.records
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(collection)
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Number of PATCHes that set the user's subscription to "licensed" —
|
||||
/// i.e. how many times a license was granted.
|
||||
fn license_grant_count(&self, user_id: &str) -> usize {
|
||||
let path = format!("/api/collections/users/records/{user_id}");
|
||||
self.log
|
||||
.lock()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.filter(|(method, request_path, body)| {
|
||||
method == "PATCH"
|
||||
&& request_path == &path
|
||||
&& body["subscription"].as_str() == Some("licensed")
|
||||
})
|
||||
.count()
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate the PocketBase filter subset used by the server:
|
||||
/// `a="x" && b>=N && c<N && (d="" || d="y")`.
|
||||
fn record_matches(record: &Value, filter: &str) -> bool {
|
||||
if filter.is_empty() {
|
||||
return true;
|
||||
}
|
||||
filter
|
||||
.split(" && ")
|
||||
.all(|clause| clause_matches(record, clause.trim()))
|
||||
}
|
||||
|
||||
fn clause_matches(record: &Value, clause: &str) -> bool {
|
||||
if let Some(inner) = clause.strip_prefix('(').and_then(|c| c.strip_suffix(')')) {
|
||||
return inner
|
||||
.split(" || ")
|
||||
.any(|alternative| clause_matches(record, alternative.trim()));
|
||||
}
|
||||
if let Some((field, value)) = clause.split_once(">=") {
|
||||
return record_number(record, field) >= value.trim().parse::<i64>().unwrap_or(i64::MAX);
|
||||
}
|
||||
if let Some((field, value)) = clause.split_once('<') {
|
||||
return record_number(record, field) < value.trim().parse::<i64>().unwrap_or(i64::MIN);
|
||||
}
|
||||
if let Some((field, value)) = clause.split_once('=') {
|
||||
let expected = value.trim().trim_matches('"');
|
||||
return record_string(record, field) == expected;
|
||||
}
|
||||
panic!("mock PocketBase cannot evaluate filter clause: {clause}");
|
||||
}
|
||||
|
||||
fn record_number(record: &Value, field: &str) -> i64 {
|
||||
let value = &record[field];
|
||||
value
|
||||
.as_i64()
|
||||
.or_else(|| value.as_f64().map(|float| float as i64))
|
||||
.or_else(|| value.as_str().and_then(|text| text.parse().ok()))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
fn record_string(record: &Value, field: &str) -> String {
|
||||
let value = &record[field];
|
||||
value.as_str().map(str::to_string).unwrap_or_else(|| {
|
||||
if value.is_null() {
|
||||
String::new()
|
||||
} else {
|
||||
value.to_string()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn auth_handler() -> Json<Value> {
|
||||
Json(json!({ "token": "testsuperusertoken" }))
|
||||
}
|
||||
|
||||
async fn list_records(
|
||||
State(pb): State<Arc<MockPocketBase>>,
|
||||
Path(collection): Path<String>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Json<Value> {
|
||||
let filter = params.get("filter").map(String::as_str).unwrap_or("");
|
||||
let matching: Vec<Value> = pb
|
||||
.records
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(&collection)
|
||||
.map(|records| {
|
||||
records
|
||||
.iter()
|
||||
.filter(|record| record_matches(record, filter))
|
||||
.cloned()
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let total = matching.len();
|
||||
let per_page = params
|
||||
.get("perPage")
|
||||
.and_then(|raw| raw.parse::<usize>().ok())
|
||||
.unwrap_or(30);
|
||||
let items: Vec<Value> = matching.into_iter().take(per_page).collect();
|
||||
Json(json!({ "items": items, "totalItems": total }))
|
||||
}
|
||||
|
||||
async fn create_record(
|
||||
State(pb): State<Arc<MockPocketBase>>,
|
||||
Path(collection): Path<String>,
|
||||
Json(body): Json<Value>,
|
||||
) -> Response {
|
||||
pb.log.lock().unwrap().push((
|
||||
"POST".to_string(),
|
||||
format!("/api/collections/{collection}/records"),
|
||||
body.clone(),
|
||||
));
|
||||
|
||||
// Emulate the unique `name` constraint on the distributed-lock collection
|
||||
// so concurrent acquisitions conflict like they do against real PocketBase.
|
||||
if collection == "checkout_locks" {
|
||||
let exists = pb
|
||||
.records
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(&collection)
|
||||
.is_some_and(|records| records.iter().any(|record| record["name"] == body["name"]));
|
||||
if exists {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "message": "name must be unique" })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
let id = pb.seed(&collection, body);
|
||||
Json(pb.record(&collection, &id).expect("record just created")).into_response()
|
||||
}
|
||||
|
||||
async fn get_record(
|
||||
State(pb): State<Arc<MockPocketBase>>,
|
||||
Path((collection, id)): Path<(String, String)>,
|
||||
) -> Response {
|
||||
match pb.record(&collection, &id) {
|
||||
Some(record) => Json(record).into_response(),
|
||||
None => StatusCode::NOT_FOUND.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn patch_record(
|
||||
State(pb): State<Arc<MockPocketBase>>,
|
||||
Path((collection, id)): Path<(String, String)>,
|
||||
Json(body): Json<Value>,
|
||||
) -> Response {
|
||||
pb.log.lock().unwrap().push((
|
||||
"PATCH".to_string(),
|
||||
format!("/api/collections/{collection}/records/{id}"),
|
||||
body.clone(),
|
||||
));
|
||||
|
||||
let mut records = pb.records.lock().unwrap();
|
||||
let Some(record) = records.get_mut(&collection).and_then(|list| {
|
||||
list.iter_mut()
|
||||
.find(|record| record["id"].as_str() == Some(&id))
|
||||
}) else {
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
};
|
||||
if let (Some(target), Some(updates)) = (record.as_object_mut(), body.as_object()) {
|
||||
for (key, value) in updates {
|
||||
target.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
Json(record.clone()).into_response()
|
||||
}
|
||||
|
||||
async fn delete_record(
|
||||
State(pb): State<Arc<MockPocketBase>>,
|
||||
Path((collection, id)): Path<(String, String)>,
|
||||
) -> Response {
|
||||
pb.log.lock().unwrap().push((
|
||||
"DELETE".to_string(),
|
||||
format!("/api/collections/{collection}/records/{id}"),
|
||||
Value::Null,
|
||||
));
|
||||
|
||||
let mut records = pb.records.lock().unwrap();
|
||||
let Some(list) = records.get_mut(&collection) else {
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
};
|
||||
let before = list.len();
|
||||
list.retain(|record| record["id"].as_str() != Some(&id));
|
||||
if list.len() == before {
|
||||
StatusCode::NOT_FOUND.into_response()
|
||||
} else {
|
||||
StatusCode::NO_CONTENT.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
fn mock_pb_router(pb: Arc<MockPocketBase>) -> Router {
|
||||
Router::new()
|
||||
.route(
|
||||
"/api/collections/_superusers/auth-with-password",
|
||||
post(auth_handler),
|
||||
)
|
||||
.route(
|
||||
"/api/collections/{collection}/records",
|
||||
get(list_records).post(create_record),
|
||||
)
|
||||
.route(
|
||||
"/api/collections/{collection}/records/{id}",
|
||||
get(get_record).patch(patch_record).delete(delete_record),
|
||||
)
|
||||
.with_state(pb)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test harness
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Everything a test needs: the server-side state under test and a handle to
/// the mock PocketBase backing it.
struct TestEnv {
    /// Shared application state handed to the route handlers.
    shared: Arc<SharedState>,
    /// The in-memory PocketBase mock, for seeding records and asserting on
    /// what the server wrote.
    pb: Arc<MockPocketBase>,
}
|
||||
|
||||
async fn setup() -> TestEnv {
|
||||
let pb = Arc::new(MockPocketBase::default());
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("bind mock PocketBase listener");
|
||||
let addr = listener.local_addr().expect("mock PocketBase address");
|
||||
let router = mock_pb_router(pb.clone());
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, router)
|
||||
.await
|
||||
.expect("mock PocketBase serve");
|
||||
});
|
||||
|
||||
let state = AppState::for_tests(format!("http://{addr}"));
|
||||
TestEnv {
|
||||
shared: Arc::new(SharedState::new(state)),
|
||||
pb,
|
||||
}
|
||||
}
|
||||
|
||||
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Panics if the system clock reports a time before the epoch.
fn now_unix() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH);
    since_epoch.expect("clock after epoch").as_secs()
}
|
||||
|
||||
fn test_user(id: &str) -> PocketBaseUser {
|
||||
PocketBaseUser {
|
||||
id: id.to_string(),
|
||||
email: format!("{id}@test.example"),
|
||||
is_admin: false,
|
||||
subscription: "free".to_string(),
|
||||
newsletter: false,
|
||||
can_see_listings: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn seed_user(pb: &MockPocketBase, id: &str) {
|
||||
pb.seed(
|
||||
"users",
|
||||
json!({
|
||||
"id": id,
|
||||
"email": format!("{id}@test.example"),
|
||||
"subscription": "free",
|
||||
"is_admin": false,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
fn seed_pending_checkout(pb: &MockPocketBase, user_id: &str, session_id: &str) -> String {
|
||||
pb.seed(
|
||||
"checkout_sessions",
|
||||
json!({
|
||||
"user": user_id,
|
||||
"stripe_session_id": session_id,
|
||||
"stripe_payment_intent_id": "",
|
||||
"checkout_url": "https://checkout.stripe.test/session",
|
||||
"amount_pence": 999,
|
||||
"expected_total_pence": 999,
|
||||
"currency": "gbp",
|
||||
"discount_coupon_id": "",
|
||||
"referral_invite_id": "",
|
||||
"status": "pending",
|
||||
"expires_at_unix": now_unix() + 1800,
|
||||
"paid_amount_pence": 0,
|
||||
"completed_at_unix": "",
|
||||
"reversal_reason": "",
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn checkout_completed_event(session_id: &str, user_id: &str, amount_total: u64) -> Vec<u8> {
|
||||
serde_json::to_vec(&json!({
|
||||
"id": "evt_test_1",
|
||||
"type": "checkout.session.completed",
|
||||
"data": { "object": {
|
||||
"id": session_id,
|
||||
"payment_intent": "pi_test_1",
|
||||
"client_reference_id": user_id,
|
||||
"payment_status": "paid",
|
||||
"currency": "gbp",
|
||||
"amount_subtotal": 999,
|
||||
"amount_total": amount_total,
|
||||
}}
|
||||
}))
|
||||
.expect("event serializes")
|
||||
}
|
||||
|
||||
/// Sign a payload the way Stripe does: HMAC-SHA256 over `{timestamp}.{payload}`
|
||||
/// with the webhook secret, presented as `t=...,v1=...`.
|
||||
fn stripe_signature_header(payload: &[u8], secret: &str) -> String {
|
||||
let timestamp = now_unix();
|
||||
let mut mac =
|
||||
Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key length");
|
||||
mac.update(format!("{timestamp}.").as_bytes());
|
||||
mac.update(payload);
|
||||
let signature = hex::encode(mac.finalize().into_bytes());
|
||||
format!("t={timestamp},v1={signature}")
|
||||
}
|
||||
|
||||
async fn deliver_webhook(env: &TestEnv, payload: Vec<u8>, signature: Option<&str>) -> Response {
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(signature) = signature {
|
||||
headers.insert(
|
||||
"stripe-signature",
|
||||
HeaderValue::from_str(signature).expect("signature header value"),
|
||||
);
|
||||
}
|
||||
post_stripe_webhook(State(env.shared.clone()), headers, Bytes::from(payload)).await
|
||||
}
|
||||
|
||||
async fn redeem_invite(env: &TestEnv, user: PocketBaseUser, code: &str) -> Response {
|
||||
post_redeem_invite(
|
||||
State(env.shared.clone()),
|
||||
Extension(OptionalUser(Some(user))),
|
||||
Json(serde_json::from_value(json!({ "code": code })).expect("redeem request deserializes")),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn response_json(response: Response) -> Value {
|
||||
let bytes = axum::body::to_bytes(response.into_body(), 1 << 20)
|
||||
.await
|
||||
.expect("response body");
|
||||
serde_json::from_slice(&bytes).expect("response body is JSON")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stripe webhook → license granting
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_with_valid_signature_grants_license() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user1");
|
||||
let reservation_id = seed_pending_checkout(&env.pb, "user1", "cs_test_abc");
|
||||
|
||||
let payload = checkout_completed_event("cs_test_abc", "user1", 999);
|
||||
let signature = stripe_signature_header(&payload, "whsec_test_secret");
|
||||
let response = deliver_webhook(&env, payload, Some(&signature)).await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let user = env.pb.record("users", "user1").expect("user exists");
|
||||
assert_eq!(user["subscription"], json!("licensed"));
|
||||
|
||||
let checkout = env
|
||||
.pb
|
||||
.record("checkout_sessions", &reservation_id)
|
||||
.expect("checkout exists");
|
||||
assert_eq!(checkout["status"], json!("completed"));
|
||||
assert_eq!(checkout["paid_amount_pence"], json!(999));
|
||||
assert_eq!(checkout["stripe_payment_intent_id"], json!("pi_test_1"));
|
||||
|
||||
assert_eq!(env.pb.license_grant_count("user1"), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_with_invalid_signature_is_rejected() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user1");
|
||||
let reservation_id = seed_pending_checkout(&env.pb, "user1", "cs_test_bad");
|
||||
|
||||
let payload = checkout_completed_event("cs_test_bad", "user1", 999);
|
||||
|
||||
// Signed with the wrong secret.
|
||||
let wrong_signature = stripe_signature_header(&payload, "whsec_wrong_secret");
|
||||
let response = deliver_webhook(&env, payload.clone(), Some(&wrong_signature)).await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Missing signature header entirely.
|
||||
let response = deliver_webhook(&env, payload, None).await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
let user = env.pb.record("users", "user1").expect("user exists");
|
||||
assert_eq!(user["subscription"], json!("free"));
|
||||
let checkout = env
|
||||
.pb
|
||||
.record("checkout_sessions", &reservation_id)
|
||||
.expect("checkout exists");
|
||||
assert_eq!(checkout["status"], json!("pending"));
|
||||
assert_eq!(env.pb.license_grant_count("user1"), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn replayed_webhook_does_not_double_grant() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user1");
|
||||
let reservation_id = seed_pending_checkout(&env.pb, "user1", "cs_test_replay");
|
||||
|
||||
let payload = checkout_completed_event("cs_test_replay", "user1", 999);
|
||||
let signature = stripe_signature_header(&payload, "whsec_test_secret");
|
||||
|
||||
let first = deliver_webhook(&env, payload.clone(), Some(&signature)).await;
|
||||
assert_eq!(first.status(), StatusCode::OK);
|
||||
let replay = deliver_webhook(&env, payload, Some(&signature)).await;
|
||||
assert_eq!(replay.status(), StatusCode::OK);
|
||||
|
||||
let user = env.pb.record("users", "user1").expect("user exists");
|
||||
assert_eq!(user["subscription"], json!("licensed"));
|
||||
let checkout = env
|
||||
.pb
|
||||
.record("checkout_sessions", &reservation_id)
|
||||
.expect("checkout exists");
|
||||
assert_eq!(checkout["status"], json!("completed"));
|
||||
|
||||
// The replay must be acknowledged without granting a second time.
|
||||
assert_eq!(env.pb.license_grant_count("user1"), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_with_tampered_amount_is_rejected_and_marks_reservation_invalid() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user1");
|
||||
let reservation_id = seed_pending_checkout(&env.pb, "user1", "cs_test_amount");
|
||||
|
||||
// Validly signed event whose amount_total does not match the reservation.
|
||||
let payload = checkout_completed_event("cs_test_amount", "user1", 500);
|
||||
let signature = stripe_signature_header(&payload, "whsec_test_secret");
|
||||
let response = deliver_webhook(&env, payload, Some(&signature)).await;
|
||||
|
||||
// Rejections are acknowledged with 200 so Stripe stops retrying.
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let user = env.pb.record("users", "user1").expect("user exists");
|
||||
assert_eq!(user["subscription"], json!("free"));
|
||||
let checkout = env
|
||||
.pb
|
||||
.record("checkout_sessions", &reservation_id)
|
||||
.expect("checkout exists");
|
||||
assert_eq!(checkout["status"], json!("invalid"));
|
||||
assert_eq!(env.pb.license_grant_count("user1"), 0);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Checkout session creation (up to the hardcoded Stripe API call)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tokio::test]
|
||||
async fn checkout_start_reserves_then_marks_failed_when_stripe_is_unreachable() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user9");
|
||||
let state = env.shared.load_state();
|
||||
let user = test_user("user9");
|
||||
|
||||
// The Stripe API URL is hardcoded to https://api.stripe.com, so the
|
||||
// session-creation call fails in tests (no network / dummy key). The
|
||||
// reservation bookkeeping before and after that call is what we assert.
|
||||
let result = start_license_checkout(
|
||||
&state,
|
||||
&user,
|
||||
"https://x/success",
|
||||
"https://x/cancel",
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
assert!(result.is_err(), "Stripe call must fail in tests");
|
||||
|
||||
let checkouts = env.pb.records_in("checkout_sessions");
|
||||
assert_eq!(checkouts.len(), 1, "exactly one reservation created");
|
||||
let checkout = &checkouts[0];
|
||||
assert_eq!(checkout["user"], json!("user9"));
|
||||
// 0 licensed users → public count 120 → second tier price (999p).
|
||||
assert_eq!(checkout["amount_pence"], json!(999));
|
||||
assert_eq!(checkout["expected_total_pence"], json!(999));
|
||||
assert_eq!(checkout["currency"], json!("gbp"));
|
||||
// The failed Stripe call must not leave a live pending reservation.
|
||||
assert_eq!(checkout["status"], json!("failed"));
|
||||
|
||||
// The cross-instance pricing lock was released.
|
||||
assert!(env.pb.records_in("checkout_locks").is_empty());
|
||||
assert_eq!(env.pb.license_grant_count("user9"), 0);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Invite redemption
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tokio::test]
|
||||
async fn admin_invite_redemption_grants_license() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user2");
|
||||
let invite_id = env.pb.seed(
|
||||
"invites",
|
||||
json!({
|
||||
"code": "admininvite1",
|
||||
"invite_type": "admin",
|
||||
"created_by": "adminuser1",
|
||||
"used_by_id": "",
|
||||
"used_at": "",
|
||||
}),
|
||||
);
|
||||
|
||||
let response = redeem_invite(&env, test_user("user2"), "admininvite1").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
let body = response_json(response).await;
|
||||
assert_eq!(body["result"], json!("licensed"));
|
||||
|
||||
let invite = env.pb.record("invites", &invite_id).expect("invite exists");
|
||||
assert_eq!(invite["used_by_id"], json!("user2"));
|
||||
|
||||
let user = env.pb.record("users", "user2").expect("user exists");
|
||||
assert_eq!(user["subscription"], json!("licensed"));
|
||||
assert_eq!(env.pb.license_grant_count("user2"), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_and_oversized_invite_codes_are_rejected() {
|
||||
let env = setup().await;
|
||||
|
||||
// Non-alphanumeric characters.
|
||||
let response = redeem_invite(&env, test_user("user2"), "bad-code!").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Longer than the 20-character limit.
|
||||
let oversized = "a".repeat(21);
|
||||
let response = redeem_invite(&env, test_user("user2"), &oversized).await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Empty code.
|
||||
let response = redeem_invite(&env, test_user("user2"), "").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
assert_eq!(env.pb.license_grant_count("user2"), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn already_used_invite_is_rejected() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user2");
|
||||
env.pb.seed(
|
||||
"invites",
|
||||
json!({
|
||||
"code": "usedinvite12",
|
||||
"invite_type": "admin",
|
||||
"created_by": "adminuser1",
|
||||
"used_by_id": "otheruser9",
|
||||
"used_at": "1700000000",
|
||||
}),
|
||||
);
|
||||
|
||||
let response = redeem_invite(&env, test_user("user2"), "usedinvite12").await;
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
let user = env.pb.record("users", "user2").expect("user exists");
|
||||
assert_eq!(user["subscription"], json!("free"));
|
||||
assert_eq!(env.pb.license_grant_count("user2"), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn referral_invite_redemption_releases_reservation_when_stripe_is_unreachable() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user3");
|
||||
let invite_id = env.pb.seed(
|
||||
"invites",
|
||||
json!({
|
||||
"code": "refcode12345",
|
||||
"invite_type": "referral",
|
||||
"created_by": "licenseduser1",
|
||||
"used_by_id": "",
|
||||
"used_at": "",
|
||||
"reserved_by_id": "",
|
||||
"reserved_checkout_id": "",
|
||||
"reserved_until_unix": 0,
|
||||
}),
|
||||
);
|
||||
|
||||
// The redemption itself is valid; it fails only at the hardcoded Stripe
|
||||
// call (coupon verification / session creation), which must roll back the
|
||||
// reservation cleanly.
|
||||
let response = redeem_invite(&env, test_user("user3"), "refcode12345").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
|
||||
|
||||
let checkouts = env.pb.records_in("checkout_sessions");
|
||||
assert_eq!(checkouts.len(), 1, "referral reservation was created");
|
||||
assert_eq!(checkouts[0]["referral_invite_id"], json!(invite_id));
|
||||
assert_eq!(checkouts[0]["status"], json!("failed"));
|
||||
|
||||
// The invite reservation was released and the invite is still unused.
|
||||
let invite = env.pb.record("invites", &invite_id).expect("invite exists");
|
||||
assert_eq!(invite["used_by_id"], json!(""));
|
||||
assert_eq!(invite["reserved_checkout_id"], json!(""));
|
||||
assert_eq!(invite["reserved_by_id"], json!(""));
|
||||
|
||||
assert_eq!(env.pb.license_grant_count("user3"), 0);
|
||||
}
|
||||
|
|
@ -182,8 +182,7 @@ impl CrimeByYearData {
|
|||
// Force-coverage calendar (optional column: legacy parquets predate it;
|
||||
// their postcodes are treated as fully covered). A row with an empty
|
||||
// list is meaningful — zero covered years — so it IS inserted.
|
||||
let mut covered_years_by_postcode: FxHashMap<String, Vec<i32>> =
|
||||
FxHashMap::default();
|
||||
let mut covered_years_by_postcode: FxHashMap<String, Vec<i32>> = FxHashMap::default();
|
||||
if let Ok(col) = df.column(COVERAGE_COLUMN) {
|
||||
let list_ca = col
|
||||
.list()
|
||||
|
|
@ -195,12 +194,12 @@ impl CrimeByYearData {
|
|||
};
|
||||
let mut years: Vec<i32> = Vec::with_capacity(inner.len());
|
||||
if !inner.is_empty() {
|
||||
let structs = inner.struct_().with_context(|| {
|
||||
format!("Inner of '{COVERAGE_COLUMN}' is not a struct")
|
||||
})?;
|
||||
let year_field = structs.field_by_name("year").with_context(|| {
|
||||
format!("Missing 'year' field in '{COVERAGE_COLUMN}'")
|
||||
})?;
|
||||
let structs = inner
|
||||
.struct_()
|
||||
.with_context(|| format!("Inner of '{COVERAGE_COLUMN}' is not a struct"))?;
|
||||
let year_field = structs
|
||||
.field_by_name("year")
|
||||
.with_context(|| format!("Missing 'year' field in '{COVERAGE_COLUMN}'"))?;
|
||||
for idx in 0..inner.len() {
|
||||
match year_field.get(idx).ok() {
|
||||
Some(AnyValue::Int32(y)) => years.push(y),
|
||||
|
|
|
|||
|
|
@ -742,6 +742,29 @@ impl PlaceData {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
impl PlaceData {
    /// Test-only constructor: every column and every index is empty. Meant
    /// for integration tests that must build an `AppState` but never read
    /// place data.
    pub(crate) fn empty_for_tests() -> Self {
        Self {
            name: vec![],
            name_lower: vec![],
            name_search: vec![],
            place_type: InternedColumn::build(&[]),
            type_rank: vec![],
            population: vec![],
            lat: vec![],
            lon: vec![],
            city: vec![],
            travel_destination: vec![],
            token_index: Default::default(),
            token_prefix_index: Default::default(),
            fuzzy_trigram_index: Default::default(),
        }
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
|
|
@ -588,6 +588,29 @@ impl POIData {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
impl POIData {
    /// Test-only constructor: every buffer, column and lookup is empty. Meant
    /// for integration tests that must build an `AppState` but never read
    /// POI data.
    pub(crate) fn empty_for_tests() -> Self {
        Self {
            id_buffer: String::default(),
            id_offsets: vec![],
            id_lengths: vec![],
            group: InternedColumn::build(&[]),
            category: InternedColumn::build(&[]),
            icon_category: InternedColumn::build(&[]),
            name: vec![],
            lat: vec![],
            lng: vec![],
            emoji: InternedColumn::build(&[]),
            priority: vec![],
            school_meta_idx: vec![],
            school_meta: vec![],
        }
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
973
server-rs/src/data/property/address_search.rs
Normal file
973
server-rs/src/data/property/address_search.rs
Normal file
|
|
@ -0,0 +1,973 @@
|
|||
//! Address search: tokenization, query parsing, inverted/prefix indexes and the
|
||||
//! ranked per-row search over property addresses.
|
||||
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
|
||||
use super::PropertyData;
|
||||
|
||||
/// Hard cap on how many candidate rows a single query will score. Posting-list
/// intersection normally leaves far fewer; only a lone, very common road word
/// (e.g. "high") gets close, and in that case in-area rows are moved to the
/// front before the cut so a refined query keeps its matches.
const ADDRESS_SEARCH_CANDIDATE_LIMIT: usize = 150_000;

/// Shortest prefix (in bytes) stored in the prefix index; shorter query terms
/// are never prefix-expanded.
const ADDRESS_SEARCH_PREFIX_MIN_LEN: usize = 4;

/// Longest prefix (in bytes) stored in the prefix index; longer query terms
/// are looked up by their first `ADDRESS_SEARCH_PREFIX_MAX_LEN` bytes.
const ADDRESS_SEARCH_PREFIX_MAX_LEN: usize = 8;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(super) struct AddressTermGroup {
|
||||
alternatives: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct AddressQuery {
|
||||
full_postcode: Option<String>,
|
||||
/// Compact uppercase outward code (optionally with a sector digit) recovered when the
|
||||
/// user appended a partial postcode like "NW1" or "NW1 6". Used as an additive ranking
|
||||
/// bias, never as a hard filter — so the disambiguating hint is honoured without
|
||||
/// excluding the same road in other areas.
|
||||
postcode_area: Option<String>,
|
||||
text_groups: Vec<AddressTermGroup>,
|
||||
numeric_terms: Vec<String>,
|
||||
candidate_terms: Vec<String>,
|
||||
}
|
||||
|
||||
/// Splits `text` into lowercase ASCII-alphanumeric words.
///
/// Apostrophe-like characters (`'`, `’`, `` ` ``) are dropped without ending
/// the current word, so "John's" becomes "johns". Every other
/// non-alphanumeric character (including non-ASCII letters) ends the word.
fn tokenize_address_text(text: &str) -> Vec<String> {
    let mut words: Vec<String> = Vec::new();
    let mut pending = String::new();

    for c in text.chars() {
        match c {
            _ if c.is_ascii_alphanumeric() => pending.push(c.to_ascii_lowercase()),
            // Apostrophes vanish but keep the word going.
            '\'' | '’' | '`' => {}
            // Anything else is a separator; flush the word if there is one.
            _ => {
                if !pending.is_empty() {
                    words.push(std::mem::take(&mut pending));
                }
            }
        }
    }

    // Flush the trailing word, if the text did not end on a separator.
    if !pending.is_empty() {
        words.push(pending);
    }
    words
}
|
||||
|
||||
/// Returns true when `compact` (a postcode with no spaces) has the shape of a
/// complete postcode: 5–7 bytes, ending in digit-letter-letter, with the
/// remaining outward part starting with a letter, fully ASCII-alphanumeric,
/// and containing at least one digit.
fn is_full_postcode_compact(compact: &str) -> bool {
    let raw = compact.as_bytes();
    if raw.len() < 5 || raw.len() > 7 {
        return false;
    }

    // With 5..=7 bytes in total the outward part is always 2..=4 bytes long.
    let (outward, inward) = raw.split_at(raw.len() - 3);

    let inward_ok = matches!(
        inward,
        [digit, first, second]
            if digit.is_ascii_digit()
                && first.is_ascii_alphabetic()
                && second.is_ascii_alphabetic()
    );
    if !inward_ok {
        return false;
    }

    outward[0].is_ascii_alphabetic()
        && outward.iter().all(|byte| byte.is_ascii_alphanumeric())
        && outward.iter().any(|byte| byte.is_ascii_digit())
}
|
||||
|
||||
/// Formats a compact postcode as uppercase "OUTWARD INWARD", where the inward
/// part is the last three bytes. Expects at least three bytes of input.
fn canonical_postcode_from_compact(compact: &str) -> String {
    let shouted = compact.to_ascii_uppercase();
    let (outward, inward) = shouted.split_at(shouted.len() - 3);
    format!("{} {}", outward, inward)
}
|
||||
|
||||
fn extract_full_postcode(tokens: &[String]) -> Option<(String, Vec<usize>)> {
|
||||
for (idx, token) in tokens.iter().enumerate() {
|
||||
let compact = token.to_ascii_uppercase();
|
||||
if is_full_postcode_compact(&compact) {
|
||||
return Some((canonical_postcode_from_compact(&compact), vec![idx]));
|
||||
}
|
||||
}
|
||||
|
||||
for idx in 0..tokens.len().saturating_sub(1) {
|
||||
let compact = format!(
|
||||
"{}{}",
|
||||
tokens[idx].to_ascii_uppercase(),
|
||||
tokens[idx + 1].to_ascii_uppercase()
|
||||
);
|
||||
if is_full_postcode_compact(&compact) {
|
||||
return Some((
|
||||
canonical_postcode_from_compact(&compact),
|
||||
vec![idx, idx + 1],
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns true for a 2–4 byte ASCII-alphanumeric token that starts with a
/// letter and contains a digit — the shape of an outward code such as "NW1".
fn looks_like_postcode_fragment(token: &str) -> bool {
    let raw = token.as_bytes();
    if raw.len() < 2 || raw.len() > 4 {
        return false;
    }
    raw[0].is_ascii_alphabetic()
        && raw.iter().any(|byte| byte.is_ascii_digit())
        && raw.iter().all(|byte| byte.is_ascii_alphanumeric())
}
|
||||
|
||||
/// Returns true when every character of `token` is an ASCII digit. An empty
/// token has no offending character, so it also yields true.
fn is_numeric_address_token(token: &str) -> bool {
    token.bytes().all(|byte| byte.is_ascii_digit())
}
|
||||
|
||||
/// Returns the interchangeable spellings of a lowercase address word, with the
/// word itself listed first. Words without a known alias yield an empty list.
fn address_token_aliases(token: &str) -> Vec<&'static str> {
    let spellings: &[&'static str] = match token {
        "apt" => &["apt", "apartment"],
        "apartment" => &["apartment", "apt"],
        "ave" => &["ave", "avenue"],
        "avenue" => &["avenue", "ave"],
        "blvd" => &["blvd", "boulevard"],
        "boulevard" => &["boulevard", "blvd"],
        "cl" => &["cl", "close"],
        "close" => &["close", "cl"],
        "ct" => &["ct", "court"],
        "court" => &["court", "ct"],
        "cres" => &["cres", "crescent"],
        "crescent" => &["crescent", "cres"],
        "dr" => &["dr", "drive"],
        "drive" => &["drive", "dr"],
        "fl" => &["fl", "flat"],
        "flat" => &["flat", "fl"],
        "gdns" => &["gdns", "gardens", "garden"],
        "garden" => &["garden", "gardens", "gdns"],
        "gardens" => &["gardens", "garden", "gdns"],
        "hse" => &["hse", "house"],
        "house" => &["house", "hse"],
        "ln" => &["ln", "lane"],
        "lane" => &["lane", "ln"],
        "rd" => &["rd", "road"],
        "road" => &["road", "rd"],
        "sq" => &["sq", "square"],
        "square" => &["square", "sq"],
        // "st" is ambiguous, so it expands to both readings.
        "st" => &["st", "street", "saint"],
        "street" => &["street", "st"],
        "saint" => &["saint", "st"],
        "terr" => &["terr", "terrace"],
        "terrace" => &["terrace", "terr"],
        _ => &[],
    };
    spellings.to_vec()
}
|
||||
|
||||
/// Returns true for words too generic to identify an address on their own:
/// articles/conjunctions, unit designators, building nouns and road-type words
/// (including their abbreviations).
///
/// Every word that has an entry in the alias table must be listed here. A
/// group whose alternatives are not all stop words contributes each non-stop
/// alternative as a candidate term, and candidate terms are intersected, so an
/// aliased non-stop word would require a row to contain both spellings at
/// once. "blvd" / "boulevard" were the one alias pair missing from the list.
fn is_address_stop_token(token: &str) -> bool {
    matches!(
        token,
        // Articles and connectives.
        "a" | "an" | "and" | "of" | "the"
            // Unit / dwelling designators.
            | "apartment" | "apt" | "block" | "flat" | "fl" | "floor" | "room" | "unit"
            // Building and estate nouns.
            | "building" | "bungalow" | "estate" | "house" | "hse" | "lodge"
            | "mansions" | "villas"
            // Road types and their abbreviations.
            | "avenue" | "ave"
            | "blvd" | "boulevard"
            | "close" | "cl"
            | "court" | "ct"
            | "cres" | "crescent"
            | "drive" | "dr"
            | "garden" | "gardens" | "gdns"
            | "grove"
            | "lane" | "ln"
            | "mews"
            | "park"
            | "place"
            | "road" | "rd"
            | "row"
            | "saint" | "st" | "street"
            | "sq" | "square"
            | "terr" | "terrace"
            | "view"
            | "walk"
            | "way"
            | "yard"
    )
}
|
||||
|
||||
fn address_term_group(token: &str) -> Option<AddressTermGroup> {
|
||||
if token.len() < 3 || is_numeric_address_token(token) || looks_like_postcode_fragment(token) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut alternatives = Vec::new();
|
||||
alternatives.push(token.to_string());
|
||||
for alias in address_token_aliases(token) {
|
||||
if !alternatives.iter().any(|existing| existing == alias) {
|
||||
alternatives.push(alias.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if alternatives
|
||||
.iter()
|
||||
.all(|alternative| is_address_stop_token(alternative))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(AddressTermGroup { alternatives })
|
||||
}
|
||||
|
||||
pub(super) fn address_search_tokens(text: &str) -> Vec<String> {
|
||||
let mut tokens: Vec<String> = tokenize_address_text(text)
|
||||
.into_iter()
|
||||
.filter(|token| is_address_search_token(token))
|
||||
.collect();
|
||||
tokens.sort_unstable();
|
||||
tokens.dedup();
|
||||
tokens
|
||||
}
|
||||
|
||||
fn is_address_search_token(token: &str) -> bool {
|
||||
if looks_like_postcode_fragment(token) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if is_numeric_address_token(token) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if token.chars().any(|ch| ch.is_ascii_digit()) {
|
||||
return token.len() >= 2;
|
||||
}
|
||||
|
||||
token.len() >= 3
|
||||
}
|
||||
|
||||
pub(super) fn is_address_candidate_token(token: &str) -> bool {
|
||||
!is_numeric_address_token(token)
|
||||
&& !looks_like_postcode_fragment(token)
|
||||
&& (token.chars().any(|ch| ch.is_ascii_digit())
|
||||
|| (token.len() >= 3 && !is_address_stop_token(token)))
|
||||
}
|
||||
|
||||
fn address_prefix_key(term: &str) -> &str {
|
||||
if term.len() > ADDRESS_SEARCH_PREFIX_MAX_LEN {
|
||||
&term[..ADDRESS_SEARCH_PREFIX_MAX_LEN]
|
||||
} else {
|
||||
term
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn build_address_prefix_index(
|
||||
address_token_index: &FxHashMap<String, Vec<u32>>,
|
||||
) -> FxHashMap<String, Vec<String>> {
|
||||
let mut prefix_index: FxHashMap<String, Vec<String>> = FxHashMap::default();
|
||||
|
||||
for token in address_token_index.keys() {
|
||||
let max_prefix_len = token.len().min(ADDRESS_SEARCH_PREFIX_MAX_LEN);
|
||||
for prefix_len in ADDRESS_SEARCH_PREFIX_MIN_LEN..=max_prefix_len {
|
||||
prefix_index
|
||||
.entry(token[..prefix_len].to_string())
|
||||
.or_default()
|
||||
.push(token.clone());
|
||||
}
|
||||
}
|
||||
|
||||
for tokens in prefix_index.values_mut() {
|
||||
tokens.sort_unstable();
|
||||
tokens.dedup();
|
||||
}
|
||||
|
||||
prefix_index
|
||||
}
|
||||
|
||||
/// Returns the row ids present in both ascending-sorted slices, in ascending
/// order.
fn intersect_sorted(left: &[u32], right: &[u32]) -> Vec<u32> {
    use std::cmp::Ordering::{Equal, Greater, Less};

    let mut common = Vec::new();
    let (mut a, mut b) = (left, right);
    // Advance whichever side holds the smaller head; keep heads that agree.
    while let (Some((&x, rest_a)), Some((&y, rest_b))) = (a.split_first(), b.split_first()) {
        match x.cmp(&y) {
            Less => a = rest_a,
            Greater => b = rest_b,
            Equal => {
                common.push(x);
                a = rest_a;
                b = rest_b;
            }
        }
    }
    common
}
|
||||
|
||||
/// Merges two ascending-sorted row-id slices into one ascending list; an id
/// present on both sides is emitted once.
fn union_sorted(left: &[u32], right: &[u32]) -> Vec<u32> {
    use std::cmp::Ordering::{Equal, Greater, Less};

    let mut merged = Vec::with_capacity(left.len() + right.len());
    let (mut a, mut b) = (left, right);
    while let (Some((&x, rest_a)), Some((&y, rest_b))) = (a.split_first(), b.split_first()) {
        match x.cmp(&y) {
            Less => {
                merged.push(x);
                a = rest_a;
            }
            Greater => {
                merged.push(y);
                b = rest_b;
            }
            Equal => {
                merged.push(x);
                a = rest_a;
                b = rest_b;
            }
        }
    }
    // At most one side still has elements; append whatever is left.
    merged.extend_from_slice(a);
    merged.extend_from_slice(b);
    merged
}
|
||||
|
||||
/// Returns true for an ordinal such as "1st", "2nd", "3rd" or "21st": one or
/// more ASCII digits followed by "st", "nd", "rd" or "th". Ordinals belong to
/// the street name ("2nd Avenue") and are not a house-number prefix.
///
/// Implemented with `strip_suffix` rather than a byte-offset split, so a token
/// containing multi-byte characters is rejected instead of panicking on a
/// non-character boundary.
fn is_ordinal_token(token: &str) -> bool {
    ["st", "nd", "rd", "th"].iter().any(|suffix| {
        token.strip_suffix(suffix).is_some_and(|digits| {
            !digits.is_empty() && digits.bytes().all(|byte| byte.is_ascii_digit())
        })
    })
}
|
||||
|
||||
/// Leading address tokens that denote a unit/house number rather than the street itself.
|
||||
fn is_house_prefix_token(token: &str) -> bool {
|
||||
if is_ordinal_token(token) {
|
||||
return false;
|
||||
}
|
||||
matches!(
|
||||
token,
|
||||
"flat" | "fl" | "apartment" | "apt" | "unit" | "no" | "block" | "floor" | "room"
|
||||
) || token.len() == 1
|
||||
|| token.chars().all(|ch| ch.is_ascii_digit())
|
||||
|| (token.chars().next().is_some_and(|ch| ch.is_ascii_digit())
|
||||
&& token.chars().any(|ch| ch.is_ascii_alphabetic()))
|
||||
}
|
||||
|
||||
/// Street-level key for an address: drops the leading house-number / flat prefix so that
|
||||
/// "12 Baker Street" and "5 Baker Street" collapse to a single street entry.
|
||||
fn street_key(address: &str) -> String {
|
||||
let tokens = tokenize_address_text(address);
|
||||
let mut start = 0;
|
||||
while start < tokens.len() && is_house_prefix_token(&tokens[start]) {
|
||||
start += 1;
|
||||
}
|
||||
if start >= tokens.len() {
|
||||
return tokens.join(" ");
|
||||
}
|
||||
tokens[start..].join(" ")
|
||||
}
|
||||
|
||||
/// Lowercase road-type words and abbreviations. A query containing one of
/// these and no house number is treated as a road browse, whose results are
/// collapsed to a single entry per street.
const ROAD_TYPE_TOKENS: &[&str] = &[
    "street", "st",
    "road", "rd",
    "lane", "ln",
    "avenue", "ave",
    "close", "cl",
    "drive", "dr",
    "way",
    "court", "ct",
    "crescent", "cres",
    "place",
    "terrace", "terr",
    "grove",
    "gardens", "gdns",
    "walk",
    "row",
    "square", "sq",
    "hill",
    "parade",
    "mews",
    "embankment",
    "broadway",
    "boulevard", "blvd",
];
|
||||
|
||||
fn query_has_road_type(query: &str) -> bool {
|
||||
tokenize_address_text(query)
|
||||
.iter()
|
||||
.any(|token| ROAD_TYPE_TOKENS.contains(&token.as_str()))
|
||||
}
|
||||
|
||||
/// Returns the outward code of a canonical postcode: the text before the
/// first space, or the whole input when it contains no space.
fn outcode_of(postcode: &str) -> &str {
    match postcode.split_once(' ') {
        Some((outward, _inward)) => outward,
        None => postcode,
    }
}
|
||||
|
||||
fn parse_address_query(query: &str) -> AddressQuery {
|
||||
let tokens = tokenize_address_text(query);
|
||||
let (full_postcode, postcode_token_indices) = extract_full_postcode(&tokens)
|
||||
.map(|(postcode, indices)| (Some(postcode), indices))
|
||||
.unwrap_or((None, Vec::new()));
|
||||
|
||||
let skip_postcode_tokens: FxHashSet<usize> = postcode_token_indices.into_iter().collect();
|
||||
|
||||
// Recover an appended partial postcode (outcode, or outcode + sector digit) as a ranking
|
||||
// bias rather than discarding it — but only from the TRAILING position, so a leading road
|
||||
// designation like "A4 Great West Road" is not mistaken for an area refinement.
|
||||
let mut postcode_area: Option<String> = None;
|
||||
let mut consumed_partial_tokens: FxHashSet<usize> = FxHashSet::default();
|
||||
if full_postcode.is_none() && !tokens.is_empty() {
|
||||
let last = tokens.len() - 1;
|
||||
if !skip_postcode_tokens.contains(&last) {
|
||||
let sector_digit =
|
||||
tokens[last].len() == 1 && tokens[last].chars().all(|ch| ch.is_ascii_digit());
|
||||
if last >= 1
|
||||
&& sector_digit
|
||||
&& !skip_postcode_tokens.contains(&(last - 1))
|
||||
&& looks_like_postcode_fragment(&tokens[last - 1])
|
||||
{
|
||||
postcode_area = Some(format!(
|
||||
"{}{}",
|
||||
tokens[last - 1].to_ascii_uppercase(),
|
||||
tokens[last]
|
||||
));
|
||||
consumed_partial_tokens.insert(last);
|
||||
consumed_partial_tokens.insert(last - 1);
|
||||
} else if looks_like_postcode_fragment(&tokens[last]) {
|
||||
postcode_area = Some(tokens[last].to_ascii_uppercase());
|
||||
consumed_partial_tokens.insert(last);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut text_groups = Vec::new();
|
||||
let mut numeric_terms = Vec::new();
|
||||
let mut candidate_terms = Vec::new();
|
||||
|
||||
for (idx, token) in tokens.iter().enumerate() {
|
||||
if skip_postcode_tokens.contains(&idx)
|
||||
|| consumed_partial_tokens.contains(&idx)
|
||||
|| looks_like_postcode_fragment(token)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if is_numeric_address_token(token) {
|
||||
numeric_terms.push(token.clone());
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(group) = address_term_group(token) {
|
||||
for alternative in &group.alternatives {
|
||||
if !is_address_stop_token(alternative)
|
||||
&& !candidate_terms.iter().any(|term| term == alternative)
|
||||
{
|
||||
candidate_terms.push(alternative.clone());
|
||||
}
|
||||
}
|
||||
text_groups.push(group);
|
||||
} else if token.chars().any(|ch| ch.is_ascii_digit()) && token.len() >= 2 {
|
||||
numeric_terms.push(token.clone());
|
||||
if !candidate_terms.iter().any(|term| term == token) {
|
||||
candidate_terms.push(token.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text_groups.dedup_by(|left, right| left.alternatives == right.alternatives);
|
||||
numeric_terms.sort_unstable();
|
||||
numeric_terms.dedup();
|
||||
|
||||
AddressQuery {
|
||||
full_postcode,
|
||||
postcode_area,
|
||||
text_groups,
|
||||
numeric_terms,
|
||||
candidate_terms,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true when an address token satisfies a text query term: query
/// terms of three or more bytes match as a prefix, shorter ones must match
/// exactly.
fn token_matches_query_term(token: &str, query_term: &str) -> bool {
    if query_term.len() >= 3 {
        // Equality is the special case of a prefix match.
        token.starts_with(query_term)
    } else {
        token == query_term
    }
}
|
||||
|
||||
/// Returns true when an address token starts with the numeric query term (an
/// exact match included), so "1" matches both "1" and "10".
fn token_matches_numeric_term(token: &str, query_term: &str) -> bool {
    token.starts_with(query_term)
}
|
||||
|
||||
/// Test-only check on plain strings: true when at least one alternative of
/// `group` matches at least one of `tokens`.
#[cfg(test)]
fn address_tokens_match_group(tokens: &[String], group: &AddressTermGroup) -> bool {
    for alternative in &group.alternatives {
        for token in tokens {
            if token_matches_query_term(token, alternative) {
                return true;
            }
        }
    }
    false
}
|
||||
|
||||
impl PropertyData {
|
||||
/// Returns the interned search-token keys stored for `row`, located in the
/// shared key buffer by the row's offset and length entries.
fn row_address_search_tokens(&self, row: usize) -> &[lasso::Spur] {
    let start = self.address_search_token_offsets[row] as usize;
    let count = self.address_search_token_lengths[row] as usize;
    &self.address_search_token_keys[start..][..count]
}
|
||||
|
||||
/// Search individual property addresses, returning `(row, score)` ranked best-first.
|
||||
///
|
||||
/// Candidate rows come from intersecting the posting lists of the distinctive words the
|
||||
/// user typed in full (so "Cherry Hinton Road" narrows to rows containing both), unioned
|
||||
/// with the exact-postcode rows when a complete postcode is present (so a postcode is a
|
||||
/// boost, not an all-or-nothing gate). An appended partial postcode keeps in-area rows
|
||||
/// ahead of the candidate cut and adds a scoring bias. With a road-type word and no house
|
||||
/// number, results collapse to one row per street.
|
||||
pub fn search_addresses(&self, query: &str, limit: usize) -> Vec<(usize, i32)> {
|
||||
if limit == 0 {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let parsed = parse_address_query(query);
|
||||
if parsed.full_postcode.is_none()
|
||||
&& parsed.text_groups.is_empty()
|
||||
&& parsed.numeric_terms.is_empty()
|
||||
{
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut candidate_rows = self.address_candidate_rows(&parsed.candidate_terms);
|
||||
|
||||
// A complete postcode contributes its rows too, instead of replacing the road match.
|
||||
if let Some(postcode) = parsed.full_postcode.as_deref() {
|
||||
if let Some(rows) = self
|
||||
.postcode_interner
|
||||
.get(postcode)
|
||||
.and_then(|key| self.postcode_row_index.get(&key))
|
||||
{
|
||||
candidate_rows = if candidate_rows.is_empty() {
|
||||
rows.clone()
|
||||
} else {
|
||||
union_sorted(&candidate_rows, rows)
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if candidate_rows.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
// When the user appended a partial postcode, keep in-area rows ahead of the cut so the
|
||||
// refinement still surfaces even for very common roads. Single pass (stable partition) so
|
||||
// the postcode check — which allocates — runs exactly once per candidate.
|
||||
if let Some(area) = parsed.postcode_area.as_deref() {
|
||||
let mut in_area = Vec::new();
|
||||
let mut others = Vec::new();
|
||||
for &row in &candidate_rows {
|
||||
if self.row_postcode_in_area(row as usize, area) {
|
||||
in_area.push(row);
|
||||
} else {
|
||||
others.push(row);
|
||||
}
|
||||
}
|
||||
in_area.extend(others);
|
||||
candidate_rows = in_area;
|
||||
}
|
||||
candidate_rows.truncate(ADDRESS_SEARCH_CANDIDATE_LIMIT);
|
||||
|
||||
let mut scored: Vec<(i32, usize, usize)> = candidate_rows
|
||||
.into_iter()
|
||||
.filter_map(|row| {
|
||||
let row = row as usize;
|
||||
self.address_match_score(row, &parsed)
|
||||
.map(|score| (score, self.address(row).len(), row))
|
||||
})
|
||||
.collect();
|
||||
|
||||
scored.sort_unstable_by(|left, right| {
|
||||
right
|
||||
.0
|
||||
.cmp(&left.0)
|
||||
.then(left.1.cmp(&right.1))
|
||||
.then(left.2.cmp(&right.2))
|
||||
});
|
||||
|
||||
// Collapse a road browse (road-type word, no house number) to one row per street.
|
||||
let collapse_streets = parsed.numeric_terms.is_empty() && query_has_road_type(query);
|
||||
|
||||
let mut seen = FxHashSet::default();
|
||||
let mut results = Vec::with_capacity(limit);
|
||||
for (score, _, row) in scored {
|
||||
let address = self.address(row).trim();
|
||||
if address.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let key = if collapse_streets {
|
||||
format!(
|
||||
"{}\n{}",
|
||||
street_key(address),
|
||||
outcode_of(self.postcode(row))
|
||||
)
|
||||
} else {
|
||||
format!("{}\n{}", address.to_ascii_lowercase(), self.postcode(row))
|
||||
};
|
||||
if !seen.insert(key) {
|
||||
continue;
|
||||
}
|
||||
results.push((row, score));
|
||||
if results.len() == limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// True when the row's postcode begins with the compact partial-postcode `area`
|
||||
/// (e.g. "NW1" or "NW16" matches "NW1 6XE").
|
||||
fn row_postcode_in_area(&self, row: usize, area: &str) -> bool {
|
||||
let mut compact = String::new();
|
||||
for ch in self.postcode(row).chars() {
|
||||
if !ch.is_whitespace() {
|
||||
compact.push(ch.to_ascii_uppercase());
|
||||
}
|
||||
}
|
||||
compact.starts_with(area)
|
||||
}
|
||||
|
||||
/// Candidate rows for the distinctive query words. Words typed in full intersect by their
|
||||
/// exact posting lists (precise); a still-being-typed final word with no exact match seeds
|
||||
/// from the smallest prefix-expanded posting list (so partial typing keeps working).
|
||||
fn address_candidate_rows(&self, terms: &[String]) -> Vec<u32> {
|
||||
let mut exact: Vec<&[u32]> = terms
|
||||
.iter()
|
||||
.filter_map(|term| self.address_token_index.get(term).map(Vec::as_slice))
|
||||
.collect();
|
||||
|
||||
if !exact.is_empty() {
|
||||
exact.sort_by_key(|rows| rows.len());
|
||||
let mut acc = exact[0].to_vec();
|
||||
for rows in &exact[1..] {
|
||||
if acc.is_empty() {
|
||||
break;
|
||||
}
|
||||
acc = intersect_sorted(&acc, rows);
|
||||
}
|
||||
return acc;
|
||||
}
|
||||
|
||||
self.prefix_seed_rows(terms)
|
||||
}
|
||||
|
||||
/// Seed rows from the smallest prefix-expanded term — used only when no word matched an
|
||||
/// indexed token exactly (i.e. the user is still typing the final word).
|
||||
fn prefix_seed_rows(&self, terms: &[String]) -> Vec<u32> {
|
||||
let mut best: Option<Vec<u32>> = None;
|
||||
for term in terms {
|
||||
if term.len() < ADDRESS_SEARCH_PREFIX_MIN_LEN {
|
||||
continue;
|
||||
}
|
||||
let Some(tokens) = self.address_prefix_index.get(address_prefix_key(term)) else {
|
||||
continue;
|
||||
};
|
||||
let mut union: Vec<u32> = Vec::new();
|
||||
for token in tokens {
|
||||
if !token.starts_with(term) {
|
||||
continue;
|
||||
}
|
||||
if let Some(rows) = self.address_token_index.get(token) {
|
||||
union = if union.is_empty() {
|
||||
rows.clone()
|
||||
} else {
|
||||
union_sorted(&union, rows)
|
||||
};
|
||||
}
|
||||
}
|
||||
if !union.is_empty()
|
||||
&& best
|
||||
.as_ref()
|
||||
.is_none_or(|current| union.len() < current.len())
|
||||
{
|
||||
best = Some(union);
|
||||
}
|
||||
}
|
||||
best.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn address_match_score(&self, row: usize, parsed: &AddressQuery) -> Option<i32> {
|
||||
if self.address(row).trim().is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let tokens = self.row_address_search_tokens(row);
|
||||
if parsed
|
||||
.text_groups
|
||||
.iter()
|
||||
.any(|group| !self.address_tokens_match_group(tokens, group))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let numeric_matches = parsed
|
||||
.numeric_terms
|
||||
.iter()
|
||||
.filter(|term| {
|
||||
tokens.iter().any(|token| {
|
||||
token_matches_numeric_term(self.address_search_interner.resolve(token), term)
|
||||
})
|
||||
})
|
||||
.count();
|
||||
|
||||
if !parsed.numeric_terms.is_empty() && numeric_matches == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut score = 0;
|
||||
if parsed.full_postcode.is_some() {
|
||||
score += 1_000;
|
||||
}
|
||||
score += (parsed.text_groups.len() as i32) * 200;
|
||||
score += (numeric_matches as i32) * 90;
|
||||
if numeric_matches == parsed.numeric_terms.len() && numeric_matches > 0 {
|
||||
score += 50;
|
||||
}
|
||||
// Additive bias (never a filter) when the row sits in the appended partial postcode.
|
||||
if let Some(area) = parsed.postcode_area.as_deref() {
|
||||
if self.row_postcode_in_area(row, area) {
|
||||
score += 400;
|
||||
}
|
||||
}
|
||||
|
||||
Some(score)
|
||||
}
|
||||
|
||||
fn address_tokens_match_group(&self, tokens: &[lasso::Spur], group: &AddressTermGroup) -> bool {
|
||||
group.alternatives.iter().any(|alternative| {
|
||||
tokens.iter().any(|token| {
|
||||
token_matches_query_term(self.address_search_interner.resolve(token), alternative)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn full_postcode_detection_accepts_common_formats() {
|
||||
assert!(is_full_postcode_compact("SW1A1AA"));
|
||||
assert!(is_full_postcode_compact("E142DG"));
|
||||
assert!(is_full_postcode_compact("M11AE"));
|
||||
assert!(!is_full_postcode_compact("E14"));
|
||||
assert!(!is_full_postcode_compact("DOWNING"));
|
||||
assert!(!is_full_postcode_compact("10A"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_query_parsing_skips_postcodes_and_street_suffixes() {
|
||||
let parsed = parse_address_query("Flat 2, 10 Downing St, SW1A 2AA");
|
||||
|
||||
assert_eq!(parsed.full_postcode.as_deref(), Some("SW1A 2AA"));
|
||||
assert_eq!(
|
||||
parsed.numeric_terms,
|
||||
vec!["10".to_string(), "2".to_string()]
|
||||
);
|
||||
assert_eq!(parsed.candidate_terms, vec!["downing".to_string()]);
|
||||
assert_eq!(parsed.text_groups.len(), 1);
|
||||
assert_eq!(
|
||||
parsed.text_groups[0].alternatives,
|
||||
vec!["downing".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_query_parsing_handles_compact_postcodes() {
|
||||
let parsed = parse_address_query("10 downing street sw1a1aa");
|
||||
|
||||
assert_eq!(parsed.full_postcode.as_deref(), Some("SW1A 1AA"));
|
||||
assert_eq!(parsed.numeric_terms, vec!["10".to_string()]);
|
||||
assert_eq!(parsed.candidate_terms, vec!["downing".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_query_recovers_appended_partial_postcode_as_bias() {
|
||||
let parsed = parse_address_query("Baker Street NW1");
|
||||
assert_eq!(parsed.full_postcode, None);
|
||||
assert_eq!(parsed.postcode_area.as_deref(), Some("NW1"));
|
||||
// The road words are still searchable; the postcode fragment did not consume them.
|
||||
assert_eq!(parsed.candidate_terms, vec!["baker".to_string()]);
|
||||
assert!(parsed.numeric_terms.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_query_recovers_outcode_plus_sector_without_a_phantom_house_number() {
|
||||
let parsed = parse_address_query("High Street CR0 2");
|
||||
assert_eq!(parsed.postcode_area.as_deref(), Some("CR02"));
|
||||
// The lone sector digit must not be treated as a house number.
|
||||
assert!(parsed.numeric_terms.is_empty());
|
||||
assert_eq!(parsed.candidate_terms, vec!["high".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_postcode_takes_precedence_over_partial_bias() {
|
||||
let parsed = parse_address_query("Baker Street NW1 6XE");
|
||||
assert_eq!(parsed.full_postcode.as_deref(), Some("NW1 6XE"));
|
||||
assert_eq!(parsed.postcode_area, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn intersect_and_union_sorted_row_ids() {
|
||||
assert_eq!(
|
||||
intersect_sorted(&[1, 2, 3, 5], &[2, 3, 4, 5]),
|
||||
vec![2, 3, 5]
|
||||
);
|
||||
assert_eq!(intersect_sorted(&[1, 2], &[3, 4]), Vec::<u32>::new());
|
||||
assert_eq!(union_sorted(&[1, 3, 5], &[2, 3, 4]), vec![1, 2, 3, 4, 5]);
|
||||
assert_eq!(union_sorted(&[], &[2, 4]), vec![2, 4]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn street_key_collapses_house_numbers_and_flats() {
|
||||
assert_eq!(street_key("12 Baker Street"), "baker street");
|
||||
assert_eq!(street_key("5 Baker Street"), "baker street");
|
||||
assert_eq!(street_key("Flat 2, 10 Downing Street"), "downing street");
|
||||
assert_eq!(street_key("221B Baker Street"), "baker street");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn street_key_keeps_ordinal_street_names() {
|
||||
// Ordinals are part of the street name, not a house-number prefix.
|
||||
assert_eq!(street_key("2nd Avenue"), "2nd avenue");
|
||||
assert_eq!(street_key("12 3rd Avenue"), "3rd avenue");
|
||||
assert!(is_ordinal_token("21st"));
|
||||
assert!(!is_ordinal_token("21"));
|
||||
assert!(!is_ordinal_token("221b"));
|
||||
}
|
||||
|
||||
/// A postcode-area refinement is only taken from the end of the query; an
/// outcode-shaped token at the start (a road designation) is ignored.
#[test]
fn postcode_area_recovered_only_from_the_trailing_position() {
    // A leading road designation must NOT be taken as an area refinement.
    let leading = parse_address_query("A4 Great West Road");
    assert!(leading.postcode_area.is_none());

    // A genuine trailing outcode still is.
    let trailing = parse_address_query("Great West Road W4");
    let area = trailing.postcode_area.as_deref();
    assert_eq!(area, Some("W4"));
}
|
||||
|
||||
/// Queries containing a road-type word are detected; bare names are not.
#[test]
fn road_type_detection() {
    for with_road_type in ["high street", "acacia avenue"] {
        assert!(query_has_road_type(with_road_type));
    }

    for without_road_type in ["acacia", "london"] {
        assert!(!query_has_road_type(without_road_type));
    }
}
|
||||
|
||||
/// A query ending in a partially typed word keeps that partial term both as a
/// candidate term and as its own single-alternative text group.
#[test]
fn address_query_parsing_keeps_partial_terms_for_row_matching() {
    let query = parse_address_query("settlers cour");

    // No postcode and no numbers were typed.
    assert!(query.full_postcode.is_none());
    assert!(query.numeric_terms.is_empty());

    let expected_terms = vec![String::from("settlers"), String::from("cour")];
    assert_eq!(query.candidate_terms, expected_terms);

    // One text group per typed word, each holding exactly the typed text.
    assert_eq!(query.text_groups.len(), 2);
    let first_alternatives = vec![String::from("settlers")];
    assert_eq!(query.text_groups[0].alternatives, first_alternatives);
    let second_alternatives = vec![String::from("cour")];
    assert_eq!(query.text_groups[1].alternatives, second_alternatives);
}
|
||||
|
||||
/// The scoring tokens of an address keep every term actually present in it
/// (numbers, "flat", partial words), returned in the order asserted below.
#[test]
fn address_search_tokens_keep_actual_address_terms_for_scoring() {
    let tokens = address_search_tokens("Flat 2, 10 Downing Cour");

    let expected: Vec<String> = ["10", "2", "cour", "downing", "flat"]
        .iter()
        .map(|token| token.to_string())
        .collect();

    assert_eq!(tokens, expected);
}
|
||||
|
||||
/// The prefix index maps a typed prefix to every indexed token starting with
/// it, and does not contain an entry for the two-letter prefix "do".
#[test]
fn address_prefix_index_finds_partial_address_terms() {
    // Same insertion sequence as a hand-written series of `insert` calls.
    let mut token_index: FxHashMap<String, Vec<u32>> = FxHashMap::default();
    for (token, row) in [("downing", 1), ("downton", 2), ("market", 3)] {
        token_index.insert(token.to_string(), vec![row]);
    }

    let prefix_index = build_address_prefix_index(&token_index);
    let tokens_for = |prefix: &str| prefix_index.get(prefix).cloned().unwrap_or_default();

    // A shared prefix yields both tokens; longer prefixes disambiguate.
    assert_eq!(
        tokens_for("down"),
        vec!["downing".to_string(), "downton".to_string()]
    );
    assert_eq!(tokens_for("downi"), vec!["downing".to_string()]);
    assert_eq!(tokens_for("downt"), vec!["downton".to_string()]);

    assert!(!prefix_index.contains_key("do"));
}
|
||||
|
||||
/// A term group matches an address when a typed prefix matches one of its
/// tokens, or when any listed alternative (alias) does.
#[test]
fn address_term_matching_allows_prefixes_and_aliases() {
    let address_tokens = tokenize_address_text("10 Downing Street");

    // "down" is a prefix of the address token "downing".
    let by_prefix = address_term_group("down").expect("prefix term should be searchable");
    assert!(address_tokens_match_group(&address_tokens, &by_prefix));

    // Either alternative of the group is enough for a match.
    let by_alias = AddressTermGroup {
        alternatives: vec![String::from("st"), String::from("street")],
    };
    assert!(address_tokens_match_group(&address_tokens, &by_alias));
}
|
||||
|
||||
/// A partially typed term ("cou") matches an address whose token ("court")
/// starts with it.
#[test]
fn address_term_matching_uses_actual_token_prefixes() {
    let address_tokens = tokenize_address_text("12 Settlers Court");
    let partial = address_term_group("cou").expect("partial term should be searchable");

    assert!(address_tokens_match_group(&address_tokens, &partial));
}
|
||||
}
|
||||
34
server-rs/src/data/property/h3.rs
Normal file
34
server-rs/src/data/property/h3.rs
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
//! H3 spatial cell precomputation for property rows.
|
||||
|
||||
use anyhow::Context;
|
||||
use rayon::prelude::*;
|
||||
|
||||
use crate::consts::H3_PRECOMPUTE_MAX;
|
||||
|
||||
/// Precompute H3 cell IDs for all rows at the maximum resolution only.
|
||||
/// Parent cells for lower resolutions are derived on the fly via `CellIndex::parent()`.
|
||||
pub fn precompute_h3(lat: &[f32], lon: &[f32]) -> anyhow::Result<Vec<u64>> {
|
||||
let res = H3_PRECOMPUTE_MAX;
|
||||
tracing::info!("Precomputing H3 cells at resolution {}", res);
|
||||
|
||||
let h3_res =
|
||||
h3o::Resolution::try_from(res).with_context(|| format!("Invalid H3 resolution: {res}"))?;
|
||||
|
||||
let cells: Vec<u64> = lat
|
||||
.par_iter()
|
||||
.zip(lon.par_iter())
|
||||
.enumerate()
|
||||
.map(|(i, (&latitude, &longitude))| {
|
||||
let coord = h3o::LatLng::new(latitude as f64, longitude as f64).unwrap_or_else(|err| {
|
||||
panic!(
|
||||
"Invalid coordinates at row {}: lat={}, lon={}: {}",
|
||||
i, latitude, longitude, err
|
||||
)
|
||||
});
|
||||
u64::from(coord.to_cell(h3_res))
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!("H3 precomputation complete ({} cells)", cells.len());
|
||||
Ok(cells)
|
||||
}
|
||||
1105
server-rs/src/data/property/loading.rs
Normal file
1105
server-rs/src/data/property/loading.rs
Normal file
File diff suppressed because it is too large
Load diff
238
server-rs/src/data/property/mod.rs
Normal file
238
server-rs/src/data/property/mod.rs
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
//! Property data: the row-major quantized feature matrix plus the side tables
|
||||
//! (addresses, postcodes, renovation/price history, POI metrics) built from the
|
||||
//! properties + postcode parquet files.
|
||||
//!
|
||||
//! Split by concern:
|
||||
//! - [`loading`]: parquet ingestion, validation, spatial sort and matrix build
|
||||
//! - [`stats`]: histograms, percentiles and slider-bound computation
|
||||
//! - [`quant`]: u16 quantization encode/decode
|
||||
//! - [`poi_metrics`]: postcode-level POI metric side table
|
||||
//! - [`address_search`]: address tokenization, indexing and ranked search
|
||||
//! - [`h3`]: H3 cell precomputation
|
||||
|
||||
mod address_search;
|
||||
mod h3;
|
||||
mod loading;
|
||||
mod poi_metrics;
|
||||
mod quant;
|
||||
mod stats;
|
||||
|
||||
pub use h3::precompute_h3;
|
||||
pub use poi_metrics::PostcodePoiMetrics;
|
||||
pub use quant::QuantRef;
|
||||
pub use stats::{FeatureStats, Histogram};
|
||||
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::consts::NAN_U16;
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct RenovationEvent {
|
||||
pub year: i32,
|
||||
pub event: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct HistoricalPrice {
|
||||
pub year: i32,
|
||||
pub month: u8,
|
||||
pub price: i64,
|
||||
}
|
||||
|
||||
pub struct PropertyData {
|
||||
pub lat: Vec<f32>,
|
||||
pub lon: Vec<f32>,
|
||||
pub feature_names: Vec<String>,
|
||||
pub num_features: usize,
|
||||
/// Number of numeric features (enum features start at this index).
|
||||
pub num_numeric: usize,
|
||||
/// Row-major flat array: feature_data[row * num_features + feat_idx].
|
||||
/// Quantized to u16. NaN sentinel = u16::MAX (65535).
|
||||
/// Numeric features: encoded via (val - min) / range * 65534.
|
||||
/// Enum features: stored directly as u16 cast of the f32 index.
|
||||
pub feature_data: Vec<u16>,
|
||||
/// Per-feature: range / QUANT_SCALE for fast decode.
|
||||
dequant_a: Vec<f32>,
|
||||
/// Per-feature: minimum value (offset for dequantization).
|
||||
quant_min: Vec<f32>,
|
||||
/// Per-feature: max - min (for encoding filter bounds).
|
||||
quant_range: Vec<f32>,
|
||||
pub feature_stats: Vec<FeatureStats>,
|
||||
pub poi_metrics: PostcodePoiMetrics,
|
||||
/// Unquantized last sale price used by the price-history chart.
|
||||
last_known_price_raw: Vec<f32>,
|
||||
/// Contiguous buffer holding all address strings end-to-end.
|
||||
address_buffer: String,
|
||||
/// Byte offset into `address_buffer` where each row's address starts.
|
||||
address_offsets: Vec<u32>,
|
||||
/// Length in bytes of each row's address.
|
||||
address_lengths: Vec<u16>,
|
||||
/// Interned postcodes: reader is thread-safe, keys index into it.
|
||||
postcode_interner: lasso::RodeoReader,
|
||||
postcode_keys: Vec<lasso::Spur>,
|
||||
/// Rows for each postcode, keyed by the interned postcode key.
|
||||
postcode_row_index: FxHashMap<lasso::Spur, Vec<u32>>,
|
||||
/// Inverted index from address tokens to property rows.
|
||||
address_token_index: FxHashMap<String, Vec<u32>>,
|
||||
/// Prefix lookup from typed address-token prefix to indexed full address tokens.
|
||||
address_prefix_index: FxHashMap<String, Vec<String>>,
|
||||
/// Interned normalized address-search tokens used for per-row scoring.
|
||||
address_search_interner: lasso::RodeoReader,
|
||||
/// Flat per-row normalized address-search token keys.
|
||||
address_search_token_keys: Vec<lasso::Spur>,
|
||||
/// Offset into `address_search_token_keys` for each row.
|
||||
address_search_token_offsets: Vec<u32>,
|
||||
/// Number of normalized address-search token keys for each row.
|
||||
address_search_token_lengths: Vec<u16>,
|
||||
/// For enum features: maps feature index to list of possible string values.
|
||||
/// Index in values list corresponds to the u16 value stored in feature_data.
|
||||
pub enum_values: rustc_hash::FxHashMap<usize, Vec<String>>,
|
||||
/// For enum features: maps feature index to per-value global counts (same order as enum_values).
|
||||
pub enum_counts: rustc_hash::FxHashMap<usize, Vec<u64>>,
|
||||
/// Per-row flag: true = construction date is approximate (from EPC band),
|
||||
/// false = exact (from new-build transaction date).
|
||||
/// Bit-packed: byte `row / 8`, bit `row % 8`. 8x smaller than Vec<bool>.
|
||||
approx_build_date_bits: Vec<u8>,
|
||||
/// Per-row renovation events. Keyed by (permuted) row index.
|
||||
/// Only rows with events are present in the map.
|
||||
renovation_history: FxHashMap<u32, Vec<RenovationEvent>>,
|
||||
/// Per-row historical sale transactions (Land Registry price-paid).
|
||||
/// Keyed by (permuted) row index. Only rows with prices are present.
|
||||
historical_prices: FxHashMap<u32, Vec<HistoricalPrice>>,
|
||||
property_sub_type: FxHashMap<u32, String>,
|
||||
price_qualifier: FxHashMap<u32, String>,
|
||||
}
|
||||
|
||||
impl PropertyData {
|
||||
/// Get the address string for a given row.
|
||||
pub fn address(&self, row: usize) -> &str {
|
||||
let offset = self.address_offsets[row] as usize;
|
||||
let length = self.address_lengths[row] as usize;
|
||||
&self.address_buffer[offset..offset + length]
|
||||
}
|
||||
|
||||
/// Get the postcode string for a given row.
|
||||
pub fn postcode(&self, row: usize) -> &str {
|
||||
self.postcode_interner.resolve(&self.postcode_keys[row])
|
||||
}
|
||||
|
||||
/// Get postcode components for field-level borrowing (avoids conflicting borrows with feature_data).
|
||||
pub fn postcode_parts(&self) -> (&lasso::RodeoReader, &[lasso::Spur]) {
|
||||
(&self.postcode_interner, &self.postcode_keys)
|
||||
}
|
||||
|
||||
/// Property rows for a given postcode string, or empty if unknown.
|
||||
pub fn rows_for_postcode(&self, postcode: &str) -> &[u32] {
|
||||
self.postcode_interner
|
||||
.get(postcode)
|
||||
.and_then(|key| self.postcode_row_index.get(&key))
|
||||
.map(Vec::as_slice)
|
||||
.unwrap_or(&[])
|
||||
}
|
||||
|
||||
/// Get the is_approx_build_date flag for a given row (bit-packed).
|
||||
pub fn is_approx_build_date(&self, row: usize) -> bool {
|
||||
let byte = self.approx_build_date_bits[row / 8];
|
||||
byte & (1 << (row % 8)) != 0
|
||||
}
|
||||
|
||||
/// Get renovation events for a given row (empty slice if none).
|
||||
pub fn renovation_history(&self, row: usize) -> &[RenovationEvent] {
|
||||
self.renovation_history
|
||||
.get(&(row as u32))
|
||||
.map(|v| v.as_slice())
|
||||
.unwrap_or(&[])
|
||||
}
|
||||
|
||||
/// Get historical sale transactions for a given row (empty slice if none).
|
||||
pub fn historical_prices(&self, row: usize) -> &[HistoricalPrice] {
|
||||
self.historical_prices
|
||||
.get(&(row as u32))
|
||||
.map(|v| v.as_slice())
|
||||
.unwrap_or(&[])
|
||||
}
|
||||
|
||||
/// Get property sub-type for a given row.
|
||||
pub fn property_sub_type(&self, row: usize) -> Option<&str> {
|
||||
self.property_sub_type
|
||||
.get(&(row as u32))
|
||||
.map(String::as_str)
|
||||
}
|
||||
|
||||
/// Get price qualifier for a given row.
|
||||
pub fn price_qualifier(&self, row: usize) -> Option<&str> {
|
||||
self.price_qualifier.get(&(row as u32)).map(String::as_str)
|
||||
}
|
||||
|
||||
/// Get the unquantized last sale price for charting.
|
||||
#[inline]
|
||||
pub fn last_known_price_raw(&self, row: usize) -> f32 {
|
||||
self.last_known_price_raw[row]
|
||||
}
|
||||
|
||||
/// Decode a single feature value from quantized u16 storage.
|
||||
#[inline]
|
||||
pub fn get_feature(&self, row: usize, feat_idx: usize) -> f32 {
|
||||
let raw = self.feature_data[row * self.num_features + feat_idx];
|
||||
if raw == NAN_U16 {
|
||||
return f32::NAN;
|
||||
}
|
||||
if feat_idx >= self.num_numeric {
|
||||
raw as f32
|
||||
} else {
|
||||
raw as f32 * self.dequant_a[feat_idx] + self.quant_min[feat_idx]
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a QuantRef for passing to aggregation/filter functions.
|
||||
pub fn quant_ref(&self) -> QuantRef<'_> {
|
||||
QuantRef {
|
||||
dequant_a: &self.dequant_a,
|
||||
quant_min: &self.quant_min,
|
||||
quant_range: &self.quant_range,
|
||||
num_numeric: self.num_numeric,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
impl PropertyData {
    /// Build a `PropertyData` with zero rows and zero features.
    ///
    /// Intended for integration tests that need an `AppState` but never touch
    /// property data (e.g. checkout/webhook/invite flows).
    pub(crate) fn empty_for_tests() -> Self {
        Self {
            // Coordinates and feature matrix: no rows, no features.
            lat: vec![],
            lon: vec![],
            feature_names: vec![],
            num_features: 0,
            num_numeric: 0,
            feature_data: vec![],
            dequant_a: vec![],
            quant_min: vec![],
            quant_range: vec![],
            feature_stats: vec![],
            poi_metrics: PostcodePoiMetrics::empty(0),
            last_known_price_raw: vec![],
            // Address and postcode storage.
            address_buffer: String::new(),
            address_offsets: vec![],
            address_lengths: vec![],
            postcode_interner: lasso::Rodeo::default().into_reader(),
            postcode_keys: vec![],
            postcode_row_index: FxHashMap::default(),
            // Address search indexes.
            address_token_index: FxHashMap::default(),
            address_prefix_index: FxHashMap::default(),
            address_search_interner: lasso::Rodeo::default().into_reader(),
            address_search_token_keys: vec![],
            address_search_token_offsets: vec![],
            address_search_token_lengths: vec![],
            // Enum metadata and sparse per-row side tables.
            enum_values: FxHashMap::default(),
            enum_counts: FxHashMap::default(),
            approx_build_date_bits: vec![],
            renovation_history: FxHashMap::default(),
            historical_prices: FxHashMap::default(),
            property_sub_type: FxHashMap::default(),
            price_qualifier: FxHashMap::default(),
        }
    }
}
|
||||
200
server-rs/src/data/property/poi_metrics.rs
Normal file
200
server-rs/src/data/property/poi_metrics.rs
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
//! Postcode-level POI metric side table: dynamic POI features are stored once
|
||||
//! per postcode (not per property row) to keep the hot row-major feature matrix
|
||||
//! narrow, with a per-property row mapping for lookups.
|
||||
|
||||
use anyhow::Context;
|
||||
use polars::prelude::*;
|
||||
use rayon::prelude::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::consts::{NAN_U16, QUANT_SCALE};
|
||||
use crate::features::{self, Bounds};
|
||||
|
||||
use super::quant::QuantRef;
|
||||
use super::stats::{column_to_f32_vec, compute_feature_stats, FeatureStats};
|
||||
|
||||
/// Sentinel stored in the per-property row mapping for rows that have no entry
/// in the postcode POI metric table; lookups treat it as "no metric row".
pub(super) const NO_POI_METRIC_ROW: u32 = u32::MAX;
|
||||
|
||||
pub struct PostcodePoiMetrics {
|
||||
pub feature_names: Vec<String>,
|
||||
pub name_to_index: FxHashMap<String, usize>,
|
||||
/// Metric-major storage: columns[metric_idx][postcode_metric_idx].
|
||||
pub columns: Vec<Vec<u16>>,
|
||||
pub feature_stats: Vec<FeatureStats>,
|
||||
/// Per-property row lookup into the postcode metric table.
|
||||
row_to_metric_idx: Vec<u32>,
|
||||
dequant_a: Vec<f32>,
|
||||
quant_min: Vec<f32>,
|
||||
quant_range: Vec<f32>,
|
||||
}
|
||||
|
||||
impl PostcodePoiMetrics {
|
||||
pub(super) fn empty(row_count: usize) -> Self {
|
||||
Self {
|
||||
feature_names: Vec::new(),
|
||||
name_to_index: FxHashMap::default(),
|
||||
columns: Vec::new(),
|
||||
feature_stats: Vec::new(),
|
||||
row_to_metric_idx: vec![NO_POI_METRIC_ROW; row_count],
|
||||
dequant_a: Vec::new(),
|
||||
quant_min: Vec::new(),
|
||||
quant_range: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn from_postcode_df(
|
||||
df: &DataFrame,
|
||||
feature_names: Vec<String>,
|
||||
) -> anyhow::Result<Self> {
|
||||
if feature_names.is_empty() {
|
||||
return Ok(Self::empty(0));
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
metrics = feature_names.len(),
|
||||
postcodes = df.height(),
|
||||
"Building postcode POI metric side table"
|
||||
);
|
||||
|
||||
let col_major: Vec<Vec<f32>> = feature_names
|
||||
.par_iter()
|
||||
.map(|name| {
|
||||
let column = df
|
||||
.column(name.as_str())
|
||||
.with_context(|| format!("Missing POI metric column '{name}'"))?;
|
||||
column_to_f32_vec(column)
|
||||
})
|
||||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||||
|
||||
let feature_stats: Vec<FeatureStats> = col_major
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(metric_idx, vals)| {
|
||||
let name = feature_names[metric_idx].as_str();
|
||||
let bounds = features::bounds_for(name)
|
||||
.with_context(|| format!("No bounds config for POI metric '{name}'"))?;
|
||||
Ok(compute_feature_stats(
|
||||
vals,
|
||||
&bounds,
|
||||
features::has_integer_bins(name),
|
||||
))
|
||||
})
|
||||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||||
|
||||
let mut quant_min = Vec::with_capacity(feature_names.len());
|
||||
let mut quant_range = Vec::with_capacity(feature_names.len());
|
||||
for (metric_idx, stats) in feature_stats.iter().enumerate() {
|
||||
let (min, max) = match features::bounds_for(feature_names[metric_idx].as_str()) {
|
||||
Some(Bounds::Fixed { min, max }) => (min, max),
|
||||
_ => (stats.histogram.min, stats.histogram.max),
|
||||
};
|
||||
quant_min.push(min);
|
||||
quant_range.push(if max > min { max - min } else { 0.0 });
|
||||
}
|
||||
let dequant_a: Vec<f32> = quant_range
|
||||
.iter()
|
||||
.map(|&range| {
|
||||
if range > 0.0 {
|
||||
range / QUANT_SCALE
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let columns: Vec<Vec<u16>> = col_major
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(metric_idx, vals)| {
|
||||
let range = quant_range[metric_idx];
|
||||
let min = quant_min[metric_idx];
|
||||
vals.iter()
|
||||
.map(|&value| {
|
||||
if !value.is_finite() {
|
||||
NAN_U16
|
||||
} else if range > 0.0 {
|
||||
let normalized = (value - min) / range;
|
||||
(normalized * QUANT_SCALE).round().clamp(0.0, QUANT_SCALE) as u16
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let name_to_index = feature_names
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, name)| (name.clone(), idx))
|
||||
.collect();
|
||||
|
||||
Ok(Self {
|
||||
feature_names,
|
||||
name_to_index,
|
||||
columns,
|
||||
feature_stats,
|
||||
row_to_metric_idx: Vec::new(),
|
||||
dequant_a,
|
||||
quant_min,
|
||||
quant_range,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn set_row_mapping(&mut self, row_to_metric_idx: Vec<u32>) {
|
||||
self.row_to_metric_idx = row_to_metric_idx;
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.feature_names.is_empty()
|
||||
}
|
||||
|
||||
pub fn num_features(&self) -> usize {
|
||||
self.feature_names.len()
|
||||
}
|
||||
|
||||
pub fn quant_ref(&self) -> QuantRef<'_> {
|
||||
QuantRef {
|
||||
dequant_a: &self.dequant_a,
|
||||
quant_min: &self.quant_min,
|
||||
quant_range: &self.quant_range,
|
||||
num_numeric: self.feature_names.len(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn metric_row_for_property(&self, row: usize) -> Option<usize> {
|
||||
self.row_to_metric_idx
|
||||
.get(row)
|
||||
.copied()
|
||||
.filter(|&idx| idx != NO_POI_METRIC_ROW)
|
||||
.map(|idx| idx as usize)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn raw_for_metric_row(&self, metric_row: usize, metric_idx: usize) -> u16 {
|
||||
self.columns[metric_idx][metric_row]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn raw_for_property_row(&self, row: usize, metric_idx: usize) -> u16 {
|
||||
let Some(metric_row) = self.metric_row_for_property(row) else {
|
||||
return NAN_U16;
|
||||
};
|
||||
self.raw_for_metric_row(metric_row, metric_idx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn decode_raw(&self, metric_idx: usize, raw: u16) -> f32 {
|
||||
if raw == NAN_U16 {
|
||||
f32::NAN
|
||||
} else {
|
||||
raw as f32 * self.dequant_a[metric_idx] + self.quant_min[metric_idx]
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_for_property_row(&self, row: usize, metric_idx: usize) -> f32 {
|
||||
self.decode_raw(metric_idx, self.raw_for_property_row(row, metric_idx))
|
||||
}
|
||||
}
|
||||
46
server-rs/src/data/property/quant.rs
Normal file
46
server-rs/src/data/property/quant.rs
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
//! u16 quantization: decoding stored feature values and encoding filter bounds.
|
||||
|
||||
use crate::consts::{NAN_U16, QUANT_SCALE};
|
||||
|
||||
/// Borrowed view of the per-feature quantization parameters, used to decode
/// stored u16 values and to encode filter bounds without owning the tables.
pub struct QuantRef<'a> {
    /// Per-feature decode multiplier applied to the raw u16.
    pub dequant_a: &'a [f32],
    /// Per-feature minimum, added back after scaling.
    pub quant_min: &'a [f32],
    /// Per-feature value range, used when encoding bounds.
    pub quant_range: &'a [f32],
    /// Features at or beyond this index are decoded as the raw value itself.
    pub num_numeric: usize,
}
|
||||
|
||||
impl QuantRef<'_> {
|
||||
/// Decode a raw u16 value back to f32.
|
||||
#[inline]
|
||||
pub fn decode(&self, feat_idx: usize, raw: u16) -> f32 {
|
||||
if raw == NAN_U16 {
|
||||
return f32::NAN;
|
||||
}
|
||||
if feat_idx >= self.num_numeric {
|
||||
raw as f32
|
||||
} else {
|
||||
raw as f32 * self.dequant_a[feat_idx] + self.quant_min[feat_idx]
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode a filter minimum bound to u16 (floors to include boundary values).
|
||||
#[inline]
|
||||
pub fn encode_min(&self, feat_idx: usize, value: f32) -> u16 {
|
||||
if !value.is_finite() || self.quant_range[feat_idx] == 0.0 {
|
||||
return 0;
|
||||
}
|
||||
let norm = (value - self.quant_min[feat_idx]) / self.quant_range[feat_idx];
|
||||
(norm * QUANT_SCALE).floor().clamp(0.0, QUANT_SCALE) as u16
|
||||
}
|
||||
|
||||
/// Encode a filter maximum bound to u16 (ceils to include boundary values).
|
||||
#[inline]
|
||||
pub fn encode_max(&self, feat_idx: usize, value: f32) -> u16 {
|
||||
if !value.is_finite() || self.quant_range[feat_idx] == 0.0 {
|
||||
return QUANT_SCALE as u16;
|
||||
}
|
||||
let norm = (value - self.quant_min[feat_idx]) / self.quant_range[feat_idx];
|
||||
(norm * QUANT_SCALE).ceil().clamp(0.0, QUANT_SCALE) as u16
|
||||
}
|
||||
}
|
||||
544
server-rs/src/data/property/stats.rs
Normal file
544
server-rs/src/data/property/stats.rs
Normal file
|
|
@ -0,0 +1,544 @@
|
|||
//! Feature statistics: outlier-bracketed histograms, percentile estimation and
|
||||
//! slider-bound computation.
|
||||
|
||||
use anyhow::Context;
|
||||
use polars::prelude::*;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::consts::HISTOGRAM_BINS;
|
||||
use crate::features::Bounds;
|
||||
|
||||
/// Histogram with outlier buckets at the edges.
|
||||
/// - Bin 0: [min, p1) — low outliers
|
||||
/// - Bins 1 to n-2: [p1, p99) — main distribution, evenly divided
|
||||
/// - Bin n-1: [p99, max] — high outliers
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct Histogram {
|
||||
pub min: f32,
|
||||
pub max: f32,
|
||||
/// 1st percentile (left edge of main distribution)
|
||||
pub p1: f32,
|
||||
/// 99th percentile (right edge of main distribution)
|
||||
pub p99: f32,
|
||||
pub counts: Vec<u64>,
|
||||
}
|
||||
|
||||
impl Histogram {
|
||||
/// Return the bin index for a given value using the outlier-bracket layout.
|
||||
#[cfg(test)]
|
||||
pub fn bin_for_value(&self, value: f32) -> usize {
|
||||
let num_bins = self.counts.len();
|
||||
if value < self.p1 {
|
||||
0
|
||||
} else if value >= self.p99 {
|
||||
num_bins - 1
|
||||
} else {
|
||||
let middle_bins = num_bins.saturating_sub(2);
|
||||
if middle_bins > 0 && self.p99 > self.p1 {
|
||||
let width = (self.p99 - self.p1) / middle_bins as f32;
|
||||
let middle_bin = ((value - self.p1) / width) as usize;
|
||||
(1 + middle_bin).min(num_bins - 2)
|
||||
} else {
|
||||
num_bins / 2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Width of a single middle bin (bins 1..n-2).
|
||||
#[cfg(test)]
|
||||
pub fn middle_bin_width(&self) -> f32 {
|
||||
let middle_bins = self.counts.len().saturating_sub(2);
|
||||
if middle_bins > 0 && self.p99 > self.p1 {
|
||||
(self.p99 - self.p1) / middle_bins as f32
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FeatureStats {
|
||||
pub slider_min: f32,
|
||||
pub slider_max: f32,
|
||||
pub histogram: Histogram,
|
||||
}
|
||||
|
||||
/// Estimate a percentile from a histogram of equal-width bins over `[min, max]`.
///
/// `count` is the number of values the histogram represents and `percentile`
/// is on a 0-100 scale. The estimate interpolates linearly inside the bin that
/// contains the target rank. Returns `min` when `count` is 0 or there are no
/// bins, and `max` when the target rank lies beyond every bin.
fn percentile_from_uniform_histogram(
    count: usize,
    min: f32,
    max: f32,
    prelim_counts: &[u64],
    percentile: f32,
) -> f32 {
    if count == 0 || prelim_counts.is_empty() {
        return min;
    }

    // Zero-based rank of the value the percentile points at.
    let rank = (count as f64 * percentile as f64 / 100.0).floor() as u64;
    let width = (max - min) / prelim_counts.len() as f32;

    let mut seen = 0u64;
    prelim_counts
        .iter()
        .enumerate()
        .find_map(|(idx, &in_bin)| {
            let seen_before = seen;
            seen += in_bin;
            if seen <= rank {
                return None;
            }
            // The rank falls inside this bin: interpolate from its left edge.
            let left_edge = min + idx as f32 * width;
            let share = if in_bin > 0 {
                (rank - seen_before) as f32 / in_bin as f32
            } else {
                0.0
            };
            Some(left_edge + share * width)
        })
        .unwrap_or(max)
}
|
||||
|
||||
/// Build a histogram and compute slider bounds based on the feature's Bounds config.
|
||||
pub fn compute_feature_stats(vals: &[f32], bounds: &Bounds, integer_bins: bool) -> FeatureStats {
|
||||
// Single pass: min, max, count (skipping NaN and infinity)
|
||||
let mut min = f32::INFINITY;
|
||||
let mut max = f32::NEG_INFINITY;
|
||||
let mut count = 0usize;
|
||||
for &value in vals {
|
||||
if value.is_finite() {
|
||||
if value < min {
|
||||
min = value;
|
||||
}
|
||||
if value > max {
|
||||
max = value;
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
let (slider_min, slider_max) = match bounds {
|
||||
Bounds::Fixed {
|
||||
min: fmin,
|
||||
max: fmax,
|
||||
} => (*fmin, *fmax),
|
||||
Bounds::Percentile { .. } => (0.0, 0.0),
|
||||
};
|
||||
return FeatureStats {
|
||||
slider_min,
|
||||
slider_max,
|
||||
histogram: Histogram {
|
||||
min: 0.0,
|
||||
max: 0.0,
|
||||
p1: 0.0,
|
||||
p99: 0.0,
|
||||
counts: vec![0; HISTOGRAM_BINS],
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Build preliminary histogram with uniform bins to compute percentiles
|
||||
// Use full HISTOGRAM_BINS for percentile precision
|
||||
let range = if max == min { 1.0 } else { max - min };
|
||||
let prelim_max = min + range * (1.0 + 1e-6);
|
||||
let prelim_bin_width = (prelim_max - min) / HISTOGRAM_BINS as f32;
|
||||
|
||||
let mut prelim_counts = vec![0u64; HISTOGRAM_BINS];
|
||||
for &value in vals {
|
||||
if value.is_finite() {
|
||||
let bin = ((value - min) / prelim_bin_width) as usize;
|
||||
prelim_counts[bin.min(HISTOGRAM_BINS - 1)] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute p1 and p99 from preliminary histogram
|
||||
let mut p1 = percentile_from_uniform_histogram(count, min, max, &prelim_counts, 1.0);
|
||||
let mut p99 = percentile_from_uniform_histogram(count, min, max, &prelim_counts, 99.0);
|
||||
|
||||
// Iterative refinement for outlier-dominated distributions.
|
||||
// When extreme outliers (e.g. 317M sqm from web scraping) dominate the range,
|
||||
// the uniform histogram puts all real data in one bin, making percentile
|
||||
// estimation useless. Zoom into the estimated data region and recompute.
|
||||
let mut refined_counts = prelim_counts;
|
||||
let mut refined_count = count;
|
||||
let mut refined_min = min;
|
||||
let mut refined_max = max;
|
||||
for _ in 0..3 {
|
||||
let iqr = p99 - p1;
|
||||
if iqr <= 0.0 || (refined_max - refined_min) <= 5.0 * iqr {
|
||||
break;
|
||||
}
|
||||
let new_min = (p1 - iqr).max(min);
|
||||
let new_max = p99 + iqr;
|
||||
if new_max <= new_min {
|
||||
break;
|
||||
}
|
||||
let bin_width = (new_max - new_min) / HISTOGRAM_BINS as f32;
|
||||
let mut counts = vec![0u64; HISTOGRAM_BINS];
|
||||
let mut cnt = 0usize;
|
||||
for &value in vals {
|
||||
if value.is_finite() && value >= new_min && value <= new_max {
|
||||
let bin = ((value - new_min) / bin_width) as usize;
|
||||
counts[bin.min(HISTOGRAM_BINS - 1)] += 1;
|
||||
cnt += 1;
|
||||
}
|
||||
}
|
||||
if cnt == 0 {
|
||||
break;
|
||||
}
|
||||
p1 = percentile_from_uniform_histogram(cnt, new_min, new_max, &counts, 1.0);
|
||||
p99 = percentile_from_uniform_histogram(cnt, new_min, new_max, &counts, 99.0);
|
||||
refined_counts = counts;
|
||||
refined_count = cnt;
|
||||
refined_min = new_min;
|
||||
refined_max = new_max;
|
||||
}
|
||||
|
||||
// For integer-binned features, snap p1/p99 to integer boundaries
|
||||
// so each middle bin is exactly 1 unit wide.
|
||||
if integer_bins {
|
||||
p1 = p1.floor();
|
||||
p99 = p99.ceil();
|
||||
}
|
||||
|
||||
// Determine number of histogram bins
|
||||
let num_bins = if integer_bins && p99 > p1 {
|
||||
// One middle bin per integer + 2 outlier bins
|
||||
(p99 - p1) as usize + 2
|
||||
} else {
|
||||
// Count unique values within the p1–p99 range to cap histogram bins.
|
||||
// Using the full-range cardinality would over-allocate bins when outliers
|
||||
// inflate it (e.g. bedrooms: 1–137 unique values but only ~10 within p1–p99).
|
||||
let cardinality = {
|
||||
let mut unique_set = rustc_hash::FxHashSet::default();
|
||||
for &val in vals {
|
||||
if val.is_finite() && val >= p1 && val <= p99 {
|
||||
unique_set.insert(val.to_bits());
|
||||
}
|
||||
}
|
||||
unique_set.len()
|
||||
};
|
||||
HISTOGRAM_BINS.min(cardinality).max(3)
|
||||
};
|
||||
|
||||
// Build final histogram with outlier bins at edges:
|
||||
// - Bin 0: [min, p1) — low outliers
|
||||
// - Bins 1 to n-2: [p1, p99) — main distribution, evenly divided
|
||||
// - Bin n-1: [p99, max] — high outliers
|
||||
let mut counts = vec![0u64; num_bins];
|
||||
let middle_bins = num_bins.saturating_sub(2);
|
||||
let middle_width = if middle_bins > 0 && p99 > p1 {
|
||||
(p99 - p1) / middle_bins as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
for &value in vals {
|
||||
if value.is_finite() {
|
||||
let bin = if value < p1 {
|
||||
0 // Low outlier bin
|
||||
} else if value >= p99 {
|
||||
num_bins - 1 // High outlier bin
|
||||
} else if middle_width > 0.0 {
|
||||
// Middle bins (1 to n-2)
|
||||
let middle_bin = ((value - p1) / middle_width) as usize;
|
||||
(1 + middle_bin).min(num_bins - 2)
|
||||
} else {
|
||||
num_bins / 2 // Fallback if p1 == p99
|
||||
};
|
||||
counts[bin] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let histogram = Histogram {
|
||||
min: refined_min,
|
||||
max: refined_max,
|
||||
p1,
|
||||
p99,
|
||||
counts,
|
||||
};
|
||||
|
||||
// Compute slider bounds (use refined histogram for accurate percentiles)
|
||||
let (slider_min, slider_max) = match bounds {
|
||||
Bounds::Fixed {
|
||||
min: fmin,
|
||||
max: fmax,
|
||||
} => (*fmin, *fmax),
|
||||
Bounds::Percentile { low, high } => {
|
||||
let p_low = percentile_from_uniform_histogram(
|
||||
refined_count,
|
||||
refined_min,
|
||||
refined_max,
|
||||
&refined_counts,
|
||||
*low as f32,
|
||||
);
|
||||
let p_high = percentile_from_uniform_histogram(
|
||||
refined_count,
|
||||
refined_min,
|
||||
refined_max,
|
||||
&refined_counts,
|
||||
*high as f32,
|
||||
);
|
||||
(p_low, p_high)
|
||||
}
|
||||
};
|
||||
|
||||
FeatureStats {
|
||||
slider_min,
|
||||
slider_max,
|
||||
histogram,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn column_to_f32_vec(column: &Column) -> anyhow::Result<Vec<f32>> {
|
||||
let float_series = column
|
||||
.cast(&DataType::Float32)
|
||||
.context("Failed to cast column to Float32")?;
|
||||
let chunked = float_series
|
||||
.f32()
|
||||
.context("Failed to get f32 chunked array")?;
|
||||
Ok(chunked
|
||||
.into_iter()
|
||||
.map(|value| value.unwrap_or(f32::NAN))
|
||||
.collect())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for `compute_feature_stats` and the `Histogram` helpers.

    use super::*;
    use crate::consts::QUANT_SCALE;
    use crate::features::Bounds;

    /// Shorthand for fixed slider bounds.
    fn make_fixed_bounds(min: f32, max: f32) -> Bounds {
        Bounds::Fixed { min, max }
    }

    /// Shorthand for percentile-derived slider bounds.
    fn make_percentile_bounds(low: f64, high: f64) -> Bounds {
        Bounds::Percentile { low, high }
    }

    #[test]
    fn histogram_empty_data() {
        let data: Vec<f32> = vec![];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.slider_min, 0.0);
        assert_eq!(stats.slider_max, 100.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 0);
    }

    #[test]
    fn histogram_single_value() {
        let data = vec![50.0_f32];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.min, 50.0);
        assert_eq!(stats.histogram.max, 50.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1);
    }

    #[test]
    fn histogram_uniform_distribution() {
        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.min, 0.0);
        assert_eq!(stats.histogram.max, 99.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 100);
    }

    #[test]
    fn histogram_with_nan_values() {
        let data = vec![10.0_f32, f32::NAN, 20.0, f32::NAN, 30.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 3);
        assert_eq!(stats.histogram.min, 10.0);
        assert_eq!(stats.histogram.max, 30.0);
    }

    #[test]
    fn histogram_all_nan() {
        let data = vec![f32::NAN, f32::NAN, f32::NAN];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 0);
    }

    #[test]
    fn histogram_all_same_value() {
        let data = vec![42.0_f32; 1000];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.min, 42.0);
        assert_eq!(stats.histogram.max, 42.0);
        assert_eq!(stats.histogram.p1, 42.0);
        assert_eq!(stats.histogram.p99, 42.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1000);
    }

    #[test]
    fn histogram_percentile_bounds() {
        let mut data: Vec<f32> = vec![0.0]; // Low outlier
        data.extend((1..99).map(|i| 50.0 + i as f32 * 0.01));
        data.push(1000.0); // High outlier

        let bounds = make_percentile_bounds(2.0, 98.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert!(stats.slider_min > 0.0);
        assert!(stats.slider_max < 1000.0);
    }

    #[test]
    fn fixed_price_bounds_keep_slider_cap() {
        let data = vec![400_000.0_f32, 2_500_000.0, 3_750_000.0];
        let bounds = make_fixed_bounds(0.0, 2_500_000.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.slider_min, 0.0);
        assert_eq!(stats.slider_max, 2_500_000.0);
    }

    #[test]
    fn histogram_bin_for_value() {
        let hist = Histogram {
            min: 0.0,
            max: 100.0,
            p1: 10.0,
            p99: 90.0,
            counts: vec![0; 10],
        };

        assert_eq!(hist.bin_for_value(5.0), 0); // Low outlier bin
        assert_eq!(hist.bin_for_value(95.0), 9); // High outlier bin

        let mid_value = 50.0;
        let bin = hist.bin_for_value(mid_value);
        assert!((1..=8).contains(&bin));
    }

    #[test]
    fn histogram_middle_bin_width() {
        let hist = Histogram {
            min: 0.0,
            max: 100.0,
            p1: 10.0,
            p99: 90.0,
            counts: vec![0; 10],
        };

        let expected_width = (90.0 - 10.0) / 8.0;
        assert!((hist.middle_bin_width() - expected_width).abs() < 0.001);
    }

    #[test]
    fn histogram_cardinality_caps_bins() {
        let data = vec![1.0_f32, 1.0, 2.0, 2.0, 3.0, 3.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.counts.len(), 3);
    }

    /// Min/max must come from the finite values only, regardless of where the
    /// NaNs sit or whether the input is sorted. Exercises the real
    /// `compute_feature_stats` rather than an inline re-implementation.
    #[test]
    fn min_max_skips_nan() {
        let data = vec![10.0_f32, f32::NAN, 20.0, f32::NAN, 5.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.min, 5.0);
        assert_eq!(stats.histogram.max, 20.0);
    }

    /// Only finite values may be counted into histogram bins. Exercises the
    /// real `compute_feature_stats` rather than an inline filter.
    #[test]
    fn count_skips_nan() {
        let data = vec![1.0_f32, f32::NAN, 2.0, f32::NAN, 3.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 3);
    }

    #[test]
    fn infinity_values_excluded() {
        let data = vec![f32::INFINITY, f32::NEG_INFINITY, 50.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.min, 50.0);
        assert_eq!(stats.histogram.max, 50.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1);
    }

    #[test]
    fn only_finite_values() {
        let data = vec![10.0_f32, 20.0, 30.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.min, 10.0);
        assert_eq!(stats.histogram.max, 30.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 3);
    }

    #[test]
    fn extreme_outlier_does_not_destroy_quantization() {
        // Simulate floor area: 10k normal values (50-200 sqm) + one 317M outlier
        let mut data: Vec<f32> = (0..10_000).map(|i| 50.0 + (i % 150) as f32).collect();
        data.push(317_000_000.0); // Extreme outlier from web scraping

        let bounds = make_percentile_bounds(0.0, 98.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        // After refinement, histogram range should be much tighter than 317M
        assert!(
            stats.histogram.max < 1_000_000.0,
            "histogram.max should be refined, got {}",
            stats.histogram.max,
        );
        // p1 should be near 50, not millions
        assert!(
            stats.histogram.p1 < 100.0,
            "p1 should be near real data, got {}",
            stats.histogram.p1,
        );
        // Slider min should reflect actual data range
        assert!(
            stats.slider_min < 100.0,
            "slider_min should be near real data, got {}",
            stats.slider_min,
        );

        // Quantization using histogram.min/max should give usable range
        let qmin = stats.histogram.min;
        let qrange = stats.histogram.max - stats.histogram.min;
        assert!(qrange > 0.0 && qrange < 1_000_000.0);

        // A typical floor area (100 sqm) should be distinguishable from min
        let normalized = (100.0 - qmin) / qrange;
        let encoded = (normalized * QUANT_SCALE).round() as u16;
        assert!(
            encoded > 100,
            "100 sqm should encode to a meaningful u16 value, got {}",
            encoded,
        );
    }
}
|
||||
|
|
@ -272,6 +272,21 @@ pub fn slugify(name: &str) -> String {
|
|||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
impl TravelTimeStore {
    /// Builds a store with an empty base path, no modes, no destinations and
    /// no slug mappings, for integration tests that need an `AppState` but
    /// never read travel time data.
    pub(crate) fn empty_for_tests() -> Self {
        // The cache is created with the same size argument (1) as before.
        let cache = Mutex::new(LruCache::new(1));

        TravelTimeStore {
            base_dir: PathBuf::new(),
            available_modes: Vec::new(),
            destinations: FxHashMap::default(),
            slug_to_file: FxHashMap::default(),
            cache,
        }
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
|
|
@ -1042,7 +1042,44 @@ async fn main() -> anyhow::Result<()> {
|
|||
listener,
|
||||
app.into_make_service_with_connect_info::<std::net::SocketAddr>(),
|
||||
)
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await
|
||||
.context("Server error")?;
|
||||
info!("Server shut down cleanly");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Resolves on SIGTERM or SIGINT so in-flight requests (exports, checkouts)
|
||||
/// can drain before the process exits. The realtime SSE proxy connections
|
||||
/// never complete, so a watchdog force-exits before Docker's default 10s
|
||||
/// stop grace period elapses and it sends SIGKILL.
|
||||
async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("Failed to install SIGINT handler");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.expect("Failed to install SIGTERM handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
() = ctrl_c => {},
|
||||
() = terminate => {},
|
||||
}
|
||||
info!("Shutdown signal received; draining in-flight requests");
|
||||
|
||||
tokio::spawn(async {
|
||||
tokio::time::sleep(Duration::from_secs(8)).await;
|
||||
tracing::warn!("Graceful shutdown drain timed out after 8s; forcing exit");
|
||||
std::process::exit(0);
|
||||
});
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
448
server-rs/src/routes/ai_filters/handler.rs
Normal file
448
server-rs/src/routes/ai_filters/handler.rs
Normal file
|
|
@ -0,0 +1,448 @@
|
|||
//! The `POST /api/ai-filters` route handler: rate limiting, the Gemini
|
||||
//! function-calling conversation loop, and zero-match refinement.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::State;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Json;
|
||||
use axum::Extension;
|
||||
use metrics::counter;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WEEKLY_TOKEN_LIMIT};
|
||||
use crate::pocketbase::log_ai_query;
|
||||
use crate::state::SharedState;
|
||||
use crate::utils::gemini_chat;
|
||||
|
||||
use super::matching::count_matching_rows;
|
||||
use super::parsing::{
|
||||
normalize_context_filters, strip_markdown_fences, validate_and_convert,
|
||||
validate_travel_time_filters,
|
||||
};
|
||||
use super::tools::{build_tool_declarations, execute_destination_search};
|
||||
use super::usage::{current_week_number, fetch_ai_usage, record_ai_request_usage};
|
||||
use super::{AiFiltersRequest, AiFiltersResponse};
|
||||
|
||||
/// Budget limits for the Gemini conversation loop. Separate counters prevent
|
||||
/// tool calls (destination searches) from starving JSON retries or zero-match
|
||||
/// refinements.
|
||||
///
/// Tool calls the handler will execute per request; once the running count
/// exceeds this, the model is told to stop calling tools and output JSON.
const MAX_TOOL_CALLS: usize = 4;
|
||||
/// Retries allowed for empty or non-JSON model replies (shared counter);
/// exceeding it ends the request with 502 Bad Gateway.
const MAX_RETRIES: usize = 3;
|
||||
/// Times the model is asked to relax filters that matched zero properties;
/// after that the zero-match filters are returned to the client as-is.
const MAX_REFINEMENTS: u32 = 3;
|
||||
/// Hard cap on Gemini round-trips per request, whatever the round was spent on.
const MAX_TOTAL_ROUNDS: usize = 10;
|
||||
|
||||
/// Maximum query length, counted in `char`s; longer queries are rejected with
/// 413 Payload Too Large before any Gemini call is made.
const MAX_AI_QUERY_CHARS: usize = 5000;
|
||||
|
||||
pub async fn post_ai_filters(
|
||||
State(shared): State<Arc<SharedState>>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Json(req): Json<AiFiltersRequest>,
|
||||
) -> Result<Json<AiFiltersResponse>, (StatusCode, String)> {
|
||||
let state = shared.load_state();
|
||||
// Auth check
|
||||
let user = user
|
||||
.0
|
||||
.ok_or((StatusCode::UNAUTHORIZED, "Login required".into()))?;
|
||||
|
||||
if req.query.chars().count() > MAX_AI_QUERY_CHARS {
|
||||
counter!("ai_requests_total", "status" => "query_too_long").increment(1);
|
||||
return Err((
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
format!("Query too long (max {MAX_AI_QUERY_CHARS} chars)"),
|
||||
));
|
||||
}
|
||||
|
||||
// Check weekly token usage
|
||||
let current_week = current_week_number();
|
||||
let (stored_tokens, stored_week) = fetch_ai_usage(&state, &user.id).await?;
|
||||
let tokens_used = if stored_week == current_week {
|
||||
stored_tokens
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
if tokens_used >= AI_FILTERS_WEEKLY_TOKEN_LIMIT {
|
||||
counter!("ai_requests_total", "status" => "rate_limited").increment(1);
|
||||
return Err((
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
"Weekly AI usage limit reached. Resets next week.".into(),
|
||||
));
|
||||
}
|
||||
|
||||
info!(query = %req.query, user_id = %user.id, "POST /api/ai-filters");
|
||||
|
||||
let tools = build_tool_declarations(&state);
|
||||
|
||||
// Build user message with optional context for conversational refinement
|
||||
let user_text = if let Some(ref ctx) = req.context {
|
||||
let mut msg = String::new();
|
||||
msg.push_str("Currently active filters:\n");
|
||||
let normalized_filters = normalize_context_filters(&ctx.filters);
|
||||
msg.push_str(&serde_json::to_string(&normalized_filters).unwrap_or_default());
|
||||
if !ctx.travel_time.is_empty() {
|
||||
msg.push_str("\nCurrently active travel time filters:\n");
|
||||
for tt in &ctx.travel_time {
|
||||
let bounds = match (tt.min, tt.max) {
|
||||
(Some(min), Some(max)) => format!("{}-{} min", min, max),
|
||||
(Some(min), None) => format!("min {} min", min),
|
||||
(None, Some(max)) => format!("max {} min", max),
|
||||
(None, None) => "no range".to_string(),
|
||||
};
|
||||
msg.push_str(&format!("- {} to {} ({})\n", tt.mode, tt.label, bounds));
|
||||
}
|
||||
}
|
||||
msg.push_str(&format!("\nUser request: {}", req.query));
|
||||
msg
|
||||
} else {
|
||||
req.query.clone()
|
||||
};
|
||||
|
||||
let mut contents = vec![json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": user_text }]
|
||||
})];
|
||||
|
||||
let mut total_tokens_accumulated: u64 = 0;
|
||||
let mut tool_call_count = 0usize;
|
||||
let mut retry_count = 0usize;
|
||||
let mut refinement_attempts = 0u32;
|
||||
|
||||
// Function calling loop: model may call search_destinations, we execute and feed back
|
||||
for round in 0..MAX_TOTAL_ROUNDS {
|
||||
let body = json!({
|
||||
"systemInstruction": {
|
||||
"parts": [{ "text": state.ai_filters_system_prompt }]
|
||||
},
|
||||
"contents": contents,
|
||||
"tools": tools,
|
||||
"generationConfig": {
|
||||
"temperature": AI_FILTERS_TEMPERATURE,
|
||||
"maxOutputTokens": AI_FILTERS_MAX_TOKENS,
|
||||
"thinkingConfig": { "thinkingLevel": "LOW" },
|
||||
}
|
||||
});
|
||||
|
||||
let json_resp = match gemini_chat(
|
||||
&state.http_client,
|
||||
&state.gemini_api_key,
|
||||
&state.gemini_model,
|
||||
&body,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(err) => {
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"llm_error",
|
||||
)
|
||||
.await;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
// Accumulate token usage
|
||||
total_tokens_accumulated += json_resp
|
||||
.get("usageMetadata")
|
||||
.and_then(|md| md.get("totalTokenCount"))
|
||||
.and_then(|tc| tc.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let candidate = match json_resp
|
||||
.get("candidates")
|
||||
.and_then(|cs| cs.get(0))
|
||||
.and_then(|c| c.get("content"))
|
||||
{
|
||||
Some(candidate) => candidate,
|
||||
None => {
|
||||
warn!("Malformed Gemini response: missing candidates[0].content");
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"malformed_response",
|
||||
)
|
||||
.await;
|
||||
return Err((StatusCode::BAD_GATEWAY, "Malformed Gemini response".into()));
|
||||
}
|
||||
};
|
||||
|
||||
let parts = match candidate.get("parts").and_then(|p| p.as_array()) {
|
||||
Some(parts) => parts,
|
||||
None => {
|
||||
warn!("Malformed Gemini response: missing parts array");
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"malformed_response",
|
||||
)
|
||||
.await;
|
||||
return Err((StatusCode::BAD_GATEWAY, "Malformed Gemini response".into()));
|
||||
}
|
||||
};
|
||||
|
||||
// Check if the model made a function call.
|
||||
// Find the full part (includes thoughtSignature required by Gemini 3 models).
|
||||
if let Some(fc_part) = parts.iter().find(|part| part.get("functionCall").is_some()) {
|
||||
let fc = fc_part.get("functionCall").unwrap();
|
||||
let fn_name = fc.get("name").and_then(|n| n.as_str()).unwrap_or("");
|
||||
let fn_args = fc.get("args").cloned().unwrap_or(json!({}));
|
||||
|
||||
tool_call_count += 1;
|
||||
info!(
|
||||
function = fn_name,
|
||||
round = round,
|
||||
tool_call = tool_call_count,
|
||||
"AI called tool"
|
||||
);
|
||||
|
||||
if tool_call_count > MAX_TOOL_CALLS {
|
||||
warn!("Tool call budget exhausted, forcing text output");
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": "Tool call limit reached. Output your best JSON now using the destinations you already found. Do not call any more tools." }]
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
|
||||
let fn_result = if fn_name == "search_destinations" {
|
||||
let query = fn_args.get("query").and_then(|q| q.as_str()).unwrap_or("");
|
||||
let mode = fn_args
|
||||
.get("mode")
|
||||
.and_then(|m| m.as_str())
|
||||
.unwrap_or("transit");
|
||||
execute_destination_search(&state, query, mode)
|
||||
} else {
|
||||
json!({"error": "unknown function"})
|
||||
};
|
||||
|
||||
// Append the model's full response (preserves thoughtSignature) + our function result
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{
|
||||
"functionResponse": {
|
||||
"name": fn_name,
|
||||
"response": fn_result
|
||||
}
|
||||
}]
|
||||
}));
|
||||
|
||||
// Continue the loop — model will process the results
|
||||
continue;
|
||||
}
|
||||
|
||||
// Model returned text — extract and parse as JSON
|
||||
let text = parts
|
||||
.iter()
|
||||
.find_map(|part| part.get("text").and_then(|t| t.as_str()))
|
||||
.unwrap_or("");
|
||||
let text = strip_markdown_fences(text);
|
||||
let text = text.trim();
|
||||
|
||||
if text.is_empty() {
|
||||
retry_count += 1;
|
||||
warn!(
|
||||
"Gemini returned empty text content (round {}, retry {})",
|
||||
round, retry_count
|
||||
);
|
||||
if retry_count > MAX_RETRIES {
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"empty_response",
|
||||
)
|
||||
.await;
|
||||
return Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"AI returned empty responses".into(),
|
||||
));
|
||||
}
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": "Your response was empty. Please output the JSON object." }]
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
|
||||
let raw: Value = match serde_json::from_str(text) {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
retry_count += 1;
|
||||
warn!(error = %err, round = round, retry = retry_count, "Failed to parse Gemini JSON output");
|
||||
if retry_count > MAX_RETRIES {
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"invalid_json",
|
||||
)
|
||||
.await;
|
||||
return Err((StatusCode::BAD_GATEWAY, "AI returned invalid JSON".into()));
|
||||
}
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": "That was not valid JSON. Please output ONLY the JSON object with numeric_filters, enum_filters, travel_time_filters, and notes." }]
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let filters = validate_and_convert(&raw, &state.features_response);
|
||||
let travel_time_filters = validate_travel_time_filters(&raw, &state);
|
||||
let notes = raw
|
||||
.get("notes")
|
||||
.and_then(|val| val.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
// Count matching properties and refine if too restrictive
|
||||
let (match_count, match_bounds) =
|
||||
count_matching_rows(&state, &filters, &travel_time_filters);
|
||||
info!(
|
||||
match_count = match_count,
|
||||
round = round,
|
||||
"AI filter match count"
|
||||
);
|
||||
|
||||
if match_count == 0 {
|
||||
refinement_attempts += 1;
|
||||
let total_rows = state.data.lat.len();
|
||||
info!(
|
||||
attempt = refinement_attempts,
|
||||
"0 matches out of {total_rows} — asking AI to relax filters"
|
||||
);
|
||||
|
||||
if refinement_attempts > MAX_REFINEMENTS {
|
||||
warn!("Refinement budget exhausted, returning filters with 0 matches");
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"zero_matches",
|
||||
)
|
||||
.await;
|
||||
|
||||
let notes = if notes.is_empty() {
|
||||
"No properties match these filters. Try relaxing some constraints.".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"{}. No properties match. Try relaxing some constraints.",
|
||||
notes
|
||||
)
|
||||
};
|
||||
|
||||
return Ok(Json(AiFiltersResponse {
|
||||
filters,
|
||||
travel_time_filters,
|
||||
notes,
|
||||
match_count: 0,
|
||||
match_bounds: None,
|
||||
}));
|
||||
}
|
||||
|
||||
let feedback = match refinement_attempts {
|
||||
1 => format!(
|
||||
"Your proposed filters matched 0 properties out of {total_rows} total. \
|
||||
The combination is too restrictive. Please widen some numeric ranges \
|
||||
or add more enum values while keeping the user's intent. \
|
||||
Output the adjusted JSON."
|
||||
),
|
||||
2 => format!(
|
||||
"Still 0 matches out of {total_rows}. Please widen ranges further. \
|
||||
Output the adjusted JSON."
|
||||
),
|
||||
_ => format!(
|
||||
"Still 0 matches out of {total_rows}. Please remove additional filters \
|
||||
until some properties match, keeping the user's core priority. \
|
||||
Output the adjusted JSON."
|
||||
),
|
||||
};
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": feedback }]
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"success",
|
||||
)
|
||||
.await;
|
||||
|
||||
// Log the query to PocketBase (fire-and-forget)
|
||||
let filters_json = serde_json::to_string(&filters).unwrap_or_default();
|
||||
let log_state = state.clone();
|
||||
let log_user_id = user.id.clone();
|
||||
let log_query = req.query.clone();
|
||||
let log_notes = notes.clone();
|
||||
let log_rounds = (round + 1) as u64;
|
||||
tokio::spawn(async move {
|
||||
log_ai_query(
|
||||
&log_state,
|
||||
&log_user_id,
|
||||
&log_query,
|
||||
&filters_json,
|
||||
&log_notes,
|
||||
total_tokens_accumulated,
|
||||
log_rounds,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
return Ok(Json(AiFiltersResponse {
|
||||
filters,
|
||||
travel_time_filters,
|
||||
notes,
|
||||
match_count,
|
||||
match_bounds,
|
||||
}));
|
||||
}
|
||||
|
||||
// Exhausted total round budget without getting a valid response
|
||||
warn!(
|
||||
"AI exhausted {} total rounds without final response (tools={}, retries={}, refinements={})",
|
||||
MAX_TOTAL_ROUNDS, tool_call_count, retry_count, refinement_attempts
|
||||
);
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"incomplete",
|
||||
)
|
||||
.await;
|
||||
Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"AI could not complete the request".into(),
|
||||
))
|
||||
}
|
||||
158
server-rs/src/routes/ai_filters/matching.rs
Normal file
158
server-rs/src/routes/ai_filters/matching.rs
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
//! Counting properties that match the AI-proposed property and travel time
|
||||
//! filters, and computing a camera-friendly bounding box of the matches.
|
||||
|
||||
use serde_json::Value;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::data::travel_time::TravelData;
|
||||
use crate::parsing::{parse_filters_with_poi, row_passes_filters, row_passes_poi_filters};
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::{MatchBounds, TravelTimeFilter};
|
||||
|
||||
/// Bounding box over matched coordinates, trimmed to the 5th–95th percentile
|
||||
/// per axis (when there are enough points) so a handful of remote outliers
|
||||
/// doesn't zoom the camera out to all of England.
|
||||
fn percentile_trimmed_bounds(mut lats: Vec<f32>, mut lons: Vec<f32>) -> Option<MatchBounds> {
|
||||
if lats.is_empty() || lats.len() != lons.len() {
|
||||
return None;
|
||||
}
|
||||
lats.sort_unstable_by(f32::total_cmp);
|
||||
lons.sort_unstable_by(f32::total_cmp);
|
||||
let last = lats.len() - 1;
|
||||
let (lo, hi) = if lats.len() >= 20 {
|
||||
let trim = lats.len() / 20;
|
||||
(trim, last - trim)
|
||||
} else {
|
||||
(0, last)
|
||||
};
|
||||
Some(MatchBounds {
|
||||
south: lats[lo],
|
||||
north: lats[hi],
|
||||
west: lons[lo],
|
||||
east: lons[hi],
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert validated filter JSON back to the `;;`-separated filter string format
|
||||
/// that `parse_filters` expects.
|
||||
///
|
||||
/// Numeric: `{"name": [min, max]}` → `name:min:max`
|
||||
/// Enum: `{"name": ["val1", "val2"]}` → `name:val1|val2`
|
||||
fn filters_to_filter_string(filters: &Value) -> String {
|
||||
let obj = match filters.as_object() {
|
||||
Some(obj) => obj,
|
||||
None => return String::new(),
|
||||
};
|
||||
|
||||
let mut parts = Vec::new();
|
||||
for (name, value) in obj {
|
||||
if let Some(arr) = value.as_array() {
|
||||
if arr.len() == 2 && arr[0].is_number() && arr[1].is_number() {
|
||||
let min = arr[0].as_f64().unwrap_or(0.0);
|
||||
let max = arr[1].as_f64().unwrap_or(0.0);
|
||||
parts.push(format!("{name}:{min}:{max}"));
|
||||
} else if !arr.is_empty() && arr[0].is_string() {
|
||||
let values: Vec<&str> = arr.iter().filter_map(|v| v.as_str()).collect();
|
||||
if !values.is_empty() {
|
||||
parts.push(format!("{name}:{}", values.join("|")));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parts.join(";;")
|
||||
}
|
||||
|
||||
/// Count how many rows in the property dataset pass the given property filters
|
||||
/// AND travel time filters. Travel time data is loaded from the TravelTimeStore
|
||||
/// and checked per-postcode (same logic as hexagons.rs).
|
||||
pub(super) fn count_matching_rows(
|
||||
state: &AppState,
|
||||
filters: &Value,
|
||||
travel_time_filters: &[TravelTimeFilter],
|
||||
) -> (usize, Option<MatchBounds>) {
|
||||
let filter_str = filters_to_filter_string(filters);
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = if filter_str.is_empty() {
|
||||
(Vec::new(), Vec::new(), Vec::new())
|
||||
} else {
|
||||
match parse_filters_with_poi(
|
||||
Some(&filter_str),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
) {
|
||||
Ok(f) => f,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse filters for match count: {err}");
|
||||
return (0, None);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Load travel time data for each filter entry
|
||||
let travel_data: Vec<(TravelData, Option<f32>, Option<f32>)> = travel_time_filters
|
||||
.iter()
|
||||
.filter_map(|ttf| {
|
||||
let data = state.travel_time_store.get(&ttf.mode, &ttf.slug).ok()?;
|
||||
Some((data, ttf.min, ttf.max))
|
||||
})
|
||||
.collect();
|
||||
let has_travel = !travel_data.is_empty();
|
||||
|
||||
let feature_data = &state.data.feature_data;
|
||||
let num_features = state.data.num_features;
|
||||
let num_rows = state.data.lat.len();
|
||||
let (pc_interner, pc_keys) = state.data.postcode_parts();
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
|
||||
let mut count = 0usize;
|
||||
let mut matched_lats: Vec<f32> = Vec::new();
|
||||
let mut matched_lons: Vec<f32> = Vec::new();
|
||||
for (row, pc_key) in pc_keys.iter().enumerate().take(num_rows) {
|
||||
if !row_passes_filters(
|
||||
row,
|
||||
&parsed_filters,
|
||||
&parsed_enum_filters,
|
||||
feature_data,
|
||||
num_features,
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
if has_poi_filters
|
||||
&& !row_passes_poi_filters(row, &parsed_poi_filters, &state.data.poi_metrics)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if has_travel {
|
||||
let postcode = pc_interner.resolve(pc_key);
|
||||
let mut passes_travel = true;
|
||||
for (data, fmin, fmax) in &travel_data {
|
||||
let pass = if let Some(mins) = data.get(postcode).map(|r| r.minutes as f32) {
|
||||
fmin.is_none_or(|min| mins >= min) && fmax.is_none_or(|max| mins <= max)
|
||||
} else {
|
||||
false // no travel data → postcode not reachable
|
||||
};
|
||||
if !pass {
|
||||
passes_travel = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !passes_travel {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
count += 1;
|
||||
matched_lats.push(state.data.lat[row]);
|
||||
matched_lons.push(state.data.lon[row]);
|
||||
}
|
||||
|
||||
(count, percentile_trimmed_bounds(matched_lats, matched_lons))
|
||||
}
|
||||
81
server-rs/src/routes/ai_filters/mod.rs
Normal file
81
server-rs/src/routes/ai_filters/mod.rs
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
//! AI filters: translate a natural-language property query into validated
|
||||
//! filter settings via Gemini.
|
||||
//!
|
||||
//! Split by concern:
|
||||
//! - [`handler`]: the `POST /api/ai-filters` route handler and Gemini
|
||||
//! conversation loop
|
||||
//! - [`prompt`]: system prompt building (precomputed at startup)
|
||||
//! - [`tools`]: the `search_destinations` tool declaration and execution
|
||||
//! - [`parsing`]: LLM response parsing and validation against feature metadata
|
||||
//! - [`matching`]: counting properties that match the proposed filters
|
||||
//! - [`usage`]: weekly token usage tracking / rate limiting
|
||||
|
||||
mod handler;
|
||||
mod matching;
|
||||
mod parsing;
|
||||
mod prompt;
|
||||
mod tools;
|
||||
mod usage;
|
||||
|
||||
pub use handler::post_ai_filters;
|
||||
pub use prompt::build_system_prompt;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Currently active filters supplied with a request so the query can be
/// handled as a refinement of an existing search (e.g. "make it cheaper").
#[derive(Deserialize)]
|
||||
pub struct AiFiltersContext {
|
||||
/// Active feature filters, kept as raw JSON.
filters: Value,
|
||||
/// Active travel-time filters; deserializes to an empty list when the
/// field is absent from the payload.
#[serde(default)]
|
||||
travel_time: Vec<AiTravelTimeContext>,
|
||||
}
|
||||
|
||||
/// One travel-time filter carried inside [`AiFiltersContext`].
// NOTE(review): field meanings below are inferred from `TravelTimeFilter`;
// confirm against the frontend payload.
#[derive(Deserialize)]
|
||||
pub struct AiTravelTimeContext {
|
||||
/// Transport mode string.
mode: String,
|
||||
/// Display label of the destination.
label: String,
|
||||
/// Lower bound, if any (presumably minutes — TODO confirm).
min: Option<f32>,
|
||||
/// Upper bound, if any (presumably minutes — TODO confirm).
max: Option<f32>,
|
||||
}
|
||||
|
||||
/// Deserialized AI-filters request (presumably the body of
/// `POST /api/ai-filters` — confirm in the handler module).
#[derive(Deserialize)]
|
||||
pub struct AiFiltersRequest {
|
||||
/// The user's natural-language property query.
query: String,
|
||||
/// Current filters for conversational refinement (e.g. "make it cheaper")
|
||||
context: Option<AiFiltersContext>,
|
||||
}
|
||||
|
||||
/// A validated travel-time filter returned to the client.
///
/// `min` and `max` are omitted from the serialized output when unset.
#[derive(Serialize)]
|
||||
pub struct TravelTimeFilter {
|
||||
/// Transport mode string.
mode: String,
|
||||
/// Destination slug.
slug: String,
|
||||
/// Human-readable destination label.
label: String,
|
||||
/// Lower travel-time bound in minutes, if set.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
min: Option<f32>,
|
||||
/// Upper travel-time bound in minutes, if set.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max: Option<f32>,
|
||||
}
|
||||
|
||||
/// Serialized result of an AI-filters request.
///
/// Empty `travel_time_filters`, empty `notes` and a missing `match_bounds`
/// are left out of the serialized output entirely.
#[derive(Serialize)]
|
||||
pub struct AiFiltersResponse {
|
||||
/// Proposed feature filters as a JSON value.
filters: Value,
|
||||
/// Proposed travel-time filters.
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
travel_time_filters: Vec<TravelTimeFilter>,
|
||||
/// What the LLM couldn't map to existing filters (empty if everything matched)
|
||||
#[serde(skip_serializing_if = "String::is_empty")]
|
||||
notes: String,
|
||||
/// Number of properties matching the proposed property and travel time filters.
|
||||
match_count: usize,
|
||||
/// Bounding box of the matching properties so the client can move the
|
||||
/// camera to where matches actually are. Absent when nothing matches.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
match_bounds: Option<MatchBounds>,
|
||||
}
|
||||
|
||||
/// Bounding box around the properties that matched the proposed filters.
// NOTE(review): edges are presumably latitude/longitude in degrees — confirm
// against the code that builds this value.
#[derive(Serialize)]
|
||||
pub struct MatchBounds {
|
||||
/// Southern edge.
south: f32,
|
||||
/// Western edge.
west: f32,
|
||||
/// Northern edge.
north: f32,
|
||||
/// Eastern edge.
east: f32,
|
||||
}
|
||||
385
server-rs/src/routes/ai_filters/parsing.rs
Normal file
385
server-rs/src/routes/ai_filters/parsing.rs
Normal file
|
|
@ -0,0 +1,385 @@
|
|||
//! LLM response parsing: stripping markdown fences, normalizing frontend
|
||||
//! synthetic filter keys, and validating proposed filters against feature
|
||||
//! metadata and available travel destinations.
|
||||
|
||||
use serde_json::{json, Map, Value};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::routes::{FeatureInfo, FeaturesResponse};
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::TravelTimeFilter;
|
||||
|
||||
/// Strip markdown code fences (```json ... ``` or ``` ... ```) from LLM output.
|
||||
/// Models occasionally wrap JSON in markdown fencing even when told not to.
|
||||
pub(super) fn strip_markdown_fences(text: &str) -> &str {
|
||||
let trimmed = text.trim();
|
||||
|
||||
// Try ```json\n...\n``` or ```\n...\n```
|
||||
if let Some(rest) = trimmed.strip_prefix("```") {
|
||||
// Skip optional language tag (e.g. "json")
|
||||
let rest = if let Some(newline_pos) = rest.find('\n') {
|
||||
&rest[newline_pos + 1..]
|
||||
} else {
|
||||
return trimmed;
|
||||
};
|
||||
if let Some(content) = rest.strip_suffix("```") {
|
||||
return content.trim();
|
||||
}
|
||||
}
|
||||
|
||||
trimmed
|
||||
}
|
||||
|
||||
/// Map a synthetic `Schools:<phase>:<rating>[:...]` filter key to the backend
/// feature name for that phase/rating pair.
///
/// Segments after the rating are ignored. Returns `None` when the key does
/// not start with `Schools:`, lacks a rating segment, or names an unknown
/// phase/rating combination.
fn school_feature_name_from_key(name: &str) -> Option<&'static str> {
    let (phase, tail) = name.strip_prefix("Schools:")?.split_once(':')?;
    // Only the segment directly after the phase is the rating.
    let rating = tail.split_once(':').map_or(tail, |(head, _)| head);

    let feature = match (rating, phase) {
        ("good", "primary") => "Good+ primary school catchments",
        ("good", "secondary") => "Good+ secondary school catchments",
        ("outstanding", "primary") => "Outstanding primary school catchments",
        ("outstanding", "secondary") => "Outstanding secondary school catchments",
        _ => return None,
    };
    Some(feature)
}
|
||||
|
||||
fn decode_synthetic_feature_key(name: &str, prefix: &str) -> Option<String> {
|
||||
let rest = name.strip_prefix(prefix)?;
|
||||
let (encoded, _id) = rest.rsplit_once(':')?;
|
||||
urlencoding::decode(encoded)
|
||||
.ok()
|
||||
.map(|decoded| decoded.into_owned())
|
||||
}
|
||||
|
||||
/// Convert frontend synthetic filter keys back to backend feature names.
|
||||
///
|
||||
/// The React filter UI stores configurable cards under keys such as
|
||||
/// `Political vote share:%25%20Labour:0`. The LLM and backend validators need
|
||||
/// the real feature name (`% Labour`) instead.
|
||||
fn backend_filter_name(name: &str) -> Option<String> {
|
||||
if let Some(feature_name) = school_feature_name_from_key(name) {
|
||||
return Some(feature_name.to_string());
|
||||
}
|
||||
|
||||
for prefix in [
|
||||
"Specific crimes:",
|
||||
"Political vote share:",
|
||||
"Ethnicities:",
|
||||
"Amenity distance:",
|
||||
"Transport distance:",
|
||||
"Amenities within 2km:",
|
||||
"Amenities within 5km:",
|
||||
] {
|
||||
if let Some(feature_name) = decode_synthetic_feature_key(name, prefix) {
|
||||
return Some(feature_name);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn canonical_filter_name(name: &str) -> String {
|
||||
backend_filter_name(name).unwrap_or_else(|| name.to_string())
|
||||
}
|
||||
|
||||
pub(super) fn normalize_context_filters(filters: &Value) -> Value {
|
||||
let Some(obj) = filters.as_object() else {
|
||||
return filters.clone();
|
||||
};
|
||||
|
||||
let mut normalized = Map::with_capacity(obj.len());
|
||||
for (name, value) in obj {
|
||||
normalized.insert(canonical_filter_name(name), value.clone());
|
||||
}
|
||||
Value::Object(normalized)
|
||||
}
|
||||
|
||||
/// Maximum travel-time minutes the data can contain. Matches the Java pipeline's
|
||||
/// MAX_TRIP_DURATION_MINUTES and the frontend's MAX_TRAVEL_MINUTES.
|
||||
const TRAVEL_TIME_MAX_MINUTES: f64 = 90.0;
|
||||
|
||||
fn travel_time_minute_field(item: &Value, key: &str) -> Option<f32> {
|
||||
item.get(key)
|
||||
.and_then(|val| val.as_f64())
|
||||
.filter(|val| val.is_finite())
|
||||
.map(|val| val.clamp(0.0, TRAVEL_TIME_MAX_MINUTES) as f32)
|
||||
}
|
||||
|
||||
fn parse_travel_time_bounds(item: &Value) -> (Option<f32>, Option<f32>) {
|
||||
let explicit_min = travel_time_minute_field(item, "min");
|
||||
let explicit_max = travel_time_minute_field(item, "max");
|
||||
|
||||
let (mut min, mut max) = if explicit_min.is_some() || explicit_max.is_some() {
|
||||
(explicit_min, explicit_max)
|
||||
} else {
|
||||
let value = travel_time_minute_field(item, "value");
|
||||
match (item.get("bound").and_then(|val| val.as_str()), value) {
|
||||
(Some("min"), Some(val)) => (Some(val), None),
|
||||
(Some("max"), Some(val)) => (None, Some(val)),
|
||||
_ => (None, None),
|
||||
}
|
||||
};
|
||||
|
||||
if let (Some(min_val), Some(max_val)) = (min, max) {
|
||||
if min_val > max_val {
|
||||
min = Some(max_val);
|
||||
max = Some(min_val);
|
||||
}
|
||||
}
|
||||
|
||||
(min, max)
|
||||
}
|
||||
|
||||
/// Validate travel time filters from LLM output against available destinations.
|
||||
pub(super) fn validate_travel_time_filters(raw: &Value, state: &AppState) -> Vec<TravelTimeFilter> {
|
||||
let arr = match raw
|
||||
.get("travel_time_filters")
|
||||
.and_then(|val| val.as_array())
|
||||
{
|
||||
Some(arr) => arr,
|
||||
None => return Vec::new(),
|
||||
};
|
||||
|
||||
let tt_store = &state.travel_time_store;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for item in arr {
|
||||
let mode = match item.get("mode").and_then(|val| val.as_str()) {
|
||||
Some(mode) => mode,
|
||||
None => continue,
|
||||
};
|
||||
let slug = match item.get("slug").and_then(|val| val.as_str()) {
|
||||
Some(slug) => slug,
|
||||
None => continue,
|
||||
};
|
||||
let label = item
|
||||
.get("label")
|
||||
.and_then(|val| val.as_str())
|
||||
.unwrap_or(slug);
|
||||
|
||||
// Verify this destination actually exists
|
||||
if !tt_store.has_destination(mode, slug) {
|
||||
warn!(
|
||||
mode = mode,
|
||||
slug = slug,
|
||||
"AI suggested non-existent destination"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let (min, max) = parse_travel_time_bounds(item);
|
||||
|
||||
// Only include if at least one bound is set
|
||||
if min.is_some() || max.is_some() {
|
||||
results.push(TravelTimeFilter {
|
||||
mode: mode.to_string(),
|
||||
slug: slug.to_string(),
|
||||
label: label.to_string(),
|
||||
min,
|
||||
max,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Validate LLM output against feature metadata and convert to FeatureFilters format.
|
||||
///
|
||||
/// Input format (array-based, each numeric filter sets one bound):
|
||||
/// ```json
|
||||
/// {
|
||||
/// "numeric_filters": [{"name": "Last known price", "bound": "max", "value": 300000}],
|
||||
/// "enum_filters": [{"name": "Leasehold/Freehold", "values": ["Freehold"]}]
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Output format (FeatureFilters):
|
||||
/// ```json
|
||||
/// { "Last known price": [0, 300000], "Leasehold/Freehold": ["Freehold"] }
|
||||
/// ```
|
||||
pub(super) fn validate_and_convert(raw: &Value, features: &FeaturesResponse) -> Value {
|
||||
let mut result = serde_json::Map::new();
|
||||
|
||||
// Build lookup maps from feature metadata.
|
||||
// Store both slider bounds (min/max from percentiles) and true data bounds
|
||||
// (histogram.min/max) so one-sided AI filters use the full data range.
|
||||
let mut numeric_features: rustc_hash::FxHashMap<&str, (f32, f32, f32, f32)> =
|
||||
rustc_hash::FxHashMap::default();
|
||||
let mut enum_features: rustc_hash::FxHashMap<&str, &[String]> =
|
||||
rustc_hash::FxHashMap::default();
|
||||
|
||||
for group in &features.groups {
|
||||
for feature in &group.features {
|
||||
match feature {
|
||||
FeatureInfo::Numeric {
|
||||
name,
|
||||
min,
|
||||
max,
|
||||
histogram,
|
||||
..
|
||||
} => {
|
||||
numeric_features.insert(name, (*min, *max, histogram.min, histogram.max));
|
||||
}
|
||||
FeatureInfo::Enum { name, values, .. } => {
|
||||
enum_features.insert(name, values);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process numeric filters — each sets one bound (min or max).
|
||||
// The unset side uses the true data min/max (from histogram), not
|
||||
// the slider bounds (percentile-based), so a "max" filter for crime
|
||||
// produces [0, value] rather than [2nd-percentile, value].
|
||||
if let Some(arr) = raw.get("numeric_filters").and_then(|val| val.as_array()) {
|
||||
for item in arr {
|
||||
let raw_name = match item.get("name").and_then(|val| val.as_str()) {
|
||||
Some(name) => name,
|
||||
None => continue,
|
||||
};
|
||||
let name = canonical_filter_name(raw_name);
|
||||
let (slider_min, slider_max, data_min, data_max) =
|
||||
match numeric_features.get(name.as_str()) {
|
||||
Some(range) => *range,
|
||||
None => continue,
|
||||
};
|
||||
let bound = match item.get("bound").and_then(|val| val.as_str()) {
|
||||
Some(b) => b,
|
||||
None => continue,
|
||||
};
|
||||
// Clamp value to true data range (not slider range)
|
||||
let value = match item.get("value").and_then(|val| val.as_f64()) {
|
||||
Some(v) => v.max(data_min as f64).min(data_max as f64) as f32,
|
||||
None => continue,
|
||||
};
|
||||
let (filter_min, filter_max) = match bound {
|
||||
"min" => (value, data_max),
|
||||
"max" => (data_min, value),
|
||||
_ => continue,
|
||||
};
|
||||
// Only include if range is narrower than full slider range
|
||||
if filter_min > slider_min || filter_max < slider_max {
|
||||
result.insert(name, json!([filter_min, filter_max]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process enum filters
|
||||
if let Some(arr) = raw.get("enum_filters").and_then(|val| val.as_array()) {
|
||||
for item in arr {
|
||||
let raw_name = match item.get("name").and_then(|val| val.as_str()) {
|
||||
Some(name) => name,
|
||||
None => continue,
|
||||
};
|
||||
let name = canonical_filter_name(raw_name);
|
||||
let valid_values = match enum_features.get(name.as_str()) {
|
||||
Some(values) => *values,
|
||||
None => continue,
|
||||
};
|
||||
if let Some(selected) = item.get("values").and_then(|val| val.as_array()) {
|
||||
let valid: Vec<&str> = selected
|
||||
.iter()
|
||||
.filter_map(|item| item.as_str())
|
||||
.filter(|str_val| valid_values.iter().any(|known| known == str_val))
|
||||
.collect();
|
||||
if !valid.is_empty() && valid.len() < valid_values.len() {
|
||||
result.insert(name, json!(valid));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Value::Object(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn strip_fences_json_tag() {
|
||||
let input = "```json\n{\"a\": 1}\n```";
|
||||
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_fences_no_tag() {
|
||||
let input = "```\n{\"a\": 1}\n```";
|
||||
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_fences_passthrough() {
|
||||
let input = "{\"a\": 1}";
|
||||
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_fences_whitespace() {
|
||||
let input = " ```json\n {\"a\": 1} \n``` ";
|
||||
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn synthetic_filter_keys_are_normalized_to_backend_names() {
|
||||
assert_eq!(
|
||||
canonical_filter_name("Schools:primary:good:0"),
|
||||
"Good+ primary school catchments"
|
||||
);
|
||||
// Legacy keys still carry a distance segment; it is ignored.
|
||||
assert_eq!(
|
||||
canonical_filter_name("Schools:primary:good:2:0"),
|
||||
"Good+ primary school catchments"
|
||||
);
|
||||
assert_eq!(
|
||||
canonical_filter_name("Specific crimes:Burglary%20%28avg%2Fyr%29:1"),
|
||||
"Burglary (avg/yr)"
|
||||
);
|
||||
assert_eq!(
|
||||
canonical_filter_name("Political vote share:%25%20Labour:0"),
|
||||
"% Labour"
|
||||
);
|
||||
assert_eq!(
|
||||
canonical_filter_name(
|
||||
"Transport distance:Distance%20to%20nearest%20amenity%20%28Bus%20stop%29%20%28km%29:0"
|
||||
),
|
||||
"Distance to nearest amenity (Bus stop) (km)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_filters_are_normalized_before_prompting() {
|
||||
let filters = json!({
|
||||
"Political vote share:%25%20Green:0": [40, 100],
|
||||
"Estimated current price": [0, 500000],
|
||||
});
|
||||
|
||||
let normalized = normalize_context_filters(&filters);
|
||||
assert_eq!(normalized["% Green"], json!([40, 100]));
|
||||
assert_eq!(normalized["Estimated current price"], json!([0, 500000]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn travel_time_bounds_accept_min_max_schema() {
|
||||
let item = json!({ "min": 30, "max": 45 });
|
||||
assert_eq!(parse_travel_time_bounds(&item), (Some(30.0), Some(45.0)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn travel_time_bounds_accept_legacy_bound_value_schema() {
|
||||
let item = json!({ "bound": "max", "value": 30 });
|
||||
assert_eq!(parse_travel_time_bounds(&item), (None, Some(30.0)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn travel_time_bounds_clamp_and_order_range() {
|
||||
// Data ceiling is 90 (matches Java MAX_TRIP_DURATION_MINUTES).
|
||||
// Inputs outside [0, 90] clamp (min → 90, max → 0); the inverted pair is then swapped.
|
||||
let item = json!({ "min": 150, "max": -10 });
|
||||
assert_eq!(parse_travel_time_bounds(&item), (Some(0.0), Some(90.0)));
|
||||
}
|
||||
}
|
||||
282
server-rs/src/routes/ai_filters/prompt.rs
Normal file
282
server-rs/src/routes/ai_filters/prompt.rs
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
//! System prompt building for the AI filters assistant.
|
||||
|
||||
use crate::routes::{FeatureInfo, FeaturesResponse};
|
||||
|
||||
/// Build the complete system prompt for AI filters.
|
||||
///
|
||||
/// Contains: role instructions, feature catalogue, travel time info,
|
||||
/// few-shot examples, output rules.
|
||||
/// Precomputed at startup and cached in AppState.
|
||||
pub fn build_system_prompt(
|
||||
features: &FeaturesResponse,
|
||||
mode_destinations: &[(String, usize)],
|
||||
) -> String {
|
||||
let mut parts = Vec::new();
|
||||
|
||||
parts.push(
|
||||
"You are a UK property search assistant. \
|
||||
The user describes their ideal property or area in natural language. \
|
||||
Translate their description into filter settings using ONLY the features listed below.\n\
|
||||
\n\
|
||||
Rules:\n\
|
||||
- ONLY set filters the user explicitly mentioned or clearly implied.\n\
|
||||
- Leave out any filter the user did not mention. Empty arrays are fine.\n\
|
||||
- Each numeric filter sets ONE bound only: \"min\" (at least this value) \
|
||||
or \"max\" (at most this value). Never set two filters on the same feature.\n\
|
||||
- Use EXACT feature names from the list — spelling, capitalisation, and punctuation must match.\n\
|
||||
- \"cheap\" / \"affordable\" = lower price range. \"expensive\" = higher price range.\n\
|
||||
- \"low crime\" / \"safe\" = low values on the Serious crime (avg/yr) and Minor crime (avg/yr) \
|
||||
features (area-normalised incident density near the postcode). Prefer these aggregates for broad \
|
||||
area safety; use specific crime features only when the user names a crime type.\n\
|
||||
- \"quiet\" = low Noise (dB). \"green\" / \"near parks\" = high Number of amenities (Park) within 2km \
|
||||
or low Distance to nearest park (km), depending on wording.\n\
|
||||
- \"good schools\" = Good+ school features. \"outstanding schools\" = Outstanding school features.\n\
|
||||
- Amenities and transport stops are normal filters in the feature catalogue. \
|
||||
For \"near a bus stop\", \"near a station\", \"near shops\", etc., use the exact \
|
||||
Distance to nearest amenity (...) or Number of amenities (...) feature when available.\n\
|
||||
- Politics/elections are normal filters in the Neighbours group. Use exact vote share \
|
||||
features such as % Labour, % Conservative, % Liberal Democrat, % Reform UK, % Green, \
|
||||
% Other parties, or Voter turnout (%) when the user asks for political character.\n\
|
||||
- When the user says a number like \"under 400k\", interpret it as 400000.\n\
|
||||
- When the user says \"3 bed\" or \"3 bedroom\", use Number of bedrooms & living rooms \
|
||||
(note: this counts bedrooms + living rooms combined, so 3 bed ~ min 4).\n\
|
||||
- If the user mentions something that has no matching filter, put it in \"notes\" \
|
||||
as a short phrase (e.g. \"No filter for: garden, sea view\"). \
|
||||
If everything was matched, set \"notes\" to an empty string.\n\
|
||||
\n\
|
||||
CONVERSATIONAL REFINEMENT:\n\
|
||||
The user's message may include their currently active filters as context. \
|
||||
When context is provided:\n\
|
||||
- \"make it cheaper\" / \"lower the price\" = adjust the existing price filter down\n\
|
||||
- \"also add ...\" / \"and good schools\" = keep existing filters and add new ones\n\
|
||||
- \"remove the ...\" / \"drop the ...\" = return filters WITHOUT the mentioned one\n\
|
||||
- If the request is a completely new search (not a refinement), ignore the context \
|
||||
and build filters from scratch.\n\
|
||||
- Always output the COMPLETE set of filters (existing + modified), not just the changes."
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
// Travel time section with available modes
|
||||
let modes_list = mode_destinations
|
||||
.iter()
|
||||
.map(|(mode, count)| format!("- {} ({} destinations available)", mode, count))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
parts.push(format!(
|
||||
"\n--- TRAVEL TIME FILTERS ---\n\
|
||||
You can add travel time filters when the user mentions commute times, \
|
||||
proximity to places, or wanting to be near/within X minutes of somewhere.\n\
|
||||
\n\
|
||||
Available travel-time modes (only use modes that have destinations):\n\
|
||||
{}\n\
|
||||
- \"car\" / \"drive\" / \"driving\" = car mode\n\
|
||||
- \"cycle\" / \"bike\" / \"cycling\" = bicycle mode\n\
|
||||
- \"walk\" / \"walking\" / \"on foot\" = walking mode\n\
|
||||
- \"train\" / \"tube\" / \"bus\" / \"public transport\" / \"commute\" = transit mode\n\
|
||||
- \"without buses\" / \"no bus\" / \"rail only\" = transit-no-bus mode\n\
|
||||
- \"no change\" / \"no transfer\" / \"direct\" / \"single bus/train\" = transit-no-change mode\n\
|
||||
- \"no change and no bus\" / \"direct rail/tube\" = transit-no-change-no-bus mode\n\
|
||||
- If a mode appears in the available mode list but is not named above, you may still \
|
||||
use the exact mode string from the list.\n\
|
||||
\n\
|
||||
When the user mentions a specific place, you MUST call the search_destinations \
|
||||
tool to find the exact slug. Use the name and slug from the search results.\n\
|
||||
If search_destinations returns an empty array, the destination is not available — \
|
||||
mention it in \"notes\" (e.g. \"No travel data for: Gatwick Airport\") and do NOT \
|
||||
include a travel_time_filter for it.\n\
|
||||
\n\
|
||||
Travel time values are in MINUTES (0-90 range; data is capped at 90 min).\n\
|
||||
- \"within 30 minutes\" = set \"max\": 30\n\
|
||||
- \"at least 10 minutes\" = set \"min\": 10\n\
|
||||
- \"30-45 minute commute\" = set \"min\": 30 and \"max\": 45 on the same travel_time_filter\n\
|
||||
- If only a max is given, omit min (and vice versa). Do not use bound/value for travel time.\n\
|
||||
\n\
|
||||
INFERRING TRANSPORT MODE (when the user does not specify one explicitly):\n\
|
||||
- \"commute\" to a major city centre or station = transit\n\
|
||||
- \"near\" / \"close to\" a city centre or station = transit\n\
|
||||
- \"near\" / \"close to\" a smaller town, village, or rural area = car\n\
|
||||
- \"drive\" / \"driving distance\" / \"driving time\" = always car\n\
|
||||
- If multiple modes are plausible, prefer transit for urban destinations \
|
||||
(London, Manchester, Birmingham, Leeds, etc.) and car for everything else.",
|
||||
modes_list,
|
||||
));
|
||||
|
||||
// Feature guidance
|
||||
parts.push(
|
||||
"\n--- DATA SOURCE ---\n\
|
||||
The data is historical property sales from the Land Registry.\n\
|
||||
\n\
|
||||
Use these features for price queries:\n\
|
||||
- For purchase price: use \"Estimated current price\" or \"Last known price\"\n\
|
||||
- For price per sqm: use \"Est. price per sqm\"\n\
|
||||
- For rent estimates: use \"Estimated monthly rent\""
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
// Feature catalogue
|
||||
parts.push("\n--- AVAILABLE FEATURES ---\n".to_string());
|
||||
for group in &features.groups {
|
||||
parts.push(format!("## {}", group.name));
|
||||
for feature in &group.features {
|
||||
match feature {
|
||||
FeatureInfo::Numeric {
|
||||
name,
|
||||
min,
|
||||
max,
|
||||
description,
|
||||
prefix,
|
||||
suffix,
|
||||
..
|
||||
} => {
|
||||
parts.push(format!(
|
||||
"- \"{}\" (numeric, {}{:.0}{} to {}{:.0}{}): {}",
|
||||
name, prefix, min, suffix, prefix, max, suffix, description
|
||||
));
|
||||
}
|
||||
FeatureInfo::Enum {
|
||||
name,
|
||||
values,
|
||||
description,
|
||||
..
|
||||
} => {
|
||||
parts.push(format!(
|
||||
"- \"{}\" (enum, values: [{}]): {}",
|
||||
name,
|
||||
values
|
||||
.iter()
|
||||
.map(|val| format!("\"{}\"", val))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
description
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Few-shot examples
|
||||
parts.push("\n--- EXAMPLES ---\n".to_string());
|
||||
|
||||
parts.push(
|
||||
"User: \"cheap freehold house under 400k\"\n\
|
||||
Output: {\"numeric_filters\": [{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 400000}], \
|
||||
\"enum_filters\": [{\"name\": \"Leasehold/Freehold\", \"values\": [\"Freehold\"]}, \
|
||||
{\"name\": \"Property type\", \"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
|
||||
\"travel_time_filters\": [], \
|
||||
\"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"safe quiet area with good schools and parks\"\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Serious crime (avg/yr)\", \"bound\": \"max\", \"value\": 5}, \
|
||||
{\"name\": \"Minor crime (avg/yr)\", \"bound\": \"max\", \"value\": 20}, \
|
||||
{\"name\": \"Noise (dB)\", \"bound\": \"max\", \"value\": 55}, \
|
||||
{\"name\": \"Good+ primary school catchments\", \"bound\": \"min\", \"value\": 2}, \
|
||||
{\"name\": \"Good+ secondary school catchments\", \"bound\": \"min\", \"value\": 1}, \
|
||||
{\"name\": \"Number of amenities (Park) within 2km\", \"bound\": \"min\", \"value\": 3}], \
|
||||
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"quiet area with outstanding schools\"\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Noise (dB)\", \"bound\": \"max\", \"value\": 55}, \
|
||||
{\"name\": \"Outstanding primary school catchments\", \"bound\": \"min\", \"value\": 1}, \
|
||||
{\"name\": \"Outstanding secondary school catchments\", \"bound\": \"min\", \"value\": 1}], \
|
||||
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"3 bed flat under 300k with fast broadband near the beach\"\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 300000}, \
|
||||
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 4}], \
|
||||
\"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}, \
|
||||
{\"name\": \"Max available download speed (Mbps)\", \"values\": [\"100\", \"300\", \"1000\"]}], \
|
||||
\"travel_time_filters\": [], \
|
||||
\"notes\": \"No filter for: beach proximity\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"within 30 minutes commute of Kings Cross, under 500k\"\n\
|
||||
(After calling search_destinations for \"Kings Cross\" with mode \"transit\" \
|
||||
and getting [{\"name\": \"Kings Cross\", \"slug\": \"kings-cross\", \"place_type\": \"station\"}])\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 500000}], \
|
||||
\"enum_filters\": [], \
|
||||
\"travel_time_filters\": [{\"mode\": \"transit\", \"slug\": \"kings-cross\", \
|
||||
\"label\": \"Kings Cross\", \"max\": 30}], \
|
||||
\"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"family home with garden, 45 min drive from Manchester, good schools\"\n\
|
||||
(After calling search_destinations for \"Manchester\" with mode \"car\" \
|
||||
and getting [{\"name\": \"Manchester\", \"slug\": \"manchester\", \"place_type\": \"city\"}])\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Total floor area (sqm)\", \"bound\": \"min\", \"value\": 100}, \
|
||||
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 5}, \
|
||||
{\"name\": \"Good+ primary school catchments\", \"bound\": \"min\", \"value\": 2}, \
|
||||
{\"name\": \"Good+ secondary school catchments\", \"bound\": \"min\", \"value\": 1}], \
|
||||
\"enum_filters\": [{\"name\": \"Property type\", \
|
||||
\"values\": [\"Detached\", \"Semi-Detached\"]}], \
|
||||
\"travel_time_filters\": [{\"mode\": \"car\", \"slug\": \"manchester\", \
|
||||
\"label\": \"Manchester\", \"max\": 45}], \
|
||||
\"notes\": \"No filter for: garden\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"Labour-voting area with low burglary and a station nearby\"\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"% Labour\", \"bound\": \"min\", \"value\": 40}, \
|
||||
{\"name\": \"Burglary (avg/yr)\", \"bound\": \"max\", \"value\": 10}, \
|
||||
{\"name\": \"Distance to nearest amenity (Rail station) (km)\", \"bound\": \"max\", \"value\": 1}], \
|
||||
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
// Examples showing rent and price features
|
||||
parts.push(
|
||||
"\nUser: \"2 bed flat with rent under £1500/month\"\n\
|
||||
Output: {\
|
||||
\"numeric_filters\": [{\"name\": \"Estimated monthly rent\", \"bound\": \"max\", \"value\": 1500}], \
|
||||
\"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}], \
|
||||
\"travel_time_filters\": [], \
|
||||
\"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"3 bed house under 500k with good schools\"\n\
|
||||
Output: {\
|
||||
\"numeric_filters\": [{\"name\": \"Estimated current price\", \"bound\": \"max\", \"value\": 500000}, \
|
||||
{\"name\": \"Good+ primary school catchments\", \"bound\": \"min\", \"value\": 2}], \
|
||||
\"enum_filters\": [{\"name\": \"Property type\", \
|
||||
\"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
|
||||
\"travel_time_filters\": [], \
|
||||
\"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
// Output format reminder
|
||||
parts.push(
|
||||
"\n--- OUTPUT FORMAT ---\n\
|
||||
{\"numeric_filters\": [...], \"enum_filters\": [...], \
|
||||
\"travel_time_filters\": [{\"mode\": \"...\", \"slug\": \"...\", \"label\": \"...\", \
|
||||
\"min\": N, \"max\": N}, ...], \"notes\": \"...\"}\n\
|
||||
- travel_time_filters: min and max are both optional, but include at least one. \
|
||||
Use ONLY slugs returned by search_destinations. If a place isn't found, mention it in notes.\n\
|
||||
Respond with ONLY the JSON object. No explanation."
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.join("\n")
|
||||
}
|
||||
188
server-rs/src/routes/ai_filters/tools.rs
Normal file
188
server-rs/src/routes/ai_filters/tools.rs
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
//! The `search_destinations` Gemini tool: declaration and execution against
|
||||
//! PlaceData + TravelTimeStore.
|
||||
|
||||
use serde_json::{json, Value};
|
||||
use tracing::info;
|
||||
|
||||
use crate::data::slugify;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Build the Gemini tool declaration for destination search.
|
||||
pub(super) fn build_tool_declarations(state: &AppState) -> Value {
|
||||
let modes: Vec<&str> = state
|
||||
.travel_time_store
|
||||
.available_modes
|
||||
.iter()
|
||||
.map(|mode| mode.as_str())
|
||||
.collect();
|
||||
|
||||
json!([{
|
||||
"functionDeclarations": [{
|
||||
"name": "search_destinations",
|
||||
"description": "Search for available travel time destinations (cities, stations, towns) that have precomputed travel time data. Call this when the user mentions wanting to be near, close to, or within a certain travel time of a specific place.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Place name to search for (e.g. 'Manchester', 'Kings Cross', 'Heathrow')"
|
||||
},
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": modes,
|
||||
"description": "Transport mode to search destinations for"
|
||||
}
|
||||
},
|
||||
"required": ["query", "mode"]
|
||||
}
|
||||
}]
|
||||
}])
|
||||
}
|
||||
|
||||
/// Execute a destination search against PlaceData + TravelTimeStore.
|
||||
/// Returns matching destinations as a JSON value with `results` and optional `message`.
|
||||
///
|
||||
/// Uses word-based matching: all words in the query must appear somewhere in the
|
||||
/// place name (order-independent). Also matches against slugs for short queries.
|
||||
pub(super) fn execute_destination_search(state: &AppState, query: &str, mode: &str) -> Value {
|
||||
let query_lower = query.to_lowercase();
|
||||
let query_words: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
let query_slug = slugify(query);
|
||||
let tt_store = &state.travel_time_store;
|
||||
let pd = &state.place_data;
|
||||
|
||||
let slug_set = match tt_store.destinations.get(mode) {
|
||||
Some(slugs) => slugs,
|
||||
None => {
|
||||
return json!({ "results": [], "message": format!("No travel data available for mode '{}'", mode) })
|
||||
}
|
||||
};
|
||||
|
||||
// Find places matching the query that have travel time data.
|
||||
// A place matches if ALL query words appear in its name, OR its slug matches the query slug.
|
||||
let mut matches: Vec<(usize, String, u8, u32)> = pd
|
||||
.name_lower
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, name_lower)| {
|
||||
if !pd.travel_destination[idx] {
|
||||
return None;
|
||||
}
|
||||
let words_match = query_words.iter().all(|word| name_lower.contains(word));
|
||||
let slug = slugify(&pd.name[idx]);
|
||||
let slug_match = slug.contains(&query_slug) || query_slug.contains(&slug);
|
||||
if !words_match && !slug_match {
|
||||
return None;
|
||||
}
|
||||
if slug_set.contains(&slug) {
|
||||
Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort: type rank asc, population desc
|
||||
matches.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(b.3.cmp(&a.3)));
|
||||
matches.truncate(10);
|
||||
|
||||
if matches.is_empty() {
|
||||
// Check if the query matched a city that lacks its own travel data.
|
||||
// If so, return nearby stations within that city as suggestions.
|
||||
let matched_city_name: Option<&str> =
|
||||
pd.name_lower
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find_map(|(idx, name_lower)| {
|
||||
if !pd.travel_destination[idx] {
|
||||
return None;
|
||||
}
|
||||
let words_match = query_words.iter().all(|word| name_lower.contains(word));
|
||||
let slug = slugify(&pd.name[idx]);
|
||||
let slug_match = slug.contains(&query_slug) || query_slug.contains(&slug);
|
||||
if (words_match || slug_match) && pd.type_rank[idx] == 0 {
|
||||
Some(pd.name[idx].as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(city_name) = matched_city_name {
|
||||
let city_lower = city_name.to_lowercase();
|
||||
let mut city_matches: Vec<(usize, String, u8, u32)> = pd
|
||||
.city
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, city_opt)| {
|
||||
if !pd.travel_destination[idx] {
|
||||
return None;
|
||||
}
|
||||
let city = city_opt.as_deref()?;
|
||||
if city.to_lowercase() != city_lower {
|
||||
return None;
|
||||
}
|
||||
let slug = slugify(&pd.name[idx]);
|
||||
if slug_set.contains(&slug) {
|
||||
Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
city_matches.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(b.3.cmp(&a.3)));
|
||||
city_matches.truncate(10);
|
||||
|
||||
if !city_matches.is_empty() {
|
||||
let results: Vec<Value> = city_matches
|
||||
.into_iter()
|
||||
.map(|(idx, slug, ..)| {
|
||||
json!({
|
||||
"name": pd.name[idx],
|
||||
"slug": slug,
|
||||
"place_type": pd.place_type.get(idx).to_string(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
info!(
|
||||
query = query,
|
||||
city = city_name,
|
||||
results = results.len(),
|
||||
"Destination search fell back to city stations"
|
||||
);
|
||||
|
||||
return json!({
|
||||
"results": results,
|
||||
"message": format!(
|
||||
"No travel data for '{}' directly. Pick one of these nearby stations:",
|
||||
city_name
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
query = query,
|
||||
mode = mode,
|
||||
"Destination search returned no results"
|
||||
);
|
||||
return json!({
|
||||
"results": [],
|
||||
"message": format!("No travel time data available for '{}' by {}. This destination cannot be used as a travel time filter.", query, mode)
|
||||
});
|
||||
}
|
||||
|
||||
let results: Vec<Value> = matches
|
||||
.into_iter()
|
||||
.map(|(idx, slug, ..)| {
|
||||
json!({
|
||||
"name": pd.name[idx],
|
||||
"slug": slug,
|
||||
"place_type": pd.place_type.get(idx).to_string(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
json!({ "results": results })
|
||||
}
|
||||
119
server-rs/src/routes/ai_filters/usage.rs
Normal file
119
server-rs/src/routes/ai_filters/usage.rs
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
//! Weekly AI token usage tracking and rate limiting, persisted on the user's
|
||||
//! PocketBase record.
|
||||
|
||||
use axum::http::StatusCode;
|
||||
use metrics::counter;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Monotonically increasing week number derived from Unix epoch.
|
||||
/// Resets every 7 days (604800 seconds). Used for weekly rate limiting.
|
||||
pub(super) fn current_week_number() -> u64 {
|
||||
let secs = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
// Only possible if the system clock is before 1970; fall back to
|
||||
// week 0 rather than panicking inside a request handler.
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
secs / 604_800
|
||||
}
|
||||
|
||||
/// Fetch the user's current AI token usage from PocketBase.
|
||||
/// Returns `(tokens_used, week_number)`.
|
||||
pub(super) async fn fetch_ai_usage(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
) -> Result<(u64, u64), (StatusCode, String)> {
|
||||
let token = get_superuser_token(state).await.map_err(|err| {
|
||||
warn!("Failed to auth superuser for AI usage check: {err}");
|
||||
(StatusCode::BAD_GATEWAY, "Internal error".into())
|
||||
})?;
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
warn!("Failed to fetch user record for AI usage: {err}");
|
||||
(StatusCode::BAD_GATEWAY, "Internal error".into())
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
warn!("PocketBase user fetch failed ({status})");
|
||||
return Err((StatusCode::BAD_GATEWAY, "Internal error".into()));
|
||||
}
|
||||
|
||||
let body: Value = resp.json().await.map_err(|err| {
|
||||
warn!("Failed to parse user record: {err}");
|
||||
(StatusCode::BAD_GATEWAY, "Internal error".into())
|
||||
})?;
|
||||
|
||||
let tokens_used = body
|
||||
.get("ai_tokens_used")
|
||||
.and_then(|val| val.as_u64())
|
||||
.unwrap_or(0);
|
||||
let week = body
|
||||
.get("ai_tokens_week")
|
||||
.and_then(|val| val.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
Ok((tokens_used, week))
|
||||
}
|
||||
|
||||
/// Update the user's AI token usage in PocketBase.
|
||||
/// Best-effort — logs warnings on failure but does not propagate errors.
|
||||
async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week: u64) {
|
||||
let token = match get_superuser_token(state).await {
|
||||
Ok(tk) => tk,
|
||||
Err(err) => {
|
||||
warn!("Failed to auth superuser for AI usage update: {err}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let res = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&json!({
|
||||
"ai_tokens_used": tokens_used,
|
||||
"ai_tokens_week": week,
|
||||
}))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
warn!("Failed to update AI usage ({status})");
|
||||
}
|
||||
Err(err) => warn!("Failed to update AI usage: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn record_ai_request_usage(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
existing_tokens_used: u64,
|
||||
week: u64,
|
||||
request_tokens_used: u64,
|
||||
status: &'static str,
|
||||
) {
|
||||
if request_tokens_used > 0 {
|
||||
let new_total = existing_tokens_used.saturating_add(request_tokens_used);
|
||||
update_ai_usage(state, user_id, new_total, week).await;
|
||||
counter!("ai_tokens_total").increment(request_tokens_used);
|
||||
}
|
||||
counter!("ai_requests_total", "status" => status).increment(1);
|
||||
}
|
||||
|
|
@ -780,31 +780,28 @@ pub async fn get_export(
|
|||
// groups themselves; postcodes within a group are sorted alphabetically.
|
||||
// Each group carries a rolled-up summary aggregate for its header row.
|
||||
let outcode_groups: Vec<OutcodeGroup> = {
|
||||
let mut order: Vec<String> = Vec::new();
|
||||
let mut by_outcode: FxHashMap<String, OutcodeGroup> = FxHashMap::default();
|
||||
let mut groups: Vec<OutcodeGroup> = Vec::new();
|
||||
let mut idx_by_outcode: FxHashMap<String, usize> = FxHashMap::default();
|
||||
for (i, (pc_idx, agg)) in postcode_aggs.iter().enumerate() {
|
||||
let outcode = outcode_of(&postcode_data.postcodes[*pc_idx]).to_string();
|
||||
let group = by_outcode.entry(outcode.clone()).or_insert_with(|| {
|
||||
order.push(outcode.clone());
|
||||
OutcodeGroup {
|
||||
outcode: outcode.clone(),
|
||||
let idx = *idx_by_outcode.entry(outcode.clone()).or_insert_with(|| {
|
||||
groups.push(OutcodeGroup {
|
||||
outcode,
|
||||
members: Vec::new(),
|
||||
summary: PostcodeExportAgg::new(total_export_features),
|
||||
}
|
||||
});
|
||||
group.members.push(i);
|
||||
group.summary.merge_from(agg);
|
||||
groups.len() - 1
|
||||
});
|
||||
groups[idx].members.push(i);
|
||||
groups[idx].summary.merge_from(agg);
|
||||
}
|
||||
for group in by_outcode.values_mut() {
|
||||
for group in &mut groups {
|
||||
group.members.sort_by(|&a, &b| {
|
||||
postcode_data.postcodes[postcode_aggs[a].0]
|
||||
.cmp(&postcode_data.postcodes[postcode_aggs[b].0])
|
||||
});
|
||||
}
|
||||
order
|
||||
.into_iter()
|
||||
.map(|outcode| by_outcode.remove(&outcode).unwrap())
|
||||
.collect()
|
||||
groups
|
||||
};
|
||||
|
||||
// Build Excel workbook with two sheets
|
||||
|
|
|
|||
|
|
@ -130,6 +130,11 @@ pub struct HexagonStatsResponse {
|
|||
pub price_history: Vec<PricePoint>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub crime_by_year: Vec<CrimeYearStats>,
|
||||
/// Latest year in the crime dataset as a whole. When a selection's series
|
||||
/// end earlier (force-level publication gap, e.g. Greater Manchester),
|
||||
/// the client captions the data as stale.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub crime_latest_year: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub central_postcode: Option<String>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
|
|
@ -645,12 +650,19 @@ pub async fn get_hexagon_stats(
|
|||
"GET /api/hexagon-stats"
|
||||
);
|
||||
|
||||
let crime_latest_year = if crime_by_year.is_empty() {
|
||||
None
|
||||
} else {
|
||||
stats::crime_latest_available_year(&state.crime_by_year)
|
||||
};
|
||||
|
||||
Ok(HexagonStatsResponse {
|
||||
count: total_count,
|
||||
numeric_features,
|
||||
enum_features: enum_features_out,
|
||||
price_history,
|
||||
crime_by_year,
|
||||
crime_latest_year,
|
||||
central_postcode,
|
||||
filter_exclusions,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -36,8 +36,9 @@ fn is_allowed_pb_path(path: &str) -> bool {
|
|||
|
||||
/// Dedicated HTTP client for proxying — does not follow redirects so 3xx
|
||||
/// responses are passed through to the browser (needed for OAuth flows).
|
||||
/// No overall timeout because SSE (Server-Sent Events) connections used by
|
||||
/// PocketBase realtime/OAuth2 are long-lived streams.
|
||||
/// No client-wide timeout because SSE (Server-Sent Events) connections used
|
||||
/// by PocketBase realtime/OAuth2 are long-lived streams; non-realtime
|
||||
/// requests get a per-request timeout instead (see PROXY_REQUEST_TIMEOUT).
|
||||
static PROXY_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
reqwest::Client::builder()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
|
|
@ -47,6 +48,11 @@ static PROXY_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
|||
.expect("Failed to build proxy HTTP client")
|
||||
});
|
||||
|
||||
/// Timeout for proxied requests other than the realtime SSE stream, so a hung
|
||||
/// PocketBase cannot pile up handlers indefinitely. Generous enough for file
|
||||
/// uploads/downloads over slow links.
|
||||
const PROXY_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
pub async fn proxy_to_pocketbase(
|
||||
State(shared): State<Arc<SharedState>>,
|
||||
req: Request,
|
||||
|
|
@ -58,10 +64,7 @@ pub async fn proxy_to_pocketbase(
|
|||
let target_path = path.strip_prefix("/pb").unwrap_or(path);
|
||||
if !is_allowed_pb_path(target_path) {
|
||||
warn!(path = %target_path, "Rejected PocketBase proxy request to disallowed path");
|
||||
return Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
}
|
||||
let query = req
|
||||
.uri()
|
||||
|
|
@ -73,6 +76,12 @@ pub async fn proxy_to_pocketbase(
|
|||
let method = req.method().clone();
|
||||
let mut builder = PROXY_CLIENT.request(method, &url);
|
||||
|
||||
// The realtime SSE stream is intentionally unbounded; everything else
|
||||
// must complete within the timeout.
|
||||
if target_path != "/api/realtime" {
|
||||
builder = builder.timeout(PROXY_REQUEST_TIMEOUT);
|
||||
}
|
||||
|
||||
// Forward only safe headers (allowlist)
|
||||
const ALLOWED_HEADERS: &[&str] = &[
|
||||
"content-type",
|
||||
|
|
@ -96,10 +105,7 @@ pub async fn proxy_to_pocketbase(
|
|||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
warn!("Failed to read request body: {err}");
|
||||
return Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Body::from("Failed to read request body"))
|
||||
.unwrap();
|
||||
return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response();
|
||||
}
|
||||
};
|
||||
builder = builder.body(body_bytes);
|
||||
|
|
@ -129,14 +135,14 @@ pub async fn proxy_to_pocketbase(
|
|||
// realtime system and OAuth2 flow — buffering would hang forever
|
||||
// since SSE responses never complete.
|
||||
let body = Body::from_stream(upstream.bytes_stream());
|
||||
response.body(body).unwrap()
|
||||
response.body(body).unwrap_or_else(|err| {
|
||||
warn!("Failed to build proxied response: {err}");
|
||||
(StatusCode::BAD_GATEWAY, "Invalid upstream response").into_response()
|
||||
})
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("PocketBase proxy error: {err}");
|
||||
Response::builder()
|
||||
.status(StatusCode::BAD_GATEWAY)
|
||||
.body(Body::from("PocketBase unavailable"))
|
||||
.unwrap()
|
||||
(StatusCode::BAD_GATEWAY, "PocketBase unavailable").into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -184,12 +184,19 @@ pub async fn get_postcode_stats(
|
|||
"GET /api/postcode-stats"
|
||||
);
|
||||
|
||||
let crime_latest_year = if crime_by_year.is_empty() {
|
||||
None
|
||||
} else {
|
||||
stats::crime_latest_available_year(&state.crime_by_year)
|
||||
};
|
||||
|
||||
Ok(HexagonStatsResponse {
|
||||
count: total_count,
|
||||
numeric_features,
|
||||
enum_features: enum_features_out,
|
||||
price_history,
|
||||
crime_by_year,
|
||||
crime_latest_year,
|
||||
central_postcode: None,
|
||||
filter_exclusions,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -340,16 +340,14 @@ pub fn compute_crime_by_year(
|
|||
let points: Vec<CrimeYearPoint> = years
|
||||
.iter()
|
||||
.filter_map(|&year| {
|
||||
let denom = fully_covered_rows
|
||||
+ covered_counts.get(&year).copied().unwrap_or(0);
|
||||
let denom = fully_covered_rows + covered_counts.get(&year).copied().unwrap_or(0);
|
||||
if denom == 0 {
|
||||
// No selected postcode has published data for this year.
|
||||
return None;
|
||||
}
|
||||
Some(CrimeYearPoint {
|
||||
year,
|
||||
count: (sums.get(&year).copied().unwrap_or(0.0) / denom as f64)
|
||||
as f32,
|
||||
count: (sums.get(&year).copied().unwrap_or(0.0) / denom as f64) as f32,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
|
@ -365,6 +363,19 @@ pub fn compute_crime_by_year(
|
|||
out
|
||||
}
|
||||
|
||||
/// Latest year present anywhere in the by-year crime dataset. The client
|
||||
/// compares each selection's last charted year against this to caption
|
||||
/// force-level publication gaps (e.g. Greater Manchester ends mid-2019) as
|
||||
/// stale data instead of letting old numbers read as current.
|
||||
pub fn crime_latest_available_year(crime_by_year: &CrimeByYearData) -> Option<i32> {
|
||||
crime_by_year
|
||||
.years_by_type
|
||||
.iter()
|
||||
.flatten()
|
||||
.copied()
|
||||
.max()
|
||||
}
|
||||
|
||||
pub fn compute_poi_feature_stats(
|
||||
matching_rows: &[usize],
|
||||
poi_metrics: &PostcodePoiMetrics,
|
||||
|
|
|
|||
|
|
@ -87,6 +87,105 @@ pub struct AppState {
|
|||
pub bugsink_frontend_config: Option<BugsinkFrontendConfig>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
impl AppState {
    /// Builds a minimal `AppState` for integration tests of the
    /// PocketBase/Stripe money paths (checkout, webhook, licensing, invites).
    ///
    /// Every map/property dataset is empty. Only the HTTP-facing
    /// configuration carries meaningful values: the given `pocketbase_url`,
    /// dummy Stripe/Gemini/Maps secrets, fresh caches, and an HTTP client
    /// with short timeouts (5 s overall, 2 s connect).
    pub(crate) fn for_tests(pocketbase_url: String) -> Self {
        use std::time::Duration;

        use crate::data::{
            ActualListingData, CrimeByYearData, OutcodeData, POIData, PlaceData, PostcodeData,
            PropertyData, TravelTimeStore,
        };
        use crate::utils::InternedColumn;

        // Short timeouts keep a misbehaving test server from stalling the suite.
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(5))
            .connect_timeout(Duration::from_secs(2))
            .build()
            .expect("test HTTP client should build");

        // Empty datasets, built up front so the final struct reads as a plain
        // field listing.
        let postcode_data = PostcodeData {
            postcodes: Vec::new(),
            centroids: Vec::new(),
            aabbs: Vec::new(),
            polygons: Vec::new(),
            postcode_to_idx: FxHashMap::default(),
        };

        let outcode_data = OutcodeData {
            names: Vec::new(),
            name_lower: Vec::new(),
            centroids: Vec::new(),
            cities: Vec::new(),
        };

        let actual_listings = ActualListingData {
            lat: Vec::new(),
            lon: Vec::new(),
            postcode: Vec::new(),
            address: Vec::new(),
            property_type: InternedColumn::build(&[]),
            property_sub_type: InternedColumn::build(&[]),
            leasehold_freehold: InternedColumn::build(&[]),
            price_qualifier: InternedColumn::build(&[]),
            bedrooms: Vec::new(),
            bathrooms: Vec::new(),
            rooms_total: Vec::new(),
            floor_area_sqm: Vec::new(),
            asking_price: Vec::new(),
            asking_price_per_sqm: Vec::new(),
            listing_url: Vec::new(),
            listing_status: InternedColumn::build(&[]),
            listing_date_iso: Vec::new(),
            features: Vec::new(),
            filter_feature_data: Vec::new(),
            poi_filter_feature_data: Vec::new(),
            grid: GridIndex::build(&[], &[], 0.01),
        };

        let crime_by_year = CrimeByYearData {
            crime_types: Vec::new(),
            years_by_type: Vec::new(),
            series_by_postcode: FxHashMap::default(),
            covered_years_by_postcode: FxHashMap::default(),
        };

        Self {
            // Property / map data: all empty.
            data: PropertyData::empty_for_tests(),
            grid: GridIndex::build(&[], &[], 0.01),
            h3_cells: Vec::new(),
            feature_name_to_index: FxHashMap::default(),
            min_keys: Vec::new(),
            max_keys: Vec::new(),
            avg_keys: Vec::new(),
            features_response: FeaturesResponse { groups: Vec::new() },
            ai_filters_system_prompt: String::new(),
            poi_data: Arc::new(POIData::empty_for_tests()),
            poi_grid: Arc::new(GridIndex::build(&[], &[], 0.01)),
            place_data: Arc::new(PlaceData::empty_for_tests()),
            postcode_data: Arc::new(postcode_data),
            outcode_data: Arc::new(outcode_data),
            poi_category_groups: Arc::new(Vec::new()),
            travel_time_store: Arc::new(TravelTimeStore::empty_for_tests()),
            actual_listings: Arc::new(actual_listings),
            crime_by_year: Arc::new(crime_by_year),
            // Caches start empty.
            token_cache: Arc::new(TokenCache::new()),
            superuser_token_cache: Arc::new(SuperuserTokenCache::new()),
            share_cache: Arc::new(ShareBoundsCache::new()),
            // HTTP-facing configuration with dummy credentials.
            screenshot_url: String::from("http://127.0.0.1:1/screenshot"),
            public_url: String::from("https://test.example"),
            is_dev: false,
            http_client,
            pocketbase_url,
            pocketbase_admin_email: String::from("admin@test.example"),
            pocketbase_admin_password: String::from("test-admin-password"),
            gemini_api_key: String::from("test-gemini-key"),
            gemini_model: String::from("test-model"),
            google_maps_api_key: String::from("test-maps-key"),
            stripe_secret_key: String::from("sk_test_dummy"),
            stripe_webhook_secret: String::from("whsec_test_secret"),
            stripe_referral_coupon_id: String::from("couponTest30"),
            bugsink_frontend_config: None,
        }
    }
}
|
||||
|
||||
/// Wraps AppState for shared access across route handlers.
|
||||
/// Route handlers call `load_state()` to get the current snapshot.
|
||||
pub struct SharedState {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue