Compare commits
4 commits
fe46cb3379
...
be02fc16bb
| Author | SHA1 | Date | |
|---|---|---|---|
| be02fc16bb | |||
| 4c95815dc8 | |||
| 584b053a23 | |||
| dd9f00b105 |
57 changed files with 4604 additions and 904 deletions
|
|
@ -52,7 +52,7 @@ services:
|
|||
|
||||
screenshot:
|
||||
init: true
|
||||
build: /volumes/syncthing/Projects/property-map/screenshot
|
||||
build: ./screenshot
|
||||
environment:
|
||||
PORT: "8002"
|
||||
APP_URL: http://frontend:3001
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ export default function App() {
|
|||
setShowLicenseSuccess(true);
|
||||
}
|
||||
// Always refresh auth on startup to pick up server-side subscription changes
|
||||
refreshAuth().catch(() => { });
|
||||
refreshAuth().catch(() => {});
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
const savedSearches = useSavedSearches(user?.id ?? null);
|
||||
|
|
@ -271,8 +271,8 @@ export default function App() {
|
|||
{ page: activePage },
|
||||
'',
|
||||
pageToPath(activePage, inviteCode ?? undefined) +
|
||||
window.location.search +
|
||||
window.location.hash
|
||||
window.location.search +
|
||||
window.location.hash
|
||||
);
|
||||
}
|
||||
const handlePopState = (e: PopStateEvent) => {
|
||||
|
|
@ -355,8 +355,8 @@ export default function App() {
|
|||
initialLoading={initialLoading}
|
||||
theme={theme}
|
||||
pendingInfoFeature={null}
|
||||
onClearPendingInfoFeature={() => { }}
|
||||
onNavigateTo={() => { }}
|
||||
onClearPendingInfoFeature={() => {}}
|
||||
onNavigateTo={() => {}}
|
||||
screenshotMode
|
||||
ogMode={isOgMode}
|
||||
initialTravelTime={urlState.travelTime}
|
||||
|
|
|
|||
|
|
@ -119,10 +119,7 @@ export function DualHistogram({
|
|||
})}
|
||||
</div>
|
||||
{showMeanMarker && (
|
||||
<div
|
||||
className="pointer-events-none absolute inset-y-0"
|
||||
style={{ left: `${meanPct}%` }}
|
||||
>
|
||||
<div className="pointer-events-none absolute inset-y-0" style={{ left: `${meanPct}%` }}>
|
||||
<div
|
||||
className="absolute top-0 max-w-[7rem] truncate rounded-sm border border-warm-300 bg-white px-1 py-0.5 text-[9px] font-medium leading-none text-warm-600 shadow-sm dark:border-warm-600 dark:bg-navy-900 dark:text-warm-300"
|
||||
style={meanLabelStyle}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import { useTranslation } from 'react-i18next';
|
||||
import type { FeatureMeta } from '../../types';
|
||||
import { EyeIcon, InfoIcon, PlusIcon, CloseIcon } from './icons';
|
||||
import { IconButton } from './IconButton';
|
||||
|
|
@ -11,6 +12,7 @@ interface FeatureActionsProps {
|
|||
onShowInfo?: (feature: FeatureMeta) => void;
|
||||
onRemove?: (name: string) => void;
|
||||
onAdd?: (name: string) => void;
|
||||
showText?: boolean;
|
||||
}
|
||||
|
||||
export function FeatureActions({
|
||||
|
|
@ -22,36 +24,59 @@ export function FeatureActions({
|
|||
onShowInfo,
|
||||
onRemove,
|
||||
onAdd,
|
||||
showText = false,
|
||||
}: FeatureActionsProps) {
|
||||
const { t } = useTranslation();
|
||||
const isEyeActive = isPinned || isPreviewing;
|
||||
const callbackName = actionName ?? feature.name;
|
||||
const mapLabel = isPinned ? t('filters.clearColourMap') : t('filters.colourMap');
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-0.5 shrink-0">
|
||||
{feature.detail && onShowInfo && (
|
||||
<IconButton onClick={() => onShowInfo(feature)} title="Feature info" size="md">
|
||||
<InfoIcon className="w-5 h-5 md:w-3.5 md:h-3.5" />
|
||||
{feature.detail &&
|
||||
onShowInfo &&
|
||||
(showText ? (
|
||||
<IconButton onClick={() => onShowInfo(feature)} title={t('filters.aboutData')} size="md">
|
||||
<InfoIcon className="w-4 h-4" />
|
||||
</IconButton>
|
||||
) : (
|
||||
<IconButton onClick={() => onShowInfo(feature)} title={t('filters.aboutData')} size="md">
|
||||
<InfoIcon className="w-5 h-5 md:w-3.5 md:h-3.5" />
|
||||
</IconButton>
|
||||
))}
|
||||
{showText ? (
|
||||
<IconButton
|
||||
onClick={() => onTogglePin(callbackName)}
|
||||
title={mapLabel}
|
||||
active={isEyeActive}
|
||||
size="md"
|
||||
>
|
||||
<EyeIcon filled={isEyeActive} className="w-4 h-4" />
|
||||
</IconButton>
|
||||
) : (
|
||||
<IconButton
|
||||
onClick={() => onTogglePin(callbackName)}
|
||||
title={mapLabel}
|
||||
active={isEyeActive}
|
||||
size="md"
|
||||
>
|
||||
<EyeIcon filled={isEyeActive} className="w-5 h-5 md:w-3.5 md:h-3.5" />
|
||||
</IconButton>
|
||||
)}
|
||||
<IconButton
|
||||
onClick={() => onTogglePin(callbackName)}
|
||||
title={isPinned ? 'Unpin colour view' : 'Colour map by this feature'}
|
||||
active={isEyeActive}
|
||||
size="md"
|
||||
>
|
||||
<EyeIcon filled={isEyeActive} className="w-5 h-5 md:w-3.5 md:h-3.5" />
|
||||
</IconButton>
|
||||
{onAdd && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => onAdd(callbackName)}
|
||||
title="Add filter"
|
||||
className="p-1 rounded-md text-teal-600 dark:text-teal-400 bg-teal-50 dark:bg-teal-900/30 hover:bg-teal-100 dark:hover:bg-teal-800/40"
|
||||
title={t('filters.addFilterLabel')}
|
||||
aria-label={t('filters.addFilterLabel')}
|
||||
className="inline-flex items-center gap-1 rounded-md bg-teal-50 px-2 py-1 text-xs font-semibold text-teal-700 hover:bg-teal-100 dark:bg-teal-900/30 dark:text-teal-300 dark:hover:bg-teal-800/40"
|
||||
>
|
||||
<PlusIcon className="w-5 h-5 md:w-5 md:h-5" strokeWidth={2.5} />
|
||||
<PlusIcon className="w-4 h-4" strokeWidth={2.5} />
|
||||
<span>{t('filters.addFilterAction')}</span>
|
||||
</button>
|
||||
)}
|
||||
{onRemove && (
|
||||
<IconButton onClick={() => onRemove(callbackName)} title="Remove filter">
|
||||
<IconButton onClick={() => onRemove(callbackName)} title={t('filters.removeFilter')}>
|
||||
<CloseIcon className="w-3.5 h-3.5" />
|
||||
</IconButton>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -45,7 +45,8 @@ export function FeatureLabel({
|
|||
<button
|
||||
onClick={() => onShowInfo(feature)}
|
||||
className="p-1 -m-0.5 rounded text-warm-400 hover:text-warm-700 dark:hover:text-warm-300 hover:bg-warm-100 dark:hover:bg-warm-700 shrink-0"
|
||||
title={t('filters.featureInfo')}
|
||||
title={t('filters.aboutData')}
|
||||
aria-label={t('filters.aboutData')}
|
||||
>
|
||||
<InfoIcon className="w-3.5 h-3.5" />
|
||||
</button>
|
||||
|
|
|
|||
|
|
@ -97,8 +97,8 @@ export default function Header({
|
|||
const [copied, setCopied] = useState(false);
|
||||
const [sharing, setSharing] = useState(false);
|
||||
const [menuOpen, setMenuOpen] = useState(false);
|
||||
const [isDashboardTabletSidebarWidth, setIsDashboardTabletSidebarWidth] = useState(() =>
|
||||
window.matchMedia(DASHBOARD_TABLET_SIDEBAR_QUERY).matches
|
||||
const [isDashboardTabletSidebarWidth, setIsDashboardTabletSidebarWidth] = useState(
|
||||
() => window.matchMedia(DASHBOARD_TABLET_SIDEBAR_QUERY).matches
|
||||
);
|
||||
const useSidebarNav = isMobile || (activePage === 'dashboard' && isDashboardTabletSidebarWidth);
|
||||
|
||||
|
|
|
|||
|
|
@ -268,27 +268,27 @@ export function useDeckLayers({
|
|||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const pieProps: any = isEnum
|
||||
? {
|
||||
extensions: [new PieHexExtension(requireEnumPalette(enumPaletteRef.current))],
|
||||
getCenter: (d: HexagonData) => [d.lon, d.lat],
|
||||
getRatios0: (d: HexagonData) => {
|
||||
const r = distToRatios(d[distKey]);
|
||||
return [r[0], r[1], r[2], r[3]];
|
||||
},
|
||||
getRatios1: (d: HexagonData) => {
|
||||
const r = distToRatios(d[distKey]);
|
||||
return [r[4], r[5], r[6], r[7]];
|
||||
},
|
||||
getRatios2: (d: HexagonData) => {
|
||||
const r = distToRatios(d[distKey]);
|
||||
return [r[8], r[9]];
|
||||
},
|
||||
updateTriggers: {
|
||||
getCenter: [colorTrigger, data],
|
||||
getRatios0: [colorTrigger, data],
|
||||
getRatios1: [colorTrigger, data],
|
||||
getRatios2: [colorTrigger, data],
|
||||
},
|
||||
}
|
||||
extensions: [new PieHexExtension(requireEnumPalette(enumPaletteRef.current))],
|
||||
getCenter: (d: HexagonData) => [d.lon, d.lat],
|
||||
getRatios0: (d: HexagonData) => {
|
||||
const r = distToRatios(d[distKey]);
|
||||
return [r[0], r[1], r[2], r[3]];
|
||||
},
|
||||
getRatios1: (d: HexagonData) => {
|
||||
const r = distToRatios(d[distKey]);
|
||||
return [r[4], r[5], r[6], r[7]];
|
||||
},
|
||||
getRatios2: (d: HexagonData) => {
|
||||
const r = distToRatios(d[distKey]);
|
||||
return [r[8], r[9]];
|
||||
},
|
||||
updateTriggers: {
|
||||
getCenter: [colorTrigger, data],
|
||||
getRatios0: [colorTrigger, data],
|
||||
getRatios1: [colorTrigger, data],
|
||||
getRatios2: [colorTrigger, data],
|
||||
},
|
||||
}
|
||||
: {};
|
||||
|
||||
return new H3HexagonLayer<HexagonData>({
|
||||
|
|
|
|||
|
|
@ -6,19 +6,16 @@ export function useLicense() {
|
|||
const [checkingOut, setCheckingOut] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const startCheckout = useCallback(async (referralCode?: string) => {
|
||||
trackEvent('Checkout Start', { has_referral: String(!!referralCode) });
|
||||
const startCheckout = useCallback(async () => {
|
||||
trackEvent('Checkout Start', { has_referral: 'false' });
|
||||
setCheckingOut(true);
|
||||
setError(null);
|
||||
try {
|
||||
const body: Record<string, string> = {};
|
||||
if (referralCode) body.referral_code = referralCode;
|
||||
|
||||
const res = await fetch(apiUrl('checkout'), {
|
||||
method: 'POST',
|
||||
...authHeaders({
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(body),
|
||||
body: JSON.stringify({}),
|
||||
}),
|
||||
});
|
||||
assertOk(res, 'Checkout');
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ export function useLocationSearch(mode?: string) {
|
|||
const [activeIndex, setActiveIndex] = useState(-1);
|
||||
const [open, setOpen] = useState(false);
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
const debounceRef = useRef<ReturnType<typeof setTimeout>>();
|
||||
const debounceRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
const latestQueryRef = useRef('');
|
||||
const lastResultsRef = useRef<SearchResult[]>([]);
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import Supercluster from 'supercluster';
|
|||
import type { POI } from '../types';
|
||||
import {
|
||||
POI_GROUP_COLORS,
|
||||
POI_DEFAULT_COLOR,
|
||||
MINOR_POI_CATEGORIES,
|
||||
MINOR_POI_ZOOM_THRESHOLD,
|
||||
POI_CLUSTER_RADIUS,
|
||||
|
|
@ -40,6 +39,30 @@ interface UsePoiLayersProps {
|
|||
isDark: boolean;
|
||||
}
|
||||
|
||||
function getPoiIconUrlForPoi(poi: POI): string {
|
||||
return getPoiIconUrl(poi.category, poi.emoji, poi.icon_category, poi.name);
|
||||
}
|
||||
|
||||
function isBundledPoiIcon(url: string): boolean {
|
||||
return url.startsWith('/assets/poi-icons/');
|
||||
}
|
||||
|
||||
function hasBundledPoiLogo(poi: POI): boolean {
|
||||
return isBundledPoiIcon(getPoiIconUrlForPoi(poi));
|
||||
}
|
||||
|
||||
function getPoiGroupColor(group: string): [number, number, number] {
|
||||
const color = POI_GROUP_COLORS[group];
|
||||
if (!color) {
|
||||
throw new Error(`Missing POI group color for '${group}'`);
|
||||
}
|
||||
return color;
|
||||
}
|
||||
|
||||
function getPoiIconSize(poi: POI): number {
|
||||
return hasBundledPoiLogo(poi) ? 24 : 18;
|
||||
}
|
||||
|
||||
export function usePoiLayers({ pois, zoom, isDark }: UsePoiLayersProps) {
|
||||
const [popupInfo, setPopupInfo] = useState<PopupInfo | null>(null);
|
||||
|
||||
|
|
@ -139,7 +162,7 @@ export function usePoiLayers({ pois, zoom, isDark }: UsePoiLayersProps) {
|
|||
id: 'poi-shadow',
|
||||
data: visiblePois,
|
||||
getPosition: (d) => [d.lng, d.lat],
|
||||
getRadius: 16,
|
||||
getRadius: (d) => (hasBundledPoiLogo(d) ? 0 : 16),
|
||||
radiusUnits: 'pixels',
|
||||
getFillColor: isDark ? [0, 0, 0, 50] : [0, 0, 0, 25],
|
||||
pickable: false,
|
||||
|
|
@ -154,11 +177,17 @@ export function usePoiLayers({ pois, zoom, isDark }: UsePoiLayersProps) {
|
|||
id: 'poi-background',
|
||||
data: visiblePois,
|
||||
getPosition: (d) => [d.lng, d.lat],
|
||||
getRadius: 14,
|
||||
getRadius: (d) => (hasBundledPoiLogo(d) ? 24 : 14),
|
||||
radiusUnits: 'pixels',
|
||||
getFillColor: isDark ? [41, 37, 36, 255] : [255, 255, 255, 255],
|
||||
getFillColor: (d) =>
|
||||
hasBundledPoiLogo(d)
|
||||
? ([0, 0, 0, 0] as [number, number, number, number])
|
||||
: isDark
|
||||
? ([41, 37, 36, 255] as [number, number, number, number])
|
||||
: ([255, 255, 255, 255] as [number, number, number, number]),
|
||||
getLineColor: (d) => {
|
||||
const c = POI_GROUP_COLORS[d.group] || POI_DEFAULT_COLOR;
|
||||
if (hasBundledPoiLogo(d)) return [0, 0, 0, 0] as [number, number, number, number];
|
||||
const c = getPoiGroupColor(d.group);
|
||||
return [c[0], c[1], c[2], 255] as [number, number, number, number];
|
||||
},
|
||||
getLineWidth: 2.5,
|
||||
|
|
@ -177,12 +206,16 @@ export function usePoiLayers({ pois, zoom, isDark }: UsePoiLayersProps) {
|
|||
id: 'poi-icons',
|
||||
data: visiblePois,
|
||||
getPosition: (d) => [d.lng, d.lat],
|
||||
getIcon: (d) => ({
|
||||
url: getPoiIconUrl(d.category, d.emoji, d.icon_category, d.name),
|
||||
width: 72,
|
||||
height: 72,
|
||||
}),
|
||||
getSize: 18,
|
||||
getIcon: (d) => {
|
||||
const url = getPoiIconUrlForPoi(d);
|
||||
const isLogo = isBundledPoiIcon(url);
|
||||
return {
|
||||
url,
|
||||
width: isLogo ? 96 : 72,
|
||||
height: isLogo ? 48 : 72,
|
||||
};
|
||||
},
|
||||
getSize: getPoiIconSize,
|
||||
sizeUnits: 'pixels',
|
||||
pickable: false,
|
||||
transitions: { getSize: { duration: 300, enter: () => [0] } },
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ const descriptions: Record<string, Record<string, string>> = {
|
|||
'Property type': 'Type de bien : individuel, jumelé, mitoyen, appartement ou autre',
|
||||
'Leasehold/Freehold': 'Indique si le bien est en bail ou en pleine propriété',
|
||||
'Last known price': 'Dernier prix de vente enregistré au Land Registry',
|
||||
'Estimated current price': 'Estimation du prix actuel ajusté à l’inflation',
|
||||
'Estimated current price': 'Estimation modélisée du prix actuel',
|
||||
'Price per sqm': 'Prix de vente divisé par la surface totale',
|
||||
'Est. price per sqm': 'Prix actuel estimé divisé par la surface totale',
|
||||
'Estimated monthly rent': 'Loyer mensuel privé moyen pour le secteur',
|
||||
|
|
@ -48,14 +48,15 @@ const descriptions: Record<string, Record<string, string>> = {
|
|||
'Outstanding secondary schools within 5km':
|
||||
'Collèges/lycées notés Excellent par Ofsted dans un rayon de 5 km',
|
||||
'Education, Skills and Training Score':
|
||||
'Score de qualité éducative du secteur (plus élevé = meilleur)',
|
||||
'Income Score': 'Taux de précarité de revenu, inversé (plus élevé = moins précaire)',
|
||||
'Employment Score': 'Taux de précarité d’emploi, inversé (plus élevé = moins précaire)',
|
||||
'Centile de défaveur éducative (plus élevé = moins défavorisé)',
|
||||
'Income Score': 'Centile de défaveur de revenu (plus élevé = moins défavorisé)',
|
||||
'Employment Score': 'Centile de défaveur d’emploi (plus élevé = moins défavorisé)',
|
||||
'Health Deprivation and Disability Score':
|
||||
'Score de santé et handicap (plus élevé = meilleurs résultats)',
|
||||
'Housing Conditions Score': 'Qualité et état du logement (plus élevé = meilleur)',
|
||||
'Centile de défaveur santé et handicap (plus élevé = meilleurs résultats)',
|
||||
'Housing Conditions Score':
|
||||
'Centile des conditions de logement (plus élevé = meilleures conditions)',
|
||||
'Air Quality and Road Safety Score':
|
||||
'Qualité de l’air et sécurité routière (plus élevé = meilleur)',
|
||||
'Centile air et sécurité routière (plus élevé = meilleures conditions)',
|
||||
'Serious crime per 1k residents (avg/yr)': 'Taux de crimes graves pour 1 000 habitants par an',
|
||||
'Minor crime per 1k residents (avg/yr)': 'Taux de délits mineurs pour 1 000 habitants par an',
|
||||
'Serious crime (avg/yr)': 'Agrégat des catégories de crimes graves par an',
|
||||
|
|
@ -107,7 +108,7 @@ const descriptions: Record<string, Record<string, string>> = {
|
|||
'Immobilientyp: freistehend, Doppelhaushälfte, Reihenhaus, Wohnung oder sonstige',
|
||||
'Leasehold/Freehold': 'Ob die Immobilie Erbbaurecht oder Volleigentum ist',
|
||||
'Last known price': 'Letzter Verkaufspreis laut Land Registry',
|
||||
'Estimated current price': 'Inflationsbereinigter Schätzwert der Immobilie',
|
||||
'Estimated current price': 'Modellierter aktueller Schätzwert der Immobilie',
|
||||
'Price per sqm': 'Verkaufspreis geteilt durch die Gesamtfläche',
|
||||
'Est. price per sqm': 'Geschätzter aktueller Preis geteilt durch die Gesamtfläche',
|
||||
'Estimated monthly rent': 'Durchschnittliche monatliche Privatmiete in der Gegend',
|
||||
|
|
@ -137,14 +138,15 @@ const descriptions: Record<string, Record<string, string>> = {
|
|||
'Von Ofsted mit Hervorragend bewertete Grundschulen im Umkreis von 5 km',
|
||||
'Outstanding secondary schools within 5km':
|
||||
'Von Ofsted mit Hervorragend bewertete weiterführende Schulen im Umkreis von 5 km',
|
||||
'Education, Skills and Training Score': 'Bildungsqualitätsscore der Gegend (höher = besser)',
|
||||
'Income Score': 'Einkommensbenachteiligungsrate, invertiert (höher = weniger benachteiligt)',
|
||||
'Employment Score':
|
||||
'Beschäftigungsbenachteiligungsrate, invertiert (höher = weniger benachteiligt)',
|
||||
'Education, Skills and Training Score':
|
||||
'Bildungs- und Ausbildungsbenachteiligungs-Perzentil (höher = weniger benachteiligt)',
|
||||
'Income Score': 'Einkommensbenachteiligungs-Perzentil (höher = weniger benachteiligt)',
|
||||
'Employment Score': 'Beschäftigungsbenachteiligungs-Perzentil (höher = weniger benachteiligt)',
|
||||
'Health Deprivation and Disability Score':
|
||||
'Gesundheits- und Behinderungsscore (höher = bessere Ergebnisse)',
|
||||
'Housing Conditions Score': 'Wohnqualität und -zustand (höher = besser)',
|
||||
'Air Quality and Road Safety Score': 'Luftqualität und Verkehrssicherheit (höher = besser)',
|
||||
'Gesundheits- und Behinderungsbenachteiligungs-Perzentil (höher = bessere Ergebnisse)',
|
||||
'Housing Conditions Score': 'Perzentil der Wohnbedingungen (höher = bessere Bedingungen)',
|
||||
'Air Quality and Road Safety Score':
|
||||
'Perzentil für Luftqualität und Verkehrssicherheit (höher = bessere Bedingungen)',
|
||||
'Serious crime per 1k residents (avg/yr)':
|
||||
'Rate schwerer Straftaten pro 1.000 Einwohner pro Jahr',
|
||||
'Minor crime per 1k residents (avg/yr)':
|
||||
|
|
@ -199,7 +201,7 @@ const descriptions: Record<string, Record<string, string>> = {
|
|||
'Property type': '房产类型:独立式、半独立式、联排、公寓或其他',
|
||||
'Leasehold/Freehold': '该房产是租赁产权还是永久产权',
|
||||
'Last known price': 'Land Registry记录的最近一次售价',
|
||||
'Estimated current price': '经通胀调整后的当前估计价值',
|
||||
'Estimated current price': '模型估算的当前价格',
|
||||
'Price per sqm': '售价除以总建筑面积',
|
||||
'Est. price per sqm': '估计当前价格除以总建筑面积',
|
||||
'Estimated monthly rent': '当地私人租赁的平均月租',
|
||||
|
|
@ -220,12 +222,12 @@ const descriptions: Record<string, Record<string, string>> = {
|
|||
'Outstanding secondary schools within 2km': 'Ofsted评为优秀的2公里内中学',
|
||||
'Outstanding primary schools within 5km': 'Ofsted评为优秀的5公里内小学',
|
||||
'Outstanding secondary schools within 5km': 'Ofsted评为优秀的5公里内中学',
|
||||
'Education, Skills and Training Score': '当地教育质量得分(越高越好)',
|
||||
'Income Score': '收入贫困率,反向指标(越高越不贫困)',
|
||||
'Employment Score': '就业贫困率,反向指标(越高越不贫困)',
|
||||
'Health Deprivation and Disability Score': '健康与残障得分(越高健康状况越好)',
|
||||
'Housing Conditions Score': '住房质量和状况(越高越好)',
|
||||
'Air Quality and Road Safety Score': '空气质量和道路安全(越高越好)',
|
||||
'Education, Skills and Training Score': '教育与技能贫困百分位(越高越不贫困)',
|
||||
'Income Score': '收入贫困百分位(越高越不贫困)',
|
||||
'Employment Score': '就业贫困百分位(越高越不贫困)',
|
||||
'Health Deprivation and Disability Score': '健康与残障贫困百分位(越高结果越好)',
|
||||
'Housing Conditions Score': '住房条件百分位(越高条件越好)',
|
||||
'Air Quality and Road Safety Score': '空气质量和道路安全百分位(越高条件越好)',
|
||||
'Serious crime per 1k residents (avg/yr)': '每千人每年严重犯罪率',
|
||||
'Minor crime per 1k residents (avg/yr)': '每千人每年轻微犯罪率',
|
||||
'Serious crime (avg/yr)': '严重犯罪类别年度总计',
|
||||
|
|
@ -269,7 +271,7 @@ const descriptions: Record<string, Record<string, string>> = {
|
|||
'Property type': 'संपत्ति प्रकार: अलग, अर्ध-स्वतंत्र, कतारबद्ध, फ्लैट या अन्य',
|
||||
'Leasehold/Freehold': 'बताता है कि संपत्ति लीजहोल्ड है या फ्रीहोल्ड',
|
||||
'Last known price': 'Land Registry में दर्ज अंतिम बिक्री कीमत',
|
||||
'Estimated current price': 'महंगाई और स्थानीय कीमत बदलाव से समायोजित मौजूदा अनुमानित मूल्य',
|
||||
'Estimated current price': 'मॉडल से अनुमानित मौजूदा मूल्य',
|
||||
'Price per sqm': 'बिक्री कीमत को कुल फर्श क्षेत्र से विभाजित किया गया',
|
||||
'Est. price per sqm': 'मौजूदा अनुमानित कीमत को कुल फर्श क्षेत्र से विभाजित किया गया',
|
||||
'Estimated monthly rent': 'क्षेत्र का औसत निजी मासिक किराया',
|
||||
|
|
@ -292,12 +294,14 @@ const descriptions: Record<string, Record<string, string>> = {
|
|||
'Outstanding secondary schools within 2km': '2 किमी के भीतर Ofsted Outstanding सेकेंडरी स्कूल',
|
||||
'Outstanding primary schools within 5km': '5 किमी के भीतर Ofsted Outstanding प्राइमरी स्कूल',
|
||||
'Outstanding secondary schools within 5km': '5 किमी के भीतर Ofsted Outstanding सेकेंडरी स्कूल',
|
||||
'Education, Skills and Training Score': 'स्थानीय शिक्षा गुणवत्ता स्कोर (अधिक = बेहतर)',
|
||||
'Income Score': 'आय वंचना दर, उलटी की गई (अधिक = कम वंचना)',
|
||||
'Employment Score': 'रोजगार वंचना दर, उलटी की गई (अधिक = कम वंचना)',
|
||||
'Health Deprivation and Disability Score': 'स्वास्थ्य और विकलांगता स्कोर (अधिक = बेहतर परिणाम)',
|
||||
'Housing Conditions Score': 'आवास गुणवत्ता और स्थिति (अधिक = बेहतर)',
|
||||
'Air Quality and Road Safety Score': 'हवा की गुणवत्ता और सड़क सुरक्षा (अधिक = बेहतर)',
|
||||
'Education, Skills and Training Score': 'शिक्षा और कौशल वंचना percentile (अधिक = कम वंचना)',
|
||||
'Income Score': 'आय वंचना percentile (अधिक = कम वंचना)',
|
||||
'Employment Score': 'रोजगार वंचना percentile (अधिक = कम वंचना)',
|
||||
'Health Deprivation and Disability Score':
|
||||
'स्वास्थ्य और विकलांगता वंचना percentile (अधिक = बेहतर परिणाम)',
|
||||
'Housing Conditions Score': 'आवास स्थिति percentile (अधिक = बेहतर स्थिति)',
|
||||
'Air Quality and Road Safety Score':
|
||||
'हवा की गुणवत्ता और सड़क सुरक्षा percentile (अधिक = बेहतर स्थिति)',
|
||||
'Serious crime per 1k residents (avg/yr)': 'प्रति 1,000 निवासियों सालाना गंभीर अपराध दर',
|
||||
'Minor crime per 1k residents (avg/yr)': 'प्रति 1,000 निवासियों सालाना मामूली अपराध दर',
|
||||
'Serious crime (avg/yr)': 'गंभीर अपराध श्रेणियों का सालाना कुल',
|
||||
|
|
@ -342,7 +346,7 @@ const descriptions: Record<string, Record<string, string>> = {
|
|||
'Property type': 'Ingatlantípus: különálló, ikerház, sorház, lakás vagy egyéb',
|
||||
'Leasehold/Freehold': 'Az ingatlan bérleti jogú vagy teljes tulajdonú',
|
||||
'Last known price': 'A Land Registry-ben rögzített utolsó eladási ár',
|
||||
'Estimated current price': 'Inflációval korrigált becsült jelenlegi érték',
|
||||
'Estimated current price': 'Modellezett becsült jelenlegi érték',
|
||||
'Price per sqm': 'Eladási ár osztva az összes alapterülettel',
|
||||
'Est. price per sqm': 'Becsült jelenlegi ár osztva az összes alapterülettel',
|
||||
'Estimated monthly rent': 'A környék átlagos havi magánbérleti díja',
|
||||
|
|
@ -374,14 +378,14 @@ const descriptions: Record<string, Record<string, string>> = {
|
|||
'Outstanding secondary schools within 5km':
|
||||
'Ofsted által Kiváló minősítésű középiskolák 5 km-en belül',
|
||||
'Education, Skills and Training Score':
|
||||
'A környék oktatási minőségi pontszáma (magasabb = jobb)',
|
||||
'Income Score': 'Jövedelmi deprivációs ráta, invertálva (magasabb = kevésbé hátrányos)',
|
||||
'Employment Score':
|
||||
'Foglalkoztatási deprivációs ráta, invertálva (magasabb = kevésbé hátrányos)',
|
||||
'Oktatási és készségbeli deprivációs percentilis (magasabb = kevésbé hátrányos)',
|
||||
'Income Score': 'Jövedelmi deprivációs percentilis (magasabb = kevésbé hátrányos)',
|
||||
'Employment Score': 'Foglalkoztatási deprivációs percentilis (magasabb = kevésbé hátrányos)',
|
||||
'Health Deprivation and Disability Score':
|
||||
'Egészségügyi és fogyatékossági pontszám (magasabb = jobb eredmények)',
|
||||
'Housing Conditions Score': 'Lakásminőség és állapot (magasabb = jobb)',
|
||||
'Air Quality and Road Safety Score': 'Levegőminőség és közlekedésbiztonság (magasabb = jobb)',
|
||||
'Egészségügyi és fogyatékossági deprivációs percentilis (magasabb = jobb eredmények)',
|
||||
'Housing Conditions Score': 'Lakáskörülmények percentilise (magasabb = jobb körülmények)',
|
||||
'Air Quality and Road Safety Score':
|
||||
'Levegőminőség és közlekedésbiztonság percentilise (magasabb = jobb körülmények)',
|
||||
'Serious crime per 1k residents (avg/yr)': 'Súlyos bűncselekmények aránya 1000 lakosra évente',
|
||||
'Minor crime per 1k residents (avg/yr)': 'Kisebb bűncselekmények aránya 1000 lakosra évente',
|
||||
'Serious crime (avg/yr)': 'Súlyos bűncselekményi kategóriák éves összesítése',
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
@config "../tailwind.config.js";
|
||||
|
||||
@import "tailwindcss";
|
||||
@import 'tailwindcss';
|
||||
|
||||
html,
|
||||
body,
|
||||
|
|
@ -66,7 +66,7 @@ h3 {
|
|||
}
|
||||
|
||||
.home-content-surface {
|
||||
--home-hex-pattern: url("/home-hex-pattern.svg");
|
||||
--home-hex-pattern: url('/home-hex-pattern.svg');
|
||||
--home-pointer-active: 0;
|
||||
--home-pointer-x: 50%;
|
||||
--home-pointer-y: 50%;
|
||||
|
|
@ -124,7 +124,7 @@ h3 {
|
|||
}
|
||||
|
||||
.dark .home-content-surface {
|
||||
--home-hex-pattern: url("/home-hex-pattern-dark.svg");
|
||||
--home-hex-pattern: url('/home-hex-pattern-dark.svg');
|
||||
background: linear-gradient(180deg, #121827 0%, #0a0e1a 42%, #10211f 100%);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import {
|
|||
ZOOM_TO_RESOLUTION_THRESHOLDS,
|
||||
TWEMOJI_BASE,
|
||||
BUFFER_MULTIPLIER,
|
||||
ENUM_PALETTE,
|
||||
POI_CATEGORY_LOGOS,
|
||||
type GradientStop,
|
||||
} from './consts';
|
||||
|
|
@ -78,19 +77,21 @@ export function getMapStyle(theme: 'light' | 'dark'): StyleSpecification {
|
|||
// In dark mode, make all text white with dark outline
|
||||
const modifiedLayers = baseLayers
|
||||
.filter((layer) => !layer.id.includes('buildings'))
|
||||
.map((layer) => {
|
||||
.map((original) => {
|
||||
let layer = original;
|
||||
|
||||
// Modify road opacity
|
||||
if (layer.id.includes('roads_') || layer.id.includes('road_')) {
|
||||
if (layer.type === 'line') {
|
||||
return { ...layer, paint: { ...layer.paint, 'line-opacity': ROAD_OPACITY } };
|
||||
layer = { ...layer, paint: { ...layer.paint, 'line-opacity': ROAD_OPACITY } };
|
||||
} else if (layer.type === 'fill') {
|
||||
return { ...layer, paint: { ...layer.paint, 'fill-opacity': ROAD_OPACITY } };
|
||||
layer = { ...layer, paint: { ...layer.paint, 'fill-opacity': ROAD_OPACITY } };
|
||||
}
|
||||
}
|
||||
|
||||
// Modify text colors in dark mode
|
||||
if (isDark && layer.type === 'symbol' && layer.paint?.['text-color']) {
|
||||
return {
|
||||
layer = {
|
||||
...layer,
|
||||
paint: {
|
||||
...layer.paint,
|
||||
|
|
@ -234,9 +235,32 @@ export function getBoundsFromViewState(
|
|||
return { south, west, north, east };
|
||||
}
|
||||
|
||||
export function getLatitudeAtVerticalPixelOffset(
|
||||
latitude: number,
|
||||
zoom: number,
|
||||
pixelOffsetY: number
|
||||
): number {
|
||||
const worldSize = TILE_SIZE * Math.pow(2, zoom);
|
||||
const pixelY = latitudeToWorldY(latitude, worldSize) + pixelOffsetY;
|
||||
return worldYToLatitude(pixelY, worldSize);
|
||||
}
|
||||
|
||||
export function getBoundsWithBottomScreenInset(
|
||||
bounds: [number, number, number, number],
|
||||
zoom: number,
|
||||
bottomInsetPx: number
|
||||
): [number, number, number, number] {
|
||||
if (bottomInsetPx <= 0) return bounds;
|
||||
|
||||
const [west, south, east, north] = bounds;
|
||||
return [west, getLatitudeAtVerticalPixelOffset(south, zoom, bottomInsetPx), east, north];
|
||||
}
|
||||
|
||||
export function emojiToTwemojiUrl(emoji: string): string {
|
||||
const codePoint = emoji.codePointAt(0);
|
||||
if (!codePoint) return `${TWEMOJI_BASE}1f4cd.png`;
|
||||
if (!codePoint) {
|
||||
throw new Error('Cannot build a Twemoji URL without an emoji');
|
||||
}
|
||||
const hex = codePoint.toString(16);
|
||||
return `${TWEMOJI_BASE}${hex}.png`;
|
||||
}
|
||||
|
|
@ -287,7 +311,7 @@ function inferPoiIconCategory(category: string, name?: string): string | undefin
|
|||
|
||||
export function getPoiIconUrl(
|
||||
category: string,
|
||||
emoji: string,
|
||||
_emoji: string,
|
||||
iconCategory?: string,
|
||||
name?: string
|
||||
): string {
|
||||
|
|
@ -295,13 +319,17 @@ export function getPoiIconUrl(
|
|||
if (resolvedIconCategory && POI_CATEGORY_LOGOS[resolvedIconCategory]) {
|
||||
return POI_CATEGORY_LOGOS[resolvedIconCategory];
|
||||
}
|
||||
return POI_CATEGORY_LOGOS[category] ?? emojiToTwemojiUrl(emoji);
|
||||
const categoryLogo = POI_CATEGORY_LOGOS[category];
|
||||
if (!categoryLogo) {
|
||||
throw new Error(`Missing POI icon for category '${category}'`);
|
||||
}
|
||||
return categoryLogo;
|
||||
}
|
||||
|
||||
/** Look up a discrete color from the enum palette by index (wraps if > palette size). */
|
||||
export function enumIndexToColor(
|
||||
index: number,
|
||||
palette: [number, number, number][] = ENUM_PALETTE
|
||||
palette: [number, number, number][]
|
||||
): [number, number, number] {
|
||||
const i = Math.round(Math.max(0, index)) % palette.length;
|
||||
return palette[i];
|
||||
|
|
@ -324,7 +352,7 @@ export function getFeatureFillColor(
|
|||
isDark: boolean,
|
||||
alpha: number,
|
||||
enumCount: number = 0,
|
||||
enumPalette?: [number, number, number][],
|
||||
enumPalette?: [number, number, number][] | null,
|
||||
featureGradient: GradientStop[] = FEATURE_GRADIENT
|
||||
): [number, number, number, number] {
|
||||
if (colorRange) {
|
||||
|
|
@ -343,6 +371,9 @@ export function getFeatureFillColor(
|
|||
|
||||
// Discrete coloring for enum features (used as base; PieHexExtension overrides when active)
|
||||
if (enumCount > 0) {
|
||||
if (!enumPalette) {
|
||||
throw new Error('Enum feature fill requested without an enum color palette');
|
||||
}
|
||||
const rgb = enumIndexToColor(Math.round(value as number), enumPalette);
|
||||
return [...rgb, alpha] as [number, number, number, number];
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,10 +77,6 @@ const POI_FILTER_CONFIGS: Record<
|
|||
},
|
||||
};
|
||||
|
||||
function isPoiFilterNameValue(name: string): name is PoiFilterName {
|
||||
return POI_FILTER_NAMES.includes(name as PoiFilterName);
|
||||
}
|
||||
|
||||
function getConfig(filterName: PoiFilterName) {
|
||||
return POI_FILTER_CONFIGS[filterName];
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import {
|
|||
type TravelTimeEntry,
|
||||
type TravelTimeInitial,
|
||||
} from '../hooks/useTravelTime';
|
||||
import { INITIAL_VIEW_STATE } from './consts';
|
||||
import {
|
||||
SCHOOL_FILTER_NAME,
|
||||
createSchoolFilterKey,
|
||||
|
|
@ -21,13 +22,56 @@ import {
|
|||
isSpecificCrimeFeatureName,
|
||||
isSpecificCrimeFilterName,
|
||||
} from './crime-filter';
|
||||
import {
|
||||
ETHNICITIES_FILTER_NAME,
|
||||
createEthnicityFilterKey,
|
||||
getEthnicityFeatureName,
|
||||
isEthnicityFeatureName,
|
||||
isEthnicityFilterName,
|
||||
} from './ethnicity-filter';
|
||||
import {
|
||||
POI_DISTANCE_FILTER_NAME,
|
||||
POI_COUNT_2KM_FILTER_NAME,
|
||||
POI_COUNT_5KM_FILTER_NAME,
|
||||
createPoiFilterKey,
|
||||
createPoiDistanceFilterKey,
|
||||
getPoiDistanceFeatureName,
|
||||
getPoiFilterName,
|
||||
isPoiDistanceFeatureName,
|
||||
isPoiDistanceFilterName,
|
||||
type PoiFilterName,
|
||||
} from './poi-distance-filter';
|
||||
|
||||
function parseFilters(params: URLSearchParams): FeatureFilters | undefined {
|
||||
const POI_NONE_PARAM = '__none';
|
||||
|
||||
export interface UrlState {
|
||||
viewState: ViewState;
|
||||
filters: FeatureFilters;
|
||||
poiCategories: Set<string>;
|
||||
tab: 'properties' | 'area';
|
||||
travelTime?: TravelTimeInitial;
|
||||
postcode?: string;
|
||||
share?: string;
|
||||
}
|
||||
|
||||
function parseFilters(params: URLSearchParams): FeatureFilters {
|
||||
const filterParams = params.getAll('filter');
|
||||
const schoolParams = params.getAll('school');
|
||||
const crimeParams = params.getAll('crime');
|
||||
if (filterParams.length === 0 && schoolParams.length === 0 && crimeParams.length === 0) {
|
||||
return undefined;
|
||||
const ethnicityParams = params.getAll('ethnicity');
|
||||
const poiDistanceParams = params.getAll('poiDistance');
|
||||
const poiCount2KmParams = params.getAll('poiCount2km');
|
||||
const poiCount5KmParams = params.getAll('poiCount5km');
|
||||
if (
|
||||
filterParams.length === 0 &&
|
||||
schoolParams.length === 0 &&
|
||||
crimeParams.length === 0 &&
|
||||
ethnicityParams.length === 0 &&
|
||||
poiDistanceParams.length === 0 &&
|
||||
poiCount2KmParams.length === 0 &&
|
||||
poiCount5KmParams.length === 0
|
||||
) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const filters: FeatureFilters = {};
|
||||
|
|
@ -82,20 +126,65 @@ function parseFilters(params: URLSearchParams): FeatureFilters | undefined {
|
|||
filters[createSpecificCrimeFilterKey(featureName, index)] = [min, max];
|
||||
});
|
||||
|
||||
return Object.keys(filters).length > 0 ? filters : undefined;
|
||||
ethnicityParams.forEach((entry, index) => {
|
||||
const parts = entry.split(':');
|
||||
if (parts.length < 3) return;
|
||||
const featureName = parts.slice(0, -2).join(':');
|
||||
const min = Number(parts[parts.length - 2]);
|
||||
const max = Number(parts[parts.length - 1]);
|
||||
if (!isEthnicityFeatureName(featureName) || isNaN(min) || isNaN(max)) {
|
||||
return;
|
||||
}
|
||||
filters[createEthnicityFilterKey(featureName, index)] = [min, max];
|
||||
});
|
||||
|
||||
poiDistanceParams.forEach((entry, index) => {
|
||||
const parts = entry.split(':');
|
||||
if (parts.length < 3) return;
|
||||
const featureName = decodeURIComponent(parts.slice(0, -2).join(':'));
|
||||
const min = Number(parts[parts.length - 2]);
|
||||
const max = Number(parts[parts.length - 1]);
|
||||
if (!isPoiDistanceFeatureName(featureName) || isNaN(min) || isNaN(max)) {
|
||||
return;
|
||||
}
|
||||
filters[createPoiDistanceFilterKey(featureName, index)] = [min, max];
|
||||
});
|
||||
|
||||
const parsePoiCountParams = (
|
||||
entries: string[],
|
||||
filterName: PoiFilterName,
|
||||
startIndex: number
|
||||
) => {
|
||||
entries.forEach((entry, index) => {
|
||||
const parts = entry.split(':');
|
||||
if (parts.length < 3) return;
|
||||
const featureName = decodeURIComponent(parts.slice(0, -2).join(':'));
|
||||
const min = Number(parts[parts.length - 2]);
|
||||
const max = Number(parts[parts.length - 1]);
|
||||
if (getPoiFilterName(featureName) !== filterName || isNaN(min) || isNaN(max)) {
|
||||
return;
|
||||
}
|
||||
filters[createPoiFilterKey(filterName, featureName, startIndex + index)] = [min, max];
|
||||
});
|
||||
};
|
||||
parsePoiCountParams(poiCount2KmParams, POI_COUNT_2KM_FILTER_NAME, poiDistanceParams.length);
|
||||
parsePoiCountParams(
|
||||
poiCount5KmParams,
|
||||
POI_COUNT_5KM_FILTER_NAME,
|
||||
poiDistanceParams.length + poiCount2KmParams.length
|
||||
);
|
||||
|
||||
return filters;
|
||||
}
|
||||
|
||||
export function parseUrlState(): {
|
||||
viewState?: ViewState;
|
||||
filters?: FeatureFilters;
|
||||
poiCategories?: Set<string>;
|
||||
tab?: 'properties' | 'area';
|
||||
travelTime?: TravelTimeInitial;
|
||||
postcode?: string;
|
||||
share?: string;
|
||||
} {
|
||||
export function parseUrlState(): UrlState {
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
const result: ReturnType<typeof parseUrlState> = {};
|
||||
const result: UrlState = {
|
||||
viewState: INITIAL_VIEW_STATE,
|
||||
filters: parseFilters(params),
|
||||
poiCategories: new Set(),
|
||||
tab: 'area',
|
||||
};
|
||||
|
||||
// Share-link code: grants bbox-scoped access to the area the link references
|
||||
// even for unlicensed users. The backend looks the code up against PocketBase.
|
||||
|
|
@ -117,13 +206,16 @@ export function parseUrlState(): {
|
|||
}
|
||||
}
|
||||
|
||||
// Filters: repeated `filter` params
|
||||
result.filters = parseFilters(params);
|
||||
|
||||
// POI categories: repeated `poi` params
|
||||
const poiParams = params.getAll('poi');
|
||||
if (poiParams.length > 0) {
|
||||
result.poiCategories = new Set(poiParams.filter(Boolean));
|
||||
if (poiParams.includes(POI_NONE_PARAM)) {
|
||||
result.poiCategories = new Set();
|
||||
} else {
|
||||
result.poiCategories = new Set(
|
||||
poiParams.filter((value) => value && value !== POI_NONE_PARAM)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Tab: full name
|
||||
|
|
@ -209,6 +301,27 @@ export function stateToParams(
|
|||
continue;
|
||||
}
|
||||
|
||||
const ethnicityFeatureName = getEthnicityFeatureName(name);
|
||||
if (ethnicityFeatureName && isEthnicityFilterName(name)) {
|
||||
const [min, max] = value as [number, number];
|
||||
params.append('ethnicity', `${ethnicityFeatureName}:${min}:${max}`);
|
||||
continue;
|
||||
}
|
||||
|
||||
const poiDistanceFeatureName = getPoiDistanceFeatureName(name);
|
||||
if (poiDistanceFeatureName && isPoiDistanceFilterName(name)) {
|
||||
const [min, max] = value as [number, number];
|
||||
const filterName = getPoiFilterName(name);
|
||||
const paramName =
|
||||
filterName === POI_COUNT_2KM_FILTER_NAME
|
||||
? 'poiCount2km'
|
||||
: filterName === POI_COUNT_5KM_FILTER_NAME
|
||||
? 'poiCount5km'
|
||||
: 'poiDistance';
|
||||
params.append(paramName, `${encodeURIComponent(poiDistanceFeatureName)}:${min}:${max}`);
|
||||
continue;
|
||||
}
|
||||
|
||||
const meta = features.find((f) => f.name === name);
|
||||
if (meta?.type === 'enum') {
|
||||
params.append('filter', `${name}:${(value as string[]).join('|')}`);
|
||||
|
|
@ -218,8 +331,12 @@ export function stateToParams(
|
|||
}
|
||||
}
|
||||
|
||||
for (const category of selectedPOICategories) {
|
||||
params.append('poi', category);
|
||||
if (selectedPOICategories.size === 0) {
|
||||
params.append('poi', POI_NONE_PARAM);
|
||||
} else {
|
||||
for (const category of selectedPOICategories) {
|
||||
params.append('poi', category);
|
||||
}
|
||||
}
|
||||
|
||||
if (rightPaneTab === 'properties') {
|
||||
|
|
@ -255,18 +372,45 @@ export function summarizeParams(queryString: string): string {
|
|||
const filterParams = params.getAll('filter');
|
||||
const schoolParams = params.getAll('school');
|
||||
const crimeParams = params.getAll('crime');
|
||||
if (filterParams.length > 0 || schoolParams.length > 0 || crimeParams.length > 0) {
|
||||
const ethnicityParams = params.getAll('ethnicity');
|
||||
const poiDistanceParams = params.getAll('poiDistance');
|
||||
const poiCount2KmParams = params.getAll('poiCount2km');
|
||||
const poiCount5KmParams = params.getAll('poiCount5km');
|
||||
if (
|
||||
filterParams.length > 0 ||
|
||||
schoolParams.length > 0 ||
|
||||
crimeParams.length > 0 ||
|
||||
ethnicityParams.length > 0 ||
|
||||
poiDistanceParams.length > 0 ||
|
||||
poiCount2KmParams.length > 0 ||
|
||||
poiCount5KmParams.length > 0
|
||||
) {
|
||||
const filterNames = filterParams
|
||||
.map((entry) => {
|
||||
const colonIdx = entry.indexOf(':');
|
||||
const name = colonIdx > 0 ? entry.substring(0, colonIdx) : entry;
|
||||
return isSpecificCrimeFeatureName(name) ? SPECIFIC_CRIMES_FILTER_NAME : name;
|
||||
if (isSpecificCrimeFeatureName(name)) return SPECIFIC_CRIMES_FILTER_NAME;
|
||||
if (isEthnicityFeatureName(name)) return ETHNICITIES_FILTER_NAME;
|
||||
if (isPoiDistanceFeatureName(name)) return POI_DISTANCE_FILTER_NAME;
|
||||
return name;
|
||||
})
|
||||
.filter((n) => n);
|
||||
for (let i = 0; i < schoolParams.length; i++) filterNames.push(SCHOOL_FILTER_NAME);
|
||||
for (let i = 0; i < crimeParams.length; i++) {
|
||||
filterNames.push(SPECIFIC_CRIMES_FILTER_NAME);
|
||||
}
|
||||
for (let i = 0; i < ethnicityParams.length; i++) {
|
||||
filterNames.push(ETHNICITIES_FILTER_NAME);
|
||||
}
|
||||
for (let i = 0; i < poiDistanceParams.length; i++) {
|
||||
filterNames.push(POI_DISTANCE_FILTER_NAME);
|
||||
}
|
||||
for (let i = 0; i < poiCount2KmParams.length; i++) {
|
||||
filterNames.push(POI_COUNT_2KM_FILTER_NAME);
|
||||
}
|
||||
for (let i = 0; i < poiCount5KmParams.length; i++) {
|
||||
filterNames.push(POI_COUNT_5KM_FILTER_NAME);
|
||||
}
|
||||
if (filterNames.length > 0) {
|
||||
parts.push(
|
||||
filterNames.length <= 2
|
||||
|
|
@ -278,7 +422,7 @@ export function summarizeParams(queryString: string): string {
|
|||
|
||||
const poiParams = params.getAll('poi');
|
||||
if (poiParams.length > 0) {
|
||||
const count = poiParams.filter(Boolean).length;
|
||||
const count = poiParams.filter((value) => value && value !== POI_NONE_PARAM).length;
|
||||
if (count > 0) {
|
||||
parts.push(
|
||||
count === 1
|
||||
|
|
|
|||
|
|
@ -46,7 +46,14 @@ module.exports = (env, argv) => {
|
|||
test: /\.css$/,
|
||||
use: [
|
||||
isProduction ? MiniCssExtractPlugin.loader : 'style-loader',
|
||||
'css-loader',
|
||||
{
|
||||
loader: 'css-loader',
|
||||
options: {
|
||||
url: {
|
||||
filter: (url) => !url.startsWith('/'),
|
||||
},
|
||||
},
|
||||
},
|
||||
'postcss-loader',
|
||||
],
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,9 +1,15 @@
|
|||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import urllib.request
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from pipeline.transform.transform_poi import NAPTAN_EMOJIS, _CATEGORIES
|
||||
|
||||
GLYPHS_BASE = "https://protomaps.github.io/basemaps-assets/fonts"
|
||||
|
|
@ -14,53 +20,80 @@ POI_ICON_BASE = "https://geolytix.github.io/MapIcons"
|
|||
# Font stacks used by @protomaps/basemaps with lang='en'
|
||||
FONT_STACKS = ["Noto Sans Regular", "Noto Sans Italic", "Noto Sans Medium"]
|
||||
|
||||
# Fallback emoji not in any category
|
||||
_FALLBACK_EMOJIS = ["📍"]
|
||||
|
||||
POI_ICON_PATHS = [
|
||||
"asda/asda_express_24px.svg",
|
||||
"asda/asda_green_basket_24px.svg",
|
||||
"asda/asda_green_trolley_24px.svg",
|
||||
"asda/asda_living_24px.svg",
|
||||
"asda/asda_pfs_24px.svg",
|
||||
"asda/asda_primary.svg",
|
||||
"asda/asda_superstore_green_trolley_24px.svg",
|
||||
"brands/aldi_24px.svg",
|
||||
"brands/amazon_fresh_alt_24px.svg",
|
||||
"brands/booths_24px.svg",
|
||||
"brands/budgens_24px.svg",
|
||||
"brands/centra_24px.svg",
|
||||
"brands/cook.svg",
|
||||
"brands/coop_24px.svg",
|
||||
"brands/costco_24px.svg",
|
||||
"brands/dunnes_stores_24px.svg",
|
||||
"brands/farmfoods_updated_24px.svg",
|
||||
"brands/heron_24px.svg",
|
||||
"brands/iceland_24px.svg",
|
||||
"brands/iceland_food_warehouse_24px.svg",
|
||||
"brands/lidl_24px.svg",
|
||||
"brands/little_waitrose_24px.svg",
|
||||
"brands/makro_24px.svg",
|
||||
"brands/mns_24px.svg",
|
||||
"brands/mns_food_24px.svg",
|
||||
"brands/mns_high_street_24px.svg",
|
||||
"brands/mns_hospital_24px.svg",
|
||||
"brands/mns_moto_24px.svg",
|
||||
"brands/mns_outlet_24px.svg",
|
||||
"brands/morrisons_24px.svg",
|
||||
"brands/morrisons_daily_24px.svg",
|
||||
"brands/sainsburys_24px.svg",
|
||||
"brands/sainsburys_local_24px.svg",
|
||||
"brands/spar_24px.svg",
|
||||
"brands/tesco_24px.svg",
|
||||
"brands/tesco_express_24px.svg",
|
||||
"brands/tesco_extra_24px.svg",
|
||||
"brands/waitrose_24px.svg",
|
||||
"brands/wholefoods_24px.svg",
|
||||
"logos/planet_organic_24px.svg",
|
||||
"brands_2023/supermarkets/farmfoods.svg",
|
||||
"brands_2023/supermarkets/heron_foods.svg",
|
||||
"brands_2023/supermarkets/little_waitrose.svg",
|
||||
"brands_2024/amazon_fresh.svg",
|
||||
"brands_2024/booths.svg",
|
||||
"brands_2024/budgens.svg",
|
||||
"brands_2024/cook.svg",
|
||||
"brands_2024/dunnes_stores.svg",
|
||||
"brands_2024/iceland.svg",
|
||||
"brands_2024/makro.svg",
|
||||
"brands_2024/mns.svg",
|
||||
"brands_2024/morrisons_daily.svg",
|
||||
"brands_2024/sainsburys_local.svg",
|
||||
"brands_2024/wholefoods.svg",
|
||||
"logos/aldi.svg",
|
||||
"logos/asda.svg",
|
||||
"logos/centra.svg",
|
||||
"logos/coop.svg",
|
||||
"logos/lidl.svg",
|
||||
"logos/morrisons.svg",
|
||||
"logos/planet_organic.svg",
|
||||
"logos/sainsburys.svg",
|
||||
"logos/spar.svg",
|
||||
"logos/tesco.svg",
|
||||
"logos/tesco_express.svg",
|
||||
"logos/tesco_extra.svg",
|
||||
"logos/waitrose.svg",
|
||||
"public_transport/london_tube.svg",
|
||||
"visuals/mns.svg",
|
||||
]
|
||||
|
||||
DERIVED_POI_ICON_PATHS = [
|
||||
("costco_logo", "brands/costco.svg", "logos/costco.svg"),
|
||||
(
|
||||
"embedded_png",
|
||||
"brands/iceland_food_warehouse_24px.svg",
|
||||
"logos/the_food_warehouse.png",
|
||||
),
|
||||
]
|
||||
|
||||
POI_ICON_SVG_CROPS = {
|
||||
"brands_2023/supermarkets/farmfoods.svg": (1.293, 7.314, 15.48, 3.293),
|
||||
"brands_2023/supermarkets/heron_foods.svg": (0.062, 6.68, 17.995, 5.325),
|
||||
"brands_2023/supermarkets/little_waitrose.svg": (0.916, 5.645, 16.365, 6.719),
|
||||
"brands_2024/amazon_fresh.svg": (3.817, 1.646, 16.367, 16.358),
|
||||
"brands_2024/booths.svg": (1.456, 7.143, 15.313, 3.512),
|
||||
"brands_2024/budgens.svg": (2.251, 2.278, 13.6, 13.612),
|
||||
"brands_2024/cook.svg": (5.028, 5.493, 13.945, 9.648),
|
||||
"brands_2024/dunnes_stores.svg": (4.375, 7.732, 15.249, 5.055),
|
||||
"brands_2024/iceland.svg": (1.136, 6.823, 16.067, 4.302),
|
||||
"brands_2024/makro.svg": (4.411, 6.098, 16.397, 5.428),
|
||||
"brands_2024/mns.svg": (4.042, 6.986, 16.171, 6.724),
|
||||
"brands_2024/morrisons_daily.svg": (3.341, 4.414, 17.317, 8.248),
|
||||
"brands_2024/sainsburys_local.svg": (4.58, 1.61, 14.84, 14.849),
|
||||
"brands_2024/wholefoods.svg": (4.17, 2.193, 15.659, 15.668),
|
||||
"logos/aldi.svg": (4.813, 2.563, 14.374, 14.383),
|
||||
"logos/asda.svg": (3.91, 7.135, 16.181, 5.442),
|
||||
"logos/centra.svg": (3.36, 7.35, 17.28, 4.651),
|
||||
"logos/coop.svg": (6.407, 4.658, 11.187, 11.793),
|
||||
"logos/costco.svg": (70.61, 144.908, 256.67, 85.825),
|
||||
"logos/lidl.svg": (4.938, 2.973, 13.985, 13.985),
|
||||
"logos/morrisons.svg": (5.231, 2.985, 13.538, 13.398),
|
||||
"logos/planet_organic.svg": (5.528, 3.564, 12.943, 12.943),
|
||||
"logos/sainsburys.svg": (7.502, 3.572, 8.996, 12.646),
|
||||
"logos/spar.svg": (4.933, 2.968, 14.133, 13.853),
|
||||
"logos/tesco.svg": (4.338, 6.865, 15.324, 5.359),
|
||||
"logos/tesco_express.svg": (5.231, 5.933, 13.538, 8.345),
|
||||
"logos/tesco_extra.svg": (4.933, 5.775, 14.133, 8.519),
|
||||
"logos/waitrose.svg": (5.528, 6.09, 12.943, 9.855),
|
||||
}
|
||||
|
||||
POI_ICON_SVG_INTRINSIC_MAX = 512
|
||||
|
||||
|
||||
def collect_twemoji_codes() -> list[str]:
|
||||
"""Derive twemoji hex codes from transform_poi categories.
|
||||
|
|
@ -76,9 +109,6 @@ def collect_twemoji_codes() -> list[str]:
|
|||
for emoji in NAPTAN_EMOJIS.values():
|
||||
emojis.add(emoji)
|
||||
|
||||
for emoji in _FALLBACK_EMOJIS:
|
||||
emojis.add(emoji)
|
||||
|
||||
# First codepoint hex, matching frontend logic
|
||||
return sorted({f"{ord(e[0]):x}" for e in emojis})
|
||||
|
||||
|
|
@ -97,6 +127,214 @@ def download_file(url: str, dest: Path) -> tuple[bool, str]:
|
|||
return False, url
|
||||
|
||||
|
||||
def download_text(url: str) -> str:
|
||||
with urllib.request.urlopen(url) as response:
|
||||
return response.read().decode("utf-8")
|
||||
|
||||
|
||||
def build_costco_logo(marker_svg: str) -> str:
|
||||
start = marker_svg.find('<g><path d=" M 316.312')
|
||||
end = marker_svg.rfind("</g></g></svg>")
|
||||
if start < 0 or end < 0:
|
||||
raise ValueError("Costco marker SVG layout changed")
|
||||
|
||||
logo_group = marker_svg[start : end + 4]
|
||||
return (
|
||||
'<?xml version="1.0" encoding="UTF-8"?>\n'
|
||||
'<svg xmlns="http://www.w3.org/2000/svg" viewBox="70 145 260 90" '
|
||||
'width="260pt" height="90pt" preserveAspectRatio="xMidYMid meet">\n'
|
||||
f"{logo_group}\n"
|
||||
"</svg>\n"
|
||||
)
|
||||
|
||||
|
||||
def trim_white_png(png_bytes: bytes) -> bytes:
|
||||
image = Image.open(BytesIO(png_bytes)).convert("RGBA")
|
||||
pixels = image.load()
|
||||
|
||||
for y in range(image.height):
|
||||
for x in range(image.width):
|
||||
red, green, blue, alpha = pixels[x, y]
|
||||
if red > 245 and green > 245 and blue > 245:
|
||||
pixels[x, y] = (red, green, blue, 0)
|
||||
|
||||
alpha_box = image.getchannel("A").getbbox()
|
||||
if alpha_box:
|
||||
image = image.crop(alpha_box)
|
||||
|
||||
out = BytesIO()
|
||||
image.save(out, format="PNG")
|
||||
return out.getvalue()
|
||||
|
||||
|
||||
def extract_embedded_png(marker_svg: str) -> bytes:
|
||||
match = re.search(r"base64,([^\"']+)", marker_svg)
|
||||
if not match:
|
||||
raise ValueError("POI marker SVG did not contain an embedded PNG")
|
||||
return trim_white_png(base64.b64decode(match.group(1)))
|
||||
|
||||
|
||||
def svg_intrinsic_size(width: float, height: float) -> tuple[int, int]:
|
||||
if width <= 0 or height <= 0:
|
||||
return (POI_ICON_SVG_INTRINSIC_MAX, POI_ICON_SVG_INTRINSIC_MAX)
|
||||
if width >= height:
|
||||
return (
|
||||
POI_ICON_SVG_INTRINSIC_MAX,
|
||||
max(1, round(POI_ICON_SVG_INTRINSIC_MAX * height / width)),
|
||||
)
|
||||
return (
|
||||
max(1, round(POI_ICON_SVG_INTRINSIC_MAX * width / height)),
|
||||
POI_ICON_SVG_INTRINSIC_MAX,
|
||||
)
|
||||
|
||||
|
||||
def set_svg_geometry(svg_text: str, crop: tuple[float, float, float, float]) -> str:
|
||||
x, y, width, height = crop
|
||||
view_box = f"{x:g} {y:g} {width:g} {height:g}"
|
||||
intrinsic_width, intrinsic_height = svg_intrinsic_size(width, height)
|
||||
|
||||
svg_text = re.sub(r'viewBox="[^"]+"', f'viewBox="{view_box}"', svg_text, count=1)
|
||||
if 'viewBox="' not in svg_text:
|
||||
svg_text = re.sub(r"<svg\b", f'<svg viewBox="{view_box}"', svg_text, count=1)
|
||||
|
||||
svg_text = re.sub(r'width="[^"]+"', f'width="{intrinsic_width}"', svg_text, count=1)
|
||||
if 'width="' not in svg_text:
|
||||
svg_text = re.sub(
|
||||
r"<svg\b", f'<svg width="{intrinsic_width}"', svg_text, count=1
|
||||
)
|
||||
|
||||
svg_text = re.sub(
|
||||
r'height="[^"]+"', f'height="{intrinsic_height}"', svg_text, count=1
|
||||
)
|
||||
if 'height="' not in svg_text:
|
||||
svg_text = re.sub(
|
||||
r"<svg\b", f'<svg height="{intrinsic_height}"', svg_text, count=1
|
||||
)
|
||||
|
||||
return svg_text
|
||||
|
||||
|
||||
def get_svg_view_box(svg_text: str) -> tuple[float, float, float, float] | None:
|
||||
match = re.search(r'viewBox="([^"]+)"', svg_text)
|
||||
if not match:
|
||||
return None
|
||||
parts = [
|
||||
float(part) for part in re.split(r"[\s,]+", match.group(1).strip()) if part
|
||||
]
|
||||
if len(parts) != 4:
|
||||
return None
|
||||
return (parts[0], parts[1], parts[2], parts[3])
|
||||
|
||||
|
||||
def crop_poi_svg_icons(poi_icons_dir: Path) -> None:
|
||||
for icon_path, crop in POI_ICON_SVG_CROPS.items():
|
||||
dest = poi_icons_dir / icon_path
|
||||
if not dest.exists():
|
||||
continue
|
||||
svg_text = dest.read_text(encoding="utf-8")
|
||||
if icon_path == "brands_2024/dunnes_stores.svg":
|
||||
svg_text = svg_text.replace('fill="#fffcfc"', 'fill="#111111"')
|
||||
svg_text = svg_text.replace('fill="#fcfcfc"', 'fill="#111111"')
|
||||
dest.write_text(set_svg_geometry(svg_text, crop), encoding="utf-8")
|
||||
|
||||
for dest in poi_icons_dir.rglob("*.svg"):
|
||||
svg_text = dest.read_text(encoding="utf-8")
|
||||
view_box = get_svg_view_box(svg_text)
|
||||
if view_box:
|
||||
dest.write_text(set_svg_geometry(svg_text, view_box), encoding="utf-8")
|
||||
|
||||
|
||||
def download_derived_poi_icon(
|
||||
kind: str, source_path: str, dest: Path
|
||||
) -> tuple[bool, str]:
|
||||
url = f"{POI_ICON_BASE}/{source_path}"
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
source = download_text(url)
|
||||
if kind == "costco_logo":
|
||||
dest.write_text(build_costco_logo(source), encoding="utf-8")
|
||||
elif kind == "embedded_png":
|
||||
dest.write_bytes(extract_embedded_png(source))
|
||||
else:
|
||||
raise ValueError(f"Unknown derived POI icon kind: {kind}")
|
||||
return True, url
|
||||
except urllib.error.HTTPError as e:
|
||||
print(f" {e.code} {url}", file=sys.stderr)
|
||||
return False, url
|
||||
except Exception as e:
|
||||
print(f" ERROR {url}: {e}", file=sys.stderr)
|
||||
return False, url
|
||||
|
||||
|
||||
# Slategray accent used by civic POI icons (school, library, building, …) in
|
||||
# protomaps' v4 sprite. We match it so the townhall blends in with its peers.
|
||||
_TOWNHALL_COLOR = {
|
||||
"light": (135, 128, 171),
|
||||
"dark": (118, 118, 127),
|
||||
}
|
||||
_TOWNHALL_LOGICAL_SIZE = 17
|
||||
|
||||
|
||||
def _render_townhall_glyph(size_px: int, color: tuple[int, int, int]) -> Image.Image:
|
||||
# Draw at 8× resolution and downsample with Lanczos so the pediment's
|
||||
# diagonals come out anti-aliased; PIL's polygon fill is otherwise aliased.
|
||||
super_factor = 8
|
||||
canvas = size_px * super_factor
|
||||
img = Image.new("RGBA", (canvas, canvas), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
fill = (*color, 255)
|
||||
|
||||
def s(v: float) -> float:
|
||||
return v * canvas / _TOWNHALL_LOGICAL_SIZE
|
||||
|
||||
draw.polygon([(s(8.5), s(1)), (s(15), s(6.5)), (s(2), s(6.5))], fill=fill)
|
||||
draw.rectangle([(s(1), s(6.5)), (s(16), s(8.5))], fill=fill)
|
||||
for column_x in (3, 8, 13):
|
||||
draw.rectangle([(s(column_x), s(8.5)), (s(column_x + 1.5), s(14))], fill=fill)
|
||||
draw.rectangle([(s(0), s(14)), (s(17), s(15.5))], fill=fill)
|
||||
|
||||
return img.resize((size_px, size_px), Image.LANCZOS)
|
||||
|
||||
|
||||
def inject_townhall_sprite(sprites_dir: Path) -> None:
|
||||
"""Append a townhall glyph to each downloaded sprite sheet.
|
||||
|
||||
Protomaps' v4 sprite omits `townhall` even though the basemap style
|
||||
references it; we add the icon here so MapLibre can resolve the name
|
||||
natively at runtime.
|
||||
"""
|
||||
for theme in ("light", "dark"):
|
||||
color = _TOWNHALL_COLOR[theme]
|
||||
for suffix, scale in (("", 1), ("@2x", 2)):
|
||||
json_path = sprites_dir / f"{theme}{suffix}.json"
|
||||
png_path = sprites_dir / f"{theme}{suffix}.png"
|
||||
if not json_path.exists() or not png_path.exists():
|
||||
continue
|
||||
|
||||
manifest = json.loads(json_path.read_text())
|
||||
sheet = Image.open(png_path).convert("RGBA")
|
||||
|
||||
glyph_size = _TOWNHALL_LOGICAL_SIZE * scale
|
||||
glyph = _render_townhall_glyph(glyph_size, color)
|
||||
|
||||
new_width = max(sheet.width, glyph_size)
|
||||
new_height = sheet.height + glyph_size
|
||||
extended = Image.new("RGBA", (new_width, new_height), (0, 0, 0, 0))
|
||||
extended.paste(sheet, (0, 0))
|
||||
extended.paste(glyph, (0, sheet.height))
|
||||
extended.save(png_path, optimize=True)
|
||||
|
||||
manifest["townhall"] = {
|
||||
"x": 0,
|
||||
"y": sheet.height,
|
||||
"width": glyph_size,
|
||||
"height": glyph_size,
|
||||
"pixelRatio": scale,
|
||||
}
|
||||
json_path.write_text(json.dumps(manifest))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
|
|
@ -147,7 +385,7 @@ def main():
|
|||
# Skip already-downloaded files
|
||||
remaining = [(url, dest) for url, dest in tasks]
|
||||
|
||||
print(f"Downloading {len(remaining)} assets")
|
||||
print(f"Downloading {len(remaining) + len(DERIVED_POI_ICON_PATHS)} assets")
|
||||
|
||||
ok = 0
|
||||
fail = 0
|
||||
|
|
@ -162,6 +400,18 @@ def main():
|
|||
else:
|
||||
fail += 1
|
||||
|
||||
for kind, source_path, dest_path in DERIVED_POI_ICON_PATHS:
|
||||
success, _url = download_derived_poi_icon(
|
||||
kind, source_path, poi_icons_dir / dest_path
|
||||
)
|
||||
if success:
|
||||
ok += 1
|
||||
else:
|
||||
fail += 1
|
||||
|
||||
crop_poi_svg_icons(poi_icons_dir)
|
||||
inject_townhall_sprite(sprites_dir)
|
||||
|
||||
print(f"Done: {ok} downloaded, {fail} failed")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ Reuses the same england-latest.osm.pbf as pois.py.
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import osmium
|
||||
|
|
@ -44,11 +45,37 @@ _STATION_STRIP = (
|
|||
" underground station",
|
||||
" railway station",
|
||||
" dlr station",
|
||||
" station dlr",
|
||||
" dlr",
|
||||
" overground station",
|
||||
" tram stop",
|
||||
" station",
|
||||
)
|
||||
|
||||
_DLR_CODE_RE = re.compile(r"ZZDL([A-Z0-9]{3})")
|
||||
|
||||
|
||||
def _is_dlr_station(tags: dict[str, str]) -> bool:
|
||||
name = tags.get("name", "").lower()
|
||||
network = tags.get("network", "").lower()
|
||||
operator = tags.get("operator", "").lower()
|
||||
return (
|
||||
"docklands" in network
|
||||
or "dlr" in network
|
||||
or "docklands" in operator
|
||||
or "dlr" in operator
|
||||
or name.endswith(" dlr")
|
||||
or " dlr " in name
|
||||
)
|
||||
|
||||
|
||||
def _is_tram_station(tags: dict[str, str]) -> bool:
|
||||
if _is_dlr_station(tags):
|
||||
return False
|
||||
station_tag = tags.get("station", "")
|
||||
network = tags.get("network", "").lower()
|
||||
return station_tag == "light_rail" or "tramlink" in network or "tram" in network
|
||||
|
||||
|
||||
def _station_display_name(name: str, tags: dict[str, str]) -> str:
|
||||
"""Build a descriptive station name like 'Bank tube station'."""
|
||||
|
|
@ -78,6 +105,96 @@ def _station_display_name(name: str, tags: dict[str, str]) -> str:
|
|||
return f"{name} {suffix}"
|
||||
|
||||
|
||||
def _station_name_score(name: str) -> tuple[int, int]:
|
||||
lower = name.lower()
|
||||
suffix_penalty = int(
|
||||
lower.endswith(
|
||||
(
|
||||
" underground station",
|
||||
" tube station",
|
||||
" dlr station",
|
||||
" railway station",
|
||||
" rail station",
|
||||
" station dlr",
|
||||
" station",
|
||||
)
|
||||
)
|
||||
or lower.endswith(" dlr")
|
||||
)
|
||||
return (suffix_penalty, len(name))
|
||||
|
||||
|
||||
def _naptan_dlr_stations(naptan_path: Path) -> list[dict]:
|
||||
"""Extract station-level DLR destinations from NaPTAN access nodes."""
|
||||
df = pl.read_parquet(naptan_path)
|
||||
required = {"id", "name", "category", "lat", "lng"}
|
||||
missing = required - set(df.columns)
|
||||
if missing:
|
||||
raise ValueError(f"NaPTAN file is missing columns: {sorted(missing)}")
|
||||
|
||||
rows: dict[str, dict] = {}
|
||||
for row in df.iter_rows(named=True):
|
||||
atco_id = str(row["id"] or "")
|
||||
match = _DLR_CODE_RE.search(atco_id)
|
||||
if not match:
|
||||
continue
|
||||
if row["category"] not in {"Tube station", "Rail station"}:
|
||||
continue
|
||||
|
||||
code = match.group(1)
|
||||
raw_name = str(row["name"] or "")
|
||||
if not raw_name:
|
||||
continue
|
||||
|
||||
lat = float(row["lat"])
|
||||
lon = float(row["lng"])
|
||||
current = rows.get(code)
|
||||
if current is None:
|
||||
rows[code] = {
|
||||
"raw_name": raw_name,
|
||||
"lat_sum": lat,
|
||||
"lon_sum": lon,
|
||||
"count": 1,
|
||||
}
|
||||
continue
|
||||
|
||||
current["lat_sum"] += lat
|
||||
current["lon_sum"] += lon
|
||||
current["count"] += 1
|
||||
if _station_name_score(raw_name) < _station_name_score(current["raw_name"]):
|
||||
current["raw_name"] = raw_name
|
||||
|
||||
stations = []
|
||||
for station in rows.values():
|
||||
count = station["count"]
|
||||
display_name = _station_display_name(station["raw_name"], {"network": "DLR"})
|
||||
stations.append(
|
||||
{
|
||||
"name": display_name,
|
||||
"place_type": "station",
|
||||
"lat": station["lat_sum"] / count,
|
||||
"lon": station["lon_sum"] / count,
|
||||
"population": 0,
|
||||
"travel_destination": True,
|
||||
}
|
||||
)
|
||||
|
||||
return sorted(stations, key=lambda station: station["name"])
|
||||
|
||||
|
||||
def _append_naptan_dlr_stations(places: list[dict], naptan_path: Path) -> int:
|
||||
existing_names = {str(place["name"]).casefold() for place in places}
|
||||
added = 0
|
||||
for station in _naptan_dlr_stations(naptan_path):
|
||||
key = station["name"].casefold()
|
||||
if key in existing_names:
|
||||
continue
|
||||
places.append(station)
|
||||
existing_names.add(key)
|
||||
added += 1
|
||||
return added
|
||||
|
||||
|
||||
class PlaceHandler(osmium.SimpleHandler):
|
||||
def __init__(self, progress: tqdm, england_polygon) -> None:
|
||||
super().__init__()
|
||||
|
|
@ -145,14 +262,7 @@ class PlaceHandler(osmium.SimpleHandler):
|
|||
# Railway stations (tube, national rail, DLR, overground, Elizabeth line)
|
||||
if n.tags.get("railway") == "station":
|
||||
tags = dict(n.tags)
|
||||
station_tag = tags.get("station", "")
|
||||
network = tags.get("network", "").lower()
|
||||
# Skip tram stops
|
||||
if (
|
||||
station_tag == "light_rail"
|
||||
or "tramlink" in network
|
||||
or "tram" in network
|
||||
):
|
||||
if _is_tram_station(tags):
|
||||
return
|
||||
display_name = _station_display_name(name, tags)
|
||||
self._add(
|
||||
|
|
@ -178,6 +288,11 @@ def main() -> None:
|
|||
required=True,
|
||||
help="England boundary GeoJSON file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--naptan",
|
||||
type=Path,
|
||||
help="Optional NaPTAN parquet file used to add DLR station destinations",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
pbf_file = args.pbf
|
||||
|
|
@ -195,6 +310,9 @@ def main() -> None:
|
|||
handler.apply_file(str(pbf_file), locations=True)
|
||||
|
||||
print(f"Extracted {len(handler.places):,} place nodes")
|
||||
if args.naptan:
|
||||
added = _append_naptan_dlr_stations(handler.places, args.naptan)
|
||||
print(f"Added {added:,} DLR station destinations from NaPTAN")
|
||||
|
||||
if handler.places:
|
||||
df = pl.DataFrame(handler.places)
|
||||
|
|
|
|||
81
pipeline/download/test_places.py
Normal file
81
pipeline/download/test_places.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
import polars as pl
|
||||
|
||||
from pipeline.download.places import (
|
||||
_is_dlr_station,
|
||||
_is_tram_station,
|
||||
_naptan_dlr_stations,
|
||||
_station_display_name,
|
||||
)
|
||||
|
||||
|
||||
def test_dlr_light_rail_is_not_treated_as_tram():
|
||||
dlr_tags = {
|
||||
"name": "Lewisham DLR",
|
||||
"railway": "station",
|
||||
"station": "light_rail",
|
||||
"network": "Docklands Light Railway",
|
||||
}
|
||||
|
||||
assert _is_dlr_station(dlr_tags)
|
||||
assert not _is_tram_station(dlr_tags)
|
||||
assert _station_display_name("Lewisham DLR", dlr_tags) == "Lewisham DLR station"
|
||||
assert (
|
||||
_station_display_name("Tower Gateway Station DLR", dlr_tags)
|
||||
== "Tower Gateway DLR station"
|
||||
)
|
||||
|
||||
|
||||
def test_tram_light_rail_is_still_excluded():
|
||||
tram_tags = {
|
||||
"name": "East Croydon",
|
||||
"railway": "station",
|
||||
"station": "light_rail",
|
||||
"network": "London Trams",
|
||||
}
|
||||
|
||||
assert not _is_dlr_station(tram_tags)
|
||||
assert _is_tram_station(tram_tags)
|
||||
|
||||
|
||||
def test_naptan_dlr_stations_are_deduplicated_by_atco_code(tmp_path):
|
||||
naptan = tmp_path / "naptan.parquet"
|
||||
pl.DataFrame(
|
||||
{
|
||||
"id": [
|
||||
"4900ZZDLSHA3",
|
||||
"9400ZZDLSHA",
|
||||
"4900ZZDLGRE1",
|
||||
"490002076RV",
|
||||
"4900ZZLUBNK",
|
||||
],
|
||||
"name": [
|
||||
"Shadwell DLR",
|
||||
"Shadwell DLR Station",
|
||||
"Greenwich Station",
|
||||
"Tower Gateway Station DLR",
|
||||
"Bank",
|
||||
],
|
||||
"category": [
|
||||
"Tube station",
|
||||
"Tube station",
|
||||
"Rail station",
|
||||
"Bus stop",
|
||||
"Tube station",
|
||||
],
|
||||
"lat": [51.51156, 51.511693, 51.47794, 51.510575, 51.5131],
|
||||
"lng": [-0.055595, -0.056643, -0.01442, -0.07514, -0.0894],
|
||||
}
|
||||
).write_parquet(naptan)
|
||||
|
||||
stations = _naptan_dlr_stations(naptan)
|
||||
|
||||
assert [station["name"] for station in stations] == [
|
||||
"Greenwich DLR station",
|
||||
"Shadwell DLR station",
|
||||
]
|
||||
shadwell = next(
|
||||
station for station in stations if station["name"].startswith("Shadwell")
|
||||
)
|
||||
assert shadwell["lat"] == (51.51156 + 51.511693) / 2
|
||||
assert shadwell["place_type"] == "station"
|
||||
assert shadwell["travel_destination"] is True
|
||||
|
|
@ -56,6 +56,7 @@ NR_AUTH_URL = "https://opendata.nationalrail.co.uk/authenticate"
|
|||
NR_TIMETABLE_URL = "https://opendata.nationalrail.co.uk/api/staticfeeds/3.0/timetable"
|
||||
|
||||
USER_AGENT = "property-map-pipeline/1.0 (https://github.com)"
|
||||
TRANSXCHANGE2GTFS_PACKAGE = "transxchange2gtfs@1.12.0"
|
||||
|
||||
|
||||
def _download_http(
|
||||
|
|
@ -473,10 +474,50 @@ def convert_tfl_to_gtfs(raw_dir: Path, output_dir: Path) -> Path:
|
|||
download_naptan()
|
||||
|
||||
print("Converting TfL TransXChange → GTFS...")
|
||||
# The shim patches known packaging/runtime issues in the pinned npm package
|
||||
# before loading its CLI from npx's temporary install.
|
||||
shim_path = Path(__file__).with_name("transxchange2gtfs_shim.js")
|
||||
subprocess.run(
|
||||
["npx", "--yes", "transxchange2gtfs", str(txc_path), str(dest)],
|
||||
[
|
||||
"npx",
|
||||
"--yes",
|
||||
"--package",
|
||||
TRANSXCHANGE2GTFS_PACKAGE,
|
||||
"sh",
|
||||
"-c",
|
||||
"\n".join(
|
||||
[
|
||||
'bin="$(command -v transxchange2gtfs)"',
|
||||
'script="$(readlink -f "$bin")"',
|
||||
'pkg_dir="$(dirname "$(dirname "$script")")"',
|
||||
'shim="$1"',
|
||||
"shift",
|
||||
'exec node "$shim" "$pkg_dir" "$@"',
|
||||
]
|
||||
),
|
||||
"transxchange2gtfs",
|
||||
str(shim_path.resolve()),
|
||||
str(txc_path.resolve()),
|
||||
str(dest.resolve()),
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
required_files = {
|
||||
"agency.txt",
|
||||
"calendar.txt",
|
||||
"calendar_dates.txt",
|
||||
"routes.txt",
|
||||
"stop_times.txt",
|
||||
"stops.txt",
|
||||
"trips.txt",
|
||||
}
|
||||
if not dest.exists() or not zipfile.is_zipfile(dest):
|
||||
raise RuntimeError(f"transxchange2gtfs did not create a valid GTFS zip: {dest}")
|
||||
with zipfile.ZipFile(dest) as z:
|
||||
missing = required_files - set(z.namelist())
|
||||
if missing:
|
||||
missing_str = ", ".join(sorted(missing))
|
||||
raise RuntimeError(f"TfL GTFS zip is missing required files: {missing_str}")
|
||||
size_mb = dest.stat().st_size / (1024 * 1024)
|
||||
print(f" Saved to {dest} ({size_mb:.1f} MB)")
|
||||
return dest
|
||||
|
|
|
|||
76
pipeline/download/transxchange2gtfs_shim.js
Normal file
76
pipeline/download/transxchange2gtfs_shim.js
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
#!/usr/bin/env node
|
||||
"use strict";
|
||||
|
||||
const fs = require("fs");
|
||||
const path = require("path");
|
||||
const { createRequire } = require("module");
|
||||
|
||||
const [pkgDirArg, ...converterArgs] = process.argv.slice(2);
|
||||
|
||||
if (!pkgDirArg || converterArgs.length < 2) {
|
||||
console.error(
|
||||
"Usage: transxchange2gtfs_shim.js <package-dir> <input...> <output>",
|
||||
);
|
||||
process.exit(2);
|
||||
}
|
||||
|
||||
const pkgDir = path.resolve(pkgDirArg);
|
||||
|
||||
function replaceOnce(relativePath, before, after) {
|
||||
const file = path.join(pkgDir, relativePath);
|
||||
const original = fs.readFileSync(file, "utf8");
|
||||
if (original.includes(before)) {
|
||||
fs.writeFileSync(file, original.replace(before, after));
|
||||
} else if (original.includes(after)) {
|
||||
return;
|
||||
} else {
|
||||
throw new Error(`Could not patch ${relativePath}: expected text not found`);
|
||||
}
|
||||
}
|
||||
|
||||
// The published 1.12.0 package has a few compatibility issues with current
|
||||
// TfL TransXChange exports:
|
||||
// - the bin script points at dist/src/cli.js, but the package ships dist/cli.js
|
||||
// - the compiled date-holidays import expects a synthetic default export
|
||||
// - some TfL journeys reference timing links without matching route-link geometry
|
||||
//
|
||||
// GTFS shapes are optional for R5 routing. Clear shape references and omit
|
||||
// shapes.txt so missing route geometry does not drop otherwise usable trips.
|
||||
function patchPackage() {
|
||||
replaceOnce(
|
||||
"dist/transxchange/TransXChangeJourneyStream.js",
|
||||
"distanceSoFarM += routeLink.Distance;",
|
||||
"distanceSoFarM += routeLink ? routeLink.Distance : 0;",
|
||||
);
|
||||
replaceOnce(
|
||||
"dist/gtfs/TripsStream.js",
|
||||
"(0, crypto_1.createHash)('md5').update(JSON.stringify({ routeId: journey.route, routeLinkSeq: journey.routeLinkIds })).digest(\"hex\"));",
|
||||
"\"\");",
|
||||
);
|
||||
replaceOnce(
|
||||
"dist/gtfs/StopTimesStream.js",
|
||||
"stop.shapeDistTraveled, stop.exactTime ? \"1\" : \"0\");",
|
||||
"\"\", stop.exactTime ? \"1\" : \"0\");",
|
||||
);
|
||||
replaceOnce(
|
||||
"dist/Container.js",
|
||||
"\"stops.txt\": transxchange.pipe(new StopsStream_1.StopsStream(naptanIndex)),\n \"shapes.txt\": journeyStream.pipe(new ShapesStream_1.ShapesStream())",
|
||||
"\"stops.txt\": transxchange.pipe(new StopsStream_1.StopsStream(naptanIndex))",
|
||||
);
|
||||
replaceOnce(
|
||||
"dist/Container.js",
|
||||
"\"routes.txt\": transxchange.pipe(new RoutesStream_1.RoutesStream()),\n \"transfers.txt\": transxchange.pipe(new TransfersStream_1.TransfersStream(naptanIndex, locationIndex)),\n \"stops.txt\": transxchange.pipe(new StopsStream_1.StopsStream(naptanIndex))",
|
||||
"\"routes.txt\": transxchange.pipe(new RoutesStream_1.RoutesStream()),\n \"stops.txt\": transxchange.pipe(new StopsStream_1.StopsStream(naptanIndex))",
|
||||
);
|
||||
}
|
||||
|
||||
patchPackage();
|
||||
|
||||
const pkgRequire = createRequire(path.join(pkgDir, "package.json"));
|
||||
const Holidays = pkgRequire("date-holidays");
|
||||
if (!Holidays.default) {
|
||||
Holidays.default = Holidays;
|
||||
}
|
||||
|
||||
process.argv = [process.argv[0], "transxchange2gtfs", ...converterArgs];
|
||||
require(path.join(pkgDir, "dist", "cli.js"));
|
||||
|
|
@ -7,6 +7,15 @@ from pipeline.utils.postcode_mapping import build_postcode_mapping
|
|||
|
||||
MIN_FLOOR_AREA_M2 = 10
|
||||
|
||||
_IOD_PERCENTILE_COLUMNS = [
|
||||
"Education, Skills and Training Score",
|
||||
"Income Score (rate)",
|
||||
"Employment Score (rate)",
|
||||
"Health Deprivation and Disability Score",
|
||||
"Indoors Sub-domain Score",
|
||||
"Outdoors Sub-domain Score",
|
||||
]
|
||||
|
||||
|
||||
_AREA_COLUMNS = [
|
||||
"Postcode",
|
||||
|
|
@ -51,6 +60,14 @@ _AREA_COLUMNS = [
|
|||
"Number of parks within 1km",
|
||||
"Distance to nearest train or tube station (km)",
|
||||
"Distance to nearest park (km)",
|
||||
"Distance to nearest grocery store (km)",
|
||||
"Distance to nearest tube station (km)",
|
||||
"Distance to nearest rail station (km)",
|
||||
"Distance to nearest Waitrose (km)",
|
||||
"Distance to nearest Tesco (km)",
|
||||
"Distance to nearest cafe (km)",
|
||||
"Distance to nearest pub (km)",
|
||||
"Distance to nearest restaurant (km)",
|
||||
# Environment
|
||||
"Noise (dB)",
|
||||
"Max available download speed (Mbps)",
|
||||
|
|
@ -76,6 +93,34 @@ _AREA_COLUMNS = [
|
|||
]
|
||||
|
||||
|
||||
def _is_dynamic_poi_metric_column(column: str) -> bool:
|
||||
return (
|
||||
column.startswith("Distance to nearest ")
|
||||
and column.endswith(" POI (km)")
|
||||
) or (
|
||||
column.startswith("Number of ")
|
||||
and (column.endswith(" POIs within 2km") or column.endswith(" POIs within 5km"))
|
||||
)
|
||||
|
||||
|
||||
def _less_deprived_percentile_expr(column: str) -> pl.Expr:
|
||||
"""Convert an IoD deprivation score to a 0-100 less-deprived percentile."""
|
||||
non_null_count = pl.col(column).count()
|
||||
descending_rank = pl.col(column).rank("average", descending=True)
|
||||
return (
|
||||
pl.when(pl.col(column).is_null())
|
||||
.then(None)
|
||||
.when(pl.col(column) == pl.col(column).min())
|
||||
.then(100.0)
|
||||
.when(pl.col(column) == pl.col(column).max())
|
||||
.then(0.0)
|
||||
.when(non_null_count > 1)
|
||||
.then(((descending_rank - 1) / (non_null_count - 1) * 100).round(1))
|
||||
.otherwise(100.0)
|
||||
.alias(column)
|
||||
)
|
||||
|
||||
|
||||
def _build(
|
||||
epc_pp_path: Path,
|
||||
arcgis_path: Path,
|
||||
|
|
@ -134,20 +179,11 @@ def _build(
|
|||
)
|
||||
wide = wide.join(arcgis, on="postcode", how="left")
|
||||
|
||||
iod = pl.scan_parquet(iod_path)
|
||||
iod = pl.scan_parquet(iod_path).with_columns(
|
||||
*(_less_deprived_percentile_expr(c) for c in _IOD_PERCENTILE_COLUMNS)
|
||||
)
|
||||
wide = wide.join(iod, left_on="lsoa21", right_on="LSOA code (2021)", how="left")
|
||||
|
||||
# Invert deprivation scores so that higher values = less deprived (better)
|
||||
iod_score_cols = [
|
||||
"Education, Skills and Training Score",
|
||||
"Income Score (rate)",
|
||||
"Employment Score (rate)",
|
||||
"Health Deprivation and Disability Score",
|
||||
"Indoors Sub-domain Score",
|
||||
"Outdoors Sub-domain Score",
|
||||
]
|
||||
wide = wide.with_columns(*(pl.col(c).max() - pl.col(c) for c in iod_score_cols))
|
||||
|
||||
ethnicity = pl.scan_parquet(ethnicity_path)
|
||||
wide = wide.join(
|
||||
ethnicity,
|
||||
|
|
@ -351,6 +387,14 @@ def _build(
|
|||
"parks_1km": "Number of parks within 1km",
|
||||
"train_tube_nearest_km": "Distance to nearest train or tube station (km)",
|
||||
"parks_nearest_km": "Distance to nearest park (km)",
|
||||
"grocery_store_nearest_km": "Distance to nearest grocery store (km)",
|
||||
"tube_station_nearest_km": "Distance to nearest tube station (km)",
|
||||
"rail_station_nearest_km": "Distance to nearest rail station (km)",
|
||||
"waitrose_nearest_km": "Distance to nearest Waitrose (km)",
|
||||
"tesco_nearest_km": "Distance to nearest Tesco (km)",
|
||||
"cafe_nearest_km": "Distance to nearest cafe (km)",
|
||||
"pub_nearest_km": "Distance to nearest pub (km)",
|
||||
"restaurant_nearest_km": "Distance to nearest restaurant (km)",
|
||||
"latest_price": "Last known price",
|
||||
"number_habitable_rooms": "Number of bedrooms & living rooms",
|
||||
"noise_lden_db": "Noise (dB)",
|
||||
|
|
@ -381,10 +425,14 @@ def _build(
|
|||
|
||||
# Split into postcode-level and property-level dataframes
|
||||
area_cols = [c for c in _AREA_COLUMNS if c in df.columns]
|
||||
area_cols.extend(
|
||||
c for c in df.columns if _is_dynamic_poi_metric_column(c) and c not in area_cols
|
||||
)
|
||||
area_col_set = set(area_cols)
|
||||
postcode_df = df.select(area_cols).group_by("Postcode").first()
|
||||
print(f"Postcode rows: {postcode_df.height} (unique postcodes)")
|
||||
|
||||
property_cols = [c for c in df.columns if c not in _AREA_COLUMNS or c == "Postcode"]
|
||||
property_cols = [c for c in df.columns if c not in area_col_set or c == "Postcode"]
|
||||
properties_df = df.select(property_cols)
|
||||
print(f"Property rows: {properties_df.height}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
"""Compute POI proximity counts and distances per postcode from ArcGIS + filtered POIs."""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import unicodedata
|
||||
from pathlib import Path
|
||||
|
||||
import polars as pl
|
||||
|
|
@ -15,9 +17,25 @@ POI_GROUPS_2KM = {
|
|||
"groceries": ["Greengrocer", "Supermarket", "Convenience Store"],
|
||||
}
|
||||
|
||||
# Groups for which to compute distance to nearest POI (from filtered POIs)
|
||||
# Groups for which to compute distance to nearest POI (from filtered POIs).
|
||||
# Keep `train_tube` for the existing backend feature; the individual POI
|
||||
# distance filters below power the frontend dropdown.
|
||||
DISTANCE_GROUPS = {
|
||||
"train_tube": ["Tube station", "Rail station"],
|
||||
"grocery_store": [
|
||||
"Greengrocer",
|
||||
"Supermarket",
|
||||
"Convenience Store",
|
||||
"Waitrose",
|
||||
"Tesco",
|
||||
],
|
||||
"tube_station": ["Tube station"],
|
||||
"rail_station": ["Rail station"],
|
||||
"waitrose": ["Waitrose"],
|
||||
"tesco": ["Tesco"],
|
||||
"cafe": ["Café"],
|
||||
"pub": ["Pub"],
|
||||
"restaurant": ["Restaurant"],
|
||||
}
|
||||
|
||||
# OS Open Greenspace function types used for park counts and distance calculation.
|
||||
|
|
@ -27,6 +45,69 @@ GREENSPACE_PARK_FUNCTIONS = {
|
|||
"parks": ["Public Park Or Garden", "Playing Field", "Play Space"],
|
||||
}
|
||||
|
||||
GROCERY_DYNAMIC_FILTER_MIN_POIS = 100
|
||||
DYNAMIC_FILTER_ALL_GROUPS = {"Public Transport", "Leisure"}
|
||||
DYNAMIC_FILTER_COUNT_THRESHOLD_GROUPS = {"Groceries"}
|
||||
|
||||
|
||||
def _poi_category_slug(category: str) -> str:
|
||||
ascii_text = (
|
||||
unicodedata.normalize("NFKD", category)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii")
|
||||
.lower()
|
||||
)
|
||||
slug = re.sub(r"[^a-z0-9]+", "_", ascii_text).strip("_")
|
||||
return slug or "poi"
|
||||
|
||||
|
||||
def _build_poi_category_groups(
|
||||
pois: pl.DataFrame,
|
||||
) -> tuple[dict[str, list[str]], dict[str, str]]:
|
||||
"""Build one proximity group for each POI category selected for filters."""
|
||||
if "group" not in pois.columns:
|
||||
raise ValueError("POI dataframe must include a 'group' column")
|
||||
|
||||
categories = (
|
||||
pois.group_by("group", "category")
|
||||
.len()
|
||||
.filter(
|
||||
pl.col("group").is_in(list(DYNAMIC_FILTER_ALL_GROUPS))
|
||||
| (
|
||||
pl.col("group").is_in(list(DYNAMIC_FILTER_COUNT_THRESHOLD_GROUPS))
|
||||
& (pl.col("len") > GROCERY_DYNAMIC_FILTER_MIN_POIS)
|
||||
)
|
||||
)
|
||||
.select("category")
|
||||
.sort("category")
|
||||
.to_series()
|
||||
.to_list()
|
||||
)
|
||||
used_slugs: dict[str, int] = {}
|
||||
groups: dict[str, list[str]] = {}
|
||||
display_names: dict[str, str] = {}
|
||||
|
||||
for category in categories:
|
||||
if not isinstance(category, str) or not category:
|
||||
continue
|
||||
base_slug = f"poi_{_poi_category_slug(category)}"
|
||||
slug_count = used_slugs.get(base_slug, 0)
|
||||
used_slugs[base_slug] = slug_count + 1
|
||||
group_key = base_slug if slug_count == 0 else f"{base_slug}_{slug_count + 1}"
|
||||
groups[group_key] = [category]
|
||||
display_names[group_key] = category
|
||||
|
||||
return groups, display_names
|
||||
|
||||
|
||||
def _dynamic_poi_metric_renames(display_names: dict[str, str]) -> dict[str, str]:
|
||||
renames: dict[str, str] = {}
|
||||
for group_key, category in display_names.items():
|
||||
renames[f"{group_key}_nearest_km"] = f"Distance to nearest {category} POI (km)"
|
||||
renames[f"{group_key}_2km"] = f"Number of {category} POIs within 2km"
|
||||
renames[f"{group_key}_5km"] = f"Number of {category} POIs within 5km"
|
||||
return renames
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
|
@ -56,12 +137,35 @@ def main():
|
|||
)
|
||||
|
||||
pois = pl.read_parquet(args.pois)
|
||||
poi_category_groups, poi_display_names = _build_poi_category_groups(pois)
|
||||
|
||||
# Count amenity POIs within 2km
|
||||
counts_2km = count_pois_per_postcode(
|
||||
postcodes, pois, groups=POI_GROUPS_2KM, radius_km=2
|
||||
)
|
||||
|
||||
# Dynamic POI filters: nearest distance plus counts within 2km and 5km for
|
||||
# the selected public transport, grocery, and leisure categories.
|
||||
dynamic_counts_2km = count_pois_per_postcode(
|
||||
postcodes, pois, groups=poi_category_groups, radius_km=2
|
||||
)
|
||||
dynamic_counts_5km = count_pois_per_postcode(
|
||||
postcodes, pois, groups=poi_category_groups, radius_km=5
|
||||
)
|
||||
dynamic_distances = min_distance_per_postcode(
|
||||
postcodes, pois, groups=poi_category_groups
|
||||
)
|
||||
dynamic_renames = _dynamic_poi_metric_renames(poi_display_names)
|
||||
dynamic_counts_2km = dynamic_counts_2km.rename(
|
||||
{k: v for k, v in dynamic_renames.items() if k in dynamic_counts_2km.columns}
|
||||
)
|
||||
dynamic_counts_5km = dynamic_counts_5km.rename(
|
||||
{k: v for k, v in dynamic_renames.items() if k in dynamic_counts_5km.columns}
|
||||
)
|
||||
dynamic_distances = dynamic_distances.rename(
|
||||
{k: v for k, v in dynamic_renames.items() if k in dynamic_distances.columns}
|
||||
)
|
||||
|
||||
# Distance to nearest train/tube station (from filtered POIs)
|
||||
distances = min_distance_per_postcode(postcodes, pois, groups=DISTANCE_GROUPS)
|
||||
|
||||
|
|
@ -77,6 +181,9 @@ def main():
|
|||
# Join all results on postcode
|
||||
result = (
|
||||
counts_2km.join(distances, on="postcode")
|
||||
.join(dynamic_counts_2km, on="postcode")
|
||||
.join(dynamic_counts_5km, on="postcode")
|
||||
.join(dynamic_distances, on="postcode")
|
||||
.join(park_counts_1km, on="postcode")
|
||||
.join(park_distances, on="postcode")
|
||||
)
|
||||
|
|
|
|||
33
pipeline/transform/test_merge.py
Normal file
33
pipeline/transform/test_merge.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import polars as pl
|
||||
|
||||
from pipeline.transform.merge import (
|
||||
_is_dynamic_poi_metric_column,
|
||||
_less_deprived_percentile_expr,
|
||||
)
|
||||
|
||||
|
||||
def test_less_deprived_percentile_expr_preserves_direction_and_nulls() -> None:
|
||||
df = pl.DataFrame({"Income Score (rate)": [1.0, 2.0, 3.0, None]})
|
||||
|
||||
result = df.lazy().with_columns(
|
||||
_less_deprived_percentile_expr("Income Score (rate)")
|
||||
).collect()
|
||||
|
||||
assert result["Income Score (rate)"].to_list() == [100.0, 50.0, 0.0, None]
|
||||
|
||||
|
||||
def test_less_deprived_percentile_expr_uses_exact_scale_endpoints() -> None:
|
||||
df = pl.DataFrame({"Income Score (rate)": [1.0, 1.0, 2.0, 3.0, 3.0]})
|
||||
|
||||
result = df.lazy().with_columns(
|
||||
_less_deprived_percentile_expr("Income Score (rate)")
|
||||
).collect()
|
||||
|
||||
assert result["Income Score (rate)"].to_list() == [100.0, 100.0, 50.0, 0.0, 0.0]
|
||||
|
||||
|
||||
def test_dynamic_poi_metric_columns_are_area_level() -> None:
|
||||
assert _is_dynamic_poi_metric_column("Distance to nearest Cafe POI (km)")
|
||||
assert _is_dynamic_poi_metric_column("Number of Cafe POIs within 2km")
|
||||
assert _is_dynamic_poi_metric_column("Number of Cafe POIs within 5km")
|
||||
assert not _is_dynamic_poi_metric_column("Number of restaurants within 2km")
|
||||
41
pipeline/transform/test_poi_proximity.py
Normal file
41
pipeline/transform/test_poi_proximity.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import polars as pl
|
||||
|
||||
from pipeline.transform.poi_proximity import _build_poi_category_groups
|
||||
|
||||
|
||||
def test_dynamic_poi_groups_include_requested_categories_only() -> None:
|
||||
pois = pl.DataFrame(
|
||||
{
|
||||
"group": (
|
||||
["Public Transport"] * 2
|
||||
+ ["Leisure"] * 2
|
||||
+ ["Groceries"] * 101
|
||||
+ ["Groceries"] * 100
|
||||
+ ["Education"] * 200
|
||||
+ ["Health"] * 200
|
||||
),
|
||||
"category": (
|
||||
["Rail station", "Bus stop"]
|
||||
+ ["Café", "Restaurant"]
|
||||
+ ["Tesco"] * 101
|
||||
+ ["Waitrose"] * 100
|
||||
+ ["School"] * 200
|
||||
+ ["Pharmacy"] * 200
|
||||
),
|
||||
"lat": [51.5] * 605,
|
||||
"lng": [-0.1] * 605,
|
||||
}
|
||||
)
|
||||
|
||||
groups, display_names = _build_poi_category_groups(pois)
|
||||
|
||||
assert set(display_names.values()) == {
|
||||
"Bus stop",
|
||||
"Café",
|
||||
"Rail station",
|
||||
"Restaurant",
|
||||
"Tesco",
|
||||
}
|
||||
assert "poi_waitrose" not in groups
|
||||
assert "poi_school" not in groups
|
||||
assert "poi_pharmacy" not in groups
|
||||
|
|
@ -1128,12 +1128,18 @@ GROCERY_FASCIA_ICON_NAMES: dict[str, str] = {
|
|||
def normalize_grocery_retailer(retailer: str | None) -> str:
|
||||
if retailer is None:
|
||||
return ""
|
||||
return GROCERY_RETAILER_DISPLAY_NAMES.get(retailer, retailer)
|
||||
display_name = GROCERY_RETAILER_DISPLAY_NAMES.get(retailer)
|
||||
if display_name is None:
|
||||
raise ValueError(f"Missing grocery retailer display name for {retailer!r}")
|
||||
return display_name
|
||||
|
||||
|
||||
def normalize_grocery_icon_category(fascia: str | None, retailer: str | None) -> str:
|
||||
if fascia:
|
||||
return GROCERY_FASCIA_ICON_NAMES.get(fascia, normalize_grocery_retailer(fascia))
|
||||
icon_name = GROCERY_FASCIA_ICON_NAMES.get(fascia)
|
||||
if icon_name is None:
|
||||
raise ValueError(f"Missing grocery fascia icon name for {fascia!r}")
|
||||
return icon_name
|
||||
return normalize_grocery_retailer(retailer)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,9 +2,12 @@
|
|||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from scipy.spatial import cKDTree
|
||||
|
||||
from .haversine import haversine_km
|
||||
|
||||
EARTH_RADIUS_KM = 6371.0088
|
||||
|
||||
|
||||
def _build_poi_grid(
|
||||
pois: pl.DataFrame, grid_size: float = 0.05
|
||||
|
|
@ -49,6 +52,21 @@ def _get_nearby_indices(
|
|||
return np.concatenate(nearby_indices)
|
||||
|
||||
|
||||
def _project_lat_lng_km(
|
||||
lats: np.ndarray, lngs: np.ndarray, origin_lat: float
|
||||
) -> np.ndarray:
|
||||
"""Project WGS84 coordinates to local km coordinates for nearest-neighbour lookup."""
|
||||
lat_rad = np.radians(lats)
|
||||
lng_rad = np.radians(lngs)
|
||||
origin_lat_rad = np.radians(origin_lat)
|
||||
return np.column_stack(
|
||||
(
|
||||
EARTH_RADIUS_KM * lng_rad * np.cos(origin_lat_rad),
|
||||
EARTH_RADIUS_KM * lat_rad,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def count_pois_per_postcode(
|
||||
postcodes_df: pl.DataFrame,
|
||||
pois: pl.DataFrame,
|
||||
|
|
@ -136,7 +154,7 @@ def min_distance_per_postcode(
|
|||
) -> pl.DataFrame:
|
||||
"""
|
||||
For each postcode, compute the distance (km) to the closest POI per group.
|
||||
Returns NaN where no POI of that group exists within the grid search range (~5.5km).
|
||||
Returns NaN where no POI of that group exists.
|
||||
"""
|
||||
print("Computing minimum POI distances per postcode...")
|
||||
|
||||
|
|
@ -144,51 +162,84 @@ def min_distance_per_postcode(
|
|||
n_pois = len(pois)
|
||||
print(f" {n_postcodes:,} postcodes, {n_pois:,} POIs")
|
||||
|
||||
grid_size = 0.05
|
||||
print(" Building POI spatial grid...")
|
||||
poi_lats, poi_lngs, poi_cats, poi_grid = _build_poi_grid(pois, grid_size)
|
||||
print(f" POI grid has {len(poi_grid):,} occupied cells")
|
||||
|
||||
category_masks = {}
|
||||
for group, categories in groups.items():
|
||||
mask = np.isin(poi_cats, categories)
|
||||
category_masks[group] = mask
|
||||
print(f" {group}: {mask.sum():,} POIs")
|
||||
|
||||
pc_lats = postcodes_df["lat"].to_numpy()
|
||||
pc_lons = postcodes_df["lon"].to_numpy()
|
||||
pc_codes = postcodes_df["postcode"].to_list()
|
||||
valid_pc_mask = np.isfinite(pc_lats) & np.isfinite(pc_lons)
|
||||
valid_pc_indices = np.flatnonzero(valid_pc_mask)
|
||||
|
||||
result_min_dist = {
|
||||
group: np.full(n_postcodes, np.nan, dtype=np.float32) for group in groups
|
||||
}
|
||||
|
||||
batch_size = 50000
|
||||
n_batches = (n_postcodes + batch_size - 1) // batch_size
|
||||
print(f" Processing {n_postcodes:,} postcodes in {n_batches} batches...")
|
||||
if n_pois == 0 or len(valid_pc_indices) == 0:
|
||||
print(" No valid postcode/POI coordinates; returning NaN distances")
|
||||
return pl.DataFrame(
|
||||
{
|
||||
"postcode": pc_codes,
|
||||
**{
|
||||
f"{group}_nearest_km": values
|
||||
for group, values in result_min_dist.items()
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
for batch_idx in range(n_batches):
|
||||
start_idx = batch_idx * batch_size
|
||||
end_idx = min(start_idx + batch_size, n_postcodes)
|
||||
poi_lats = pois["lat"].to_numpy()
|
||||
poi_lngs = pois["lng"].to_numpy()
|
||||
poi_cats = pois["category"].to_numpy()
|
||||
valid_poi_mask = np.isfinite(poi_lats) & np.isfinite(poi_lngs)
|
||||
origin_lat = float(np.nanmean(pc_lats[valid_pc_mask]))
|
||||
query_xy = _project_lat_lng_km(
|
||||
pc_lats[valid_pc_indices], pc_lons[valid_pc_indices], origin_lat
|
||||
)
|
||||
|
||||
if batch_idx % 5 == 0:
|
||||
print(
|
||||
f" Batch {batch_idx + 1}/{n_batches}: postcodes {start_idx:,} - {end_idx:,}"
|
||||
)
|
||||
batch_size = 200_000
|
||||
n_batches = (len(valid_pc_indices) + batch_size - 1) // batch_size
|
||||
|
||||
for i in range(start_idx, end_idx):
|
||||
nearby = _get_nearby_indices(pc_lats[i], pc_lons[i], poi_grid, grid_size)
|
||||
if nearby is None:
|
||||
continue
|
||||
for group, categories in groups.items():
|
||||
group_indices = np.flatnonzero(valid_poi_mask & np.isin(poi_cats, categories))
|
||||
print(f" {group}: {len(group_indices):,} POIs")
|
||||
if len(group_indices) == 0:
|
||||
continue
|
||||
|
||||
distances = haversine_km(
|
||||
poi_lats[nearby], poi_lngs[nearby], pc_lats[i], pc_lons[i]
|
||||
)
|
||||
poi_xy = _project_lat_lng_km(
|
||||
poi_lats[group_indices], poi_lngs[group_indices], origin_lat
|
||||
)
|
||||
tree = cKDTree(poi_xy)
|
||||
k = min(8, len(group_indices))
|
||||
|
||||
for group, cat_mask in category_masks.items():
|
||||
group_mask = cat_mask[nearby]
|
||||
if group_mask.any():
|
||||
result_min_dist[group][i] = distances[group_mask].min()
|
||||
for batch_idx in range(n_batches):
|
||||
start_idx = batch_idx * batch_size
|
||||
end_idx = min(start_idx + batch_size, len(valid_pc_indices))
|
||||
batch_pc_indices = valid_pc_indices[start_idx:end_idx]
|
||||
batch_xy = query_xy[start_idx:end_idx]
|
||||
|
||||
if batch_idx == 0 or (batch_idx + 1) % 5 == 0:
|
||||
print(
|
||||
f" Batch {batch_idx + 1}/{n_batches}: postcodes {start_idx:,} - {end_idx:,}"
|
||||
)
|
||||
|
||||
_, nearest = tree.query(batch_xy, k=k)
|
||||
nearest = np.asarray(nearest)
|
||||
|
||||
if k == 1:
|
||||
candidate_indices = group_indices[nearest]
|
||||
distances = haversine_km(
|
||||
poi_lats[candidate_indices],
|
||||
poi_lngs[candidate_indices],
|
||||
pc_lats[batch_pc_indices],
|
||||
pc_lons[batch_pc_indices],
|
||||
)
|
||||
else:
|
||||
candidate_indices = group_indices[nearest]
|
||||
distances = haversine_km(
|
||||
poi_lats[candidate_indices],
|
||||
poi_lngs[candidate_indices],
|
||||
pc_lats[batch_pc_indices, None],
|
||||
pc_lons[batch_pc_indices, None],
|
||||
).min(axis=1)
|
||||
|
||||
result_min_dist[group][batch_pc_indices] = distances.astype(np.float32)
|
||||
|
||||
result_data = {"postcode": pc_codes}
|
||||
for group in groups:
|
||||
|
|
|
|||
|
|
@ -113,9 +113,9 @@ def test_min_distance_finds_nearest(postcodes, pois):
|
|||
# Restaurant is co-located — distance ~0
|
||||
assert ec1a["restaurants_nearest_km"][0] < 0.01
|
||||
|
||||
# Far-away postcode should have NaN (no POIs within grid range)
|
||||
# Far-away postcode should still get the global nearest distance.
|
||||
zz99 = result.filter(pl.col("postcode") == "ZZ99 9ZZ")
|
||||
assert np.isnan(zz99["train_tube_nearest_km"][0])
|
||||
assert zz99["train_tube_nearest_km"][0] > 300
|
||||
|
||||
|
||||
def test_min_distance_no_pois_returns_nan(postcodes):
|
||||
|
|
|
|||
|
|
@ -111,20 +111,23 @@ fi
|
|||
# R5 writes .mapdb temp files next to OSM/GTFS files during network construction.
|
||||
# Copy source data to a writable build dir to avoid polluting the originals.
|
||||
mkdir -p "$NETWORK_DIR"
|
||||
OSM_PBF="property-data/england-latest.osm.pbf"
|
||||
TRANSIT_SRC="property-data/transit"
|
||||
NETWORK_DATA_DIR="$TRANSIT_SRC"
|
||||
NETWORK_DATA_DIR="$NETWORK_DIR/build"
|
||||
|
||||
if [ ! -f "$NETWORK_DIR/network.dat" ]; then
|
||||
BUILD_DIR="$NETWORK_DIR/build"
|
||||
echo "--- No cached network — copying transit data to build dir ---"
|
||||
mkdir -p "$BUILD_DIR"
|
||||
if ! cp "$TRANSIT_SRC"/raw/*.osm.pbf "$BUILD_DIR/" 2>/dev/null; then
|
||||
echo "Warning: no .osm.pbf files found in $TRANSIT_SRC/raw/"
|
||||
if [ ! -f "$OSM_PBF" ]; then
|
||||
echo "Error: OSM PBF not found at $OSM_PBF"
|
||||
echo "Download it from https://download.geofabrik.de/europe/united-kingdom/england-latest.osm.pbf"
|
||||
exit 1
|
||||
fi
|
||||
cp "$OSM_PBF" "$BUILD_DIR/"
|
||||
if ! cp "$TRANSIT_SRC"/*.zip "$BUILD_DIR/" 2>/dev/null; then
|
||||
echo "Warning: no .zip files found in $TRANSIT_SRC/"
|
||||
echo "Warning: no GTFS .zip files found in $TRANSIT_SRC/ — transit routing will be unavailable"
|
||||
fi
|
||||
NETWORK_DATA_DIR="$BUILD_DIR"
|
||||
fi
|
||||
|
||||
# --- Step 5: Run batch ---
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::consts::NAN_U16;
|
||||
use crate::data::QuantRef;
|
||||
use crate::data::{PostcodePoiMetrics, QuantRef};
|
||||
|
||||
/// Optional per-enum-value distribution tracking for a single feature.
|
||||
/// Counts how many rows have each enum value (by raw u16 index).
|
||||
|
|
@ -21,6 +21,69 @@ pub struct Aggregator {
|
|||
pub enum_dist: Option<EnumDist>,
|
||||
}
|
||||
|
||||
/// Accumulator for postcode-level POI metrics stored outside `feature_data`.
|
||||
/// Only constructed when a request selects POI metric fields.
|
||||
pub struct PoiAggregator {
|
||||
pub mins: Box<[f32]>,
|
||||
pub maxs: Box<[f32]>,
|
||||
pub sums: Box<[f64]>,
|
||||
pub counts: Box<[u32]>,
|
||||
}
|
||||
|
||||
impl PoiAggregator {
|
||||
pub fn new(num_features: usize) -> Self {
|
||||
Self {
|
||||
mins: vec![f32::INFINITY; num_features].into_boxed_slice(),
|
||||
maxs: vec![f32::NEG_INFINITY; num_features].into_boxed_slice(),
|
||||
sums: vec![0.0f64; num_features].into_boxed_slice(),
|
||||
counts: vec![0u32; num_features].into_boxed_slice(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn add_row_selective(
|
||||
&mut self,
|
||||
poi_metrics: &PostcodePoiMetrics,
|
||||
row: usize,
|
||||
indices: &[usize],
|
||||
) {
|
||||
let Some(metric_row) = poi_metrics.metric_row_for_property(row) else {
|
||||
return;
|
||||
};
|
||||
for &metric_idx in indices {
|
||||
let raw = poi_metrics.raw_for_metric_row(metric_row, metric_idx);
|
||||
if raw == NAN_U16 {
|
||||
continue;
|
||||
}
|
||||
let value = poi_metrics.decode_raw(metric_idx, raw);
|
||||
if value < self.mins[metric_idx] {
|
||||
self.mins[metric_idx] = value;
|
||||
}
|
||||
if value > self.maxs[metric_idx] {
|
||||
self.maxs[metric_idx] = value;
|
||||
}
|
||||
self.sums[metric_idx] += value as f64;
|
||||
self.counts[metric_idx] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn merge(&mut self, other: &PoiAggregator) {
|
||||
for i in 0..self.counts.len() {
|
||||
if other.counts[i] == 0 {
|
||||
continue;
|
||||
}
|
||||
if other.mins[i] < self.mins[i] {
|
||||
self.mins[i] = other.mins[i];
|
||||
}
|
||||
if other.maxs[i] > self.maxs[i] {
|
||||
self.maxs[i] = other.maxs[i];
|
||||
}
|
||||
self.sums[i] += other.sums[i];
|
||||
self.counts[i] += other.counts[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for enum distribution tracking, passed to Aggregator::new.
|
||||
/// (feature_index, number_of_enum_values)
|
||||
pub type EnumDistConfig = Option<(usize, usize)>;
|
||||
|
|
|
|||
807
server-rs/src/checkout_sessions.rs
Normal file
807
server-rs/src/checkout_sessions.rs
Normal file
|
|
@ -0,0 +1,807 @@
|
|||
use std::sync::LazyLock;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::auth::PocketBaseUser;
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::pocketbase_locks::acquire_pocketbase_lock;
|
||||
use crate::routes::pricing::{count_licensed_users, price_for_count};
|
||||
use crate::state::AppState;
|
||||
|
||||
pub const CHECKOUT_CURRENCY: &str = "gbp";
|
||||
|
||||
const CHECKOUT_SESSION_TTL_SECS: u64 = 31 * 60;
|
||||
const CHECKOUT_PRODUCT_NAME: &str = "Perfect Postcodes Lifetime License";
|
||||
const CHECKOUT_COLLECTION: &str = "checkout_sessions";
|
||||
const CHECKOUT_PRICING_LOCK_NAME: &str = "checkout:pricing";
|
||||
const CHECKOUT_PRICING_LOCK_TTL_SECS: u64 = 5 * 60;
|
||||
const REFERRAL_DISCOUNT_PERCENT: u64 = 30;
|
||||
|
||||
static CHECKOUT_RESERVATION_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
|
||||
|
||||
pub enum CheckoutStart {
|
||||
Free,
|
||||
Stripe { url: String },
|
||||
}
|
||||
|
||||
pub enum CheckoutCompletion {
|
||||
Grant(VerifiedCheckout),
|
||||
AlreadyHandled,
|
||||
Rejected(String),
|
||||
}
|
||||
|
||||
pub struct VerifiedCheckout {
|
||||
pub reservation_id: String,
|
||||
pub user_id: String,
|
||||
pub paid_amount_pence: u64,
|
||||
pub referral_invite_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PendingCheckout {
|
||||
id: String,
|
||||
user_id: String,
|
||||
stripe_session_id: String,
|
||||
checkout_url: String,
|
||||
amount_pence: u64,
|
||||
expected_total_pence: u64,
|
||||
currency: String,
|
||||
referral_invite_id: String,
|
||||
status: String,
|
||||
}
|
||||
|
||||
pub fn now_unix_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
pub async fn start_license_checkout(
|
||||
state: &AppState,
|
||||
user: &PocketBaseUser,
|
||||
success_url: &str,
|
||||
cancel_url: &str,
|
||||
discount_coupon_id: Option<&str>,
|
||||
referral_invite_id: Option<&str>,
|
||||
) -> anyhow::Result<CheckoutStart> {
|
||||
let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
|
||||
let pricing_lock = acquire_pocketbase_lock(
|
||||
state,
|
||||
CHECKOUT_PRICING_LOCK_NAME,
|
||||
CHECKOUT_PRICING_LOCK_TTL_SECS,
|
||||
)
|
||||
.await?;
|
||||
let result = start_license_checkout_locked(
|
||||
state,
|
||||
user,
|
||||
success_url,
|
||||
cancel_url,
|
||||
discount_coupon_id,
|
||||
referral_invite_id,
|
||||
)
|
||||
.await;
|
||||
if let Err(err) = pricing_lock.release().await {
|
||||
warn!("Failed to release checkout pricing lock: {err}");
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
async fn start_license_checkout_locked(
|
||||
state: &AppState,
|
||||
user: &PocketBaseUser,
|
||||
success_url: &str,
|
||||
cancel_url: &str,
|
||||
discount_coupon_id: Option<&str>,
|
||||
referral_invite_id: Option<&str>,
|
||||
) -> anyhow::Result<CheckoutStart> {
|
||||
let now = now_unix_secs();
|
||||
expire_stale_pending_checkouts(state, now).await?;
|
||||
|
||||
if let Some(existing) = find_active_checkout_for_user(
|
||||
state,
|
||||
&user.id,
|
||||
discount_coupon_id.unwrap_or_default(),
|
||||
referral_invite_id.unwrap_or_default(),
|
||||
now,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
if !existing.checkout_url.is_empty() {
|
||||
return Ok(CheckoutStart::Stripe {
|
||||
url: existing.checkout_url,
|
||||
});
|
||||
}
|
||||
if let Err(err) = mark_checkout_status(state, &existing.id, "failed").await {
|
||||
warn!(
|
||||
reservation_id = %existing.id,
|
||||
"Failed to fail incomplete checkout reservation: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let licensed_count = count_licensed_users(state).await?;
|
||||
let pending_count = count_active_pending_checkouts(state, now).await?;
|
||||
let price_pence = price_for_count(licensed_count + pending_count);
|
||||
|
||||
if price_pence == 0 {
|
||||
grant_license(state, &user.id).await?;
|
||||
return Ok(CheckoutStart::Free);
|
||||
}
|
||||
|
||||
let expires_at_unix = now + CHECKOUT_SESSION_TTL_SECS;
|
||||
let expected_total_pence = expected_total_for_checkout(price_pence, discount_coupon_id);
|
||||
let reservation_id = create_pending_checkout(
|
||||
state,
|
||||
PendingCheckoutInput {
|
||||
user_id: &user.id,
|
||||
amount_pence: price_pence,
|
||||
expected_total_pence,
|
||||
currency: CHECKOUT_CURRENCY,
|
||||
discount_coupon_id: discount_coupon_id.unwrap_or_default(),
|
||||
referral_invite_id: referral_invite_id.unwrap_or_default(),
|
||||
expires_at_unix,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let stripe_result = create_stripe_session(
|
||||
state,
|
||||
user,
|
||||
&reservation_id,
|
||||
price_pence,
|
||||
success_url,
|
||||
cancel_url,
|
||||
expires_at_unix,
|
||||
discount_coupon_id,
|
||||
)
|
||||
.await;
|
||||
|
||||
let (stripe_session_id, url) = match stripe_result {
|
||||
Ok(session) => session,
|
||||
Err(err) => {
|
||||
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
|
||||
warn!(
|
||||
reservation_id,
|
||||
"Failed to mark checkout reservation failed: {mark_err}"
|
||||
);
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = attach_stripe_session(state, &reservation_id, &stripe_session_id, &url).await
|
||||
{
|
||||
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
|
||||
warn!(
|
||||
reservation_id,
|
||||
"Failed to mark checkout reservation failed: {mark_err}"
|
||||
);
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
Ok(CheckoutStart::Stripe { url })
|
||||
}
|
||||
|
||||
pub async fn verify_checkout_completion(
|
||||
state: &AppState,
|
||||
session: &Value,
|
||||
) -> anyhow::Result<CheckoutCompletion> {
|
||||
let session_id = match session["id"].as_str() {
|
||||
Some(id) if is_safe_stripe_session_id(id) => id,
|
||||
_ => {
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"missing or invalid session id".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let checkout = match find_checkout_by_stripe_session(state, session_id).await? {
|
||||
Some(checkout) => checkout,
|
||||
None => {
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout session has no reservation".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
if checkout.status == "completed" {
|
||||
return Ok(CheckoutCompletion::AlreadyHandled);
|
||||
}
|
||||
if checkout.status != "pending" && checkout.status != "expired" {
|
||||
return Ok(CheckoutCompletion::Rejected(format!(
|
||||
"checkout reservation is {}",
|
||||
checkout.status
|
||||
)));
|
||||
}
|
||||
if checkout.stripe_session_id != session_id {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout reservation session id mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let client_reference_id = session["client_reference_id"].as_str().unwrap_or_default();
|
||||
if client_reference_id != checkout.user_id {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout client_reference_id mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let payment_status = session["payment_status"].as_str().unwrap_or_default();
|
||||
if payment_status != "paid" {
|
||||
return Ok(CheckoutCompletion::Rejected(format!(
|
||||
"checkout payment_status is {payment_status}"
|
||||
)));
|
||||
}
|
||||
|
||||
let currency = session["currency"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_ascii_lowercase();
|
||||
if currency != checkout.currency {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout currency mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let amount_subtotal = match number_field(session, "amount_subtotal") {
|
||||
Some(amount) => amount,
|
||||
None => {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout amount_subtotal missing".into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
if amount_subtotal != checkout.amount_pence {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout amount_subtotal mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let amount_total = match number_field(session, "amount_total") {
|
||||
Some(amount) => amount,
|
||||
None => {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout amount_total missing".into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
if amount_total != checkout.expected_total_pence {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout amount_total mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(CheckoutCompletion::Grant(VerifiedCheckout {
|
||||
reservation_id: checkout.id,
|
||||
user_id: checkout.user_id,
|
||||
paid_amount_pence: amount_total,
|
||||
referral_invite_id: checkout.referral_invite_id,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn mark_checkout_completed(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
paid_amount_pence: u64,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"status": "completed",
|
||||
"paid_amount_pence": paid_amount_pence,
|
||||
"completed_at_unix": now_unix_secs().to_string(),
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase checkout completion update failed")
|
||||
}
|
||||
|
||||
pub async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "subscription": "licensed" }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase license update failed")?;
|
||||
|
||||
state.token_cache.invalidate_by_user_id(user_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn mark_referral_invite_used(
|
||||
state: &AppState,
|
||||
invite_id: &str,
|
||||
user_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
if invite_id.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
if !is_safe_pocketbase_id(invite_id) || !is_safe_pocketbase_id(user_id) {
|
||||
return Err(anyhow!("invalid PocketBase id"));
|
||||
}
|
||||
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let existing_used_by = fetch_invite_used_by(state, pb_url, &token, invite_id).await?;
|
||||
if existing_used_by == user_id {
|
||||
return Ok(());
|
||||
}
|
||||
if !existing_used_by.is_empty() {
|
||||
return Err(anyhow!("referral invite already used by another account"));
|
||||
}
|
||||
|
||||
let url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"used_by_id": user_id,
|
||||
"used_at": now_unix_secs().to_string(),
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase invite usage update failed")
|
||||
}
|
||||
|
||||
async fn fetch_invite_used_by(
|
||||
state: &AppState,
|
||||
pb_url: &str,
|
||||
token: &str,
|
||||
invite_id: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
Ok(body["used_by_id"].as_str().unwrap_or_default().to_string())
|
||||
}
|
||||
|
||||
pub async fn active_referral_checkout_user(
|
||||
state: &AppState,
|
||||
invite_id: &str,
|
||||
) -> anyhow::Result<Option<String>> {
|
||||
if !is_safe_pocketbase_id(invite_id) {
|
||||
return Err(anyhow!("invalid PocketBase invite id"));
|
||||
}
|
||||
|
||||
let now = now_unix_secs();
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!(
|
||||
"status=\"pending\" && expires_at_unix>={now} && referral_invite_id=\"{}\"",
|
||||
invite_id
|
||||
);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
Ok(body["items"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.and_then(|item| item["user"].as_str())
|
||||
.map(str::to_string))
|
||||
}
|
||||
|
||||
async fn count_active_pending_checkouts(state: &AppState, now: u64) -> anyhow::Result<u64> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("status=\"pending\" && expires_at_unix>={now}");
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
Ok(body["totalItems"].as_u64().unwrap_or(0))
|
||||
}
|
||||
|
||||
async fn find_active_checkout_for_user(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
discount_coupon_id: &str,
|
||||
referral_invite_id: &str,
|
||||
now: u64,
|
||||
) -> anyhow::Result<Option<PendingCheckout>> {
|
||||
if !is_safe_pocketbase_id(user_id) {
|
||||
return Err(anyhow!("invalid PocketBase user id"));
|
||||
}
|
||||
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!(
|
||||
"status=\"pending\" && expires_at_unix>={now} && user=\"{}\" && discount_coupon_id=\"{}\" && referral_invite_id=\"{}\"",
|
||||
user_id, discount_coupon_id, referral_invite_id
|
||||
);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let item = body["items"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.cloned();
|
||||
|
||||
item.map(parse_pending_checkout).transpose()
|
||||
}
|
||||
|
||||
async fn expire_stale_pending_checkouts(state: &AppState, now: u64) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("status=\"pending\" && expires_at_unix<{now}");
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=50",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let Some(items) = body["items"].as_array() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
for id in items.iter().filter_map(|item| item["id"].as_str()) {
|
||||
if let Err(err) = mark_checkout_status(state, id, "expired").await {
|
||||
warn!(
|
||||
reservation_id = id,
|
||||
"Failed to expire checkout reservation: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct PendingCheckoutInput<'a> {
|
||||
user_id: &'a str,
|
||||
amount_pence: u64,
|
||||
expected_total_pence: u64,
|
||||
currency: &'a str,
|
||||
discount_coupon_id: &'a str,
|
||||
referral_invite_id: &'a str,
|
||||
expires_at_unix: u64,
|
||||
}
|
||||
|
||||
async fn create_pending_checkout(
|
||||
state: &AppState,
|
||||
input: PendingCheckoutInput<'_>,
|
||||
) -> anyhow::Result<String> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records");
|
||||
let resp = state
|
||||
.http_client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"user": input.user_id,
|
||||
"stripe_session_id": "",
|
||||
"checkout_url": "",
|
||||
"amount_pence": input.amount_pence,
|
||||
"expected_total_pence": input.expected_total_pence,
|
||||
"currency": input.currency,
|
||||
"discount_coupon_id": input.discount_coupon_id,
|
||||
"referral_invite_id": input.referral_invite_id,
|
||||
"status": "pending",
|
||||
"expires_at_unix": input.expires_at_unix,
|
||||
"paid_amount_pence": 0,
|
||||
"completed_at_unix": "",
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
body["id"]
|
||||
.as_str()
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| anyhow!("PocketBase checkout reservation missing id"))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn create_stripe_session(
|
||||
state: &AppState,
|
||||
user: &PocketBaseUser,
|
||||
reservation_id: &str,
|
||||
price_pence: u64,
|
||||
success_url: &str,
|
||||
cancel_url: &str,
|
||||
expires_at_unix: u64,
|
||||
discount_coupon_id: Option<&str>,
|
||||
) -> anyhow::Result<(String, String)> {
|
||||
let mut form_params = vec![
|
||||
("mode", "payment".to_string()),
|
||||
("payment_method_types[0]", "card".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][unit_amount]",
|
||||
price_pence.to_string(),
|
||||
),
|
||||
(
|
||||
"line_items[0][price_data][currency]",
|
||||
CHECKOUT_CURRENCY.to_string(),
|
||||
),
|
||||
(
|
||||
"line_items[0][price_data][product_data][name]",
|
||||
CHECKOUT_PRODUCT_NAME.to_string(),
|
||||
),
|
||||
("line_items[0][quantity]", "1".to_string()),
|
||||
("success_url", success_url.to_string()),
|
||||
("cancel_url", cancel_url.to_string()),
|
||||
("expires_at", expires_at_unix.to_string()),
|
||||
("client_reference_id", user.id.clone()),
|
||||
("customer_email", user.email.clone()),
|
||||
("metadata[pending_checkout_id]", reservation_id.to_string()),
|
||||
("metadata[expected_amount_pence]", price_pence.to_string()),
|
||||
(
|
||||
"metadata[expected_total_pence]",
|
||||
expected_total_for_checkout(price_pence, discount_coupon_id).to_string(),
|
||||
),
|
||||
("metadata[expected_currency]", CHECKOUT_CURRENCY.to_string()),
|
||||
];
|
||||
|
||||
if let Some(coupon_id) = discount_coupon_id.filter(|id| !id.is_empty()) {
|
||||
form_params.push(("discounts[0][coupon]", coupon_id.to_string()));
|
||||
form_params.push(("metadata[discount_coupon_id]", coupon_id.to_string()));
|
||||
}
|
||||
|
||||
let resp = state
|
||||
.http_client
|
||||
.post("https://api.stripe.com/v1/checkout/sessions")
|
||||
.basic_auth(&state.stripe_secret_key, None::<&str>)
|
||||
.form(&form_params)
|
||||
.send()
|
||||
.await
|
||||
.context("Stripe checkout request failed")?;
|
||||
|
||||
ensure_success_ref(&resp)
|
||||
.await
|
||||
.context("Stripe checkout failed")?;
|
||||
|
||||
let body: Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Stripe response")?;
|
||||
let session_id = body["id"]
|
||||
.as_str()
|
||||
.filter(|id| is_safe_stripe_session_id(id))
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| anyhow!("Stripe session missing valid id"))?;
|
||||
let url = body["url"]
|
||||
.as_str()
|
||||
.map(str::to_string)
|
||||
.filter(|url| !url.is_empty())
|
||||
.ok_or_else(|| anyhow!("Stripe session missing URL"))?;
|
||||
|
||||
Ok((session_id, url))
|
||||
}
|
||||
|
||||
async fn attach_stripe_session(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
stripe_session_id: &str,
|
||||
checkout_url: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"stripe_session_id": stripe_session_id,
|
||||
"checkout_url": checkout_url,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase checkout session attach failed")
|
||||
}
|
||||
|
||||
async fn mark_checkout_status(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
status: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "status": status }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.with_context(|| format!("PocketBase checkout status update failed for {reservation_id}"))
|
||||
}
|
||||
|
||||
async fn find_checkout_by_stripe_session(
|
||||
state: &AppState,
|
||||
stripe_session_id: &str,
|
||||
) -> anyhow::Result<Option<PendingCheckout>> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("stripe_session_id=\"{}\"", stripe_session_id);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let item = body["items"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.cloned();
|
||||
|
||||
item.map(parse_pending_checkout).transpose()
|
||||
}
|
||||
|
||||
fn parse_pending_checkout(item: Value) -> anyhow::Result<PendingCheckout> {
|
||||
Ok(PendingCheckout {
|
||||
id: item["id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("checkout reservation missing id"))?
|
||||
.to_string(),
|
||||
user_id: item["user"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("checkout reservation missing user"))?
|
||||
.to_string(),
|
||||
stripe_session_id: item["stripe_session_id"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
checkout_url: item["checkout_url"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
amount_pence: number_field(&item, "amount_pence")
|
||||
.ok_or_else(|| anyhow!("checkout reservation missing amount_pence"))?,
|
||||
expected_total_pence: number_field(&item, "expected_total_pence")
|
||||
.ok_or_else(|| anyhow!("checkout reservation missing expected_total_pence"))?,
|
||||
currency: item["currency"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_ascii_lowercase(),
|
||||
referral_invite_id: item["referral_invite_id"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
status: item["status"].as_str().unwrap_or_default().to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn expected_total_for_checkout(amount_pence: u64, discount_coupon_id: Option<&str>) -> u64 {
|
||||
if discount_coupon_id.is_some_and(|id| !id.is_empty()) {
|
||||
return ((amount_pence * (100 - REFERRAL_DISCOUNT_PERCENT)) / 100).max(1);
|
||||
}
|
||||
amount_pence
|
||||
}
|
||||
|
||||
fn number_field(value: &Value, field: &str) -> Option<u64> {
|
||||
value[field].as_u64().or_else(|| {
|
||||
value[field]
|
||||
.as_f64()
|
||||
.filter(|n| n.is_finite() && *n >= 0.0 && n.fract() == 0.0)
|
||||
.map(|n| n as u64)
|
||||
})
|
||||
}
|
||||
|
||||
fn is_safe_stripe_session_id(id: &str) -> bool {
|
||||
!id.is_empty()
|
||||
&& id.len() <= 128
|
||||
&& id
|
||||
.bytes()
|
||||
.all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'-')
|
||||
}
|
||||
|
||||
fn is_safe_pocketbase_id(id: &str) -> bool {
|
||||
!id.is_empty() && id.len() <= 32 && id.bytes().all(|b| b.is_ascii_alphanumeric())
|
||||
}
|
||||
|
||||
async fn ensure_success(resp: reqwest::Response) -> anyhow::Result<()> {
|
||||
if resp.status().is_success() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
Err(anyhow!("upstream returned {status}: {text}"))
|
||||
}
|
||||
|
||||
async fn ensure_success_ref(resp: &reqwest::Response) -> anyhow::Result<()> {
|
||||
if resp.status().is_success() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err(anyhow!("upstream returned {}", resp.status()))
|
||||
}
|
||||
|
|
@ -97,7 +97,7 @@ fn build_search_text(name: &str, place_type: &str) -> String {
|
|||
}
|
||||
|
||||
if place_type == "station" {
|
||||
let suffix_aliases: [(&str, &[&str]); 5] = [
|
||||
let suffix_aliases: [(&str, &[&str]); 6] = [
|
||||
(
|
||||
" tube station",
|
||||
&[" underground station", " station", " tube", " underground"],
|
||||
|
|
@ -118,6 +118,7 @@ fn build_search_text(name: &str, place_type: &str) -> String {
|
|||
" elizabeth line station",
|
||||
&[" station", " elizabeth line", " crossrail station"],
|
||||
),
|
||||
(" dlr station", &[" station", " dlr"]),
|
||||
];
|
||||
|
||||
for (suffix, replacements) in suffix_aliases {
|
||||
|
|
@ -139,10 +140,15 @@ fn extract_str_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<String>> {
|
|||
let string_column = column
|
||||
.str()
|
||||
.with_context(|| format!("Column '{name}' is not a string column"))?;
|
||||
Ok(string_column
|
||||
string_column
|
||||
.into_iter()
|
||||
.map(|value| value.unwrap_or("").to_string())
|
||||
.collect())
|
||||
.enumerate()
|
||||
.map(|(row, value)| {
|
||||
value
|
||||
.map(ToString::to_string)
|
||||
.with_context(|| format!("Column '{name}' has null at row {row}"))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extract_f32_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<f32>> {
|
||||
|
|
@ -155,33 +161,37 @@ fn extract_f32_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<f32>> {
|
|||
let float_column = cast
|
||||
.f32()
|
||||
.with_context(|| format!("Column '{name}' is not a float32 column"))?;
|
||||
Ok(float_column
|
||||
float_column
|
||||
.into_iter()
|
||||
.map(|value| value.unwrap_or(0.0))
|
||||
.collect())
|
||||
.enumerate()
|
||||
.map(|(row, value)| value.with_context(|| format!("Column '{name}' has null at row {row}")))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extract_bool_col_or_default(
|
||||
df: &DataFrame,
|
||||
name: &str,
|
||||
default_value: bool,
|
||||
) -> anyhow::Result<Vec<bool>> {
|
||||
let Ok(column) = df.column(name) else {
|
||||
return Ok(vec![default_value; df.height()]);
|
||||
};
|
||||
fn extract_bool_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<bool>> {
|
||||
let column = df
|
||||
.column(name)
|
||||
.with_context(|| format!("Missing column '{name}' in places data"))?;
|
||||
let bool_column = column
|
||||
.bool()
|
||||
.with_context(|| format!("Column '{name}' is not a boolean column"))?;
|
||||
Ok(bool_column
|
||||
bool_column
|
||||
.into_iter()
|
||||
.map(|value| value.unwrap_or(default_value))
|
||||
.collect())
|
||||
.enumerate()
|
||||
.map(|(row, value)| value.with_context(|| format!("Column '{name}' has null at row {row}")))
|
||||
.collect()
|
||||
}
|
||||
|
||||
impl PlaceData {
|
||||
pub fn load(parquet_path: &Path) -> anyhow::Result<Self> {
|
||||
super::run_polars_io(|| Self::load_inner(parquet_path))
|
||||
}
|
||||
|
||||
fn load_inner(parquet_path: &Path) -> anyhow::Result<Self> {
|
||||
info!("Loading place data from {:?}...", parquet_path);
|
||||
|
||||
let parquet_path = PlRefPath::try_from_path(parquet_path)
|
||||
.context("Failed to normalize places parquet path")?;
|
||||
let df = LazyFrame::scan_parquet(parquet_path, Default::default())
|
||||
.context("Failed to scan places parquet")?
|
||||
.collect()
|
||||
|
|
@ -210,7 +220,7 @@ impl PlaceData {
|
|||
let type_rank_vec: Vec<u8> = place_type_raw.iter().map(|pt| type_rank(pt)).collect();
|
||||
let place_type = InternedColumn::build(&place_type_raw);
|
||||
let travel_destination = if df.column("travel_destination").is_ok() {
|
||||
extract_bool_col_or_default(&df, "travel_destination", true)?
|
||||
extract_bool_col(&df, "travel_destination")?
|
||||
} else {
|
||||
place_type_raw
|
||||
.iter()
|
||||
|
|
@ -296,6 +306,7 @@ mod tests {
|
|||
assert!(build_search_text("King's Cross tube station", "station")
|
||||
.contains("kings cross underground"));
|
||||
assert!(build_search_text("St Albans", "city").contains("saint albans"));
|
||||
assert!(build_search_text("Shadwell DLR station", "station").contains("shadwell station"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use anyhow::{bail, Context};
|
|||
use polars::frame::DataFrame;
|
||||
use polars::lazy::frame::LazyFrame;
|
||||
use polars::prelude::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use serde::Serialize;
|
||||
use tracing::info;
|
||||
|
||||
|
|
@ -17,6 +18,94 @@ pub struct POICategoryGroup {
|
|||
pub categories: Vec<String>,
|
||||
}
|
||||
|
||||
const GROCERY_DASHBOARD_CATEGORIES: &[&str] = &[
|
||||
"Supermarket",
|
||||
"Convenience Store",
|
||||
"Bakery",
|
||||
"Greengrocer",
|
||||
"Aldi",
|
||||
"Amazon",
|
||||
"Asda",
|
||||
"Booths",
|
||||
"Budgens",
|
||||
"Centra",
|
||||
"Co-op",
|
||||
"COOK",
|
||||
"Costco",
|
||||
"Dunnes Stores",
|
||||
"Farmfoods",
|
||||
"Heron Foods",
|
||||
"Iceland",
|
||||
"Lidl",
|
||||
"Makro",
|
||||
"M&S",
|
||||
"Morrisons",
|
||||
"Planet Organic",
|
||||
"Sainsbury's",
|
||||
"Spar",
|
||||
"Tesco",
|
||||
"The Food Warehouse",
|
||||
"Waitrose",
|
||||
"Whole Foods Market",
|
||||
];
|
||||
|
||||
const DASHBOARD_POI_GROUPS: &[(&str, &[&str])] = &[
|
||||
(
|
||||
"Public Transport",
|
||||
&[
|
||||
"Rail station",
|
||||
"Tube station",
|
||||
"Bus station",
|
||||
"Bus stop",
|
||||
"Airport",
|
||||
],
|
||||
),
|
||||
("Groceries", GROCERY_DASHBOARD_CATEGORIES),
|
||||
("Food & Drink", &["Café", "Restaurant", "Pub", "Fast Food"]),
|
||||
("Green Space", &["Park", "Playground"]),
|
||||
("Education", &["School"]),
|
||||
(
|
||||
"Health",
|
||||
&["GP Surgery", "Pharmacy", "Dentist", "Hospital & Clinic"],
|
||||
),
|
||||
(
|
||||
"Leisure",
|
||||
&[
|
||||
"Gym & Fitness",
|
||||
"Sports Centre",
|
||||
"Cinema",
|
||||
"Theatre",
|
||||
"Library",
|
||||
],
|
||||
),
|
||||
(
|
||||
"Practical",
|
||||
&["Post Office", "Bank", "EV Charging", "Fuel Station"],
|
||||
),
|
||||
];
|
||||
|
||||
fn add_category_filter_index(
|
||||
category_values: &[String],
|
||||
category: &str,
|
||||
selected: &mut FxHashSet<u16>,
|
||||
) {
|
||||
if let Some(pos) = category_values.iter().position(|value| value == category) {
|
||||
selected.insert(pos as u16);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resolve_poi_category_filter(category_values: &[String], categories: &str) -> FxHashSet<u16> {
|
||||
let mut selected = FxHashSet::default();
|
||||
for part in categories.split(',') {
|
||||
let category = part.trim();
|
||||
if category.is_empty() {
|
||||
continue;
|
||||
}
|
||||
add_category_filter_index(category_values, category, &mut selected);
|
||||
}
|
||||
selected
|
||||
}
|
||||
|
||||
pub struct POIData {
|
||||
/// Contiguous buffer holding all POI ID strings end-to-end.
|
||||
id_buffer: String,
|
||||
|
|
@ -53,13 +142,18 @@ fn extract_str_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<String>> {
|
|||
let string_column = column
|
||||
.str()
|
||||
.with_context(|| format!("Column '{name}' is not a string column"))?;
|
||||
Ok(string_column
|
||||
string_column
|
||||
.into_iter()
|
||||
.map(|value| value.unwrap_or("").to_string())
|
||||
.collect())
|
||||
.enumerate()
|
||||
.map(|(row, value)| {
|
||||
value
|
||||
.map(ToString::to_string)
|
||||
.with_context(|| format!("Column '{name}' has null at row {row}"))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extract_f32_col(df: &DataFrame, name: &str, default: f32) -> anyhow::Result<Vec<f32>> {
|
||||
fn extract_f32_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<f32>> {
|
||||
let column = df
|
||||
.column(name)
|
||||
.with_context(|| format!("Missing column '{name}' in POI data"))?;
|
||||
|
|
@ -69,16 +163,23 @@ fn extract_f32_col(df: &DataFrame, name: &str, default: f32) -> anyhow::Result<V
|
|||
let float_column = cast
|
||||
.f32()
|
||||
.with_context(|| format!("Column '{name}' is not a float32 column"))?;
|
||||
Ok(float_column
|
||||
float_column
|
||||
.into_iter()
|
||||
.map(|value| value.unwrap_or(default))
|
||||
.collect())
|
||||
.enumerate()
|
||||
.map(|(row, value)| value.with_context(|| format!("Column '{name}' has null at row {row}")))
|
||||
.collect()
|
||||
}
|
||||
|
||||
impl POIData {
|
||||
pub fn load(parquet_path: &Path) -> anyhow::Result<Self> {
|
||||
super::run_polars_io(|| Self::load_inner(parquet_path))
|
||||
}
|
||||
|
||||
fn load_inner(parquet_path: &Path) -> anyhow::Result<Self> {
|
||||
info!("Loading POI data from {:?}...", parquet_path);
|
||||
|
||||
let parquet_path = PlRefPath::try_from_path(parquet_path)
|
||||
.context("Failed to normalize POI parquet path")?;
|
||||
let df = LazyFrame::scan_parquet(parquet_path, Default::default())
|
||||
.context("Failed to scan POI parquet")?
|
||||
.collect()
|
||||
|
|
@ -91,18 +192,10 @@ impl POIData {
|
|||
let name = extract_str_col(&df, "name")?;
|
||||
let category_raw = extract_str_col(&df, "category")?;
|
||||
let group_raw = extract_str_col(&df, "group")?;
|
||||
let lat = extract_f32_col(&df, "lat", 0.0)?;
|
||||
let lng = extract_f32_col(&df, "lng", 0.0)?;
|
||||
let lat = extract_f32_col(&df, "lat")?;
|
||||
let lng = extract_f32_col(&df, "lng")?;
|
||||
let emoji_raw = extract_str_col(&df, "emoji")?;
|
||||
let icon_category_raw = if df
|
||||
.get_column_names()
|
||||
.iter()
|
||||
.any(|name| name.as_str() == "icon_category")
|
||||
{
|
||||
extract_str_col(&df, "icon_category")?
|
||||
} else {
|
||||
category_raw.clone()
|
||||
};
|
||||
let icon_category_raw = extract_str_col(&df, "icon_category")?;
|
||||
|
||||
// Pack POI IDs into a contiguous buffer
|
||||
let total_id_bytes: usize = id_raw.iter().map(|s| s.len()).sum();
|
||||
|
|
@ -152,7 +245,7 @@ impl POIData {
|
|||
})
|
||||
}
|
||||
|
||||
/// Build category groups from the loaded POI data, validated against POI_GROUP_ORDER.
|
||||
/// Build dashboard category groups from every category present in the loaded POI data.
|
||||
pub fn category_groups(&self) -> anyhow::Result<Vec<POICategoryGroup>> {
|
||||
let mut group_cats: HashMap<String, HashSet<String>> = HashMap::new();
|
||||
let num_pois = self.category.indices.len();
|
||||
|
|
@ -174,18 +267,78 @@ impl POIData {
|
|||
);
|
||||
}
|
||||
|
||||
POI_GROUP_ORDER
|
||||
let preferred_order: HashMap<&str, HashMap<&str, usize>> = DASHBOARD_POI_GROUPS
|
||||
.iter()
|
||||
.map(|group_name| {
|
||||
let name = group_name.to_string();
|
||||
let mut categories: Vec<String> = group_cats
|
||||
.remove(&name)
|
||||
.context("POI group validated but missing from map")?
|
||||
.into_iter()
|
||||
.collect();
|
||||
categories.sort();
|
||||
Ok(POICategoryGroup { name, categories })
|
||||
.map(|(group, categories)| {
|
||||
(
|
||||
*group,
|
||||
categories
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, category)| (*category, idx))
|
||||
.collect(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.collect();
|
||||
|
||||
let groups: Vec<POICategoryGroup> = POI_GROUP_ORDER
|
||||
.iter()
|
||||
.filter_map(|group_name| {
|
||||
let mut categories: Vec<String> = group_cats
|
||||
.get(*group_name)
|
||||
.map(|categories| categories.iter().cloned().collect())
|
||||
.unwrap_or_default();
|
||||
if categories.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let group_order = preferred_order.get(*group_name);
|
||||
categories.sort_by(|a, b| {
|
||||
let a_order = group_order.and_then(|order| order.get(a.as_str())).copied();
|
||||
let b_order = group_order.and_then(|order| order.get(b.as_str())).copied();
|
||||
match (a_order, b_order) {
|
||||
(Some(left), Some(right)) => left.cmp(&right),
|
||||
(Some(_), None) => std::cmp::Ordering::Less,
|
||||
(None, Some(_)) => std::cmp::Ordering::Greater,
|
||||
(None, None) => a.cmp(b),
|
||||
}
|
||||
});
|
||||
Some(POICategoryGroup {
|
||||
name: (*group_name).to_string(),
|
||||
categories,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(groups)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn category_filter_matches_exact_present_categories() {
|
||||
let values = vec![
|
||||
"Supermarket".to_string(),
|
||||
"Tesco".to_string(),
|
||||
"Aldi".to_string(),
|
||||
"Rail station".to_string(),
|
||||
];
|
||||
|
||||
let selected = resolve_poi_category_filter(&values, "Supermarket,Rail station");
|
||||
|
||||
assert!(selected.contains(&0));
|
||||
assert!(selected.contains(&3));
|
||||
assert_eq!(selected.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_category_filter_matches_nothing() {
|
||||
let values = vec!["Supermarket".to_string()];
|
||||
|
||||
let selected = resolve_poi_category_filter(&values, "Unknown");
|
||||
|
||||
assert!(selected.is_empty());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -195,33 +195,38 @@ impl PostcodeData {
|
|||
|
||||
// Extract all outer rings from the geometry
|
||||
let rings: Vec<Vec<[f32; 2]>> = match feature.geometry {
|
||||
Geometry::Polygon { coordinates } => coordinates
|
||||
.first()
|
||||
.map(|ring| {
|
||||
vec![ring
|
||||
.iter()
|
||||
.map(|[lon, lat]| [*lon as f32, *lat as f32])
|
||||
.collect()]
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
Geometry::Polygon { coordinates } => {
|
||||
let ring = coordinates.first().with_context(|| {
|
||||
format!("Postcode '{postcode}' polygon has no outer ring")
|
||||
})?;
|
||||
vec![ring
|
||||
.iter()
|
||||
.map(|[lon, lat]| [*lon as f32, *lat as f32])
|
||||
.collect()]
|
||||
}
|
||||
Geometry::MultiPolygon { coordinates } => coordinates
|
||||
.iter()
|
||||
.filter_map(|poly| {
|
||||
poly.first().map(|ring| {
|
||||
ring.iter()
|
||||
.map(|[lon, lat]| [*lon as f32, *lat as f32])
|
||||
.collect()
|
||||
})
|
||||
.enumerate()
|
||||
.map(|(idx, poly)| {
|
||||
let ring = poly.first().with_context(|| {
|
||||
format!(
|
||||
"Postcode '{postcode}' multipolygon part {idx} has no outer ring"
|
||||
)
|
||||
})?;
|
||||
Ok(ring
|
||||
.iter()
|
||||
.map(|[lon, lat]| [*lon as f32, *lat as f32])
|
||||
.collect())
|
||||
})
|
||||
.collect(),
|
||||
.collect::<anyhow::Result<Vec<_>>>()?,
|
||||
};
|
||||
|
||||
// Compute centroid across all vertices from all rings
|
||||
let total_vertices: usize = rings.iter().map(|ring| ring.len()).sum();
|
||||
let centroid = if total_vertices == 0 {
|
||||
tracing::warn!(postcode = %postcode, "Postcode polygon has zero vertices, defaulting centroid to (0,0)");
|
||||
(0.0, 0.0)
|
||||
} else {
|
||||
if total_vertices == 0 {
|
||||
anyhow::bail!("Postcode '{postcode}' polygon has zero vertices");
|
||||
}
|
||||
let centroid = {
|
||||
let mut sum_lat: f32 = 0.0;
|
||||
let mut sum_lon: f32 = 0.0;
|
||||
for ring in &rings {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ const ADDRESS_SEARCH_CANDIDATE_LIMIT: usize = 50_000;
|
|||
const ADDRESS_SEARCH_MAX_POSTINGS_PER_TOKEN: usize = 250_000;
|
||||
const ADDRESS_SEARCH_PREFIX_MIN_LEN: usize = 4;
|
||||
const ADDRESS_SEARCH_PREFIX_MAX_LEN: usize = 8;
|
||||
const NO_POI_METRIC_ROW: u32 = u32::MAX;
|
||||
|
||||
fn is_numeric_dtype(dtype: &DataType) -> bool {
|
||||
matches!(
|
||||
|
|
@ -495,6 +496,187 @@ impl QuantRef<'_> {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct PostcodePoiMetrics {
|
||||
pub feature_names: Vec<String>,
|
||||
pub name_to_index: FxHashMap<String, usize>,
|
||||
/// Metric-major storage: columns[metric_idx][postcode_metric_idx].
|
||||
pub columns: Vec<Vec<u16>>,
|
||||
pub feature_stats: Vec<FeatureStats>,
|
||||
/// Per-property row lookup into the postcode metric table.
|
||||
row_to_metric_idx: Vec<u32>,
|
||||
dequant_a: Vec<f32>,
|
||||
quant_min: Vec<f32>,
|
||||
quant_range: Vec<f32>,
|
||||
}
|
||||
|
||||
impl PostcodePoiMetrics {
|
||||
fn empty(row_count: usize) -> Self {
|
||||
Self {
|
||||
feature_names: Vec::new(),
|
||||
name_to_index: FxHashMap::default(),
|
||||
columns: Vec::new(),
|
||||
feature_stats: Vec::new(),
|
||||
row_to_metric_idx: vec![NO_POI_METRIC_ROW; row_count],
|
||||
dequant_a: Vec::new(),
|
||||
quant_min: Vec::new(),
|
||||
quant_range: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_postcode_df(df: &DataFrame, feature_names: Vec<String>) -> anyhow::Result<Self> {
|
||||
if feature_names.is_empty() {
|
||||
return Ok(Self::empty(0));
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
metrics = feature_names.len(),
|
||||
postcodes = df.height(),
|
||||
"Building postcode POI metric side table"
|
||||
);
|
||||
|
||||
let col_major: Vec<Vec<f32>> = feature_names
|
||||
.par_iter()
|
||||
.map(|name| {
|
||||
let column = df
|
||||
.column(name.as_str())
|
||||
.with_context(|| format!("Missing POI metric column '{name}'"))?;
|
||||
column_to_f32_vec(column)
|
||||
})
|
||||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||||
|
||||
let feature_stats: Vec<FeatureStats> = col_major
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(metric_idx, vals)| {
|
||||
let name = feature_names[metric_idx].as_str();
|
||||
let bounds = features::bounds_for(name)
|
||||
.with_context(|| format!("No bounds config for POI metric '{name}'"))?;
|
||||
Ok(compute_feature_stats(
|
||||
vals,
|
||||
&bounds,
|
||||
features::has_integer_bins(name),
|
||||
))
|
||||
})
|
||||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||||
|
||||
let mut quant_min = Vec::with_capacity(feature_names.len());
|
||||
let mut quant_range = Vec::with_capacity(feature_names.len());
|
||||
for (metric_idx, stats) in feature_stats.iter().enumerate() {
|
||||
let (min, max) = match features::bounds_for(feature_names[metric_idx].as_str()) {
|
||||
Some(Bounds::Fixed { min, max }) => (min, max),
|
||||
_ => (stats.histogram.min, stats.histogram.max),
|
||||
};
|
||||
quant_min.push(min);
|
||||
quant_range.push(if max > min { max - min } else { 0.0 });
|
||||
}
|
||||
let dequant_a: Vec<f32> = quant_range
|
||||
.iter()
|
||||
.map(|&range| {
|
||||
if range > 0.0 {
|
||||
range / QUANT_SCALE
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let columns: Vec<Vec<u16>> = col_major
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(metric_idx, vals)| {
|
||||
let range = quant_range[metric_idx];
|
||||
let min = quant_min[metric_idx];
|
||||
vals.iter()
|
||||
.map(|&value| {
|
||||
if !value.is_finite() {
|
||||
NAN_U16
|
||||
} else if range > 0.0 {
|
||||
let normalized = (value - min) / range;
|
||||
(normalized * QUANT_SCALE).round().clamp(0.0, QUANT_SCALE) as u16
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let name_to_index = feature_names
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, name)| (name.clone(), idx))
|
||||
.collect();
|
||||
|
||||
Ok(Self {
|
||||
feature_names,
|
||||
name_to_index,
|
||||
columns,
|
||||
feature_stats,
|
||||
row_to_metric_idx: Vec::new(),
|
||||
dequant_a,
|
||||
quant_min,
|
||||
quant_range,
|
||||
})
|
||||
}
|
||||
|
||||
fn set_row_mapping(&mut self, row_to_metric_idx: Vec<u32>) {
|
||||
self.row_to_metric_idx = row_to_metric_idx;
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.feature_names.is_empty()
|
||||
}
|
||||
|
||||
pub fn num_features(&self) -> usize {
|
||||
self.feature_names.len()
|
||||
}
|
||||
|
||||
pub fn quant_ref(&self) -> QuantRef<'_> {
|
||||
QuantRef {
|
||||
dequant_a: &self.dequant_a,
|
||||
quant_min: &self.quant_min,
|
||||
quant_range: &self.quant_range,
|
||||
num_numeric: self.feature_names.len(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn metric_row_for_property(&self, row: usize) -> Option<usize> {
|
||||
self.row_to_metric_idx
|
||||
.get(row)
|
||||
.copied()
|
||||
.filter(|&idx| idx != NO_POI_METRIC_ROW)
|
||||
.map(|idx| idx as usize)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn raw_for_metric_row(&self, metric_row: usize, metric_idx: usize) -> u16 {
|
||||
self.columns[metric_idx][metric_row]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn raw_for_property_row(&self, row: usize, metric_idx: usize) -> u16 {
|
||||
let Some(metric_row) = self.metric_row_for_property(row) else {
|
||||
return NAN_U16;
|
||||
};
|
||||
self.raw_for_metric_row(metric_row, metric_idx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn decode_raw(&self, metric_idx: usize, raw: u16) -> f32 {
|
||||
if raw == NAN_U16 {
|
||||
f32::NAN
|
||||
} else {
|
||||
raw as f32 * self.dequant_a[metric_idx] + self.quant_min[metric_idx]
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_for_property_row(&self, row: usize, metric_idx: usize) -> f32 {
|
||||
self.decode_raw(metric_idx, self.raw_for_property_row(row, metric_idx))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PropertyData {
|
||||
pub lat: Vec<f32>,
|
||||
pub lon: Vec<f32>,
|
||||
|
|
@ -514,6 +696,7 @@ pub struct PropertyData {
|
|||
/// Per-feature: max - min (for encoding filter bounds).
|
||||
quant_range: Vec<f32>,
|
||||
pub feature_stats: Vec<FeatureStats>,
|
||||
pub poi_metrics: PostcodePoiMetrics,
|
||||
/// Unquantized last sale price used by the price-history chart.
|
||||
last_known_price_raw: Vec<f32>,
|
||||
/// Contiguous buffer holding all address strings end-to-end.
|
||||
|
|
@ -1055,19 +1238,54 @@ pub fn precompute_h3(lat: &[f32], lon: &[f32]) -> anyhow::Result<Vec<u64>> {
|
|||
|
||||
impl PropertyData {
|
||||
pub fn load(properties_path: &Path, postcode_features_path: &Path) -> anyhow::Result<Self> {
|
||||
super::run_polars_io(|| Self::load_inner(properties_path, postcode_features_path))
|
||||
}
|
||||
|
||||
fn load_inner(properties_path: &Path, postcode_features_path: &Path) -> anyhow::Result<Self> {
|
||||
// Load postcode.parquet
|
||||
tracing::info!(
|
||||
"Loading postcode features from {:?}",
|
||||
postcode_features_path
|
||||
);
|
||||
let postcode_features_path = PlRefPath::try_from_path(postcode_features_path)
|
||||
.context("Failed to normalize postcode parquet path")?;
|
||||
let postcode_df = LazyFrame::scan_parquet(postcode_features_path, Default::default())
|
||||
.context("Failed to scan postcode parquet")?
|
||||
.collect()
|
||||
.context("Failed to read postcode parquet")?;
|
||||
tracing::info!(rows = postcode_df.height(), "Postcode features loaded");
|
||||
|
||||
let mut poi_metric_names: Vec<String> = postcode_df
|
||||
.get_column_names()
|
||||
.iter()
|
||||
.map(|name| name.as_str())
|
||||
.filter(|&name| features::is_dynamic_poi_feature(name))
|
||||
.map(str::to_string)
|
||||
.collect();
|
||||
poi_metric_names.sort_by_key(|name| features::dynamic_poi_feature_sort_key(name));
|
||||
|
||||
let poi_metric_by_postcode: FxHashMap<String, u32> = if poi_metric_names.is_empty() {
|
||||
FxHashMap::default()
|
||||
} else {
|
||||
let postcode_column = postcode_df
|
||||
.column("Postcode")
|
||||
.context("Postcode feature parquet missing 'Postcode' column")?
|
||||
.str()
|
||||
.context("'Postcode' column in postcode feature parquet is not a string")?;
|
||||
postcode_column
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, postcode)| {
|
||||
postcode.map(|postcode| (postcode.to_string(), idx as u32))
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
let mut poi_metrics = PostcodePoiMetrics::from_postcode_df(&postcode_df, poi_metric_names)?;
|
||||
|
||||
// Load properties.parquet and join with postcode data for lat/lon + area features
|
||||
tracing::info!("Loading properties from {:?}", properties_path);
|
||||
let properties_path = PlRefPath::try_from_path(properties_path)
|
||||
.context("Failed to normalize properties parquet path")?;
|
||||
let properties_lf = LazyFrame::scan_parquet(properties_path, Default::default())
|
||||
.context("Failed to scan properties parquet")?;
|
||||
let combined = properties_lf
|
||||
|
|
@ -1082,14 +1300,20 @@ impl PropertyData {
|
|||
let total_rows = combined.height();
|
||||
tracing::info!(rows = total_rows, "Properties joined with postcodes");
|
||||
|
||||
// Get configured feature/enum names in config order
|
||||
let numeric_names = features::all_numeric_feature_names();
|
||||
// Get configured feature/enum names in config order. Dynamic POI
|
||||
// metrics live in a postcode-level side table so they do not widen the
|
||||
// hot row-major property feature matrix.
|
||||
let configured_numeric_names = features::all_numeric_feature_names();
|
||||
let enum_names = features::all_enum_feature_names();
|
||||
|
||||
let schema = combined.schema();
|
||||
let numeric_names: Vec<String> = configured_numeric_names
|
||||
.iter()
|
||||
.map(|name| (*name).to_string())
|
||||
.collect();
|
||||
|
||||
for name in &numeric_names {
|
||||
match schema.get(name) {
|
||||
match schema.get(name.as_str()) {
|
||||
Some(dtype) if is_numeric_dtype(dtype) => {}
|
||||
Some(dtype) => bail!(
|
||||
"Configured numeric feature '{}' has non-numeric type {:?}",
|
||||
|
|
@ -1120,8 +1344,8 @@ impl PropertyData {
|
|||
// Combine numeric and enum feature names (numeric first, then enum)
|
||||
let feature_names: Vec<String> = numeric_names
|
||||
.iter()
|
||||
.chain(enum_names.iter())
|
||||
.map(|name| name.to_string())
|
||||
.chain(enum_names.iter().map(|name| name.to_string()))
|
||||
.collect();
|
||||
let num_features = feature_names.len();
|
||||
let num_numeric = numeric_names.len();
|
||||
|
|
@ -1138,16 +1362,16 @@ impl PropertyData {
|
|||
select_exprs.push(col("lon").cast(DataType::Float32));
|
||||
|
||||
// Select numeric features as Float32 (datetime columns → fractional year)
|
||||
for &name in &numeric_names {
|
||||
if is_datetime_dtype(schema.get(name).unwrap()) {
|
||||
for name in &numeric_names {
|
||||
if is_datetime_dtype(schema.get(name.as_str()).unwrap()) {
|
||||
select_exprs.push(
|
||||
(col(name).dt().year().cast(DataType::Float32)
|
||||
+ (col(name).dt().month().cast(DataType::Float32) - lit(1.0f32))
|
||||
(col(name.as_str()).dt().year().cast(DataType::Float32)
|
||||
+ (col(name.as_str()).dt().month().cast(DataType::Float32) - lit(1.0f32))
|
||||
/ lit(12.0f32))
|
||||
.alias(name),
|
||||
.alias(name.as_str()),
|
||||
);
|
||||
} else {
|
||||
select_exprs.push(col(name).cast(DataType::Float32));
|
||||
select_exprs.push(col(name.as_str()).cast(DataType::Float32));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1233,7 +1457,7 @@ impl PropertyData {
|
|||
.par_iter()
|
||||
.map(|name| {
|
||||
let column = df
|
||||
.column(name)
|
||||
.column(name.as_str())
|
||||
.with_context(|| format!("Missing feature column '{name}'"))?;
|
||||
column_to_f32_vec(column)
|
||||
})
|
||||
|
|
@ -1244,10 +1468,10 @@ impl PropertyData {
|
|||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(feat_index, vals)| {
|
||||
let name = numeric_names[feat_index];
|
||||
let name = numeric_names[feat_index].as_str();
|
||||
let bounds = features::bounds_for(name)
|
||||
.with_context(|| format!("No bounds config for feature '{}'", name))?;
|
||||
let stats = compute_feature_stats(vals, bounds, features::has_integer_bins(name));
|
||||
let stats = compute_feature_stats(vals, &bounds, features::has_integer_bins(name));
|
||||
tracing::debug!(
|
||||
feature = %name,
|
||||
slider_min = format_args!("{:.2}", stats.slider_min),
|
||||
|
|
@ -1268,8 +1492,8 @@ impl PropertyData {
|
|||
let mut quant_min = Vec::with_capacity(num_features);
|
||||
let mut quant_range = Vec::with_capacity(num_features);
|
||||
for (feat_idx, stats) in numeric_feature_stats.iter().enumerate() {
|
||||
let (min, max) = match features::bounds_for(numeric_names[feat_idx]) {
|
||||
Some(Bounds::Fixed { min, max }) => (*min, *max),
|
||||
let (min, max) = match features::bounds_for(numeric_names[feat_idx].as_str()) {
|
||||
Some(Bounds::Fixed { min, max }) => (min, max),
|
||||
_ => (stats.histogram.min, stats.histogram.max),
|
||||
};
|
||||
quant_min.push(min);
|
||||
|
|
@ -1284,10 +1508,15 @@ impl PropertyData {
|
|||
let string_column = column
|
||||
.str()
|
||||
.with_context(|| format!("Column '{name}' is not a string column"))?;
|
||||
Ok(string_column
|
||||
string_column
|
||||
.into_iter()
|
||||
.map(|value| value.unwrap_or("").to_string())
|
||||
.collect())
|
||||
.enumerate()
|
||||
.map(|(row, value)| {
|
||||
value
|
||||
.map(ToString::to_string)
|
||||
.with_context(|| format!("Required column '{name}' has null at row {row}"))
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
let address_raw = extract_string_col(&df, "Address per Property Register")?;
|
||||
|
|
@ -1325,18 +1554,18 @@ impl PropertyData {
|
|||
// enum_col_major: Vec<(values_list, encoded_as_f32)>
|
||||
let enum_col_major: Vec<(Vec<String>, Vec<f32>)> = enum_names
|
||||
.par_iter()
|
||||
.filter_map(|&name| {
|
||||
let column_data = df.column(name).ok()?;
|
||||
let string_column = column_data.str().ok()?;
|
||||
.map(|&name| -> anyhow::Result<(Vec<String>, Vec<f32>)> {
|
||||
let column_data = df
|
||||
.column(name)
|
||||
.with_context(|| format!("Required enum column '{name}' not found"))?;
|
||||
let string_column = column_data
|
||||
.str()
|
||||
.with_context(|| format!("Enum column '{name}' is not a string column"))?;
|
||||
let unique_set: std::collections::HashSet<String> = string_column
|
||||
.into_iter()
|
||||
.filter_map(|value| {
|
||||
let text = value.unwrap_or("");
|
||||
if text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text.to_string())
|
||||
}
|
||||
let text = value?.trim();
|
||||
(!text.is_empty()).then(|| text.to_string())
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
|
@ -1373,20 +1602,22 @@ impl PropertyData {
|
|||
|
||||
let encoded: Vec<f32> = string_column
|
||||
.into_iter()
|
||||
.map(|value| {
|
||||
let text = value.unwrap_or("");
|
||||
if text.is_empty() {
|
||||
f32::NAN
|
||||
} else {
|
||||
*value_to_idx.get(text).unwrap_or(&f32::NAN)
|
||||
}
|
||||
.enumerate()
|
||||
.map(|(row, value)| {
|
||||
let Some(text) = value.map(str::trim).filter(|text| !text.is_empty())
|
||||
else {
|
||||
return Ok(f32::NAN);
|
||||
};
|
||||
value_to_idx.get(text).copied().with_context(|| {
|
||||
format!("Enum column '{name}' has unknown value '{text}' at row {row}")
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||||
|
||||
tracing::debug!(column = %name, unique_values = unique.len(), "Enum feature encoded as f32");
|
||||
Some((unique, encoded))
|
||||
Ok((unique, encoded))
|
||||
})
|
||||
.collect();
|
||||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||||
|
||||
// Extract is_approx_build_date: 0.0 = exact, anything else (1.0/NaN) = approximate
|
||||
let is_approx_build_date_raw: Vec<bool> = if has_approx_col {
|
||||
|
|
@ -1487,13 +1718,13 @@ impl PropertyData {
|
|||
.collect();
|
||||
let last_known_price_raw: Vec<f32> = numeric_names
|
||||
.iter()
|
||||
.position(|&name| name == "Last known price")
|
||||
.position(|name| name == "Last known price")
|
||||
.map(|price_idx| {
|
||||
perm.iter()
|
||||
.map(|&perm_index| numeric_col_major[price_idx][perm_index as usize])
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_else(|| vec![f32::NAN; row_count]);
|
||||
.context("Required numeric column 'Last known price' not configured")?;
|
||||
|
||||
// Build contiguous address buffer and address search index (permuted)
|
||||
tracing::info!("Building interned strings");
|
||||
|
|
@ -1561,6 +1792,20 @@ impl PropertyData {
|
|||
}
|
||||
let postcode_interner = postcode_rodeo.into_reader();
|
||||
|
||||
let row_to_poi_metric_idx: Vec<u32> = if poi_metrics.is_empty() {
|
||||
vec![NO_POI_METRIC_ROW; row_count]
|
||||
} else {
|
||||
perm.iter()
|
||||
.map(|&old_row| {
|
||||
poi_metric_by_postcode
|
||||
.get(postcode_raw[old_row as usize].as_str())
|
||||
.copied()
|
||||
.unwrap_or(NO_POI_METRIC_ROW)
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
poi_metrics.set_row_mapping(row_to_poi_metric_idx);
|
||||
|
||||
// Pack is_approx_build_date into a bitvec (8 bools per byte)
|
||||
let num_bytes = row_count.div_ceil(8);
|
||||
let mut approx_build_date_bits = vec![0u8; num_bytes];
|
||||
|
|
@ -1697,6 +1942,7 @@ impl PropertyData {
|
|||
quant_min,
|
||||
quant_range,
|
||||
feature_stats,
|
||||
poi_metrics,
|
||||
last_known_price_raw,
|
||||
address_buffer,
|
||||
address_offsets,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use std::sync::Arc;
|
|||
use anyhow::Context;
|
||||
use parking_lot::Mutex;
|
||||
use polars::lazy::frame::LazyFrame;
|
||||
use polars::prelude::PlRefPath;
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use tracing::info;
|
||||
|
||||
|
|
@ -155,15 +156,23 @@ impl TravelTimeStore {
|
|||
/// Returns a cached or freshly-loaded postcode → travel_minutes mapping.
|
||||
pub fn get(&self, mode: &str, slug: &str) -> anyhow::Result<TravelData> {
|
||||
let key = (mode.to_string(), slug.to_string());
|
||||
|
||||
// Check cache first
|
||||
{
|
||||
let mut cache = self.cache.lock();
|
||||
if let Some(data) = cache.get(&key) {
|
||||
return Ok(data);
|
||||
}
|
||||
if let Some(data) = self.get_cached(&key) {
|
||||
return Ok(data);
|
||||
}
|
||||
|
||||
super::run_polars_io(|| self.load_uncached(key))
|
||||
}
|
||||
|
||||
fn get_cached(&self, key: &(String, String)) -> Option<TravelData> {
|
||||
let mut cache = self.cache.lock();
|
||||
cache.get(key)
|
||||
}
|
||||
|
||||
fn load_uncached(&self, key: (String, String)) -> anyhow::Result<TravelData> {
|
||||
if let Some(data) = self.get_cached(&key) {
|
||||
return Ok(data);
|
||||
}
|
||||
let (mode, slug) = &key;
|
||||
// Resolve slug to actual filename (may have numeric prefix).
|
||||
// Reject unknown slugs rather than falling back to raw input to prevent path traversal.
|
||||
let file_stem = self
|
||||
|
|
@ -175,7 +184,9 @@ impl TravelTimeStore {
|
|||
.join(mode)
|
||||
.join(format!("{}.parquet", file_stem));
|
||||
|
||||
let df = LazyFrame::scan_parquet(&path, Default::default())
|
||||
let parquet_path = PlRefPath::try_from_path(&path)
|
||||
.with_context(|| format!("Failed to normalize path: {}", path.display()))?;
|
||||
let df = LazyFrame::scan_parquet(parquet_path, Default::default())
|
||||
.with_context(|| format!("Failed to scan: {}", path.display()))?
|
||||
.collect()
|
||||
.with_context(|| format!("Failed to read: {}", path.display()))?;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
//! Static feature configuration. Every numeric and enum column in wide.parquet
|
||||
//! must be declared here. Unknown columns cause a startup panic.
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum Bounds {
|
||||
/// Fixed min/max values for the slider
|
||||
Fixed { min: f32, max: f32 },
|
||||
|
|
@ -61,6 +62,26 @@ pub struct FeatureGroup {
|
|||
}
|
||||
|
||||
pub static FEATURE_GROUPS: &[FeatureGroup] = &[
|
||||
FeatureGroup {
|
||||
name: "Transport",
|
||||
features: &[
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Distance to nearest train or tube station (km)",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Distance to the closest train or tube station",
|
||||
detail: "Straight-line distance in kilometres from the postcode to the nearest rail station or Tube/metro/tram stop.",
|
||||
source: "naptan",
|
||||
prefix: "",
|
||||
suffix: " km",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
],
|
||||
},
|
||||
FeatureGroup {
|
||||
name: "Properties",
|
||||
features: &[
|
||||
|
|
@ -78,6 +99,21 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
|
|||
detail: "From HM Land Registry Price Paid data. Freehold means you own the building and the land it stands on. Leasehold means you own the building but not the land: you have a lease from the freeholder for a set number of years.",
|
||||
source: "price-paid",
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Estimated current price",
|
||||
bounds: Bounds::Fixed {
|
||||
min: 0.0,
|
||||
max: 2_500_000.0,
|
||||
},
|
||||
step: 10000.0,
|
||||
description: "Modelled estimate of the current property value",
|
||||
detail: "Based on the last sale price, local repeat-sales price movement, and nearby recently sold properties. The repeat-sales index is tracked by postcode sector and property type, with smoothing and neighbour blending where data is sparse. Recent sales stay close to the recorded price; older sales depend more on the model.",
|
||||
source: "price-paid",
|
||||
prefix: "£",
|
||||
suffix: "",
|
||||
raw: false,
|
||||
absolute: true,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Last known price",
|
||||
bounds: Bounds::Fixed {
|
||||
|
|
@ -94,19 +130,19 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
|
|||
absolute: true,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Estimated current price",
|
||||
bounds: Bounds::Fixed {
|
||||
min: 0.0,
|
||||
max: 2_500_000.0,
|
||||
name: "Est. price per sqm",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 0.0,
|
||||
high: 98.0,
|
||||
},
|
||||
step: 10000.0,
|
||||
description: "Inflation-adjusted estimate of the current property value",
|
||||
detail: "Based on the last sale price, adjusted for local price changes over time using a repeat-sales index (tracked per postcode sector and property type). If post-sale improvements are detected from EPC records, a renovation premium is added. Recent sales will be close to the original price; older sales are adjusted more.",
|
||||
step: 100.0,
|
||||
description: "Estimated current price divided by total floor area",
|
||||
detail: "Calculated by dividing the modelled estimated current price by the total floor area from the EPC certificate. Provides a more up-to-date price-per-area comparison than the historical sale price per sqm.",
|
||||
source: "price-paid",
|
||||
prefix: "£",
|
||||
suffix: "",
|
||||
raw: false,
|
||||
absolute: true,
|
||||
absolute: false,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Price per sqm",
|
||||
|
|
@ -123,21 +159,6 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
|
|||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Est. price per sqm",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 0.0,
|
||||
high: 98.0,
|
||||
},
|
||||
step: 100.0,
|
||||
description: "Estimated current price divided by total floor area",
|
||||
detail: "Calculated by dividing the inflation-adjusted estimated current price (including any renovation premium) by the total floor area from the EPC certificate. Provides a more up-to-date price-per-area comparison than the historical sale price per sqm.",
|
||||
source: "price-paid",
|
||||
prefix: "£",
|
||||
suffix: "",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Estimated monthly rent",
|
||||
bounds: Bounds::Percentile { low: 2.0, high: 98.0 },
|
||||
|
|
@ -248,26 +269,6 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
|
|||
}),
|
||||
],
|
||||
},
|
||||
FeatureGroup {
|
||||
name: "Transport",
|
||||
features: &[
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Distance to nearest train or tube station (km)",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Distance to the closest train or tube station",
|
||||
detail: "Straight-line distance in kilometres from the postcode to the nearest rail station or Tube/metro/tram stop.",
|
||||
source: "naptan",
|
||||
prefix: "",
|
||||
suffix: " km",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
],
|
||||
},
|
||||
FeatureGroup {
|
||||
name: "Education",
|
||||
features: &[
|
||||
|
|
@ -393,18 +394,18 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
|
|||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Education, Skills and Training Score",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
bounds: Bounds::Fixed {
|
||||
min: 0.0,
|
||||
max: 100.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Education quality score for the local area (higher = better)",
|
||||
detail: "From the English Indices of Deprivation (inverted so higher = better). Covers school attainment, entry to higher education, adult qualifications, and English language proficiency. Higher scores indicate less deprivation.",
|
||||
step: 1.0,
|
||||
description: "Education and skills deprivation percentile (higher = less deprived)",
|
||||
detail: "From the English Indices of Deprivation, converted to a national percentile where 0% is most deprived and 100% is least deprived. Covers school attainment, entry to higher education, adult qualifications, and English language proficiency.",
|
||||
source: "iod",
|
||||
prefix: "",
|
||||
suffix: "",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
suffix: "%",
|
||||
raw: true,
|
||||
absolute: true,
|
||||
}),
|
||||
],
|
||||
},
|
||||
|
|
@ -413,72 +414,78 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
|
|||
features: &[
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Income Score",
|
||||
bounds: Bounds::Fixed { min: 0.0, max: 1.0 },
|
||||
step: 0.01,
|
||||
description: "Income deprivation rate, inverted (higher = less deprived)",
|
||||
detail: "From the English Indices of Deprivation (inverted so higher = better). Higher values indicate less income deprivation. Based on Income Support, income-based Jobseeker's Allowance, income-based Employment and Support Allowance, Pension Credit, Working Tax Credit and Child Tax Credit, Universal Credit, and asylum seekers.",
|
||||
bounds: Bounds::Fixed {
|
||||
min: 0.0,
|
||||
max: 100.0,
|
||||
},
|
||||
step: 1.0,
|
||||
description: "Income deprivation percentile (higher = less deprived)",
|
||||
detail: "From the English Indices of Deprivation, converted to a national percentile where 0% is most income deprived and 100% is least income deprived. Based on Income Support, income-based Jobseeker's Allowance, income-based Employment and Support Allowance, Pension Credit, Working Tax Credit and Child Tax Credit, Universal Credit, and asylum seekers.",
|
||||
source: "iod",
|
||||
prefix: "",
|
||||
suffix: "",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
suffix: "%",
|
||||
raw: true,
|
||||
absolute: true,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Employment Score",
|
||||
bounds: Bounds::Fixed { min: 0.0, max: 1.0 },
|
||||
step: 0.01,
|
||||
description: "Employment deprivation rate, inverted (higher = less deprived)",
|
||||
detail: "From the English Indices of Deprivation (inverted so higher = better). Higher values indicate less employment deprivation. Based on claimants of Jobseeker's Allowance, Employment and Support Allowance, Incapacity Benefit, Severe Disablement Allowance, Carer's Allowance, and relevant Universal Credit claimants.",
|
||||
bounds: Bounds::Fixed {
|
||||
min: 0.0,
|
||||
max: 100.0,
|
||||
},
|
||||
step: 1.0,
|
||||
description: "Employment deprivation percentile (higher = less deprived)",
|
||||
detail: "From the English Indices of Deprivation, converted to a national percentile where 0% is most employment deprived and 100% is least employment deprived. Based on claimants of Jobseeker's Allowance, Employment and Support Allowance, Incapacity Benefit, Severe Disablement Allowance, Carer's Allowance, and relevant Universal Credit claimants.",
|
||||
source: "iod",
|
||||
prefix: "",
|
||||
suffix: "",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
suffix: "%",
|
||||
raw: true,
|
||||
absolute: true,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Health Deprivation and Disability Score",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
bounds: Bounds::Fixed {
|
||||
min: 0.0,
|
||||
max: 100.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Health and disability score (higher = better health outcomes)",
|
||||
detail: "From the English Indices of Deprivation (inverted so higher = better). Higher scores indicate lower risk of premature death and better quality of life. Derived from years of potential life lost, comparative illness and disability ratio, acute morbidity, and mood and anxiety disorders.",
|
||||
step: 1.0,
|
||||
description: "Health and disability deprivation percentile (higher = better outcomes)",
|
||||
detail: "From the English Indices of Deprivation, converted to a national percentile where 0% is most health deprived and 100% is least health deprived. Derived from years of potential life lost, comparative illness and disability ratio, acute morbidity, and mood and anxiety disorders.",
|
||||
source: "iod",
|
||||
prefix: "",
|
||||
suffix: "",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
suffix: "%",
|
||||
raw: true,
|
||||
absolute: true,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Housing Conditions Score",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
bounds: Bounds::Fixed {
|
||||
min: 0.0,
|
||||
max: 100.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Housing quality and conditions (higher = better)",
|
||||
detail: "From the English Indices of Deprivation, Living Environment domain (inverted so higher = better). Measures the quality of housing stock: central heating availability, housing condition, and Decent Homes standards. Higher scores indicate better housing conditions.",
|
||||
step: 1.0,
|
||||
description: "Housing conditions percentile (higher = better conditions)",
|
||||
detail: "From the English Indices of Deprivation, Living Environment domain, converted to a national percentile where 0% is most deprived and 100% is least deprived. Measures the quality of housing stock: central heating availability, housing condition, and Decent Homes standards.",
|
||||
source: "iod",
|
||||
prefix: "",
|
||||
suffix: "",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
suffix: "%",
|
||||
raw: true,
|
||||
absolute: true,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Air Quality and Road Safety Score",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
bounds: Bounds::Fixed {
|
||||
min: 0.0,
|
||||
max: 100.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Air quality and road safety (higher = better)",
|
||||
detail: "From the English Indices of Deprivation, Living Environment domain (inverted so higher = better). Measures the outdoor living environment quality through air quality indicators and road traffic accident casualties involving pedestrians and cyclists. Higher scores indicate better outdoor environments.",
|
||||
step: 1.0,
|
||||
description: "Air quality and road safety percentile (higher = better conditions)",
|
||||
detail: "From the English Indices of Deprivation, Living Environment domain, converted to a national percentile where 0% is most deprived and 100% is least deprived. Measures the outdoor living environment through air quality indicators and road traffic accident casualties involving pedestrians and cyclists.",
|
||||
source: "iod",
|
||||
prefix: "",
|
||||
suffix: "",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
suffix: "%",
|
||||
raw: true,
|
||||
absolute: true,
|
||||
}),
|
||||
],
|
||||
},
|
||||
|
|
@ -996,6 +1003,126 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
|
|||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Distance to nearest grocery store (km)",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Distance to the closest grocery shop or supermarket",
|
||||
detail: "Straight-line distance in kilometres from the postcode to the nearest grocery shop, supermarket, or convenience store. Uses OpenStreetMap POIs, with Waitrose and Tesco coverage from GEOLYTIX retail points.",
|
||||
source: "osm-pois",
|
||||
prefix: "",
|
||||
suffix: " km",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Distance to nearest tube station (km)",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Distance to the closest Tube, metro, tram, or DLR stop",
|
||||
detail: "Straight-line distance in kilometres from the postcode to the nearest NaPTAN station classified as Tube, metro, tram, or DLR.",
|
||||
source: "naptan",
|
||||
prefix: "",
|
||||
suffix: " km",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Distance to nearest rail station (km)",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Distance to the closest National Rail station",
|
||||
detail: "Straight-line distance in kilometres from the postcode to the nearest NaPTAN railway station.",
|
||||
source: "naptan",
|
||||
prefix: "",
|
||||
suffix: " km",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Distance to nearest Waitrose (km)",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Distance to the closest Waitrose store",
|
||||
detail: "Straight-line distance in kilometres from the postcode to the nearest Waitrose or Little Waitrose store in the GEOLYTIX Grocery Retail Points dataset.",
|
||||
source: "geolytix-retail-points",
|
||||
prefix: "",
|
||||
suffix: " km",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Distance to nearest Tesco (km)",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Distance to the closest Tesco store",
|
||||
detail: "Straight-line distance in kilometres from the postcode to the nearest Tesco store in the GEOLYTIX Grocery Retail Points dataset.",
|
||||
source: "geolytix-retail-points",
|
||||
prefix: "",
|
||||
suffix: " km",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Distance to nearest cafe (km)",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Distance to the closest cafe",
|
||||
detail: "Straight-line distance in kilometres from the postcode to the nearest cafe, ice-cream shop, or internet cafe mapped in OpenStreetMap.",
|
||||
source: "osm-pois",
|
||||
prefix: "",
|
||||
suffix: " km",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Distance to nearest pub (km)",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Distance to the closest pub",
|
||||
detail: "Straight-line distance in kilometres from the postcode to the nearest pub, social club, brewery, distillery, or winery mapped in OpenStreetMap.",
|
||||
source: "osm-pois",
|
||||
prefix: "",
|
||||
suffix: " km",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Distance to nearest restaurant (km)",
|
||||
bounds: Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
},
|
||||
step: 0.1,
|
||||
description: "Distance to the closest restaurant",
|
||||
detail: "Straight-line distance in kilometres from the postcode to the nearest restaurant or food court mapped in OpenStreetMap.",
|
||||
source: "osm-pois",
|
||||
prefix: "",
|
||||
suffix: " km",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
}),
|
||||
Feature::Numeric(FeatureConfig {
|
||||
name: "Number of parks within 1km",
|
||||
bounds: Bounds::Percentile {
|
||||
|
|
@ -1105,20 +1232,76 @@ pub fn order_for(name: &str) -> Option<&'static [&'static str]> {
|
|||
|
||||
/// Whether this feature should use integer-width histogram bins.
|
||||
pub fn has_integer_bins(name: &str) -> bool {
|
||||
INTEGER_BIN_FEATURES.contains(&name)
|
||||
INTEGER_BIN_FEATURES.contains(&name) || dynamic_poi_count_radius(name).is_some()
|
||||
}
|
||||
|
||||
/// Look up the Bounds config for a numeric feature by name.
|
||||
pub fn bounds_for(name: &str) -> Option<&'static Bounds> {
|
||||
pub fn bounds_for(name: &str) -> Option<Bounds> {
|
||||
if dynamic_poi_distance_category(name).is_some() {
|
||||
return Some(Bounds::Percentile {
|
||||
low: 2.0,
|
||||
high: 98.0,
|
||||
});
|
||||
}
|
||||
if dynamic_poi_count_radius(name).is_some() {
|
||||
return Some(Bounds::Percentile {
|
||||
low: 5.0,
|
||||
high: 95.0,
|
||||
});
|
||||
}
|
||||
|
||||
FEATURE_GROUPS
|
||||
.iter()
|
||||
.flat_map(|group| group.features.iter())
|
||||
.find_map(|feature| match feature {
|
||||
Feature::Numeric(c) if c.name == name => Some(&c.bounds),
|
||||
Feature::Numeric(c) if c.name == name => Some(c.bounds),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dynamic_poi_distance_category(name: &str) -> Option<&str> {
|
||||
name.strip_prefix("Distance to nearest ")
|
||||
.and_then(|rest| rest.strip_suffix(" POI (km)"))
|
||||
.filter(|category| !category.is_empty())
|
||||
}
|
||||
|
||||
pub fn dynamic_poi_count_radius(name: &str) -> Option<u8> {
|
||||
let rest = name.strip_prefix("Number of ")?;
|
||||
let (_category, suffix) = rest.rsplit_once(" POIs within ")?;
|
||||
match suffix {
|
||||
"2km" => Some(2),
|
||||
"5km" => Some(5),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dynamic_poi_count_category(name: &str) -> Option<&str> {
|
||||
let rest = name.strip_prefix("Number of ")?;
|
||||
let (category, suffix) = rest.rsplit_once(" POIs within ")?;
|
||||
matches!(suffix, "2km" | "5km")
|
||||
.then_some(category)
|
||||
.filter(|category| !category.is_empty())
|
||||
}
|
||||
|
||||
pub fn is_dynamic_poi_feature(name: &str) -> bool {
|
||||
dynamic_poi_distance_category(name).is_some() || dynamic_poi_count_category(name).is_some()
|
||||
}
|
||||
|
||||
pub fn dynamic_poi_feature_sort_key(name: &str) -> (u8, String) {
|
||||
if let Some(category) = dynamic_poi_distance_category(name) {
|
||||
return (0, category.to_ascii_lowercase());
|
||||
}
|
||||
if let Some(category) = dynamic_poi_count_category(name) {
|
||||
let metric_order = match dynamic_poi_count_radius(name) {
|
||||
Some(2) => 1,
|
||||
Some(5) => 2,
|
||||
_ => 3,
|
||||
};
|
||||
return (metric_order, category.to_ascii_lowercase());
|
||||
}
|
||||
(9, name.to_ascii_lowercase())
|
||||
}
|
||||
|
||||
/// Canonical display order for POI category groups.
|
||||
/// The server will panic at startup if the data contains groups not in this list or vice versa.
|
||||
pub const POI_GROUP_ORDER: &[&str] = &[
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
mod aggregation;
|
||||
mod auth;
|
||||
mod checkout_sessions;
|
||||
mod consts;
|
||||
mod data;
|
||||
mod features;
|
||||
|
|
@ -10,6 +11,7 @@ mod metrics;
|
|||
mod og_middleware;
|
||||
pub mod parsing;
|
||||
mod pocketbase;
|
||||
mod pocketbase_locks;
|
||||
mod routes;
|
||||
mod state;
|
||||
pub mod utils;
|
||||
|
|
|
|||
|
|
@ -4,8 +4,11 @@ mod filters;
|
|||
mod h3;
|
||||
|
||||
pub use bounds::{bounds_intersect, h3_cell_bounds, parse_bounds, require_bounds};
|
||||
pub use fields::{parse_enum_dist, parse_field_indices, parse_field_set};
|
||||
pub use fields::{
|
||||
parse_enum_dist, parse_field_indices, parse_field_indices_with_poi, parse_field_set,
|
||||
};
|
||||
pub use filters::{
|
||||
count_filter_impacts, parse_filters, row_passes_filters, ParsedEnumFilter, ParsedFilter,
|
||||
count_filter_impacts, parse_filters, parse_filters_with_poi, row_passes_filters,
|
||||
row_passes_poi_filters, ParsedEnumFilter, ParsedFilter, ParsedPoiFilter,
|
||||
};
|
||||
pub use h3::{cell_for_row, cell_for_row_cached, needs_parent, validate_h3_resolution};
|
||||
|
|
|
|||
|
|
@ -31,6 +31,55 @@ pub fn parse_field_indices(
|
|||
Ok(Some(indices))
|
||||
}
|
||||
|
||||
pub struct ParsedFieldIndices {
|
||||
/// None means no `fields` param was supplied, so normal aggregation keeps
|
||||
/// its existing "all configured features" behavior.
|
||||
pub normal: Option<Vec<usize>>,
|
||||
pub poi: Vec<usize>,
|
||||
}
|
||||
|
||||
/// Parse `?fields=` against both the row-major feature matrix and the
|
||||
/// postcode-level POI side table.
|
||||
pub fn parse_field_indices_with_poi(
|
||||
fields: Option<&str>,
|
||||
name_to_index: &FxHashMap<String, usize>,
|
||||
poi_name_to_index: &FxHashMap<String, usize>,
|
||||
) -> Result<ParsedFieldIndices, (StatusCode, String)> {
|
||||
let Some(fields_str) = fields else {
|
||||
return Ok(ParsedFieldIndices {
|
||||
normal: None,
|
||||
poi: Vec::new(),
|
||||
});
|
||||
};
|
||||
if fields_str.is_empty() {
|
||||
return Ok(ParsedFieldIndices {
|
||||
normal: Some(Vec::new()),
|
||||
poi: Vec::new(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut normal = Vec::new();
|
||||
let mut poi = Vec::new();
|
||||
for name in fields_str.split(";;") {
|
||||
let name = name.trim();
|
||||
if name.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Some(&idx) = name_to_index.get(name) {
|
||||
normal.push(idx);
|
||||
} else if let Some(&idx) = poi_name_to_index.get(name) {
|
||||
poi.push(idx);
|
||||
} else {
|
||||
return Err((StatusCode::BAD_REQUEST, format!("Unknown field: {}", name)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ParsedFieldIndices {
|
||||
normal: Some(normal),
|
||||
poi,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse an optional `?enum_dist=` query param into (feature_index, num_values) for
|
||||
/// per-value distribution counting. Returns None if not requested.
|
||||
/// Returns 400 if the feature name is unknown or not an enum feature.
|
||||
|
|
@ -73,3 +122,28 @@ pub fn parse_field_set(fields: Option<&str>) -> (bool, HashSet<String>) {
|
|||
.unwrap_or_default();
|
||||
(fields_specified, field_set)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_field_indices_with_poi_splits_normal_and_side_fields() {
|
||||
let normal: FxHashMap<String, usize> = [("Price".to_string(), 0), ("Area".to_string(), 1)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
let poi: FxHashMap<String, usize> = [("Distance to nearest cafe POI (km)".to_string(), 2)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let parsed = parse_field_indices_with_poi(
|
||||
Some("Price;;Distance to nearest cafe POI (km)"),
|
||||
&normal,
|
||||
&poi,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(parsed.normal, Some(vec![0]));
|
||||
assert_eq!(parsed.poi, vec![2]);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
|
||||
use crate::consts::NAN_U16;
|
||||
use crate::data::QuantRef;
|
||||
use crate::data::{PostcodePoiMetrics, QuantRef};
|
||||
|
||||
/// Filter for numeric features: value must be in [min_u16, max_u16] range (quantized).
|
||||
#[derive(Debug)]
|
||||
|
|
@ -19,6 +19,20 @@ pub struct ParsedEnumFilter {
|
|||
pub allowed: FxHashSet<u16>,
|
||||
}
|
||||
|
||||
/// Filter for postcode-level POI metrics stored in the side table.
|
||||
#[derive(Debug)]
|
||||
pub struct ParsedPoiFilter {
|
||||
pub metric_idx: usize,
|
||||
pub min_u16: u16,
|
||||
pub max_u16: u16,
|
||||
}
|
||||
|
||||
pub type ParsedFiltersWithPoi = (
|
||||
Vec<ParsedFilter>,
|
||||
Vec<ParsedEnumFilter>,
|
||||
Vec<ParsedPoiFilter>,
|
||||
);
|
||||
|
||||
/// Parse `;;`-separated filter string into numeric and enum filters.
|
||||
/// Numeric format: `name:min:max`
|
||||
/// Enum format: `name:val1|val2|val3` (pipe-separated string values)
|
||||
|
|
@ -110,6 +124,101 @@ pub fn parse_filters(
|
|||
Ok((numeric, enums))
|
||||
}
|
||||
|
||||
/// Parse filters while allowing dynamic POI metric names that live outside the
|
||||
/// row-major property feature matrix.
|
||||
pub fn parse_filters_with_poi(
|
||||
filter_str: Option<&str>,
|
||||
feature_name_to_index: &FxHashMap<String, usize>,
|
||||
enum_values: &FxHashMap<usize, Vec<String>>,
|
||||
quant: &QuantRef,
|
||||
poi_name_to_index: &FxHashMap<String, usize>,
|
||||
poi_quant: &QuantRef,
|
||||
) -> Result<ParsedFiltersWithPoi, String> {
|
||||
let mut numeric = Vec::new();
|
||||
let mut enums = Vec::new();
|
||||
let mut poi = Vec::new();
|
||||
|
||||
let input = match filter_str.filter(|text| !text.is_empty()) {
|
||||
Some(text) => text,
|
||||
None => return Ok((numeric, enums, poi)),
|
||||
};
|
||||
|
||||
for entry in input.split(";;") {
|
||||
let parts: Vec<&str> = entry.splitn(2, ':').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(format!("Malformed filter entry (missing ':'): '{entry}'"));
|
||||
}
|
||||
let name = parts[0].trim();
|
||||
let rest = parts[1].trim();
|
||||
|
||||
if let Some(&feat_idx) = feature_name_to_index.get(name) {
|
||||
if let Some(values) = enum_values.get(&feat_idx) {
|
||||
let mut allowed: FxHashSet<u16> = FxHashSet::default();
|
||||
for value in rest.split('|') {
|
||||
let value = value.trim();
|
||||
match values.iter().position(|existing| existing == value) {
|
||||
Some(position) => {
|
||||
allowed.insert(position as u16);
|
||||
}
|
||||
None => {
|
||||
return Err(format!(
|
||||
"Unknown value '{}' for enum feature '{}'. Valid values: {:?}",
|
||||
value, name, values
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
enums.push(ParsedEnumFilter { feat_idx, allowed });
|
||||
} else {
|
||||
let (min, max) = parse_numeric_filter_bounds(name, rest, entry)?;
|
||||
numeric.push(ParsedFilter {
|
||||
feat_idx,
|
||||
min_u16: quant.encode_min(feat_idx, min),
|
||||
max_u16: quant.encode_max(feat_idx, max),
|
||||
});
|
||||
}
|
||||
} else if let Some(&metric_idx) = poi_name_to_index.get(name) {
|
||||
let (min, max) = parse_numeric_filter_bounds(name, rest, entry)?;
|
||||
poi.push(ParsedPoiFilter {
|
||||
metric_idx,
|
||||
min_u16: poi_quant.encode_min(metric_idx, min),
|
||||
max_u16: poi_quant.encode_max(metric_idx, max),
|
||||
});
|
||||
} else {
|
||||
return Err(format!("Unknown feature in filter: '{name}'"));
|
||||
}
|
||||
}
|
||||
|
||||
numeric.sort_unstable_by_key(|f| f.max_u16.saturating_sub(f.min_u16));
|
||||
enums.sort_unstable_by_key(|f| f.allowed.len());
|
||||
poi.sort_unstable_by_key(|f| f.max_u16.saturating_sub(f.min_u16));
|
||||
|
||||
Ok((numeric, enums, poi))
|
||||
}
|
||||
|
||||
fn parse_numeric_filter_bounds(name: &str, rest: &str, entry: &str) -> Result<(f32, f32), String> {
|
||||
let num_parts: Vec<&str> = rest.splitn(2, ':').collect();
|
||||
if num_parts.len() != 2 {
|
||||
return Err(format!(
|
||||
"Numeric filter '{name}' must have format 'name:min:max', got '{entry}'"
|
||||
));
|
||||
}
|
||||
let min = num_parts[0]
|
||||
.trim()
|
||||
.parse::<f32>()
|
||||
.map_err(|err| format!("Invalid min value in filter '{name}': {err}"))?;
|
||||
let max = num_parts[1]
|
||||
.trim()
|
||||
.parse::<f32>()
|
||||
.map_err(|err| format!("Invalid max value in filter '{name}': {err}"))?;
|
||||
if min.is_finite() && max.is_finite() && min > max {
|
||||
return Err(format!(
|
||||
"Numeric filter '{name}' has inverted range: min ({min}) > max ({max})"
|
||||
));
|
||||
}
|
||||
Ok((min, max))
|
||||
}
|
||||
|
||||
/// Check if a row passes all filters.
|
||||
/// All features (numeric and enum) are stored in feature_data as quantized u16.
|
||||
pub fn row_passes_filters(
|
||||
|
|
@ -130,6 +239,18 @@ pub fn row_passes_filters(
|
|||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn row_passes_poi_filters(
|
||||
row: usize,
|
||||
filters: &[ParsedPoiFilter],
|
||||
poi_metrics: &PostcodePoiMetrics,
|
||||
) -> bool {
|
||||
filters.iter().all(|filter| {
|
||||
let raw = poi_metrics.raw_for_property_row(row, filter.metric_idx);
|
||||
raw != NAN_U16 && raw >= filter.min_u16 && raw <= filter.max_u16
|
||||
})
|
||||
}
|
||||
|
||||
/// Single-pass marginal impact counting.
|
||||
///
|
||||
/// Returns `(total_passing, impacts)` where `impacts[i]` is how many MORE rows
|
||||
|
|
@ -330,6 +451,35 @@ mod tests {
|
|||
assert_eq!(enums[0].allowed.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_filters_with_poi_splits_side_table_filters() {
|
||||
let tq = test_quant(3, 2);
|
||||
let poi_tq = test_quant(2, 2);
|
||||
let poi_map: FxHashMap<String, usize> = [
|
||||
("Distance to nearest cafe POI (km)".into(), 0),
|
||||
("Number of cafe POIs within 2km".into(), 1),
|
||||
]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let (numeric, enums, poi) = parse_filters_with_poi(
|
||||
Some("price:100:500;;rating:A;;Distance to nearest cafe POI (km):0:1.5"),
|
||||
&feature_name_to_index(),
|
||||
&enum_values(),
|
||||
&tq.as_ref(),
|
||||
&poi_map,
|
||||
&poi_tq.as_ref(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(numeric.len(), 1);
|
||||
assert_eq!(enums.len(), 1);
|
||||
assert_eq!(poi.len(), 1);
|
||||
assert_eq!(poi[0].metric_idx, 0);
|
||||
assert_eq!(poi[0].min_u16, 0);
|
||||
assert_eq!(poi[0].max_u16, 99);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_filters_empty() {
|
||||
let tq = test_quant(3, 2);
|
||||
|
|
|
|||
|
|
@ -88,6 +88,8 @@ struct CreateCollection {
|
|||
update_rule: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
delete_rule: Option<String>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
indexes: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
|
@ -308,12 +310,13 @@ async fn ensure_user_fields(client: &Client, base_url: &str, token: &str) -> any
|
|||
let has_ai_tokens_used = fields.iter().any(|f| f["name"] == "ai_tokens_used");
|
||||
let has_ai_tokens_week = fields.iter().any(|f| f["name"] == "ai_tokens_week");
|
||||
|
||||
if has_is_admin
|
||||
let has_all_required_fields = has_is_admin
|
||||
&& has_subscription
|
||||
&& has_newsletter
|
||||
&& has_ai_tokens_used
|
||||
&& has_ai_tokens_week
|
||||
{
|
||||
&& has_ai_tokens_week;
|
||||
|
||||
if has_all_required_fields {
|
||||
info!("PocketBase users collection already has all required fields");
|
||||
return Ok(());
|
||||
}
|
||||
|
|
@ -372,6 +375,52 @@ async fn ensure_user_fields(client: &Client, base_url: &str, token: &str) -> any
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Ensure clients can manage normal account data but cannot self-grant paid or
|
||||
/// admin-only state. Superuser writes from the Rust API bypass these rules.
|
||||
async fn ensure_user_auth_rules(
|
||||
client: &Client,
|
||||
base_url: &str,
|
||||
token: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let url = format!("{base_url}/api/collections/users");
|
||||
let self_only = "id = @request.auth.id";
|
||||
let protected_fields_absent = concat!(
|
||||
"@request.body.subscription:isset = false",
|
||||
" && @request.body.is_admin:isset = false",
|
||||
" && @request.body.ai_tokens_used:isset = false",
|
||||
" && @request.body.ai_tokens_week:isset = false"
|
||||
);
|
||||
let protected_fields_unchanged = concat!(
|
||||
"@request.body.subscription:changed = false",
|
||||
" && @request.body.is_admin:changed = false",
|
||||
" && @request.body.ai_tokens_used:changed = false",
|
||||
" && @request.body.ai_tokens_week:changed = false"
|
||||
);
|
||||
let update_rule = format!("{self_only} && {protected_fields_unchanged}");
|
||||
|
||||
let resp = client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"listRule": self_only,
|
||||
"viewRule": self_only,
|
||||
"createRule": protected_fields_absent,
|
||||
"updateRule": update_rule,
|
||||
"deleteRule": self_only,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to update users collection API rules ({status}): {text}");
|
||||
}
|
||||
|
||||
info!("PocketBase users collection API rules hardened");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Ensure a collection has API rules allowing users to manage their own records.
|
||||
async fn ensure_user_owned_rules(
|
||||
client: &Client,
|
||||
|
|
@ -404,6 +453,263 @@ async fn ensure_user_owned_rules(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Ensure a collection is accessible only via server-side superuser calls.
|
||||
async fn ensure_server_only_rules(
|
||||
client: &Client,
|
||||
base_url: &str,
|
||||
token: &str,
|
||||
collection_name: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let url = format!("{base_url}/api/collections/{collection_name}");
|
||||
let resp = client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"listRule": serde_json::Value::Null,
|
||||
"viewRule": serde_json::Value::Null,
|
||||
"createRule": serde_json::Value::Null,
|
||||
"updateRule": serde_json::Value::Null,
|
||||
"deleteRule": serde_json::Value::Null,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to lock {collection_name} API rules ({status}): {text}");
|
||||
}
|
||||
|
||||
info!("PocketBase collection '{collection_name}' locked to superuser access");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ensure_checkout_sessions_fields(
|
||||
client: &Client,
|
||||
base_url: &str,
|
||||
token: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let url = format!("{base_url}/api/collections/checkout_sessions");
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to fetch checkout_sessions collection ({status}): {text}");
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let fields = body["fields"]
|
||||
.as_array()
|
||||
.ok_or_else(|| anyhow::anyhow!("checkout_sessions collection has no fields array"))?;
|
||||
let users_id = find_users_collection_id(client, base_url, token).await?;
|
||||
|
||||
let mut new_fields = fields.clone();
|
||||
let mut add_field = |name: &str, field: serde_json::Value| {
|
||||
if !fields.iter().any(|f| f["name"] == name) {
|
||||
new_fields.push(field);
|
||||
}
|
||||
};
|
||||
|
||||
add_field(
|
||||
"user",
|
||||
serde_json::json!({
|
||||
"name": "user",
|
||||
"type": "relation",
|
||||
"required": true,
|
||||
"maxSelect": 1,
|
||||
"collectionId": users_id,
|
||||
}),
|
||||
);
|
||||
add_field(
|
||||
"stripe_session_id",
|
||||
serde_json::json!({ "name": "stripe_session_id", "type": "text", "required": false }),
|
||||
);
|
||||
add_field(
|
||||
"checkout_url",
|
||||
serde_json::json!({ "name": "checkout_url", "type": "text", "required": false }),
|
||||
);
|
||||
add_field(
|
||||
"amount_pence",
|
||||
serde_json::json!({ "name": "amount_pence", "type": "number" }),
|
||||
);
|
||||
add_field(
|
||||
"expected_total_pence",
|
||||
serde_json::json!({ "name": "expected_total_pence", "type": "number" }),
|
||||
);
|
||||
add_field(
|
||||
"currency",
|
||||
serde_json::json!({ "name": "currency", "type": "text", "required": true }),
|
||||
);
|
||||
add_field(
|
||||
"discount_coupon_id",
|
||||
serde_json::json!({ "name": "discount_coupon_id", "type": "text", "required": false }),
|
||||
);
|
||||
add_field(
|
||||
"referral_invite_id",
|
||||
serde_json::json!({ "name": "referral_invite_id", "type": "text", "required": false }),
|
||||
);
|
||||
add_field(
|
||||
"status",
|
||||
serde_json::json!({ "name": "status", "type": "text", "required": true }),
|
||||
);
|
||||
add_field(
|
||||
"expires_at_unix",
|
||||
serde_json::json!({ "name": "expires_at_unix", "type": "number" }),
|
||||
);
|
||||
add_field(
|
||||
"paid_amount_pence",
|
||||
serde_json::json!({ "name": "paid_amount_pence", "type": "number" }),
|
||||
);
|
||||
add_field(
|
||||
"completed_at_unix",
|
||||
serde_json::json!({ "name": "completed_at_unix", "type": "text", "required": false }),
|
||||
);
|
||||
|
||||
if new_fields.len() == fields.len() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let patch_resp = client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "fields": new_fields }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !patch_resp.status().is_success() {
|
||||
let status = patch_resp.status();
|
||||
let text = patch_resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to patch checkout_sessions fields ({status}): {text}");
|
||||
}
|
||||
|
||||
info!("PocketBase checkout_sessions collection fields updated");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ensure_checkout_locks_fields(
|
||||
client: &Client,
|
||||
base_url: &str,
|
||||
token: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let url = format!("{base_url}/api/collections/checkout_locks");
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to fetch checkout_locks collection ({status}): {text}");
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let fields = body["fields"]
|
||||
.as_array()
|
||||
.ok_or_else(|| anyhow::anyhow!("checkout_locks collection has no fields array"))?;
|
||||
|
||||
let mut new_fields = fields.clone();
|
||||
let mut add_field = |name: &str, field: serde_json::Value| {
|
||||
if !fields.iter().any(|f| f["name"] == name) {
|
||||
new_fields.push(field);
|
||||
}
|
||||
};
|
||||
|
||||
add_field(
|
||||
"name",
|
||||
serde_json::json!({ "name": "name", "type": "text", "required": true }),
|
||||
);
|
||||
add_field(
|
||||
"owner",
|
||||
serde_json::json!({ "name": "owner", "type": "text", "required": true }),
|
||||
);
|
||||
add_field(
|
||||
"expires_at_unix",
|
||||
serde_json::json!({ "name": "expires_at_unix", "type": "number" }),
|
||||
);
|
||||
|
||||
if new_fields.len() == fields.len() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let patch_resp = client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "fields": new_fields }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !patch_resp.status().is_success() {
|
||||
let status = patch_resp.status();
|
||||
let text = patch_resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to patch checkout_locks fields ({status}): {text}");
|
||||
}
|
||||
|
||||
info!("PocketBase checkout_locks collection fields updated");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ensure_collection_indexes(
|
||||
client: &Client,
|
||||
base_url: &str,
|
||||
token: &str,
|
||||
collection_name: &str,
|
||||
required_indexes: &[(&str, &str)],
|
||||
) -> anyhow::Result<()> {
|
||||
let url = format!("{base_url}/api/collections/{collection_name}");
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to fetch {collection_name} collection ({status}): {text}");
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let indexes = body["indexes"].as_array().cloned().unwrap_or_default();
|
||||
let mut new_indexes = indexes.clone();
|
||||
|
||||
for (index_name, create_sql) in required_indexes {
|
||||
let exists = indexes
|
||||
.iter()
|
||||
.filter_map(|idx| idx.as_str())
|
||||
.any(|idx| idx.contains(index_name));
|
||||
if !exists {
|
||||
new_indexes.push(serde_json::Value::String((*create_sql).to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
if new_indexes.len() == indexes.len() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let patch_resp = client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "indexes": new_indexes }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !patch_resp.status().is_success() {
|
||||
let status = patch_resp.status();
|
||||
let text = patch_resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to patch {collection_name} indexes ({status}): {text}");
|
||||
}
|
||||
|
||||
info!("PocketBase collection '{collection_name}' indexes updated");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Ensure the `saved_searches` collection has API rules allowing users to manage their own records.
|
||||
async fn ensure_saved_searches_rules(
|
||||
client: &Client,
|
||||
|
|
@ -608,6 +914,7 @@ pub async fn ensure_collections(
|
|||
let existing = list_collections(client, base_url, &token).await?;
|
||||
|
||||
ensure_user_fields(client, base_url, &token).await?;
|
||||
ensure_user_auth_rules(client, base_url, &token).await?;
|
||||
|
||||
if !existing.iter().any(|n| n == "saved_searches") {
|
||||
let users_id = find_users_collection_id(client, base_url, &token).await?;
|
||||
|
|
@ -633,6 +940,7 @@ pub async fn ensure_collections(
|
|||
create_rule: user_only.clone(),
|
||||
update_rule: user_only.clone(),
|
||||
delete_rule: user_only,
|
||||
indexes: Vec::new(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
|
@ -667,6 +975,7 @@ pub async fn ensure_collections(
|
|||
create_rule: user_only.clone(),
|
||||
update_rule: user_only.clone(),
|
||||
delete_rule: user_only,
|
||||
indexes: Vec::new(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
|
@ -698,6 +1007,7 @@ pub async fn ensure_collections(
|
|||
create_rule: None,
|
||||
update_rule: None,
|
||||
delete_rule: None,
|
||||
indexes: Vec::new(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
|
@ -705,6 +1015,86 @@ pub async fn ensure_collections(
|
|||
ensure_autodate_fields(client, base_url, &token, "invites").await?;
|
||||
}
|
||||
|
||||
if !existing.iter().any(|n| n == "checkout_sessions") {
|
||||
let users_id = find_users_collection_id(client, base_url, &token).await?;
|
||||
create_collection(
|
||||
client,
|
||||
base_url,
|
||||
&token,
|
||||
CreateCollection {
|
||||
name: "checkout_sessions".to_string(),
|
||||
r#type: "base".to_string(),
|
||||
fields: vec![
|
||||
Field::relation("user", &users_id),
|
||||
Field::text("stripe_session_id", false),
|
||||
Field::text("checkout_url", false),
|
||||
Field::number("amount_pence"),
|
||||
Field::number("expected_total_pence"),
|
||||
Field::text("currency", true),
|
||||
Field::text("discount_coupon_id", false),
|
||||
Field::text("referral_invite_id", false),
|
||||
Field::text("status", true),
|
||||
Field::number("expires_at_unix"),
|
||||
Field::number("paid_amount_pence"),
|
||||
Field::text("completed_at_unix", false),
|
||||
Field::autodate("created", true, false),
|
||||
Field::autodate("updated", true, true),
|
||||
],
|
||||
list_rule: None,
|
||||
view_rule: None,
|
||||
create_rule: None,
|
||||
update_rule: None,
|
||||
delete_rule: None,
|
||||
indexes: Vec::new(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
ensure_server_only_rules(client, base_url, &token, "checkout_sessions").await?;
|
||||
ensure_checkout_sessions_fields(client, base_url, &token).await?;
|
||||
ensure_autodate_fields(client, base_url, &token, "checkout_sessions").await?;
|
||||
}
|
||||
|
||||
let checkout_locks_name_index =
|
||||
"CREATE UNIQUE INDEX idx_checkout_locks_name ON checkout_locks (name)";
|
||||
if !existing.iter().any(|n| n == "checkout_locks") {
|
||||
create_collection(
|
||||
client,
|
||||
base_url,
|
||||
&token,
|
||||
CreateCollection {
|
||||
name: "checkout_locks".to_string(),
|
||||
r#type: "base".to_string(),
|
||||
fields: vec![
|
||||
Field::text("name", true),
|
||||
Field::text("owner", true),
|
||||
Field::number("expires_at_unix"),
|
||||
Field::autodate("created", true, false),
|
||||
Field::autodate("updated", true, true),
|
||||
],
|
||||
list_rule: None,
|
||||
view_rule: None,
|
||||
create_rule: None,
|
||||
update_rule: None,
|
||||
delete_rule: None,
|
||||
indexes: vec![checkout_locks_name_index.to_string()],
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
ensure_server_only_rules(client, base_url, &token, "checkout_locks").await?;
|
||||
ensure_checkout_locks_fields(client, base_url, &token).await?;
|
||||
ensure_autodate_fields(client, base_url, &token, "checkout_locks").await?;
|
||||
ensure_collection_indexes(
|
||||
client,
|
||||
base_url,
|
||||
&token,
|
||||
"checkout_locks",
|
||||
&[("idx_checkout_locks_name", checkout_locks_name_index)],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if !existing.iter().any(|n| n == "short_urls") {
|
||||
create_collection(
|
||||
client,
|
||||
|
|
@ -724,6 +1114,7 @@ pub async fn ensure_collections(
|
|||
create_rule: None,
|
||||
update_rule: None,
|
||||
delete_rule: None,
|
||||
indexes: Vec::new(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
|
@ -753,6 +1144,7 @@ pub async fn ensure_collections(
|
|||
create_rule: None,
|
||||
update_rule: None,
|
||||
delete_rule: None,
|
||||
indexes: Vec::new(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
|
@ -785,6 +1177,7 @@ pub async fn ensure_collections(
|
|||
create_rule: None,
|
||||
update_rule: None,
|
||||
delete_rule: None,
|
||||
indexes: Vec::new(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
|
|
|||
264
server-rs/src/pocketbase_locks.rs
Normal file
264
server-rs/src/pocketbase_locks.rs
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use rand::RngExt;
|
||||
use serde_json::Value;
|
||||
use tokio::time::sleep;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::state::AppState;
|
||||
|
||||
const LOCK_COLLECTION: &str = "checkout_locks";
|
||||
const LOCK_ACQUIRE_TIMEOUT_SECS: u64 = 10;
|
||||
const LOCK_RETRY_DELAY_MS: u64 = 100;
|
||||
|
||||
pub struct PocketBaseLock {
|
||||
client: reqwest::Client,
|
||||
pb_url: String,
|
||||
token: String,
|
||||
record_id: Option<String>,
|
||||
name: String,
|
||||
}
|
||||
|
||||
struct ExistingLock {
|
||||
id: String,
|
||||
expires_at_unix: u64,
|
||||
}
|
||||
|
||||
pub async fn acquire_pocketbase_lock(
|
||||
state: &AppState,
|
||||
name: &str,
|
||||
ttl_secs: u64,
|
||||
) -> anyhow::Result<PocketBaseLock> {
|
||||
validate_lock_name(name)?;
|
||||
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/').to_string();
|
||||
let owner = random_owner();
|
||||
let deadline = Instant::now() + Duration::from_secs(LOCK_ACQUIRE_TIMEOUT_SECS);
|
||||
|
||||
loop {
|
||||
let now = now_unix_secs();
|
||||
if let Some(record_id) =
|
||||
try_create_lock(state, &pb_url, &token, name, &owner, now + ttl_secs).await?
|
||||
{
|
||||
return Ok(PocketBaseLock {
|
||||
client: state.http_client.clone(),
|
||||
pb_url,
|
||||
token,
|
||||
record_id: Some(record_id),
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(existing) = find_lock(state, &pb_url, &token, name).await? {
|
||||
if existing.expires_at_unix <= now {
|
||||
if let Err(err) = delete_lock_record(state, &pb_url, &token, &existing.id).await {
|
||||
warn!(
|
||||
lock_name = name,
|
||||
lock_id = %existing.id,
|
||||
"Failed to delete stale PocketBase lock: {err}"
|
||||
);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if Instant::now() >= deadline {
|
||||
bail!("Timed out acquiring PocketBase lock '{name}'");
|
||||
}
|
||||
|
||||
sleep(Duration::from_millis(LOCK_RETRY_DELAY_MS)).await;
|
||||
}
|
||||
}
|
||||
|
||||
impl PocketBaseLock {
|
||||
pub async fn release(mut self) -> anyhow::Result<()> {
|
||||
let Some(record_id) = self.record_id.take() else {
|
||||
return Ok(());
|
||||
};
|
||||
release_lock_record(&self.client, &self.pb_url, &self.token, &record_id)
|
||||
.await
|
||||
.with_context(|| format!("Failed to release PocketBase lock '{}'", self.name))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PocketBaseLock {
|
||||
fn drop(&mut self) {
|
||||
let Some(record_id) = self.record_id.take() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let client = self.client.clone();
|
||||
let pb_url = self.pb_url.clone();
|
||||
let token = self.token.clone();
|
||||
let name = self.name.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = release_lock_record(&client, &pb_url, &token, &record_id).await {
|
||||
warn!(
|
||||
lock_name = %name,
|
||||
lock_id = %record_id,
|
||||
"Failed to release PocketBase lock on drop: {err}"
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_create_lock(
|
||||
state: &AppState,
|
||||
pb_url: &str,
|
||||
token: &str,
|
||||
name: &str,
|
||||
owner: &str,
|
||||
expires_at_unix: u64,
|
||||
) -> anyhow::Result<Option<String>> {
|
||||
let url = format!("{pb_url}/api/collections/{LOCK_COLLECTION}/records");
|
||||
let resp = state
|
||||
.http_client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"name": name,
|
||||
"owner": owner,
|
||||
"expires_at_unix": expires_at_unix,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if resp.status().is_success() {
|
||||
let body: Value = resp.json().await?;
|
||||
return body["id"]
|
||||
.as_str()
|
||||
.map(str::to_string)
|
||||
.map(Some)
|
||||
.ok_or_else(|| anyhow!("PocketBase lock record missing id"));
|
||||
}
|
||||
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
if status.is_client_error() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Err(anyhow!("PocketBase lock create failed ({status}): {text}"))
|
||||
}
|
||||
|
||||
async fn find_lock(
|
||||
state: &AppState,
|
||||
pb_url: &str,
|
||||
token: &str,
|
||||
name: &str,
|
||||
) -> anyhow::Result<Option<ExistingLock>> {
|
||||
let filter = format!("name=\"{}\"", name);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{LOCK_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let Some(item) = body["items"].as_array().and_then(|items| items.first()) else {
|
||||
return Ok(None);
|
||||
};
|
||||
let id = item["id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("PocketBase lock missing id"))?
|
||||
.to_string();
|
||||
let expires_at_unix = number_field(item, "expires_at_unix").unwrap_or(0);
|
||||
|
||||
Ok(Some(ExistingLock {
|
||||
id,
|
||||
expires_at_unix,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn delete_lock_record(
|
||||
state: &AppState,
|
||||
pb_url: &str,
|
||||
token: &str,
|
||||
record_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
release_lock_record(&state.http_client, pb_url, token, record_id).await
|
||||
}
|
||||
|
||||
async fn release_lock_record(
|
||||
client: &reqwest::Client,
|
||||
pb_url: &str,
|
||||
token: &str,
|
||||
record_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let url = format!("{pb_url}/api/collections/{LOCK_COLLECTION}/records/{record_id}");
|
||||
let resp = client
|
||||
.delete(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if resp.status().is_success() || resp.status() == reqwest::StatusCode::NOT_FOUND {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
Err(anyhow!("PocketBase lock delete failed ({status}): {text}"))
|
||||
}
|
||||
|
||||
fn validate_lock_name(name: &str) -> anyhow::Result<()> {
|
||||
if name.is_empty() || name.len() > 80 {
|
||||
bail!("invalid PocketBase lock name length");
|
||||
}
|
||||
if !name
|
||||
.bytes()
|
||||
.all(|b| b.is_ascii_alphanumeric() || b == b':' || b == b'_' || b == b'-')
|
||||
{
|
||||
bail!("invalid PocketBase lock name characters");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn random_owner() -> String {
|
||||
let mut rng = rand::rng();
|
||||
(0..24)
|
||||
.map(|_| {
|
||||
let idx: u8 = rng.random_range(0..36);
|
||||
if idx < 10 {
|
||||
(b'0' + idx) as char
|
||||
} else {
|
||||
(b'a' + idx - 10) as char
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn now_unix_secs() -> u64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
fn number_field(value: &Value, field: &str) -> Option<u64> {
|
||||
value[field].as_u64().or_else(|| {
|
||||
value[field]
|
||||
.as_f64()
|
||||
.filter(|n| n.is_finite() && *n >= 0.0 && n.fract() == 0.0)
|
||||
.map(|n| n as u64)
|
||||
})
|
||||
}
|
||||
|
||||
async fn ensure_success_ref(resp: &reqwest::Response) -> anyhow::Result<()> {
|
||||
if resp.status().is_success() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err(anyhow!("upstream returned {}", resp.status()))
|
||||
}
|
||||
|
|
@ -8,10 +8,8 @@ use serde::{Deserialize, Serialize};
|
|||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::state::{AppState, SharedState};
|
||||
|
||||
use super::pricing::{count_licensed_users, price_for_count};
|
||||
use crate::checkout_sessions::{start_license_checkout, CheckoutStart};
|
||||
use crate::state::SharedState;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct CheckoutRequest {
|
||||
|
|
@ -23,8 +21,8 @@ struct CheckoutResponse {
|
|||
url: String,
|
||||
}
|
||||
|
||||
/// Create a Stripe Checkout session for the lifetime license (or grant for free if in free tier).
|
||||
/// Requires authentication. Optionally accepts a referral code to apply a coupon.
|
||||
/// Create a reserved Stripe Checkout session for the lifetime license.
|
||||
/// Requires authentication. Referral discounts are issued via invite redemption.
|
||||
pub async fn post_checkout(
|
||||
State(shared): State<Arc<SharedState>>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
|
|
@ -36,147 +34,27 @@ pub async fn post_checkout(
|
|||
None => return StatusCode::UNAUTHORIZED.into_response(),
|
||||
};
|
||||
|
||||
let count = match count_licensed_users(&state).await {
|
||||
Ok(c) => c,
|
||||
Err(err) => {
|
||||
warn!("Failed to count licensed users at checkout: {err}");
|
||||
return StatusCode::SERVICE_UNAVAILABLE.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let price_pence = price_for_count(count);
|
||||
let public_url = &state.public_url;
|
||||
let success_url = format!("{public_url}/pricing?license_success=1");
|
||||
|
||||
// Free tier — grant license directly without Stripe
|
||||
if price_pence == 0 {
|
||||
if let Err(err) = grant_license(&state, &user.id).await {
|
||||
warn!(user_id = %user.id, "Failed to grant free license: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
info!(user_id = %user.id, "Granted free early-bird license");
|
||||
return Json(CheckoutResponse { url: success_url }).into_response();
|
||||
}
|
||||
|
||||
// Paid tier — create Stripe checkout with dynamic price
|
||||
let secret_key = &state.stripe_secret_key;
|
||||
let cancel_url = format!("{public_url}/pricing");
|
||||
|
||||
let mut form_params = vec![
|
||||
("mode", "payment".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][unit_amount]",
|
||||
price_pence.to_string(),
|
||||
),
|
||||
("line_items[0][price_data][currency]", "gbp".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][product_data][name]",
|
||||
"Perfect Postcodes Lifetime License".to_string(),
|
||||
),
|
||||
("line_items[0][quantity]", "1".to_string()),
|
||||
("success_url", success_url),
|
||||
("cancel_url", cancel_url),
|
||||
("client_reference_id", user.id.clone()),
|
||||
("customer_email", user.email.clone()),
|
||||
];
|
||||
|
||||
// If a referral code is provided and valid, look it up and apply the coupon
|
||||
if let Some(ref code) = req.referral_code {
|
||||
if validate_referral_invite(&state, code).await {
|
||||
form_params.push((
|
||||
"discounts[0][coupon]",
|
||||
state.stripe_referral_coupon_id.clone(),
|
||||
));
|
||||
info!(code = %code, "Applying referral coupon to checkout");
|
||||
} else {
|
||||
warn!(code = %code, "Referral code validation failed, proceeding without discount");
|
||||
}
|
||||
if req.referral_code.is_some() {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Referral codes must be redeemed from the invite link",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let res = state
|
||||
.http_client
|
||||
.post("https://api.stripe.com/v1/checkout/sessions")
|
||||
.basic_auth(secret_key, None::<&str>)
|
||||
.form(&form_params)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: serde_json::Value = match resp.json().await {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse Stripe response: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
};
|
||||
let url = body["url"].as_str().unwrap_or_default().to_string();
|
||||
if url.is_empty() {
|
||||
warn!("Stripe session missing URL");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
info!(user_id = %user.id, price_pence, "Created Stripe checkout session");
|
||||
Json(CheckoutResponse { url }).into_response()
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
warn!("Stripe checkout failed ({status}): {text}");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
match start_license_checkout(&state, &user, &success_url, &cancel_url, None, None).await {
|
||||
Ok(CheckoutStart::Free) => {
|
||||
info!(user_id = %user.id, "Granted free early-bird license");
|
||||
Json(CheckoutResponse { url: success_url }).into_response()
|
||||
}
|
||||
Ok(CheckoutStart::Stripe { url }) => Json(CheckoutResponse { url }).into_response(),
|
||||
Err(err) => {
|
||||
warn!("Stripe request error: {err}");
|
||||
warn!(user_id = %user.id, "Failed to start checkout: {err:?}");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Grant a license by updating the user's subscription to "licensed" in PocketBase.
|
||||
async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "subscription": "licensed" }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("PocketBase update failed ({status}): {text}");
|
||||
}
|
||||
|
||||
state.token_cache.invalidate_by_user_id(user_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a referral invite code exists and is unused.
|
||||
async fn validate_referral_invite(state: &AppState, code: &str) -> bool {
|
||||
// Only allow alphanumeric codes to prevent PocketBase filter injection
|
||||
if code.is_empty() || code.len() > 20 || !code.bytes().all(|b| b.is_ascii_alphanumeric()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!(
|
||||
"code=\"{}\" && invite_type=\"referral\" && used_by_id=\"\"",
|
||||
code
|
||||
);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/invites/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
|
||||
match state.http_client.get(&url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
body["totalItems"].as_u64().unwrap_or(0) > 0
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::extract::{Query, State};
|
||||
use axum::http::{header, HeaderMap, StatusCode};
|
||||
|
|
@ -13,14 +14,18 @@ use tracing::{info, warn};
|
|||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::NAN_U16;
|
||||
use crate::data::QuantRef;
|
||||
use crate::features::INTEGER_BIN_FEATURES;
|
||||
use crate::data::{PostcodePoiMetrics, QuantRef};
|
||||
use crate::features;
|
||||
use crate::licensing::check_license_bounds;
|
||||
use crate::parsing::{parse_field_indices, parse_filters, require_bounds, row_passes_filters};
|
||||
use crate::parsing::{
|
||||
parse_field_indices_with_poi, parse_filters_with_poi, require_bounds, row_passes_filters,
|
||||
row_passes_poi_filters,
|
||||
};
|
||||
use crate::routes::{fetch_screenshot_bytes, FeatureInfo};
|
||||
use crate::state::SharedState;
|
||||
|
||||
const MAX_EXPORT_POSTCODES: usize = 250;
|
||||
const EXPORT_SCREENSHOT_TIMEOUT_SECS: u64 = 12;
|
||||
/// Height (in pixels) reserved for the screenshot row
|
||||
const IMAGE_ROW_HEIGHT: f64 = 225.0;
|
||||
|
||||
|
|
@ -41,11 +46,11 @@ struct PostcodeExportAgg {
|
|||
}
|
||||
|
||||
impl PostcodeExportAgg {
|
||||
fn new(num_features: usize) -> Self {
|
||||
fn new(total_features: usize) -> Self {
|
||||
Self {
|
||||
count: 0,
|
||||
sums: vec![0.0; num_features],
|
||||
finite_counts: vec![0; num_features],
|
||||
sums: vec![0.0; total_features],
|
||||
finite_counts: vec![0; total_features],
|
||||
enum_freqs: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
|
@ -58,6 +63,7 @@ impl PostcodeExportAgg {
|
|||
num_features: usize,
|
||||
enum_indices: &FxHashMap<usize, ()>,
|
||||
quant: &QuantRef,
|
||||
poi_metrics: &PostcodePoiMetrics,
|
||||
) {
|
||||
self.count += 1;
|
||||
let base = row * num_features;
|
||||
|
|
@ -79,6 +85,18 @@ impl PostcodeExportAgg {
|
|||
self.finite_counts[feat_idx] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let poi_offset = num_features;
|
||||
for metric_idx in 0..poi_metrics.num_features() {
|
||||
let raw = poi_metrics.raw_for_property_row(row, metric_idx);
|
||||
if raw == NAN_U16 {
|
||||
continue;
|
||||
}
|
||||
let value = poi_metrics.decode_raw(metric_idx, raw);
|
||||
let out_idx = poi_offset + metric_idx;
|
||||
self.sums[out_idx] += value as f64;
|
||||
self.finite_counts[out_idx] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -138,13 +156,17 @@ pub async fn get_export(
|
|||
check_license_bounds(&user.0, (south, west, north, east), None)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
let filters_str = params.filters;
|
||||
let fields_str = params.fields;
|
||||
|
||||
|
|
@ -164,16 +186,28 @@ pub async fn get_export(
|
|||
|
||||
// Fetch screenshot (async, before spawn_blocking)
|
||||
let auth_header = headers.get(header::AUTHORIZATION);
|
||||
let screenshot_bytes = match fetch_screenshot_bytes(&state, &frontend_params, auth_header).await
|
||||
let screenshot_fetch = fetch_screenshot_bytes(&state, &frontend_params, auth_header);
|
||||
let screenshot_bytes = match tokio::time::timeout(
|
||||
Duration::from_secs(EXPORT_SCREENSHOT_TIMEOUT_SECS),
|
||||
screenshot_fetch,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(bytes) => {
|
||||
Ok(Ok(bytes)) => {
|
||||
info!(bytes = bytes.len(), "Fetched screenshot for export");
|
||||
Some(bytes)
|
||||
}
|
||||
Err(err) => {
|
||||
Ok(Err(err)) => {
|
||||
warn!("Screenshot failed for export: {err}");
|
||||
None
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(
|
||||
timeout_secs = EXPORT_SCREENSHOT_TIMEOUT_SECS,
|
||||
"Screenshot timed out for export"
|
||||
);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// Build feature name → description map from the precomputed features response
|
||||
|
|
@ -200,6 +234,9 @@ pub async fn get_export(
|
|||
let feature_names = &state.data.feature_names;
|
||||
let enum_values = &state.data.enum_values;
|
||||
let postcode_data = &state.postcode_data;
|
||||
let poi_metrics = &state.data.poi_metrics;
|
||||
let poi_offset = num_features;
|
||||
let total_export_features = num_features + poi_metrics.num_features();
|
||||
|
||||
// Build set of enum feature indices for quick lookup
|
||||
let enum_indices: FxHashMap<usize, ()> = enum_values.keys().map(|&idx| (idx, ())).collect();
|
||||
|
|
@ -219,6 +256,10 @@ pub async fn get_export(
|
|||
) {
|
||||
return;
|
||||
}
|
||||
if has_poi_filters && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
|
||||
{
|
||||
return;
|
||||
}
|
||||
let postcode = state.data.postcode(row);
|
||||
if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) {
|
||||
postcode_rows.entry(pc_idx).or_default().push(row);
|
||||
|
|
@ -229,9 +270,16 @@ pub async fn get_export(
|
|||
let mut postcode_aggs: Vec<(usize, PostcodeExportAgg)> =
|
||||
Vec::with_capacity(postcode_rows.len());
|
||||
for (pc_idx, rows) in postcode_rows {
|
||||
let mut agg = PostcodeExportAgg::new(num_features);
|
||||
let mut agg = PostcodeExportAgg::new(total_export_features);
|
||||
for &row in &rows {
|
||||
agg.add_row(feature_data, row, num_features, &enum_indices, &quant);
|
||||
agg.add_row(
|
||||
feature_data,
|
||||
row,
|
||||
num_features,
|
||||
&enum_indices,
|
||||
&quant,
|
||||
poi_metrics,
|
||||
);
|
||||
}
|
||||
if agg.count > 0 {
|
||||
postcode_aggs.push((pc_idx, agg));
|
||||
|
|
@ -265,14 +313,19 @@ pub async fn get_export(
|
|||
// Determine column order: filter features first, then remaining
|
||||
let filter_feature_names = extract_filter_feature_names(filters_str.as_deref());
|
||||
|
||||
let field_indices =
|
||||
parse_field_indices(fields_str.as_deref(), &state.feature_name_to_index)
|
||||
.map_err(|err| err.1)?;
|
||||
let field_indices = parse_field_indices_with_poi(
|
||||
fields_str.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
)
|
||||
.map_err(|err| err.1)?;
|
||||
|
||||
let all_feature_indices: Vec<usize> = if let Some(ref indices) = field_indices {
|
||||
indices.clone()
|
||||
let all_feature_indices: Vec<usize> = if let Some(ref indices) = field_indices.normal {
|
||||
let mut selected = indices.clone();
|
||||
selected.extend(field_indices.poi.iter().map(|idx| poi_offset + *idx));
|
||||
selected
|
||||
} else {
|
||||
let mut ordered = Vec::with_capacity(num_features);
|
||||
let mut ordered = Vec::with_capacity(total_export_features);
|
||||
let mut used = FxHashSet::default();
|
||||
|
||||
for name in &filter_feature_names {
|
||||
|
|
@ -280,6 +333,11 @@ pub async fn get_export(
|
|||
if used.insert(idx) {
|
||||
ordered.push(idx);
|
||||
}
|
||||
} else if let Some(&idx) = state.data.poi_metrics.name_to_index.get(name.as_str()) {
|
||||
let virtual_idx = poi_offset + idx;
|
||||
if used.insert(virtual_idx) {
|
||||
ordered.push(virtual_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
for idx in 0..num_features {
|
||||
|
|
@ -287,15 +345,42 @@ pub async fn get_export(
|
|||
ordered.push(idx);
|
||||
}
|
||||
}
|
||||
for idx in 0..poi_metrics.num_features() {
|
||||
let virtual_idx = poi_offset + idx;
|
||||
if used.insert(virtual_idx) {
|
||||
ordered.push(virtual_idx);
|
||||
}
|
||||
}
|
||||
ordered
|
||||
};
|
||||
|
||||
// Filter-only feature indices for the Selected sheet
|
||||
let filter_feature_indices: Vec<usize> = filter_feature_names
|
||||
.iter()
|
||||
.filter_map(|name| state.feature_name_to_index.get(name.as_str()).copied())
|
||||
.filter_map(|name| {
|
||||
state
|
||||
.feature_name_to_index
|
||||
.get(name.as_str())
|
||||
.copied()
|
||||
.or_else(|| {
|
||||
state
|
||||
.data
|
||||
.poi_metrics
|
||||
.name_to_index
|
||||
.get(name.as_str())
|
||||
.map(|idx| poi_offset + *idx)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let feature_name_for_idx = |idx: usize| -> &str {
|
||||
if idx < num_features {
|
||||
&feature_names[idx]
|
||||
} else {
|
||||
&poi_metrics.feature_names[idx - poi_offset]
|
||||
}
|
||||
};
|
||||
|
||||
// Build feature unit map (feat_idx → (prefix, suffix)) for number formatting
|
||||
let feature_units: FxHashMap<usize, (&str, &str)> = state
|
||||
.features_response
|
||||
|
|
@ -309,16 +394,25 @@ pub async fn get_export(
|
|||
suffix,
|
||||
..
|
||||
} => {
|
||||
let idx = state.feature_name_to_index.get(name.as_str())?;
|
||||
Some((*idx, (*prefix, *suffix)))
|
||||
if let Some(&idx) = state.feature_name_to_index.get(name.as_str()) {
|
||||
Some((idx, (*prefix, *suffix)))
|
||||
} else {
|
||||
state
|
||||
.data
|
||||
.poi_metrics
|
||||
.name_to_index
|
||||
.get(name.as_str())
|
||||
.map(|idx| (poi_offset + *idx, (*prefix, *suffix)))
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let integer_feature_indices: FxHashSet<usize> = INTEGER_BIN_FEATURES
|
||||
let integer_feature_indices: FxHashSet<usize> = all_feature_indices
|
||||
.iter()
|
||||
.filter_map(|name| state.feature_name_to_index.get(*name).copied())
|
||||
.copied()
|
||||
.filter(|&idx| features::has_integer_bins(feature_name_for_idx(idx)))
|
||||
.collect();
|
||||
|
||||
// Build Excel number formats per feature index for unit display
|
||||
|
|
@ -435,7 +529,7 @@ pub async fn get_export(
|
|||
.write_string_with_format(
|
||||
header_row,
|
||||
col,
|
||||
&feature_names[feat_idx],
|
||||
feature_name_for_idx(feat_idx),
|
||||
&header_fmt,
|
||||
)
|
||||
.map_err(|e| format!("Failed to write header: {e}"))?;
|
||||
|
|
@ -453,7 +547,7 @@ pub async fn get_export(
|
|||
for (col_offset, &feat_idx) in feat_indices.iter().enumerate() {
|
||||
let col = (col_offset + 2) as u16;
|
||||
let desc = feature_descriptions
|
||||
.get(&feature_names[feat_idx])
|
||||
.get(feature_name_for_idx(feat_idx))
|
||||
.map(String::as_str)
|
||||
.unwrap_or("");
|
||||
sheet
|
||||
|
|
@ -477,7 +571,7 @@ pub async fn get_export(
|
|||
for (col_offset, &feat_idx) in feat_indices.iter().enumerate() {
|
||||
let col = (col_offset + 2) as u16;
|
||||
|
||||
if enum_indices.contains_key(&feat_idx) {
|
||||
if feat_idx < num_features && enum_indices.contains_key(&feat_idx) {
|
||||
if let Some(freqs) = agg.enum_freqs.get(&feat_idx) {
|
||||
if let Some((&mode_bits, _)) =
|
||||
freqs.iter().max_by_key(|(_, &count)| count)
|
||||
|
|
@ -543,7 +637,7 @@ pub async fn get_export(
|
|||
.map_err(|e| format!("Failed to set column width: {e}"))?;
|
||||
for col_offset in 0..feat_indices.len() {
|
||||
let col = (col_offset + 2) as u16;
|
||||
let feat_name = &feature_names[feat_indices[col_offset]];
|
||||
let feat_name = feature_name_for_idx(feat_indices[col_offset]);
|
||||
let width = (feat_name.len() as f64 * 1.1).clamp(10.0, 30.0);
|
||||
sheet
|
||||
.set_column_width(col, width)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use serde::Serialize;
|
|||
use tracing::info;
|
||||
|
||||
use crate::data::{Histogram, PropertyData};
|
||||
use crate::features::{Feature, FEATURE_GROUPS};
|
||||
use crate::features::{self, Feature, FEATURE_GROUPS};
|
||||
use crate::state::SharedState;
|
||||
|
||||
fn is_empty(val: &str) -> bool {
|
||||
|
|
@ -28,9 +28,9 @@ pub enum FeatureInfo {
|
|||
max: f32,
|
||||
step: f32,
|
||||
histogram: Histogram,
|
||||
description: &'static str,
|
||||
detail: &'static str,
|
||||
source: &'static str,
|
||||
description: String,
|
||||
detail: String,
|
||||
source: String,
|
||||
#[serde(skip_serializing_if = "is_empty")]
|
||||
prefix: &'static str,
|
||||
#[serde(skip_serializing_if = "is_empty")]
|
||||
|
|
@ -45,9 +45,9 @@ pub enum FeatureInfo {
|
|||
name: String,
|
||||
values: Vec<String>,
|
||||
counts: HashMap<String, u64>,
|
||||
description: &'static str,
|
||||
detail: &'static str,
|
||||
source: &'static str,
|
||||
description: String,
|
||||
detail: String,
|
||||
source: String,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -85,9 +85,9 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
|
|||
max: stats.slider_max,
|
||||
step: config.step,
|
||||
histogram: stats.histogram.clone(),
|
||||
description: config.description,
|
||||
detail: config.detail,
|
||||
source: config.source,
|
||||
description: config.description.to_string(),
|
||||
detail: config.detail.to_string(),
|
||||
source: config.source.to_string(),
|
||||
prefix: config.prefix,
|
||||
suffix: config.suffix,
|
||||
raw: config.raw,
|
||||
|
|
@ -118,9 +118,9 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
|
|||
name: config.name.to_string(),
|
||||
values: values.clone(),
|
||||
counts,
|
||||
description: config.description,
|
||||
detail: config.detail,
|
||||
source: config.source,
|
||||
description: config.description.to_string(),
|
||||
detail: config.detail.to_string(),
|
||||
source: config.source.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -136,6 +136,58 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
|
|||
}
|
||||
}
|
||||
|
||||
let mut dynamic_poi_features = Vec::new();
|
||||
for (feat_idx, name) in data.poi_metrics.feature_names.iter().enumerate() {
|
||||
if let Some(category) = features::dynamic_poi_distance_category(name) {
|
||||
let stats = &data.poi_metrics.feature_stats[feat_idx];
|
||||
dynamic_poi_features.push(FeatureInfo::Numeric {
|
||||
name: name.clone(),
|
||||
min: stats.slider_min,
|
||||
max: stats.slider_max,
|
||||
step: 0.1,
|
||||
histogram: stats.histogram.clone(),
|
||||
description: format!("Distance to the closest {category} POI"),
|
||||
detail: format!(
|
||||
"Straight-line distance in kilometres from the postcode to the nearest {category} point of interest in the POI dataset."
|
||||
),
|
||||
source: "osm-pois".to_string(),
|
||||
prefix: "",
|
||||
suffix: " km",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
});
|
||||
} else if let Some(category) = features::dynamic_poi_count_category(name) {
|
||||
let stats = &data.poi_metrics.feature_stats[feat_idx];
|
||||
let radius = features::dynamic_poi_count_radius(name).unwrap_or(0);
|
||||
dynamic_poi_features.push(FeatureInfo::Numeric {
|
||||
name: name.clone(),
|
||||
min: stats.slider_min,
|
||||
max: stats.slider_max,
|
||||
step: 1.0,
|
||||
histogram: stats.histogram.clone(),
|
||||
description: format!("Number of {category} POIs within {radius}km"),
|
||||
detail: format!(
|
||||
"Count of {category} points of interest within a {radius}km radius of the property's postcode centroid."
|
||||
),
|
||||
source: "osm-pois".to_string(),
|
||||
prefix: "",
|
||||
suffix: "",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
if !dynamic_poi_features.is_empty() {
|
||||
dynamic_poi_features.sort_by_key(|feature| match feature {
|
||||
FeatureInfo::Numeric { name, .. } => features::dynamic_poi_feature_sort_key(name),
|
||||
FeatureInfo::Enum { name, .. } => features::dynamic_poi_feature_sort_key(name),
|
||||
});
|
||||
groups.push(FeatureGroupResponse {
|
||||
name: "Nearby POIs".to_string(),
|
||||
features: dynamic_poi_features,
|
||||
});
|
||||
}
|
||||
|
||||
FeaturesResponse { groups }
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ use tracing::info;
|
|||
|
||||
use crate::consts::NAN_U16;
|
||||
use crate::data::travel_time::TravelData;
|
||||
use crate::parsing::{parse_filters, require_bounds};
|
||||
use crate::parsing::{parse_filters_with_poi, require_bounds};
|
||||
use crate::routes::travel_time::parse_optional_travel;
|
||||
use crate::state::SharedState;
|
||||
|
||||
|
|
@ -36,18 +36,21 @@ pub async fn get_filter_counts(
|
|||
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
||||
let num_regular = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_regular = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
// Only travel entries with a filter range count as filters for impact tracking
|
||||
let travel_filter_indices: Vec<usize> = travel_entries
|
||||
.iter()
|
||||
|
|
@ -65,6 +68,7 @@ pub async fn get_filter_counts(
|
|||
}
|
||||
|
||||
let filters_str = params.filters;
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
|
||||
let response = tokio::task::spawn_blocking(move || -> Result<FilterCountsResponse, String> {
|
||||
let t0 = std::time::Instant::now();
|
||||
|
|
@ -124,6 +128,23 @@ pub async fn get_filter_counts(
|
|||
}
|
||||
}
|
||||
|
||||
// Test travel time filters
|
||||
if fail_count <= 1 && has_poi_filters {
|
||||
for (i, f) in parsed_poi_filters.iter().enumerate() {
|
||||
let raw = state
|
||||
.data
|
||||
.poi_metrics
|
||||
.raw_for_property_row(row, f.metric_idx);
|
||||
if raw == NAN_U16 || raw < f.min_u16 || raw > f.max_u16 {
|
||||
fail_count += 1;
|
||||
fail_index = parsed_filters.len() + parsed_enum_filters.len() + i;
|
||||
if fail_count > 1 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test travel time filters
|
||||
if fail_count <= 1 && has_travel {
|
||||
let postcode = pc_interner.resolve(&pc_keys[row]);
|
||||
|
|
@ -169,8 +190,15 @@ pub async fn get_filter_counts(
|
|||
let name = if i < parsed_filters.len() {
|
||||
state.data.feature_names[parsed_filters[i].feat_idx].clone()
|
||||
} else if i < num_regular {
|
||||
let ei = i - parsed_filters.len();
|
||||
state.data.feature_names[parsed_enum_filters[ei].feat_idx].clone()
|
||||
let enum_start = parsed_filters.len();
|
||||
let poi_start = enum_start + parsed_enum_filters.len();
|
||||
if i < poi_start {
|
||||
let ei = i - enum_start;
|
||||
state.data.feature_names[parsed_enum_filters[ei].feat_idx].clone()
|
||||
} else {
|
||||
let pi = i - poi_start;
|
||||
state.data.poi_metrics.feature_names[parsed_poi_filters[pi].metric_idx].clone()
|
||||
}
|
||||
} else {
|
||||
let slot = i - num_regular;
|
||||
let ti = travel_filter_indices[slot];
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@ use tracing::{info, warn};
|
|||
use crate::auth::OptionalUser;
|
||||
use crate::licensing::{check_license_bounds, resolve_share_code};
|
||||
use crate::parsing::{
|
||||
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_set, parse_filters,
|
||||
row_passes_filters, validate_h3_resolution,
|
||||
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_set, parse_filters_with_poi,
|
||||
row_passes_filters, row_passes_poi_filters, validate_h3_resolution,
|
||||
};
|
||||
use crate::state::SharedState;
|
||||
|
||||
|
|
@ -110,15 +110,19 @@ pub async fn get_hexagon_stats(
|
|||
|
||||
let h3_str = params.h3;
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
|
||||
let (fields_specified, field_set) = parse_field_set(params.fields.as_deref());
|
||||
|
||||
|
|
@ -161,6 +165,12 @@ pub async fn get_hexagon_stats(
|
|||
feature_data,
|
||||
num_features,
|
||||
)
|
||||
&& (!has_poi_filters
|
||||
|| row_passes_poi_filters(
|
||||
row,
|
||||
&parsed_poi_filters,
|
||||
&state.data.poi_metrics,
|
||||
))
|
||||
{
|
||||
if has_travel {
|
||||
let postcode = state.data.postcode(row);
|
||||
|
|
@ -233,7 +243,7 @@ pub async fn get_hexagon_stats(
|
|||
let price_history =
|
||||
stats::extract_price_history(&matching_rows, &state.data, &state.feature_name_to_index);
|
||||
|
||||
let (numeric_features, enum_features_out) = stats::compute_feature_stats(
|
||||
let (mut numeric_features, enum_features_out) = stats::compute_feature_stats(
|
||||
&matching_rows,
|
||||
&state.data,
|
||||
&state.data.feature_names,
|
||||
|
|
@ -242,6 +252,12 @@ pub async fn get_hexagon_stats(
|
|||
fields_specified,
|
||||
&field_set,
|
||||
);
|
||||
numeric_features.extend(stats::compute_poi_feature_stats(
|
||||
&matching_rows,
|
||||
&state.data.poi_metrics,
|
||||
fields_specified,
|
||||
&field_set,
|
||||
));
|
||||
|
||||
let elapsed = start_time.elapsed();
|
||||
info!(
|
||||
|
|
|
|||
|
|
@ -11,14 +11,15 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::{Map, Value};
|
||||
use tracing::info;
|
||||
|
||||
use crate::aggregation::{Aggregator, EnumDistConfig};
|
||||
use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator};
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::MAX_CELLS_PER_REQUEST;
|
||||
use crate::data::travel_time::TravelData;
|
||||
use crate::licensing::{check_license_bounds, resolve_share_code};
|
||||
use crate::parsing::{
|
||||
cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices, parse_filters,
|
||||
require_bounds, row_passes_filters, validate_h3_resolution,
|
||||
cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices_with_poi,
|
||||
parse_filters_with_poi, require_bounds, row_passes_filters, row_passes_poi_filters,
|
||||
validate_h3_resolution,
|
||||
};
|
||||
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
|
||||
use crate::state::SharedState;
|
||||
|
|
@ -29,6 +30,7 @@ const PARALLEL_THRESHOLD: usize = 50_000;
|
|||
/// Per-thread aggregation result: feature accumulators + travel time accumulators.
|
||||
type ChunkResult = (
|
||||
FxHashMap<u64, Aggregator>,
|
||||
FxHashMap<u64, PoiAggregator>,
|
||||
Vec<FxHashMap<u64, TravelTimeAgg>>,
|
||||
);
|
||||
|
||||
|
|
@ -79,11 +81,14 @@ pub struct HexagonParams {
|
|||
#[allow(clippy::too_many_arguments)]
|
||||
fn build_feature_maps(
|
||||
groups: &FxHashMap<u64, Aggregator>,
|
||||
poi_groups: &FxHashMap<u64, PoiAggregator>,
|
||||
min_keys: &[String],
|
||||
max_keys: &[String],
|
||||
avg_keys: &[String],
|
||||
num_features: usize,
|
||||
indices: Option<&[usize]>,
|
||||
poi_feature_names: &[String],
|
||||
poi_indices: &[usize],
|
||||
query_bounds: (f64, f64, f64, f64),
|
||||
resolution: h3o::Resolution,
|
||||
travel_aggs: &[FxHashMap<u64, TravelTimeAgg>],
|
||||
|
|
@ -163,6 +168,25 @@ fn build_feature_maps(
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(poi_aggregation) = poi_groups.get(&cell_id) {
|
||||
for &metric_idx in poi_indices {
|
||||
if poi_aggregation.counts[metric_idx] > 0 {
|
||||
let avg = poi_aggregation.sums[metric_idx]
|
||||
/ poi_aggregation.counts[metric_idx] as f64;
|
||||
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
|
||||
serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64),
|
||||
serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64),
|
||||
serde_json::Number::from_f64(avg),
|
||||
) {
|
||||
let name = &poi_feature_names[metric_idx];
|
||||
map.insert(format!("min_{name}"), Value::Number(min_num));
|
||||
map.insert(format!("max_{name}"), Value::Number(max_num));
|
||||
map.insert(format!("avg_{name}"), Value::Number(avg_num));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add travel time aggregation fields (using pre-computed key strings)
|
||||
for (ti, agg_map) in travel_aggs.iter().enumerate() {
|
||||
if let Some(agg) = agg_map.get(&cell_id) {
|
||||
|
|
@ -209,18 +233,25 @@ pub async fn get_hexagons(
|
|||
check_license_bounds(&user.0, (south, west, north, east), share_bounds)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
|
||||
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
|
||||
.map_err(|err| (err.0, err.1).into_response())?;
|
||||
let field_indices = parse_field_indices_with_poi(
|
||||
params.fields.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
)
|
||||
.map_err(|err| (err.0, err.1).into_response())?;
|
||||
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
|
@ -269,6 +300,11 @@ pub async fn get_hexagons(
|
|||
let min_keys = &state.min_keys;
|
||||
let max_keys = &state.max_keys;
|
||||
let avg_keys = &state.avg_keys;
|
||||
let poi_metrics = &state.data.poi_metrics;
|
||||
let poi_field_indices = field_indices.poi.as_slice();
|
||||
let has_poi_fields = !poi_field_indices.is_empty();
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
let poi_num_features = poi_metrics.num_features();
|
||||
|
||||
let h3_res = h3o::Resolution::try_from(resolution)
|
||||
.map_err(|error| format!("Invalid H3 resolution {}: {}", resolution, error))?;
|
||||
|
|
@ -276,6 +312,7 @@ pub async fn get_hexagons(
|
|||
let need_parent = needs_parent(resolution);
|
||||
|
||||
let mut groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
|
||||
let mut poi_groups: FxHashMap<u64, PoiAggregator> = FxHashMap::default();
|
||||
let mut travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0..travel_entries.len())
|
||||
.map(|_| FxHashMap::default())
|
||||
.collect();
|
||||
|
|
@ -296,6 +333,7 @@ pub async fn get_hexagons(
|
|||
.par_chunks(chunk_size)
|
||||
.map(|chunk| {
|
||||
let mut local_groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
|
||||
let mut local_poi_groups: FxHashMap<u64, PoiAggregator> = FxHashMap::default();
|
||||
let mut local_travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0
|
||||
..travel_entries.len())
|
||||
.map(|_| FxHashMap::default())
|
||||
|
|
@ -315,6 +353,11 @@ pub async fn get_hexagons(
|
|||
) {
|
||||
continue;
|
||||
}
|
||||
if has_poi_filters
|
||||
&& !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if has_travel {
|
||||
travel_minutes.clear();
|
||||
|
|
@ -352,7 +395,7 @@ pub async fn get_hexagons(
|
|||
let agg = local_groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config));
|
||||
if let Some(sel_indices) = field_indices.as_deref() {
|
||||
if let Some(sel_indices) = field_indices.normal.as_deref() {
|
||||
agg.add_row_selective(
|
||||
feature_data,
|
||||
row,
|
||||
|
|
@ -364,6 +407,13 @@ pub async fn get_hexagons(
|
|||
agg.add_row(feature_data, row, num_features, &quant);
|
||||
}
|
||||
|
||||
if has_poi_fields {
|
||||
local_poi_groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| PoiAggregator::new(poi_num_features))
|
||||
.add_row_selective(poi_metrics, row, poi_field_indices);
|
||||
}
|
||||
|
||||
for (ti, minutes) in travel_minutes.iter().enumerate() {
|
||||
if let Some(mins) = minutes {
|
||||
let tagg = local_travel_aggs[ti]
|
||||
|
|
@ -374,18 +424,24 @@ pub async fn get_hexagons(
|
|||
}
|
||||
}
|
||||
|
||||
(local_groups, local_travel_aggs)
|
||||
(local_groups, local_poi_groups, local_travel_aggs)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Merge thread-local results into the main accumulators
|
||||
for (local_groups, local_travel) in thread_results {
|
||||
for (local_groups, local_poi_groups, local_travel) in thread_results {
|
||||
for (cell_id, local_agg) in local_groups {
|
||||
groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config))
|
||||
.merge(&local_agg);
|
||||
}
|
||||
for (cell_id, local_agg) in local_poi_groups {
|
||||
poi_groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| PoiAggregator::new(poi_num_features))
|
||||
.merge(&local_agg);
|
||||
}
|
||||
for (ti, local_ta) in local_travel.into_iter().enumerate() {
|
||||
for (cell_id, local_tt) in local_ta {
|
||||
travel_aggs[ti]
|
||||
|
|
@ -414,6 +470,11 @@ pub async fn get_hexagons(
|
|||
) {
|
||||
return;
|
||||
}
|
||||
if has_poi_filters
|
||||
&& !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if has_travel {
|
||||
travel_minutes.clear();
|
||||
|
|
@ -444,7 +505,7 @@ pub async fn get_hexagons(
|
|||
let aggregation = groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config));
|
||||
if let Some(sel_indices) = field_indices.as_deref() {
|
||||
if let Some(sel_indices) = field_indices.normal.as_deref() {
|
||||
aggregation.add_row_selective(
|
||||
feature_data,
|
||||
row,
|
||||
|
|
@ -456,6 +517,13 @@ pub async fn get_hexagons(
|
|||
aggregation.add_row(feature_data, row, num_features, &quant);
|
||||
}
|
||||
|
||||
if has_poi_fields {
|
||||
poi_groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| PoiAggregator::new(poi_num_features))
|
||||
.add_row_selective(poi_metrics, row, poi_field_indices);
|
||||
}
|
||||
|
||||
for (ti, minutes) in travel_minutes.iter().enumerate() {
|
||||
if let Some(mins) = minutes {
|
||||
let agg = travel_aggs[ti]
|
||||
|
|
@ -471,11 +539,14 @@ pub async fn get_hexagons(
|
|||
|
||||
let mut features = build_feature_maps(
|
||||
&groups,
|
||||
&poi_groups,
|
||||
min_keys,
|
||||
max_keys,
|
||||
avg_keys,
|
||||
num_features,
|
||||
field_indices.as_deref(),
|
||||
field_indices.normal.as_deref(),
|
||||
&poi_metrics.feature_names,
|
||||
poi_field_indices,
|
||||
(south, west, north, east),
|
||||
h3_res,
|
||||
&travel_aggs,
|
||||
|
|
@ -499,7 +570,11 @@ pub async fn get_hexagons(
|
|||
bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east),
|
||||
filters = num_filters,
|
||||
filters_raw = filters_str.as_deref().unwrap_or("-"),
|
||||
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1),
|
||||
fields = field_indices
|
||||
.normal
|
||||
.as_ref()
|
||||
.map(|v| (v.len() + poi_field_indices.len()) as i32)
|
||||
.unwrap_or(-1),
|
||||
travel_entries = travel_entries.len(),
|
||||
grid_ms = format_args!("{:.1}", t_grid.as_secs_f64() * 1000.0),
|
||||
agg_ms = format_args!("{:.1}", (t_agg - t_grid).as_secs_f64() * 1000.0),
|
||||
|
|
|
|||
|
|
@ -9,11 +9,16 @@ use serde::{Deserialize, Serialize};
|
|||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::{OptionalUser, PocketBaseUser};
|
||||
use crate::checkout_sessions::{
|
||||
active_referral_checkout_user, start_license_checkout, CheckoutStart,
|
||||
};
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::pocketbase_locks::acquire_pocketbase_lock;
|
||||
use crate::state::{AppState, SharedState};
|
||||
|
||||
static INVITE_REDEMPTIONS_IN_PROGRESS: LazyLock<Mutex<HashSet<String>>> =
|
||||
LazyLock::new(|| Mutex::new(HashSet::new()));
|
||||
const INVITE_REDEMPTION_LOCK_TTL_SECS: u64 = 5 * 60;
|
||||
|
||||
struct InviteRedemptionGuard {
|
||||
code: String,
|
||||
|
|
@ -103,7 +108,7 @@ fn validate_invite_code(code: &str) -> Result<(), &'static str> {
|
|||
}
|
||||
|
||||
fn generate_invite_code() -> String {
|
||||
use rand::Rng;
|
||||
use rand::RngExt;
|
||||
let mut rng = rand::rng();
|
||||
let chars: Vec<char> = (0..12)
|
||||
.map(|_| {
|
||||
|
|
@ -246,74 +251,26 @@ async fn grant_license_for_invite(
|
|||
async fn create_referral_checkout(
|
||||
state: &AppState,
|
||||
user: &PocketBaseUser,
|
||||
invite_id: &str,
|
||||
) -> Result<String, Response> {
|
||||
let count = match super::pricing::count_licensed_users(state).await {
|
||||
Ok(count) => count,
|
||||
Err(err) => {
|
||||
warn!("Failed to count licensed users for invite checkout: {err}");
|
||||
return Err(StatusCode::SERVICE_UNAVAILABLE.into_response());
|
||||
}
|
||||
};
|
||||
let price_pence = super::pricing::price_for_count(count);
|
||||
|
||||
let public_url = &state.public_url;
|
||||
let success_url = format!("{public_url}/pricing?license_success=1");
|
||||
let cancel_url = format!("{public_url}/pricing");
|
||||
|
||||
let form_params = vec![
|
||||
("mode", "payment".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][unit_amount]",
|
||||
price_pence.to_string(),
|
||||
),
|
||||
("line_items[0][price_data][currency]", "gbp".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][product_data][name]",
|
||||
"Perfect Postcodes Lifetime License".to_string(),
|
||||
),
|
||||
("line_items[0][quantity]", "1".to_string()),
|
||||
("success_url", success_url),
|
||||
("cancel_url", cancel_url),
|
||||
("client_reference_id", user.id.clone()),
|
||||
("customer_email", user.email.clone()),
|
||||
(
|
||||
"discounts[0][coupon]",
|
||||
state.stripe_referral_coupon_id.clone(),
|
||||
),
|
||||
];
|
||||
|
||||
let stripe_res = state
|
||||
.http_client
|
||||
.post("https://api.stripe.com/v1/checkout/sessions")
|
||||
.basic_auth(&state.stripe_secret_key, None::<&str>)
|
||||
.form(&form_params)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match stripe_res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let stripe_body: serde_json::Value = match resp.json().await {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse Stripe checkout response: {err}");
|
||||
return Err(StatusCode::BAD_GATEWAY.into_response());
|
||||
}
|
||||
};
|
||||
let checkout_url = stripe_body["url"].as_str().unwrap_or_default().to_string();
|
||||
if checkout_url.is_empty() {
|
||||
warn!("Stripe checkout response did not include a URL");
|
||||
return Err(StatusCode::BAD_GATEWAY.into_response());
|
||||
}
|
||||
Ok(checkout_url)
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
warn!("Failed to create Stripe checkout for referral invite ({status}): {text}");
|
||||
Err(StatusCode::BAD_GATEWAY.into_response())
|
||||
}
|
||||
match start_license_checkout(
|
||||
state,
|
||||
user,
|
||||
&success_url,
|
||||
&cancel_url,
|
||||
Some(&state.stripe_referral_coupon_id),
|
||||
Some(invite_id),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(CheckoutStart::Free) => Ok(success_url),
|
||||
Ok(CheckoutStart::Stripe { url }) => Ok(url),
|
||||
Err(err) => {
|
||||
warn!("Stripe request error for referral invite: {err}");
|
||||
warn!("Failed to create reserved Stripe checkout for referral invite: {err:?}");
|
||||
Err(StatusCode::BAD_GATEWAY.into_response())
|
||||
}
|
||||
}
|
||||
|
|
@ -541,6 +498,10 @@ pub async fn post_redeem_invite(
|
|||
.into_response();
|
||||
}
|
||||
|
||||
if user.is_admin || user.subscription == "licensed" {
|
||||
return (StatusCode::CONFLICT, "Account already has full access").into_response();
|
||||
}
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
|
||||
let token = match get_superuser_token(&state).await {
|
||||
|
|
@ -561,6 +522,19 @@ pub async fn post_redeem_invite(
|
|||
.into_response()
|
||||
}
|
||||
};
|
||||
let lock_name = format!("invite:{}", req.code);
|
||||
let _distributed_redemption_guard =
|
||||
match acquire_pocketbase_lock(&state, &lock_name, INVITE_REDEMPTION_LOCK_TTL_SECS).await {
|
||||
Ok(guard) => guard,
|
||||
Err(err) => {
|
||||
warn!(code = %req.code, "Failed to acquire invite redemption lock: {err}");
|
||||
return (
|
||||
StatusCode::CONFLICT,
|
||||
"Invite redemption is already in progress",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let invite = match lookup_unused_invite(&state, pb_url, &token, &req.code).await {
|
||||
Ok(Some(invite)) => invite,
|
||||
|
|
@ -591,11 +565,11 @@ pub async fn post_redeem_invite(
|
|||
};
|
||||
|
||||
if invite_type == "admin" {
|
||||
if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await {
|
||||
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
|
||||
return response;
|
||||
}
|
||||
|
||||
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
|
||||
if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await {
|
||||
return response;
|
||||
}
|
||||
|
||||
|
|
@ -607,15 +581,26 @@ pub async fn post_redeem_invite(
|
|||
.into_response();
|
||||
}
|
||||
|
||||
let checkout_url = match create_referral_checkout(&state, &user).await {
|
||||
match active_referral_checkout_user(&state, invite_id).await {
|
||||
Ok(Some(active_user_id)) if active_user_id != user.id => {
|
||||
return (
|
||||
StatusCode::CONFLICT,
|
||||
"Invite checkout is already in progress",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Ok(_) => {}
|
||||
Err(err) => {
|
||||
warn!(code = %req.code, "Failed to check active referral checkout: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
let checkout_url = match create_referral_checkout(&state, &user, invite_id).await {
|
||||
Ok(url) => url,
|
||||
Err(response) => return response,
|
||||
};
|
||||
|
||||
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
|
||||
return response;
|
||||
}
|
||||
|
||||
info!(user_id = %user.id, code = %req.code, "Referral invite redeemed; checkout created");
|
||||
Json(RedeemResponse {
|
||||
result: "checkout".to_string(),
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
|
|||
use tracing::info;
|
||||
|
||||
use crate::consts::MAX_POIS_PER_REQUEST;
|
||||
use crate::data::POICategoryGroup;
|
||||
use crate::data::{resolve_poi_category_filter, POICategoryGroup};
|
||||
use crate::parsing::require_bounds;
|
||||
use crate::state::SharedState;
|
||||
|
||||
|
|
@ -47,20 +47,7 @@ pub async fn get_pois(
|
|||
.categories
|
||||
.as_deref()
|
||||
.filter(|text| !text.is_empty())
|
||||
.map(|text| {
|
||||
text.split(',')
|
||||
.filter_map(|part| {
|
||||
let name = part.trim();
|
||||
state
|
||||
.poi_data
|
||||
.category
|
||||
.values
|
||||
.iter()
|
||||
.position(|v| v == name)
|
||||
.map(|pos| pos as u16)
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
.map(|text| resolve_poi_category_filter(&state.poi_data.category.values, text));
|
||||
let categories_raw = params.categories;
|
||||
|
||||
let num_categories = category_filter.as_ref().map(|cats| cats.len()).unwrap_or(0);
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use tracing::{info, warn};
|
|||
use crate::auth::OptionalUser;
|
||||
use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT, POSTCODE_SEARCH_OFFSET};
|
||||
use crate::licensing::{check_license_point, resolve_share_code};
|
||||
use crate::parsing::{parse_filters, row_passes_filters};
|
||||
use crate::parsing::{parse_filters_with_poi, row_passes_filters, row_passes_poi_filters};
|
||||
use crate::state::SharedState;
|
||||
use crate::utils::normalize_postcode;
|
||||
|
||||
|
|
@ -62,15 +62,19 @@ pub async fn get_postcode_properties(
|
|||
)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
||||
|
|
@ -111,6 +115,12 @@ pub async fn get_postcode_properties(
|
|||
feature_data,
|
||||
num_features,
|
||||
)
|
||||
&& (!has_poi_filters
|
||||
|| row_passes_poi_filters(
|
||||
row,
|
||||
&parsed_poi_filters,
|
||||
&state.data.poi_metrics,
|
||||
))
|
||||
{
|
||||
if has_travel
|
||||
&& !row_passes_travel_filters(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ use tracing::{info, warn};
|
|||
use crate::auth::OptionalUser;
|
||||
use crate::consts::POSTCODE_SEARCH_OFFSET;
|
||||
use crate::licensing::{check_license_point, resolve_share_code};
|
||||
use crate::parsing::{parse_field_set, parse_filters, row_passes_filters};
|
||||
use crate::parsing::{
|
||||
parse_field_set, parse_filters_with_poi, row_passes_filters, row_passes_poi_filters,
|
||||
};
|
||||
use crate::state::SharedState;
|
||||
use crate::utils::normalize_postcode;
|
||||
|
||||
|
|
@ -64,15 +66,19 @@ pub async fn get_postcode_stats(
|
|||
)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
|
||||
let (fields_specified, field_set) = parse_field_set(params.fields.as_deref());
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
|
|
@ -108,6 +114,12 @@ pub async fn get_postcode_stats(
|
|||
feature_data,
|
||||
num_features,
|
||||
)
|
||||
&& (!has_poi_filters
|
||||
|| row_passes_poi_filters(
|
||||
row,
|
||||
&parsed_poi_filters,
|
||||
&state.data.poi_metrics,
|
||||
))
|
||||
{
|
||||
if has_travel
|
||||
&& !row_passes_travel_filters(row_postcode, &travel_entries, &travel_data)
|
||||
|
|
@ -123,7 +135,7 @@ pub async fn get_postcode_stats(
|
|||
let price_history =
|
||||
stats::extract_price_history(&matching_rows, &state.data, &state.feature_name_to_index);
|
||||
|
||||
let (numeric_features, enum_features_out) = stats::compute_feature_stats(
|
||||
let (mut numeric_features, enum_features_out) = stats::compute_feature_stats(
|
||||
&matching_rows,
|
||||
&state.data,
|
||||
&state.data.feature_names,
|
||||
|
|
@ -132,6 +144,12 @@ pub async fn get_postcode_stats(
|
|||
fields_specified,
|
||||
&field_set,
|
||||
);
|
||||
numeric_features.extend(stats::compute_poi_feature_stats(
|
||||
&matching_rows,
|
||||
&state.data.poi_metrics,
|
||||
fields_specified,
|
||||
&field_set,
|
||||
));
|
||||
|
||||
let elapsed = start_time.elapsed();
|
||||
info!(
|
||||
|
|
|
|||
|
|
@ -10,14 +10,14 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::{Map, Value};
|
||||
use tracing::info;
|
||||
|
||||
use crate::aggregation::{Aggregator, EnumDistConfig};
|
||||
use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator};
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::MAX_CELLS_PER_REQUEST;
|
||||
use crate::data::travel_time::TravelData;
|
||||
use crate::licensing::{check_license_bounds, resolve_share_code};
|
||||
use crate::parsing::{
|
||||
bounds_intersect, parse_enum_dist, parse_field_indices, parse_filters, require_bounds,
|
||||
row_passes_filters,
|
||||
bounds_intersect, parse_enum_dist, parse_field_indices_with_poi, parse_filters_with_poi,
|
||||
require_bounds, row_passes_filters, row_passes_poi_filters,
|
||||
};
|
||||
use crate::pocketbase::log_user_location;
|
||||
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
|
||||
|
|
@ -64,18 +64,25 @@ pub async fn get_postcodes(
|
|||
check_license_bounds(&user.0, (south, west, north, east), share_bounds)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
|
||||
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
|
||||
.map_err(|err| (err.0, err.1).into_response())?;
|
||||
let field_indices = parse_field_indices_with_poi(
|
||||
params.fields.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
)
|
||||
.map_err(|err| (err.0, err.1).into_response())?;
|
||||
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
|
@ -123,12 +130,18 @@ pub async fn get_postcodes(
|
|||
let min_keys = &state.min_keys;
|
||||
let max_keys = &state.max_keys;
|
||||
let avg_keys = &state.avg_keys;
|
||||
let poi_metrics = &state.data.poi_metrics;
|
||||
let poi_field_indices = field_indices.poi.as_slice();
|
||||
let has_poi_fields = !poi_field_indices.is_empty();
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
let poi_num_features = poi_metrics.num_features();
|
||||
|
||||
let has_selective = field_indices.is_some();
|
||||
let sel_indices = field_indices.as_deref().unwrap_or(&[]);
|
||||
let has_selective = field_indices.normal.is_some();
|
||||
let sel_indices = field_indices.normal.as_deref().unwrap_or(&[]);
|
||||
|
||||
// Single-pass: aggregate directly into postcode_aggs while iterating properties in bounds
|
||||
let mut postcode_aggs: FxHashMap<usize, Aggregator> = FxHashMap::default();
|
||||
let mut poi_aggs: FxHashMap<usize, PoiAggregator> = FxHashMap::default();
|
||||
|
||||
state
|
||||
.grid
|
||||
|
|
@ -143,6 +156,10 @@ pub async fn get_postcodes(
|
|||
) {
|
||||
return;
|
||||
}
|
||||
if has_poi_filters && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let postcode = state.data.postcode(row);
|
||||
if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) {
|
||||
|
|
@ -154,6 +171,12 @@ pub async fn get_postcodes(
|
|||
} else {
|
||||
agg.add_row(feature_data, row, num_features, &quant);
|
||||
}
|
||||
if has_poi_fields {
|
||||
poi_aggs
|
||||
.entry(pc_idx)
|
||||
.or_insert_with(|| PoiAggregator::new(poi_num_features))
|
||||
.add_row_selective(poi_metrics, row, poi_field_indices);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -250,11 +273,12 @@ pub async fn get_postcodes(
|
|||
]),
|
||||
);
|
||||
|
||||
let iter: Box<dyn Iterator<Item = usize>> = if let Some(idx) = field_indices.as_ref() {
|
||||
Box::new(idx.iter().copied())
|
||||
} else {
|
||||
Box::new(0..num_features)
|
||||
};
|
||||
let iter: Box<dyn Iterator<Item = usize>> =
|
||||
if let Some(idx) = field_indices.normal.as_ref() {
|
||||
Box::new(idx.iter().copied())
|
||||
} else {
|
||||
Box::new(0..num_features)
|
||||
};
|
||||
|
||||
for feat_index in iter {
|
||||
if aggregation.feat_counts[feat_index] > 0 {
|
||||
|
|
@ -272,6 +296,25 @@ pub async fn get_postcodes(
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(poi_aggregation) = poi_aggs.get(&pc_idx) {
|
||||
for &metric_idx in poi_field_indices {
|
||||
if poi_aggregation.counts[metric_idx] > 0 {
|
||||
let avg = poi_aggregation.sums[metric_idx]
|
||||
/ poi_aggregation.counts[metric_idx] as f64;
|
||||
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
|
||||
serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64),
|
||||
serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64),
|
||||
serde_json::Number::from_f64(avg),
|
||||
) {
|
||||
let name = &poi_metrics.feature_names[metric_idx];
|
||||
props.insert(format!("min_{name}"), Value::Number(min_num));
|
||||
props.insert(format!("max_{name}"), Value::Number(max_num));
|
||||
props.insert(format!("avg_{name}"), Value::Number(avg_num));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add travel time aggregation fields
|
||||
if let Some(tt_aggs) = travel_aggs.get(&pc_idx) {
|
||||
for (ti, agg) in tt_aggs.iter().enumerate() {
|
||||
|
|
@ -322,7 +365,11 @@ pub async fn get_postcodes(
|
|||
bounds = format_args!("{:.6},{:.6},{:.6},{:.6}", south, west, north, east),
|
||||
filters = num_filters,
|
||||
filters_raw = filters_str.as_deref().unwrap_or("-"),
|
||||
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1),
|
||||
fields = field_indices
|
||||
.normal
|
||||
.as_ref()
|
||||
.map(|v| (v.len() + poi_field_indices.len()) as i32)
|
||||
.unwrap_or(-1),
|
||||
travel_entries = travel_entries.len(),
|
||||
agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0),
|
||||
json_ms = format_args!("{:.1}", (t_total - t_agg).as_secs_f64() * 1000.0),
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT};
|
|||
use crate::data::RenovationEvent;
|
||||
use crate::licensing::{check_license_bounds, resolve_share_code};
|
||||
use crate::parsing::{
|
||||
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_filters, row_passes_filters,
|
||||
validate_h3_resolution,
|
||||
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_filters_with_poi, row_passes_filters,
|
||||
row_passes_poi_filters, validate_h3_resolution,
|
||||
};
|
||||
use crate::state::{AppState, SharedState};
|
||||
|
||||
|
|
@ -117,6 +117,12 @@ pub fn build_property(
|
|||
features.insert(feat_name.clone(), value);
|
||||
}
|
||||
}
|
||||
for (metric_idx, metric_name) in state.data.poi_metrics.feature_names.iter().enumerate() {
|
||||
let value = state.data.poi_metrics.get_for_property_row(row, metric_idx);
|
||||
if value.is_finite() {
|
||||
features.insert(metric_name.clone(), value);
|
||||
}
|
||||
}
|
||||
|
||||
Property {
|
||||
address: non_empty_string(state.data.address(row)),
|
||||
|
|
@ -199,15 +205,19 @@ pub async fn get_hexagon_properties(
|
|||
|
||||
let h3_str = params.h3;
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
||||
|
|
@ -242,6 +252,12 @@ pub async fn get_hexagon_properties(
|
|||
feature_data,
|
||||
num_features,
|
||||
)
|
||||
&& (!has_poi_filters
|
||||
|| row_passes_poi_filters(
|
||||
row,
|
||||
&parsed_poi_filters,
|
||||
&state.data.poi_metrics,
|
||||
))
|
||||
{
|
||||
if has_travel {
|
||||
let postcode = state.data.postcode(row);
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use rustc_hash::FxHashMap;
|
|||
use tracing::warn;
|
||||
|
||||
use crate::consts::MAX_PRICE_HISTORY_POINTS;
|
||||
use crate::data::{FeatureStats, PropertyData};
|
||||
use crate::data::{FeatureStats, PostcodePoiMetrics, PropertyData};
|
||||
|
||||
use super::hexagon_stats::{EnumFeatureStats, HistogramStats, NumericFeatureStats, PricePoint};
|
||||
|
||||
|
|
@ -243,3 +243,80 @@ pub fn compute_feature_stats(
|
|||
|
||||
(numeric_features, enum_features_out)
|
||||
}
|
||||
|
||||
pub fn compute_poi_feature_stats(
|
||||
matching_rows: &[usize],
|
||||
poi_metrics: &PostcodePoiMetrics,
|
||||
fields_specified: bool,
|
||||
field_set: &HashSet<String>,
|
||||
) -> Vec<NumericFeatureStats> {
|
||||
let mut out = Vec::new();
|
||||
for (metric_idx, name) in poi_metrics.feature_names.iter().enumerate() {
|
||||
if fields_specified && !field_set.contains(name.as_str()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let global_hist = &poi_metrics.feature_stats[metric_idx].histogram;
|
||||
let p1 = global_hist.p1;
|
||||
let p99 = global_hist.p99;
|
||||
let num_bins = global_hist.counts.len();
|
||||
let middle_bins = num_bins.saturating_sub(2);
|
||||
let middle_width = if middle_bins > 0 && p99 > p1 {
|
||||
(p99 - p1) / middle_bins as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let mut count = 0usize;
|
||||
let mut min_value = f32::INFINITY;
|
||||
let mut max_value = f32::NEG_INFINITY;
|
||||
let mut sum = 0.0f64;
|
||||
let mut bins = vec![0u64; num_bins];
|
||||
|
||||
for &row in matching_rows {
|
||||
let value = poi_metrics.get_for_property_row(row, metric_idx);
|
||||
if !value.is_finite() {
|
||||
continue;
|
||||
}
|
||||
count += 1;
|
||||
if value < min_value {
|
||||
min_value = value;
|
||||
}
|
||||
if value > max_value {
|
||||
max_value = value;
|
||||
}
|
||||
sum += value as f64;
|
||||
|
||||
let bin = if value < p1 {
|
||||
0
|
||||
} else if value >= p99 {
|
||||
num_bins - 1
|
||||
} else if middle_width > 0.0 {
|
||||
let middle_bin = ((value - p1) / middle_width) as usize;
|
||||
(1 + middle_bin).min(num_bins - 2)
|
||||
} else {
|
||||
num_bins / 2
|
||||
};
|
||||
bins[bin] += 1;
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
out.push(NumericFeatureStats {
|
||||
name: name.clone(),
|
||||
count,
|
||||
min: min_value as f64,
|
||||
max: max_value as f64,
|
||||
mean: sum / count as f64,
|
||||
histogram: HistogramStats {
|
||||
min: global_hist.min as f64,
|
||||
max: global_hist.max as f64,
|
||||
p1: p1 as f64,
|
||||
p99: p99 as f64,
|
||||
counts: bins,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,78 +1,40 @@
|
|||
use std::collections::VecDeque;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::body::Bytes;
|
||||
use axum::extract::State;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use hmac::{Hmac, Mac};
|
||||
use parking_lot::Mutex;
|
||||
use rustc_hash::FxHashSet;
|
||||
use hmac::{Hmac, KeyInit, Mac};
|
||||
use sha2::Sha256;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::checkout_sessions::{
|
||||
grant_license, mark_checkout_completed, mark_referral_invite_used, verify_checkout_completion,
|
||||
CheckoutCompletion,
|
||||
};
|
||||
use crate::state::SharedState;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// Process-local LRU of recently processed Stripe event IDs.
|
||||
/// Stripe retries deliver the same event ID; we drop duplicates so we don't
|
||||
/// re-run side effects (subscription writes, token cache invalidation, logs).
|
||||
/// Capacity is intentionally generous: at typical webhook volumes this covers
|
||||
/// far more than Stripe's retry window.
|
||||
struct EventDedup {
|
||||
seen: FxHashSet<String>,
|
||||
queue: VecDeque<String>,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl EventDedup {
|
||||
fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
seen: FxHashSet::default(),
|
||||
queue: VecDeque::with_capacity(capacity),
|
||||
capacity,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if this event ID is new (and records it),
|
||||
/// `false` if it was already seen recently.
|
||||
fn check_and_insert(&mut self, id: &str) -> bool {
|
||||
if self.seen.contains(id) {
|
||||
return false;
|
||||
}
|
||||
self.seen.insert(id.to_string());
|
||||
self.queue.push_back(id.to_string());
|
||||
if self.queue.len() > self.capacity {
|
||||
if let Some(old) = self.queue.pop_front() {
|
||||
self.seen.remove(&old);
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
static EVENT_DEDUP: LazyLock<Mutex<EventDedup>> =
|
||||
LazyLock::new(|| Mutex::new(EventDedup::new(1024)));
|
||||
|
||||
/// Verify Stripe webhook signature (v1 scheme).
|
||||
fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
|
||||
// Parse timestamp and signature from header: "t=TIMESTAMP,v1=SIGNATURE"
|
||||
let mut timestamp = None;
|
||||
let mut signature = None;
|
||||
let mut signatures = Vec::new();
|
||||
for part in sig_header.split(',') {
|
||||
if let Some(ts) = part.strip_prefix("t=") {
|
||||
timestamp = Some(ts);
|
||||
} else if let Some(sig) = part.strip_prefix("v1=") {
|
||||
signature = Some(sig);
|
||||
signatures.push(sig);
|
||||
}
|
||||
}
|
||||
|
||||
let (ts, sig_hex) = match (timestamp, signature) {
|
||||
(Some(t), Some(s)) => (t, s),
|
||||
_ => return false,
|
||||
let Some(ts) = timestamp else {
|
||||
return false;
|
||||
};
|
||||
if signatures.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Reject webhooks older than 5 minutes to prevent replay attacks
|
||||
if let Ok(ts_secs) = ts.parse::<i64>() {
|
||||
|
|
@ -87,20 +49,21 @@ fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Compute expected signature: HMAC-SHA256(secret, "TIMESTAMP.PAYLOAD")
|
||||
let signed_payload = format!("{ts}.{}", String::from_utf8_lossy(payload));
|
||||
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return false,
|
||||
};
|
||||
mac.update(signed_payload.as_bytes());
|
||||
let mut signed_payload = Vec::with_capacity(ts.len() + 1 + payload.len());
|
||||
signed_payload.extend_from_slice(ts.as_bytes());
|
||||
signed_payload.push(b'.');
|
||||
signed_payload.extend_from_slice(payload);
|
||||
|
||||
// Decode the provided hex signature and verify with constant-time comparison
|
||||
let sig_bytes = match hex::decode(sig_hex) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(_) => return false,
|
||||
};
|
||||
mac.verify_slice(&sig_bytes).is_ok()
|
||||
signatures.into_iter().any(|sig_hex| {
|
||||
let Ok(sig_bytes) = hex::decode(sig_hex) else {
|
||||
return false;
|
||||
};
|
||||
let Ok(mut mac) = HmacSha256::new_from_slice(secret.as_bytes()) else {
|
||||
return false;
|
||||
};
|
||||
mac.update(&signed_payload);
|
||||
mac.verify_slice(&sig_bytes).is_ok()
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle Stripe webhook events.
|
||||
|
|
@ -140,65 +103,64 @@ pub async fn post_stripe_webhook(
|
|||
let event_type = event["type"].as_str().unwrap_or("");
|
||||
let event_id = event["id"].as_str().unwrap_or("");
|
||||
|
||||
// Idempotency: drop replays/retries of an already-processed event.
|
||||
// We always answer 200 so Stripe stops retrying.
|
||||
if !event_id.is_empty() && !EVENT_DEDUP.lock().check_and_insert(event_id) {
|
||||
info!(event_id, event_type, "Dropping duplicate Stripe webhook");
|
||||
return StatusCode::OK.into_response();
|
||||
}
|
||||
|
||||
info!(event_id, event_type, "Received Stripe webhook");
|
||||
|
||||
if event_type == "checkout.session.completed" {
|
||||
let user_id = event["data"]["object"]["client_reference_id"]
|
||||
.as_str()
|
||||
.unwrap_or("");
|
||||
if user_id.is_empty() {
|
||||
warn!("checkout.session.completed missing client_reference_id");
|
||||
return StatusCode::OK.into_response();
|
||||
}
|
||||
if !user_id.bytes().all(|b| b.is_ascii_alphanumeric()) || user_id.len() > 20 {
|
||||
warn!(user_id, "Invalid client_reference_id format in webhook");
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
}
|
||||
|
||||
// Update user subscription to "licensed" via PocketBase superuser auth
|
||||
let token = match get_superuser_token(&state).await {
|
||||
Ok(t) => t,
|
||||
Err(err) => {
|
||||
warn!("Failed to auth as PocketBase superuser in webhook: {err}");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let res = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "subscription": "licensed" }))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
state.token_cache.invalidate_by_user_id(user_id);
|
||||
let session = &event["data"]["object"];
|
||||
match verify_checkout_completion(&state, session).await {
|
||||
Ok(CheckoutCompletion::Grant(checkout)) => {
|
||||
if let Err(err) = mark_referral_invite_used(
|
||||
&state,
|
||||
&checkout.referral_invite_id,
|
||||
&checkout.user_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
user_id = %checkout.user_id,
|
||||
reservation_id = %checkout.reservation_id,
|
||||
referral_invite_id = %checkout.referral_invite_id,
|
||||
"Failed to mark referral invite used after Stripe checkout: {err:?}"
|
||||
);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
if let Err(err) = grant_license(&state, &checkout.user_id).await {
|
||||
warn!(
|
||||
user_id = %checkout.user_id,
|
||||
reservation_id = %checkout.reservation_id,
|
||||
"Failed to grant license after Stripe checkout: {err:?}"
|
||||
);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
if let Err(err) = mark_checkout_completed(
|
||||
&state,
|
||||
&checkout.reservation_id,
|
||||
checkout.paid_amount_pence,
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
user_id = %checkout.user_id,
|
||||
reservation_id = %checkout.reservation_id,
|
||||
"Failed to mark checkout completed after license grant: {err:?}"
|
||||
);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
info!(
|
||||
user_id,
|
||||
"User subscription updated to licensed via Stripe webhook"
|
||||
user_id = %checkout.user_id,
|
||||
reservation_id = %checkout.reservation_id,
|
||||
"User subscription updated to licensed via verified Stripe checkout"
|
||||
);
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
warn!(
|
||||
user_id,
|
||||
"Failed to update user subscription ({status}): {text}"
|
||||
);
|
||||
Ok(CheckoutCompletion::AlreadyHandled) => {
|
||||
info!("Stripe checkout session was already handled");
|
||||
}
|
||||
Ok(CheckoutCompletion::Rejected(reason)) => {
|
||||
warn!("Rejecting Stripe checkout completion: {reason}");
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(user_id, "PocketBase request error in webhook: {err}");
|
||||
warn!("Failed to verify Stripe checkout completion: {err:?}");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue