Test changes
Some checks failed
Build and publish Docker image / build-and-push (push) Failing after 8m20s
CI / Check (push) Failing after 10m40s

This commit is contained in:
Andras Schmelczer 2026-05-09 11:35:38 +01:00
parent 4c95815dc8
commit be02fc16bb
41 changed files with 4224 additions and 759 deletions

View file

@ -1,9 +1,15 @@
import argparse import argparse
import base64
import json
import re
import sys import sys
import urllib.request import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from pathlib import Path from pathlib import Path
from PIL import Image, ImageDraw
from pipeline.transform.transform_poi import NAPTAN_EMOJIS, _CATEGORIES from pipeline.transform.transform_poi import NAPTAN_EMOJIS, _CATEGORIES
GLYPHS_BASE = "https://protomaps.github.io/basemaps-assets/fonts" GLYPHS_BASE = "https://protomaps.github.io/basemaps-assets/fonts"
@ -14,53 +20,80 @@ POI_ICON_BASE = "https://geolytix.github.io/MapIcons"
# Font stacks used by @protomaps/basemaps with lang='en' # Font stacks used by @protomaps/basemaps with lang='en'
FONT_STACKS = ["Noto Sans Regular", "Noto Sans Italic", "Noto Sans Medium"] FONT_STACKS = ["Noto Sans Regular", "Noto Sans Italic", "Noto Sans Medium"]
# Fallback emoji not in any category
_FALLBACK_EMOJIS = ["📍"]
POI_ICON_PATHS = [ POI_ICON_PATHS = [
"asda/asda_express_24px.svg", "brands_2023/supermarkets/farmfoods.svg",
"asda/asda_green_basket_24px.svg", "brands_2023/supermarkets/heron_foods.svg",
"asda/asda_green_trolley_24px.svg", "brands_2023/supermarkets/little_waitrose.svg",
"asda/asda_living_24px.svg", "brands_2024/amazon_fresh.svg",
"asda/asda_pfs_24px.svg", "brands_2024/booths.svg",
"asda/asda_primary.svg", "brands_2024/budgens.svg",
"asda/asda_superstore_green_trolley_24px.svg", "brands_2024/cook.svg",
"brands/aldi_24px.svg", "brands_2024/dunnes_stores.svg",
"brands/amazon_fresh_alt_24px.svg", "brands_2024/iceland.svg",
"brands/booths_24px.svg", "brands_2024/makro.svg",
"brands/budgens_24px.svg", "brands_2024/mns.svg",
"brands/centra_24px.svg", "brands_2024/morrisons_daily.svg",
"brands/cook.svg", "brands_2024/sainsburys_local.svg",
"brands/coop_24px.svg", "brands_2024/wholefoods.svg",
"brands/costco_24px.svg", "logos/aldi.svg",
"brands/dunnes_stores_24px.svg", "logos/asda.svg",
"brands/farmfoods_updated_24px.svg", "logos/centra.svg",
"brands/heron_24px.svg", "logos/coop.svg",
"brands/iceland_24px.svg", "logos/lidl.svg",
"brands/iceland_food_warehouse_24px.svg", "logos/morrisons.svg",
"brands/lidl_24px.svg", "logos/planet_organic.svg",
"brands/little_waitrose_24px.svg", "logos/sainsburys.svg",
"brands/makro_24px.svg", "logos/spar.svg",
"brands/mns_24px.svg", "logos/tesco.svg",
"brands/mns_food_24px.svg", "logos/tesco_express.svg",
"brands/mns_high_street_24px.svg", "logos/tesco_extra.svg",
"brands/mns_hospital_24px.svg", "logos/waitrose.svg",
"brands/mns_moto_24px.svg",
"brands/mns_outlet_24px.svg",
"brands/morrisons_24px.svg",
"brands/morrisons_daily_24px.svg",
"brands/sainsburys_24px.svg",
"brands/sainsburys_local_24px.svg",
"brands/spar_24px.svg",
"brands/tesco_24px.svg",
"brands/tesco_express_24px.svg",
"brands/tesco_extra_24px.svg",
"brands/waitrose_24px.svg",
"brands/wholefoods_24px.svg",
"logos/planet_organic_24px.svg",
"public_transport/london_tube.svg", "public_transport/london_tube.svg",
"visuals/mns.svg",
] ]
DERIVED_POI_ICON_PATHS = [
("costco_logo", "brands/costco.svg", "logos/costco.svg"),
(
"embedded_png",
"brands/iceland_food_warehouse_24px.svg",
"logos/the_food_warehouse.png",
),
]
POI_ICON_SVG_CROPS = {
"brands_2023/supermarkets/farmfoods.svg": (1.293, 7.314, 15.48, 3.293),
"brands_2023/supermarkets/heron_foods.svg": (0.062, 6.68, 17.995, 5.325),
"brands_2023/supermarkets/little_waitrose.svg": (0.916, 5.645, 16.365, 6.719),
"brands_2024/amazon_fresh.svg": (3.817, 1.646, 16.367, 16.358),
"brands_2024/booths.svg": (1.456, 7.143, 15.313, 3.512),
"brands_2024/budgens.svg": (2.251, 2.278, 13.6, 13.612),
"brands_2024/cook.svg": (5.028, 5.493, 13.945, 9.648),
"brands_2024/dunnes_stores.svg": (4.375, 7.732, 15.249, 5.055),
"brands_2024/iceland.svg": (1.136, 6.823, 16.067, 4.302),
"brands_2024/makro.svg": (4.411, 6.098, 16.397, 5.428),
"brands_2024/mns.svg": (4.042, 6.986, 16.171, 6.724),
"brands_2024/morrisons_daily.svg": (3.341, 4.414, 17.317, 8.248),
"brands_2024/sainsburys_local.svg": (4.58, 1.61, 14.84, 14.849),
"brands_2024/wholefoods.svg": (4.17, 2.193, 15.659, 15.668),
"logos/aldi.svg": (4.813, 2.563, 14.374, 14.383),
"logos/asda.svg": (3.91, 7.135, 16.181, 5.442),
"logos/centra.svg": (3.36, 7.35, 17.28, 4.651),
"logos/coop.svg": (6.407, 4.658, 11.187, 11.793),
"logos/costco.svg": (70.61, 144.908, 256.67, 85.825),
"logos/lidl.svg": (4.938, 2.973, 13.985, 13.985),
"logos/morrisons.svg": (5.231, 2.985, 13.538, 13.398),
"logos/planet_organic.svg": (5.528, 3.564, 12.943, 12.943),
"logos/sainsburys.svg": (7.502, 3.572, 8.996, 12.646),
"logos/spar.svg": (4.933, 2.968, 14.133, 13.853),
"logos/tesco.svg": (4.338, 6.865, 15.324, 5.359),
"logos/tesco_express.svg": (5.231, 5.933, 13.538, 8.345),
"logos/tesco_extra.svg": (4.933, 5.775, 14.133, 8.519),
"logos/waitrose.svg": (5.528, 6.09, 12.943, 9.855),
}
POI_ICON_SVG_INTRINSIC_MAX = 512
def collect_twemoji_codes() -> list[str]: def collect_twemoji_codes() -> list[str]:
"""Derive twemoji hex codes from transform_poi categories. """Derive twemoji hex codes from transform_poi categories.
@ -76,9 +109,6 @@ def collect_twemoji_codes() -> list[str]:
for emoji in NAPTAN_EMOJIS.values(): for emoji in NAPTAN_EMOJIS.values():
emojis.add(emoji) emojis.add(emoji)
for emoji in _FALLBACK_EMOJIS:
emojis.add(emoji)
# First codepoint hex, matching frontend logic # First codepoint hex, matching frontend logic
return sorted({f"{ord(e[0]):x}" for e in emojis}) return sorted({f"{ord(e[0]):x}" for e in emojis})
@ -97,6 +127,214 @@ def download_file(url: str, dest: Path) -> tuple[bool, str]:
return False, url return False, url
def download_text(url: str, timeout: float = 30.0) -> str:
    """Fetch ``url`` and return the response body decoded as UTF-8.

    Args:
        url: Address to fetch; anything ``urllib.request.urlopen`` accepts.
        timeout: Seconds to wait on blocking network operations before
            giving up. Without a bound, a stalled server would hang the
            whole asset download run.

    Raises:
        urllib.error.URLError: If the request fails (``HTTPError`` for
            HTTP-level failures).
        UnicodeDecodeError: If the body is not valid UTF-8.
    """
    with urllib.request.urlopen(url, timeout=timeout) as response:
        return response.read().decode("utf-8")
def build_costco_logo(marker_svg: str) -> str:
    """Extract the logo group from the Costco marker SVG as a standalone SVG.

    The logo is located by two literal anchors: the opening of the group
    whose path starts at ``M 316.312`` and the final ``</g></g></svg>``
    closing sequence. Everything from the opening anchor through the first
    ``</g>`` of the closing sequence is wrapped in a new ``<svg>`` root with
    a fixed viewBox.

    Args:
        marker_svg: Full text of the upstream Costco marker SVG.

    Returns:
        A complete SVG document containing only the logo group.

    Raises:
        ValueError: If either anchor is missing, or the closing anchor
            appears before the opening one (which would otherwise yield an
            empty logo silently).
    """
    start = marker_svg.find('<g><path d=" M 316.312')
    end = marker_svg.rfind("</g></g></svg>")
    # `end < start` also covers `end == -1` whenever `start` was found.
    if start < 0 or end < start:
        raise ValueError("Costco marker SVG layout changed")

    # Keep the first "</g>" of the closing sequence so the group is balanced.
    logo_group = marker_svg[start : end + len("</g>")]
    return (
        '<?xml version="1.0" encoding="UTF-8"?>\n'
        '<svg xmlns="http://www.w3.org/2000/svg" viewBox="70 145 260 90" '
        'width="260pt" height="90pt" preserveAspectRatio="xMidYMid meet">\n'
        f"{logo_group}\n"
        "</svg>\n"
    )
def trim_white_png(png_bytes: bytes) -> bytes:
    """Make near-white pixels transparent and crop to the visible area.

    Every pixel whose red, green and blue channels are all above 245 gets
    alpha 0 (its colour is kept). The image is then cropped to the bounding
    box of the alpha channel, when there is one, and re-encoded as PNG.
    """
    picture = Image.open(BytesIO(png_bytes)).convert("RGBA")
    access = picture.load()
    columns, rows = picture.size

    for row in range(rows):
        for column in range(columns):
            r, g, b, _alpha = access[column, row]
            if min(r, g, b) > 245:
                access[column, row] = (r, g, b, 0)

    # getbbox() is None when nothing is left visible; keep the full canvas then.
    visible_box = picture.getchannel("A").getbbox()
    trimmed = picture.crop(visible_box) if visible_box else picture

    buffer = BytesIO()
    trimmed.save(buffer, format="PNG")
    return buffer.getvalue()
def extract_embedded_png(marker_svg: str) -> bytes:
    """Decode the first base64 payload in an SVG and return it trimmed.

    The payload is whatever follows the first ``base64,`` up to the next
    quote character; the decoded bytes are passed through
    ``trim_white_png``.

    Raises:
        ValueError: If the SVG text contains no ``base64,`` payload.
    """
    payload = re.search(r"base64,([^\"']+)", marker_svg)
    if payload is None:
        raise ValueError("POI marker SVG did not contain an embedded PNG")
    raw_png = base64.b64decode(payload.group(1))
    return trim_white_png(raw_png)
def svg_intrinsic_size(width: float, height: float) -> tuple[int, int]:
    """Scale a width/height pair so its longer side is the intrinsic maximum.

    The longer side becomes ``POI_ICON_SVG_INTRINSIC_MAX`` and the shorter
    side is scaled proportionally, rounded, and never below 1. A
    non-positive dimension yields a square of the maximum size.
    """
    longest = POI_ICON_SVG_INTRINSIC_MAX

    if width <= 0 or height <= 0:
        return (longest, longest)

    if width >= height:
        scaled_height = max(1, round(longest * height / width))
        return (longest, scaled_height)

    scaled_width = max(1, round(longest * width / height))
    return (scaled_width, longest)
def set_svg_geometry(svg_text: str, crop: tuple[float, float, float, float]) -> str:
    """Set ``viewBox``, ``width`` and ``height`` on the root ``<svg>`` tag.

    Only the opening root tag is edited. Searching the whole document, as a
    plain ``width="..."`` pattern would, can hit ``stroke-width="..."`` or a
    child element's ``width``/``height`` first, corrupting that element and
    leaving the root without the attribute.

    Args:
        svg_text: SVG document text.
        crop: ``(x, y, width, height)`` to use as the new viewBox. The
            intrinsic pixel size is derived from its width and height via
            ``svg_intrinsic_size``.

    Returns:
        The document with the three root attributes replaced, or inserted
        right after ``<svg`` when absent. Text without an ``<svg`` opening
        tag is returned unchanged.
    """
    x, y, width, height = crop
    intrinsic_width, intrinsic_height = svg_intrinsic_size(width, height)

    root = re.search(r"<svg\b[^>]*>", svg_text)
    if root is None:
        return svg_text

    root_tag = root.group(0)
    attributes = (
        ("viewBox", f"{x:g} {y:g} {width:g} {height:g}"),
        ("width", str(intrinsic_width)),
        ("height", str(intrinsic_height)),
    )
    for name, value in attributes:
        # Values are plain numbers, so they are safe as a regex replacement.
        replacement = f'{name}="{value}"'
        # The lookbehind rejects longer attribute names ending in `name`,
        # such as `stroke-width` or namespaced variants.
        root_tag, replaced = re.subn(
            rf'(?<![\w:.-]){name}="[^"]*"', replacement, root_tag, count=1
        )
        if not replaced:
            root_tag = re.sub(r"<svg\b", f"<svg {replacement}", root_tag, count=1)

    return svg_text[: root.start()] + root_tag + svg_text[root.end() :]
def get_svg_view_box(svg_text: str) -> tuple[float, float, float, float] | None:
match = re.search(r'viewBox="([^"]+)"', svg_text)
if not match:
return None
parts = [
float(part) for part in re.split(r"[\s,]+", match.group(1).strip()) if part
]
if len(parts) != 4:
return None
return (parts[0], parts[1], parts[2], parts[3])
def crop_poi_svg_icons(poi_icons_dir: Path) -> None:
    """Apply the configured crops, then resize every SVG under the directory.

    First pass: each icon listed in ``POI_ICON_SVG_CROPS`` that exists on
    disk gets its geometry set to the configured crop. The Dunnes Stores
    icon additionally has its two near-white fills replaced by ``#111111``.

    Second pass: every ``*.svg`` below ``poi_icons_dir`` that has a parsable
    viewBox is rewritten through ``set_svg_geometry`` using that viewBox.
    """
    for relative_path, crop_box in POI_ICON_SVG_CROPS.items():
        target = poi_icons_dir / relative_path
        if target.exists():
            markup = target.read_text(encoding="utf-8")
            if relative_path == "brands_2024/dunnes_stores.svg":
                for near_white_fill in ('fill="#fffcfc"', 'fill="#fcfcfc"'):
                    markup = markup.replace(near_white_fill, 'fill="#111111"')
            target.write_text(set_svg_geometry(markup, crop_box), encoding="utf-8")

    for svg_file in poi_icons_dir.rglob("*.svg"):
        markup = svg_file.read_text(encoding="utf-8")
        existing_box = get_svg_view_box(markup)
        if existing_box:
            svg_file.write_text(
                set_svg_geometry(markup, existing_box), encoding="utf-8"
            )
def download_derived_poi_icon(
    kind: str, source_path: str, dest: Path
) -> tuple[bool, str]:
    """Download a marker SVG and write an icon derived from it to ``dest``.

    ``kind`` selects the derivation: ``"costco_logo"`` writes the SVG built
    by ``build_costco_logo``; ``"embedded_png"`` writes the PNG returned by
    ``extract_embedded_png``. Any failure, including an unknown ``kind``, is
    reported on stderr instead of being raised.

    Returns:
        ``(succeeded, url)`` where ``url`` is the source address fetched.
    """
    url = f"{POI_ICON_BASE}/{source_path}"
    dest.parent.mkdir(parents=True, exist_ok=True)

    try:
        source = download_text(url)
        if kind == "embedded_png":
            dest.write_bytes(extract_embedded_png(source))
        elif kind == "costco_logo":
            dest.write_text(build_costco_logo(source), encoding="utf-8")
        else:
            raise ValueError(f"Unknown derived POI icon kind: {kind}")
    except urllib.error.HTTPError as http_error:
        print(f"  {http_error.code} {url}", file=sys.stderr)
        return False, url
    except Exception as error:
        print(f"  ERROR {url}: {error}", file=sys.stderr)
        return False, url

    return True, url
# Slategray accent used by civic POI icons (school, library, building, …) in
# protomaps' v4 sprite. We match it so the townhall blends in with its peers.
_TOWNHALL_COLOR = {
"light": (135, 128, 171),
"dark": (118, 118, 127),
}
_TOWNHALL_LOGICAL_SIZE = 17
def _render_townhall_glyph(size_px: int, color: tuple[int, int, int]) -> Image.Image:
    """Draw the townhall glyph as a square RGBA image ``size_px`` wide.

    Shapes are laid out on a ``_TOWNHALL_LOGICAL_SIZE``-unit grid, drawn
    fully opaque in ``color`` on a transparent background.
    """
    # Render oversized, then shrink with Lanczos: PIL fills polygons without
    # anti-aliasing, so the pediment's sloped edges would otherwise be jagged.
    oversample = 8
    side = size_px * oversample
    artwork = Image.new("RGBA", (side, side), (0, 0, 0, 0))
    pen = ImageDraw.Draw(artwork)
    ink = (*color, 255)

    def to_px(units: float) -> float:
        return units * side / _TOWNHALL_LOGICAL_SIZE

    def fill_box(left: float, top: float, right: float, bottom: float) -> None:
        pen.rectangle(
            [(to_px(left), to_px(top)), (to_px(right), to_px(bottom))], fill=ink
        )

    # Triangular pediment.
    pen.polygon(
        [(to_px(px), to_px(py)) for px, py in ((8.5, 1), (15, 6.5), (2, 6.5))],
        fill=ink,
    )
    # Horizontal band directly beneath the pediment.
    fill_box(1, 6.5, 16, 8.5)
    # Three columns.
    for column_left in (3, 8, 13):
        fill_box(column_left, 8.5, column_left + 1.5, 14)
    # Full-width base.
    fill_box(0, 14, 17, 15.5)

    return artwork.resize((size_px, size_px), Image.LANCZOS)
def inject_townhall_sprite(sprites_dir: Path) -> None:
    """Append a townhall glyph to each downloaded sprite sheet.

    Protomaps' v4 sprite omits `townhall` even though the basemap style
    references it; we add the icon here so MapLibre can resolve the name
    natively at runtime.

    For each of the ``light``/``dark`` themes at 1x and ``@2x``, the sheet
    PNG is extended downwards by one glyph-sized row and a ``townhall``
    entry pointing at it is added to the JSON manifest. Sheets whose PNG or
    JSON is missing are skipped, as are manifests that already contain a
    ``townhall`` entry: re-running on a patched sheet would otherwise stack
    another glyph row on every run.
    """
    for theme in ("light", "dark"):
        color = _TOWNHALL_COLOR[theme]
        for suffix, scale in (("", 1), ("@2x", 2)):
            json_path = sprites_dir / f"{theme}{suffix}.json"
            png_path = sprites_dir / f"{theme}{suffix}.png"
            if not json_path.exists() or not png_path.exists():
                continue

            # Read explicitly as UTF-8 rather than the locale default.
            manifest = json.loads(json_path.read_text(encoding="utf-8"))
            if "townhall" in manifest:
                continue

            # Close the source file before overwriting the same path below.
            with Image.open(png_path) as source:
                sheet = source.convert("RGBA")

            glyph_size = _TOWNHALL_LOGICAL_SIZE * scale
            glyph = _render_townhall_glyph(glyph_size, color)

            new_width = max(sheet.width, glyph_size)
            new_height = sheet.height + glyph_size
            extended = Image.new("RGBA", (new_width, new_height), (0, 0, 0, 0))
            extended.paste(sheet, (0, 0))
            extended.paste(glyph, (0, sheet.height))
            extended.save(png_path, optimize=True)

            manifest["townhall"] = {
                "x": 0,
                "y": sheet.height,
                "width": glyph_size,
                "height": glyph_size,
                "pixelRatio": scale,
            }
            json_path.write_text(json.dumps(manifest), encoding="utf-8")
def main(): def main():
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
@ -147,7 +385,7 @@ def main():
# Skip already-downloaded files # Skip already-downloaded files
remaining = [(url, dest) for url, dest in tasks] remaining = [(url, dest) for url, dest in tasks]
print(f"Downloading {len(remaining)} assets") print(f"Downloading {len(remaining) + len(DERIVED_POI_ICON_PATHS)} assets")
ok = 0 ok = 0
fail = 0 fail = 0
@ -162,6 +400,18 @@ def main():
else: else:
fail += 1 fail += 1
for kind, source_path, dest_path in DERIVED_POI_ICON_PATHS:
success, _url = download_derived_poi_icon(
kind, source_path, poi_icons_dir / dest_path
)
if success:
ok += 1
else:
fail += 1
crop_poi_svg_icons(poi_icons_dir)
inject_townhall_sprite(sprites_dir)
print(f"Done: {ok} downloaded, {fail} failed") print(f"Done: {ok} downloaded, {fail} failed")

View file

@ -6,6 +6,7 @@ Reuses the same england-latest.osm.pbf as pois.py.
""" """
import argparse import argparse
import re
from pathlib import Path from pathlib import Path
import osmium import osmium
@ -44,11 +45,37 @@ _STATION_STRIP = (
" underground station", " underground station",
" railway station", " railway station",
" dlr station", " dlr station",
" station dlr",
" dlr",
" overground station", " overground station",
" tram stop", " tram stop",
" station", " station",
) )
_DLR_CODE_RE = re.compile(r"ZZDL([A-Z0-9]{3})")
def _is_dlr_station(tags: dict[str, str]) -> bool:
name = tags.get("name", "").lower()
network = tags.get("network", "").lower()
operator = tags.get("operator", "").lower()
return (
"docklands" in network
or "dlr" in network
or "docklands" in operator
or "dlr" in operator
or name.endswith(" dlr")
or " dlr " in name
)
def _is_tram_station(tags: dict[str, str]) -> bool:
    """Return whether a station's tags identify it as a tram stop.

    DLR stations are never treated as trams. Otherwise a station counts as
    a tram stop when ``station`` is ``light_rail`` or the lower-cased
    ``network`` contains ``tramlink`` or ``tram``.
    """
    if _is_dlr_station(tags):
        return False
    if tags.get("station", "") == "light_rail":
        return True
    network_name = tags.get("network", "").lower()
    return "tramlink" in network_name or "tram" in network_name
def _station_display_name(name: str, tags: dict[str, str]) -> str: def _station_display_name(name: str, tags: dict[str, str]) -> str:
"""Build a descriptive station name like 'Bank tube station'.""" """Build a descriptive station name like 'Bank tube station'."""
@ -78,6 +105,96 @@ def _station_display_name(name: str, tags: dict[str, str]) -> str:
return f"{name} {suffix}" return f"{name} {suffix}"
def _station_name_score(name: str) -> tuple[int, int]:
lower = name.lower()
suffix_penalty = int(
lower.endswith(
(
" underground station",
" tube station",
" dlr station",
" railway station",
" rail station",
" station dlr",
" station",
)
)
or lower.endswith(" dlr")
)
return (suffix_penalty, len(name))
def _naptan_dlr_stations(naptan_path: Path) -> list[dict]:
    """Extract station-level DLR destinations from NaPTAN access nodes.

    Rows are kept when their ``id`` matches ``_DLR_CODE_RE``, their
    ``category`` is a tube or rail station, and their ``name`` is non-empty.
    Rows sharing the code captured from the id are merged into one station:
    coordinates are averaged and the raw name with the lowest
    ``_station_name_score`` is kept.

    Returns:
        Place dictionaries sorted by display name.

    Raises:
        ValueError: If the parquet file lacks a required column.
    """
    frame = pl.read_parquet(naptan_path)
    missing = {"id", "name", "category", "lat", "lng"} - set(frame.columns)
    if missing:
        raise ValueError(f"NaPTAN file is missing columns: {sorted(missing)}")

    station_categories = {"Tube station", "Rail station"}
    by_code: dict[str, dict] = {}

    for node in frame.iter_rows(named=True):
        code_match = _DLR_CODE_RE.search(str(node["id"] or ""))
        if code_match is None or node["category"] not in station_categories:
            continue

        raw_name = str(node["name"] or "")
        if not raw_name:
            continue

        lat = float(node["lat"])
        lon = float(node["lng"])
        code = code_match.group(1)
        entry = by_code.get(code)

        if entry is None:
            by_code[code] = {
                "raw_name": raw_name,
                "lat_sum": lat,
                "lon_sum": lon,
                "count": 1,
            }
        else:
            entry["lat_sum"] += lat
            entry["lon_sum"] += lon
            entry["count"] += 1
            # Keep the better-scoring name; the earlier one wins ties.
            if _station_name_score(raw_name) < _station_name_score(entry["raw_name"]):
                entry["raw_name"] = raw_name

    stations = [
        {
            "name": _station_display_name(entry["raw_name"], {"network": "DLR"}),
            "place_type": "station",
            "lat": entry["lat_sum"] / entry["count"],
            "lon": entry["lon_sum"] / entry["count"],
            "population": 0,
            "travel_destination": True,
        }
        for entry in by_code.values()
    ]
    stations.sort(key=lambda station: station["name"])
    return stations
def _append_naptan_dlr_stations(places: list[dict], naptan_path: Path) -> int:
    """Append NaPTAN DLR stations whose names are not already in ``places``.

    Names are compared case-insensitively (casefolded). ``places`` is
    extended in place.

    Returns:
        The number of stations appended.
    """
    seen_names = {str(place["name"]).casefold() for place in places}
    size_before = len(places)

    for candidate in _naptan_dlr_stations(naptan_path):
        folded_name = candidate["name"].casefold()
        if folded_name not in seen_names:
            seen_names.add(folded_name)
            places.append(candidate)

    return len(places) - size_before
class PlaceHandler(osmium.SimpleHandler): class PlaceHandler(osmium.SimpleHandler):
def __init__(self, progress: tqdm, england_polygon) -> None: def __init__(self, progress: tqdm, england_polygon) -> None:
super().__init__() super().__init__()
@ -145,14 +262,7 @@ class PlaceHandler(osmium.SimpleHandler):
# Railway stations (tube, national rail, DLR, overground, Elizabeth line) # Railway stations (tube, national rail, DLR, overground, Elizabeth line)
if n.tags.get("railway") == "station": if n.tags.get("railway") == "station":
tags = dict(n.tags) tags = dict(n.tags)
station_tag = tags.get("station", "") if _is_tram_station(tags):
network = tags.get("network", "").lower()
# Skip tram stops
if (
station_tag == "light_rail"
or "tramlink" in network
or "tram" in network
):
return return
display_name = _station_display_name(name, tags) display_name = _station_display_name(name, tags)
self._add( self._add(
@ -178,6 +288,11 @@ def main() -> None:
required=True, required=True,
help="England boundary GeoJSON file", help="England boundary GeoJSON file",
) )
parser.add_argument(
"--naptan",
type=Path,
help="Optional NaPTAN parquet file used to add DLR station destinations",
)
args = parser.parse_args() args = parser.parse_args()
pbf_file = args.pbf pbf_file = args.pbf
@ -195,6 +310,9 @@ def main() -> None:
handler.apply_file(str(pbf_file), locations=True) handler.apply_file(str(pbf_file), locations=True)
print(f"Extracted {len(handler.places):,} place nodes") print(f"Extracted {len(handler.places):,} place nodes")
if args.naptan:
added = _append_naptan_dlr_stations(handler.places, args.naptan)
print(f"Added {added:,} DLR station destinations from NaPTAN")
if handler.places: if handler.places:
df = pl.DataFrame(handler.places) df = pl.DataFrame(handler.places)

View file

@ -0,0 +1,81 @@
import polars as pl
from pipeline.download.places import (
_is_dlr_station,
_is_tram_station,
_naptan_dlr_stations,
_station_display_name,
)
def test_dlr_light_rail_is_not_treated_as_tram():
    """A light-rail station on the DLR network is a DLR station, not a tram."""
    lewisham_tags = {
        "name": "Lewisham DLR",
        "railway": "station",
        "station": "light_rail",
        "network": "Docklands Light Railway",
    }

    assert _is_dlr_station(lewisham_tags)
    assert not _is_tram_station(lewisham_tags)

    expected_display_names = {
        "Lewisham DLR": "Lewisham DLR station",
        "Tower Gateway Station DLR": "Tower Gateway DLR station",
    }
    for raw_name, display_name in expected_display_names.items():
        assert _station_display_name(raw_name, lewisham_tags) == display_name
def test_tram_light_rail_is_still_excluded():
    """A light-rail station on a tram network is still classed as a tram."""
    croydon_tags = {
        "name": "East Croydon",
        "railway": "station",
        "station": "light_rail",
        "network": "London Trams",
    }

    is_dlr = _is_dlr_station(croydon_tags)
    is_tram = _is_tram_station(croydon_tags)

    assert not is_dlr
    assert is_tram
def test_naptan_dlr_stations_are_deduplicated_by_atco_code(tmp_path):
    """Rows sharing a DLR code merge into one station; other rows are ignored."""
    naptan_file = tmp_path / "naptan.parquet"
    fixture = pl.DataFrame(
        {
            "id": [
                "4900ZZDLSHA3",
                "9400ZZDLSHA",
                "4900ZZDLGRE1",
                "490002076RV",
                "4900ZZLUBNK",
            ],
            "name": [
                "Shadwell DLR",
                "Shadwell DLR Station",
                "Greenwich Station",
                "Tower Gateway Station DLR",
                "Bank",
            ],
            "category": [
                "Tube station",
                "Tube station",
                "Rail station",
                "Bus stop",
                "Tube station",
            ],
            "lat": [51.51156, 51.511693, 51.47794, 51.510575, 51.5131],
            "lng": [-0.055595, -0.056643, -0.01442, -0.07514, -0.0894],
        }
    )
    fixture.write_parquet(naptan_file)

    stations = _naptan_dlr_stations(naptan_file)

    names = [station["name"] for station in stations]
    assert names == [
        "Greenwich DLR station",
        "Shadwell DLR station",
    ]

    shadwell_matches = (
        station for station in stations if station["name"].startswith("Shadwell")
    )
    shadwell = next(shadwell_matches)
    # The two Shadwell rows are averaged into a single coordinate.
    assert shadwell["lat"] == (51.51156 + 51.511693) / 2
    assert shadwell["place_type"] == "station"
    assert shadwell["travel_destination"] is True

View file

@ -56,6 +56,7 @@ NR_AUTH_URL = "https://opendata.nationalrail.co.uk/authenticate"
NR_TIMETABLE_URL = "https://opendata.nationalrail.co.uk/api/staticfeeds/3.0/timetable" NR_TIMETABLE_URL = "https://opendata.nationalrail.co.uk/api/staticfeeds/3.0/timetable"
USER_AGENT = "property-map-pipeline/1.0 (https://github.com)" USER_AGENT = "property-map-pipeline/1.0 (https://github.com)"
TRANSXCHANGE2GTFS_PACKAGE = "transxchange2gtfs@1.12.0"
def _download_http( def _download_http(
@ -473,10 +474,50 @@ def convert_tfl_to_gtfs(raw_dir: Path, output_dir: Path) -> Path:
download_naptan() download_naptan()
print("Converting TfL TransXChange → GTFS...") print("Converting TfL TransXChange → GTFS...")
# The shim patches known packaging/runtime issues in the pinned npm package
# before loading its CLI from npx's temporary install.
shim_path = Path(__file__).with_name("transxchange2gtfs_shim.js")
subprocess.run( subprocess.run(
["npx", "--yes", "transxchange2gtfs", str(txc_path), str(dest)], [
"npx",
"--yes",
"--package",
TRANSXCHANGE2GTFS_PACKAGE,
"sh",
"-c",
"\n".join(
[
'bin="$(command -v transxchange2gtfs)"',
'script="$(readlink -f "$bin")"',
'pkg_dir="$(dirname "$(dirname "$script")")"',
'shim="$1"',
"shift",
'exec node "$shim" "$pkg_dir" "$@"',
]
),
"transxchange2gtfs",
str(shim_path.resolve()),
str(txc_path.resolve()),
str(dest.resolve()),
],
check=True, check=True,
) )
required_files = {
"agency.txt",
"calendar.txt",
"calendar_dates.txt",
"routes.txt",
"stop_times.txt",
"stops.txt",
"trips.txt",
}
if not dest.exists() or not zipfile.is_zipfile(dest):
raise RuntimeError(f"transxchange2gtfs did not create a valid GTFS zip: {dest}")
with zipfile.ZipFile(dest) as z:
missing = required_files - set(z.namelist())
if missing:
missing_str = ", ".join(sorted(missing))
raise RuntimeError(f"TfL GTFS zip is missing required files: {missing_str}")
size_mb = dest.stat().st_size / (1024 * 1024) size_mb = dest.stat().st_size / (1024 * 1024)
print(f" Saved to {dest} ({size_mb:.1f} MB)") print(f" Saved to {dest} ({size_mb:.1f} MB)")
return dest return dest

View file

@ -0,0 +1,76 @@
#!/usr/bin/env node
"use strict";
const fs = require("fs");
const path = require("path");
const { createRequire } = require("module");
const [pkgDirArg, ...converterArgs] = process.argv.slice(2);
if (!pkgDirArg || converterArgs.length < 2) {
console.error(
"Usage: transxchange2gtfs_shim.js <package-dir> <input...> <output>",
);
process.exit(2);
}
const pkgDir = path.resolve(pkgDirArg);
/**
 * Replace the first occurrence of `before` with `after` in a package file.
 *
 * The call is idempotent: when `before` is absent but `after` is already
 * present, the file is assumed to be patched and is left untouched.
 *
 * @param {string} relativePath File path relative to the package directory.
 * @param {string} before Exact text expected in the unpatched file.
 * @param {string} after Literal text to put in its place.
 * @throws {Error} If neither `before` nor `after` is found in the file.
 */
function replaceOnce(relativePath, before, after) {
  const file = path.join(pkgDir, relativePath);
  const original = fs.readFileSync(file, "utf8");

  if (original.includes(before)) {
    // A replacer function inserts `after` verbatim; a plain string would have
    // `$&`, `$1`, `$$`, ... interpreted as special replacement patterns.
    fs.writeFileSync(file, original.replace(before, () => after));
    return;
  }

  if (original.includes(after)) {
    return;
  }

  throw new Error(`Could not patch ${relativePath}: expected text not found`);
}
// The published 1.12.0 package has a few compatibility issues with current
// TfL TransXChange exports:
// - the bin script points at dist/src/cli.js, but the package ships dist/cli.js
// - the compiled date-holidays import expects a synthetic default export
// - some TfL journeys reference timing links without matching route-link geometry
//
// GTFS shapes are optional for R5 routing. Clear shape references and omit
// shapes.txt so missing route geometry does not drop otherwise usable trips.
// Apply every source patch to the installed package, in order. Each patch is
// an exact-text replacement done by replaceOnce, so a changed upstream build
// fails loudly instead of being silently mis-patched.
function patchPackage() {
  // Tolerate a missing route link: add 0 to the running distance instead of
  // dereferencing an undefined `routeLink`.
  replaceOnce(
    "dist/transxchange/TransXChangeJourneyStream.js",
    "distanceSoFarM += routeLink.Distance;",
    "distanceSoFarM += routeLink ? routeLink.Distance : 0;",
  );
  // Replace the md5 hash of route + route-link sequence with an empty string.
  replaceOnce(
    "dist/gtfs/TripsStream.js",
    "(0, crypto_1.createHash)('md5').update(JSON.stringify({ routeId: journey.route, routeLinkSeq: journey.routeLinkIds })).digest(\"hex\"));",
    "\"\");",
  );
  // Emit an empty value in place of `stop.shapeDistTraveled`.
  replaceOnce(
    "dist/gtfs/StopTimesStream.js",
    "stop.shapeDistTraveled, stop.exactTime ? \"1\" : \"0\");",
    "\"\", stop.exactTime ? \"1\" : \"0\");",
  );
  // Drop the "shapes.txt" output stream that follows "stops.txt".
  replaceOnce(
    "dist/Container.js",
    "\"stops.txt\": transxchange.pipe(new StopsStream_1.StopsStream(naptanIndex)),\n \"shapes.txt\": journeyStream.pipe(new ShapesStream_1.ShapesStream())",
    "\"stops.txt\": transxchange.pipe(new StopsStream_1.StopsStream(naptanIndex))",
  );
  // Drop the "transfers.txt" output stream between "routes.txt" and
  // "stops.txt". NOTE(review): the reason for omitting transfers is not
  // stated in this file; confirm it is intentional.
  replaceOnce(
    "dist/Container.js",
    "\"routes.txt\": transxchange.pipe(new RoutesStream_1.RoutesStream()),\n \"transfers.txt\": transxchange.pipe(new TransfersStream_1.TransfersStream(naptanIndex, locationIndex)),\n \"stops.txt\": transxchange.pipe(new StopsStream_1.StopsStream(naptanIndex))",
    "\"routes.txt\": transxchange.pipe(new RoutesStream_1.RoutesStream()),\n \"stops.txt\": transxchange.pipe(new StopsStream_1.StopsStream(naptanIndex))",
  );
}
patchPackage();
const pkgRequire = createRequire(path.join(pkgDir, "package.json"));
const Holidays = pkgRequire("date-holidays");
if (!Holidays.default) {
Holidays.default = Holidays;
}
process.argv = [process.argv[0], "transxchange2gtfs", ...converterArgs];
require(path.join(pkgDir, "dist", "cli.js"));

View file

@ -7,6 +7,15 @@ from pipeline.utils.postcode_mapping import build_postcode_mapping
MIN_FLOOR_AREA_M2 = 10 MIN_FLOOR_AREA_M2 = 10
_IOD_PERCENTILE_COLUMNS = [
"Education, Skills and Training Score",
"Income Score (rate)",
"Employment Score (rate)",
"Health Deprivation and Disability Score",
"Indoors Sub-domain Score",
"Outdoors Sub-domain Score",
]
_AREA_COLUMNS = [ _AREA_COLUMNS = [
"Postcode", "Postcode",
@ -51,6 +60,14 @@ _AREA_COLUMNS = [
"Number of parks within 1km", "Number of parks within 1km",
"Distance to nearest train or tube station (km)", "Distance to nearest train or tube station (km)",
"Distance to nearest park (km)", "Distance to nearest park (km)",
"Distance to nearest grocery store (km)",
"Distance to nearest tube station (km)",
"Distance to nearest rail station (km)",
"Distance to nearest Waitrose (km)",
"Distance to nearest Tesco (km)",
"Distance to nearest cafe (km)",
"Distance to nearest pub (km)",
"Distance to nearest restaurant (km)",
# Environment # Environment
"Noise (dB)", "Noise (dB)",
"Max available download speed (Mbps)", "Max available download speed (Mbps)",
@ -76,6 +93,34 @@ _AREA_COLUMNS = [
] ]
def _is_dynamic_poi_metric_column(column: str) -> bool:
return (
column.startswith("Distance to nearest ")
and column.endswith(" POI (km)")
) or (
column.startswith("Number of ")
and (column.endswith(" POIs within 2km") or column.endswith(" POIs within 5km"))
)
def _less_deprived_percentile_expr(column: str) -> pl.Expr:
    """Convert an IoD deprivation score to a 0-100 less-deprived percentile.

    Scores are ranked in descending order with ties averaged, then scaled so
    the highest score maps to 0 and the lowest to 100, rounded to one
    decimal place. Nulls stay null. The minimum and maximum are pinned to
    exactly 100 and 0, and a column with a single non-null value maps to
    100. The result keeps the original column name.
    """
    score = pl.col(column)
    populated = score.count()
    rank_from_highest = score.rank("average", descending=True)
    scaled = ((rank_from_highest - 1) / (populated - 1) * 100).round(1)

    return (
        pl.when(score.is_null())
        .then(None)
        .when(score == score.min())
        .then(100.0)
        .when(score == score.max())
        .then(0.0)
        .when(populated > 1)
        .then(scaled)
        .otherwise(100.0)
        .alias(column)
    )
def _build( def _build(
epc_pp_path: Path, epc_pp_path: Path,
arcgis_path: Path, arcgis_path: Path,
@ -134,20 +179,11 @@ def _build(
) )
wide = wide.join(arcgis, on="postcode", how="left") wide = wide.join(arcgis, on="postcode", how="left")
iod = pl.scan_parquet(iod_path) iod = pl.scan_parquet(iod_path).with_columns(
*(_less_deprived_percentile_expr(c) for c in _IOD_PERCENTILE_COLUMNS)
)
wide = wide.join(iod, left_on="lsoa21", right_on="LSOA code (2021)", how="left") wide = wide.join(iod, left_on="lsoa21", right_on="LSOA code (2021)", how="left")
# Invert deprivation scores so that higher values = less deprived (better)
iod_score_cols = [
"Education, Skills and Training Score",
"Income Score (rate)",
"Employment Score (rate)",
"Health Deprivation and Disability Score",
"Indoors Sub-domain Score",
"Outdoors Sub-domain Score",
]
wide = wide.with_columns(*(pl.col(c).max() - pl.col(c) for c in iod_score_cols))
ethnicity = pl.scan_parquet(ethnicity_path) ethnicity = pl.scan_parquet(ethnicity_path)
wide = wide.join( wide = wide.join(
ethnicity, ethnicity,
@ -351,6 +387,14 @@ def _build(
"parks_1km": "Number of parks within 1km", "parks_1km": "Number of parks within 1km",
"train_tube_nearest_km": "Distance to nearest train or tube station (km)", "train_tube_nearest_km": "Distance to nearest train or tube station (km)",
"parks_nearest_km": "Distance to nearest park (km)", "parks_nearest_km": "Distance to nearest park (km)",
"grocery_store_nearest_km": "Distance to nearest grocery store (km)",
"tube_station_nearest_km": "Distance to nearest tube station (km)",
"rail_station_nearest_km": "Distance to nearest rail station (km)",
"waitrose_nearest_km": "Distance to nearest Waitrose (km)",
"tesco_nearest_km": "Distance to nearest Tesco (km)",
"cafe_nearest_km": "Distance to nearest cafe (km)",
"pub_nearest_km": "Distance to nearest pub (km)",
"restaurant_nearest_km": "Distance to nearest restaurant (km)",
"latest_price": "Last known price", "latest_price": "Last known price",
"number_habitable_rooms": "Number of bedrooms & living rooms", "number_habitable_rooms": "Number of bedrooms & living rooms",
"noise_lden_db": "Noise (dB)", "noise_lden_db": "Noise (dB)",
@ -381,10 +425,14 @@ def _build(
# Split into postcode-level and property-level dataframes # Split into postcode-level and property-level dataframes
area_cols = [c for c in _AREA_COLUMNS if c in df.columns] area_cols = [c for c in _AREA_COLUMNS if c in df.columns]
area_cols.extend(
c for c in df.columns if _is_dynamic_poi_metric_column(c) and c not in area_cols
)
area_col_set = set(area_cols)
postcode_df = df.select(area_cols).group_by("Postcode").first() postcode_df = df.select(area_cols).group_by("Postcode").first()
print(f"Postcode rows: {postcode_df.height} (unique postcodes)") print(f"Postcode rows: {postcode_df.height} (unique postcodes)")
property_cols = [c for c in df.columns if c not in _AREA_COLUMNS or c == "Postcode"] property_cols = [c for c in df.columns if c not in area_col_set or c == "Postcode"]
properties_df = df.select(property_cols) properties_df = df.select(property_cols)
print(f"Property rows: {properties_df.height}") print(f"Property rows: {properties_df.height}")

View file

@ -1,6 +1,8 @@
"""Compute POI proximity counts and distances per postcode from ArcGIS + filtered POIs.""" """Compute POI proximity counts and distances per postcode from ArcGIS + filtered POIs."""
import argparse import argparse
import re
import unicodedata
from pathlib import Path from pathlib import Path
import polars as pl import polars as pl
@ -15,9 +17,25 @@ POI_GROUPS_2KM = {
"groceries": ["Greengrocer", "Supermarket", "Convenience Store"], "groceries": ["Greengrocer", "Supermarket", "Convenience Store"],
} }
# Groups for which to compute distance to nearest POI (from filtered POIs) # Groups for which to compute distance to nearest POI (from filtered POIs).
# Keep `train_tube` for the existing backend feature; the individual POI
# distance filters below power the frontend dropdown.
DISTANCE_GROUPS = { DISTANCE_GROUPS = {
"train_tube": ["Tube station", "Rail station"], "train_tube": ["Tube station", "Rail station"],
"grocery_store": [
"Greengrocer",
"Supermarket",
"Convenience Store",
"Waitrose",
"Tesco",
],
"tube_station": ["Tube station"],
"rail_station": ["Rail station"],
"waitrose": ["Waitrose"],
"tesco": ["Tesco"],
"cafe": ["Café"],
"pub": ["Pub"],
"restaurant": ["Restaurant"],
} }
# OS Open Greenspace function types used for park counts and distance calculation. # OS Open Greenspace function types used for park counts and distance calculation.
@ -27,6 +45,69 @@ GREENSPACE_PARK_FUNCTIONS = {
"parks": ["Public Park Or Garden", "Playing Field", "Play Space"], "parks": ["Public Park Or Garden", "Playing Field", "Play Space"],
} }
# A category in a count-threshold group gets its own dynamic filter only when
# it has strictly more POIs than this value.
GROCERY_DYNAMIC_FILTER_MIN_POIS = 100
# POI groups whose categories always get a dynamic filter, regardless of count.
DYNAMIC_FILTER_ALL_GROUPS = {"Public Transport", "Leisure"}
# POI groups whose categories get a dynamic filter only above the count threshold.
DYNAMIC_FILTER_COUNT_THRESHOLD_GROUPS = {"Groceries"}
def _poi_category_slug(category: str) -> str:
ascii_text = (
unicodedata.normalize("NFKD", category)
.encode("ascii", "ignore")
.decode("ascii")
.lower()
)
slug = re.sub(r"[^a-z0-9]+", "_", ascii_text).strip("_")
return slug or "poi"
def _build_poi_category_groups(
    pois: pl.DataFrame,
) -> tuple[dict[str, list[str]], dict[str, str]]:
    """Build one proximity group for each POI category selected for filters.

    A category is selected when its group is in ``DYNAMIC_FILTER_ALL_GROUPS``,
    or when its group is in ``DYNAMIC_FILTER_COUNT_THRESHOLD_GROUPS`` and the
    (group, category) pair has strictly more than
    ``GROCERY_DYNAMIC_FILTER_MIN_POIS`` rows.

    Returns:
        A ``(groups, display_names)`` pair keyed by a unique ``poi_<slug>``
        group key. ``groups`` maps each key to a single-element category list;
        ``display_names`` maps each key to the original category string.

    Raises:
        ValueError: if ``pois`` has no ``group`` column.
    """
    if "group" not in pois.columns:
        raise ValueError("POI dataframe must include a 'group' column")

    per_category_counts = pois.group_by("group", "category").len()
    selected = per_category_counts.filter(
        pl.col("group").is_in(list(DYNAMIC_FILTER_ALL_GROUPS))
        | (
            pl.col("group").is_in(list(DYNAMIC_FILTER_COUNT_THRESHOLD_GROUPS))
            & (pl.col("len") > GROCERY_DYNAMIC_FILTER_MIN_POIS)
        )
    )
    # A category can be selected through more than one group; keep it once so
    # that it does not produce two group keys sharing one display name (which
    # would later yield duplicate renamed metric columns).
    categories = (
        selected.select("category").unique().sort("category").to_series().to_list()
    )

    groups: dict[str, list[str]] = {}
    display_names: dict[str, str] = {}
    for category in categories:
        if not isinstance(category, str) or not category:
            continue

        base_slug = f"poi_{_poi_category_slug(category)}"
        # Different categories can reduce to the same slug (e.g. accents or
        # punctuation only). Probe suffixes until the key is genuinely unused,
        # so a category that naturally slugs to "<base>_2" is never overwritten.
        group_key = base_slug
        suffix = 2
        while group_key in groups:
            group_key = f"{base_slug}_{suffix}"
            suffix += 1

        groups[group_key] = [category]
        display_names[group_key] = category

    return groups, display_names
def _dynamic_poi_metric_renames(display_names: dict[str, str]) -> dict[str, str]:
renames: dict[str, str] = {}
for group_key, category in display_names.items():
renames[f"{group_key}_nearest_km"] = f"Distance to nearest {category} POI (km)"
renames[f"{group_key}_2km"] = f"Number of {category} POIs within 2km"
renames[f"{group_key}_5km"] = f"Number of {category} POIs within 5km"
return renames
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -56,12 +137,35 @@ def main():
) )
pois = pl.read_parquet(args.pois) pois = pl.read_parquet(args.pois)
poi_category_groups, poi_display_names = _build_poi_category_groups(pois)
# Count amenity POIs within 2km # Count amenity POIs within 2km
counts_2km = count_pois_per_postcode( counts_2km = count_pois_per_postcode(
postcodes, pois, groups=POI_GROUPS_2KM, radius_km=2 postcodes, pois, groups=POI_GROUPS_2KM, radius_km=2
) )
# Dynamic POI filters: nearest distance plus counts within 2km and 5km for
# the selected public transport, grocery, and leisure categories.
dynamic_counts_2km = count_pois_per_postcode(
postcodes, pois, groups=poi_category_groups, radius_km=2
)
dynamic_counts_5km = count_pois_per_postcode(
postcodes, pois, groups=poi_category_groups, radius_km=5
)
dynamic_distances = min_distance_per_postcode(
postcodes, pois, groups=poi_category_groups
)
dynamic_renames = _dynamic_poi_metric_renames(poi_display_names)
dynamic_counts_2km = dynamic_counts_2km.rename(
{k: v for k, v in dynamic_renames.items() if k in dynamic_counts_2km.columns}
)
dynamic_counts_5km = dynamic_counts_5km.rename(
{k: v for k, v in dynamic_renames.items() if k in dynamic_counts_5km.columns}
)
dynamic_distances = dynamic_distances.rename(
{k: v for k, v in dynamic_renames.items() if k in dynamic_distances.columns}
)
# Distance to nearest train/tube station (from filtered POIs) # Distance to nearest train/tube station (from filtered POIs)
distances = min_distance_per_postcode(postcodes, pois, groups=DISTANCE_GROUPS) distances = min_distance_per_postcode(postcodes, pois, groups=DISTANCE_GROUPS)
@ -77,6 +181,9 @@ def main():
# Join all results on postcode # Join all results on postcode
result = ( result = (
counts_2km.join(distances, on="postcode") counts_2km.join(distances, on="postcode")
.join(dynamic_counts_2km, on="postcode")
.join(dynamic_counts_5km, on="postcode")
.join(dynamic_distances, on="postcode")
.join(park_counts_1km, on="postcode") .join(park_counts_1km, on="postcode")
.join(park_distances, on="postcode") .join(park_distances, on="postcode")
) )

View file

@ -0,0 +1,33 @@
import polars as pl
from pipeline.transform.merge import (
_is_dynamic_poi_metric_column,
_less_deprived_percentile_expr,
)
def test_less_deprived_percentile_expr_preserves_direction_and_nulls() -> None:
    """Lowest score maps to 100, highest to 0, and nulls stay null."""
    column = "Income Score (rate)"
    frame = pl.DataFrame({column: [1.0, 2.0, 3.0, None]})

    ranked = frame.lazy().with_columns(_less_deprived_percentile_expr(column)).collect()

    assert ranked[column].to_list() == [100.0, 50.0, 0.0, None]
def test_less_deprived_percentile_expr_uses_exact_scale_endpoints() -> None:
    """Tied extremes all land exactly on the 100 and 0 endpoints."""
    column = "Income Score (rate)"
    frame = pl.DataFrame({column: [1.0, 1.0, 2.0, 3.0, 3.0]})

    ranked = frame.lazy().with_columns(_less_deprived_percentile_expr(column)).collect()

    assert ranked[column].to_list() == [100.0, 100.0, 50.0, 0.0, 0.0]
def test_dynamic_poi_metric_columns_are_area_level() -> None:
    """Dynamic POI metric labels are recognised; other count columns are not."""
    dynamic_columns = (
        "Distance to nearest Cafe POI (km)",
        "Number of Cafe POIs within 2km",
        "Number of Cafe POIs within 5km",
    )
    for column in dynamic_columns:
        assert _is_dynamic_poi_metric_column(column)

    assert not _is_dynamic_poi_metric_column("Number of restaurants within 2km")

View file

@ -0,0 +1,41 @@
import polars as pl
from pipeline.transform.poi_proximity import _build_poi_category_groups
def test_dynamic_poi_groups_include_requested_categories_only() -> None:
    """Only transport, leisure, and above-threshold grocery categories are kept."""
    # (group, category, number of POI rows)
    fixture = [
        ("Public Transport", "Rail station", 1),
        ("Public Transport", "Bus stop", 1),
        ("Leisure", "Café", 1),
        ("Leisure", "Restaurant", 1),
        ("Groceries", "Tesco", 101),  # just above the grocery threshold
        ("Groceries", "Waitrose", 100),  # exactly at the threshold: excluded
        ("Education", "School", 200),
        ("Health", "Pharmacy", 200),
    ]
    group_column: list[str] = []
    category_column: list[str] = []
    for group, category, row_count in fixture:
        group_column.extend([group] * row_count)
        category_column.extend([category] * row_count)
    total_rows = len(group_column)

    pois = pl.DataFrame(
        {
            "group": group_column,
            "category": category_column,
            "lat": [51.5] * total_rows,
            "lng": [-0.1] * total_rows,
        }
    )

    groups, display_names = _build_poi_category_groups(pois)

    expected_categories = {"Bus stop", "Café", "Rail station", "Restaurant", "Tesco"}
    assert set(display_names.values()) == expected_categories
    for excluded_key in ("poi_waitrose", "poi_school", "poi_pharmacy"):
        assert excluded_key not in groups

View file

@ -1128,12 +1128,18 @@ GROCERY_FASCIA_ICON_NAMES: dict[str, str] = {
def normalize_grocery_retailer(retailer: str | None) -> str: def normalize_grocery_retailer(retailer: str | None) -> str:
if retailer is None: if retailer is None:
return "" return ""
return GROCERY_RETAILER_DISPLAY_NAMES.get(retailer, retailer) display_name = GROCERY_RETAILER_DISPLAY_NAMES.get(retailer)
if display_name is None:
raise ValueError(f"Missing grocery retailer display name for {retailer!r}")
return display_name
def normalize_grocery_icon_category(fascia: str | None, retailer: str | None) -> str: def normalize_grocery_icon_category(fascia: str | None, retailer: str | None) -> str:
if fascia: if fascia:
return GROCERY_FASCIA_ICON_NAMES.get(fascia, normalize_grocery_retailer(fascia)) icon_name = GROCERY_FASCIA_ICON_NAMES.get(fascia)
if icon_name is None:
raise ValueError(f"Missing grocery fascia icon name for {fascia!r}")
return icon_name
return normalize_grocery_retailer(retailer) return normalize_grocery_retailer(retailer)

View file

@ -2,9 +2,12 @@
import numpy as np import numpy as np
import polars as pl import polars as pl
from scipy.spatial import cKDTree
from .haversine import haversine_km from .haversine import haversine_km
EARTH_RADIUS_KM = 6371.0088
def _build_poi_grid( def _build_poi_grid(
pois: pl.DataFrame, grid_size: float = 0.05 pois: pl.DataFrame, grid_size: float = 0.05
@ -49,6 +52,21 @@ def _get_nearby_indices(
return np.concatenate(nearby_indices) return np.concatenate(nearby_indices)
def _project_lat_lng_km(
    lats: np.ndarray, lngs: np.ndarray, origin_lat: float
) -> np.ndarray:
    """Project WGS84 coordinates to local km coordinates for nearest-neighbour lookup.

    Uses an equirectangular projection: x is the longitude in radians scaled by
    the Earth radius and the cosine of ``origin_lat``; y is the latitude in
    radians scaled by the Earth radius. The two are returned as columns of a
    single array (x first, y second).
    """
    x_km = EARTH_RADIUS_KM * np.radians(lngs) * np.cos(np.radians(origin_lat))
    y_km = EARTH_RADIUS_KM * np.radians(lats)
    return np.column_stack((x_km, y_km))
def count_pois_per_postcode( def count_pois_per_postcode(
postcodes_df: pl.DataFrame, postcodes_df: pl.DataFrame,
pois: pl.DataFrame, pois: pl.DataFrame,
@ -136,7 +154,7 @@ def min_distance_per_postcode(
) -> pl.DataFrame: ) -> pl.DataFrame:
""" """
For each postcode, compute the distance (km) to the closest POI per group. For each postcode, compute the distance (km) to the closest POI per group.
Returns NaN where no POI of that group exists within the grid search range (~5.5km). Returns NaN where no POI of that group exists.
""" """
print("Computing minimum POI distances per postcode...") print("Computing minimum POI distances per postcode...")
@ -144,51 +162,84 @@ def min_distance_per_postcode(
n_pois = len(pois) n_pois = len(pois)
print(f" {n_postcodes:,} postcodes, {n_pois:,} POIs") print(f" {n_postcodes:,} postcodes, {n_pois:,} POIs")
grid_size = 0.05
print(" Building POI spatial grid...")
poi_lats, poi_lngs, poi_cats, poi_grid = _build_poi_grid(pois, grid_size)
print(f" POI grid has {len(poi_grid):,} occupied cells")
category_masks = {}
for group, categories in groups.items():
mask = np.isin(poi_cats, categories)
category_masks[group] = mask
print(f" {group}: {mask.sum():,} POIs")
pc_lats = postcodes_df["lat"].to_numpy() pc_lats = postcodes_df["lat"].to_numpy()
pc_lons = postcodes_df["lon"].to_numpy() pc_lons = postcodes_df["lon"].to_numpy()
pc_codes = postcodes_df["postcode"].to_list() pc_codes = postcodes_df["postcode"].to_list()
valid_pc_mask = np.isfinite(pc_lats) & np.isfinite(pc_lons)
valid_pc_indices = np.flatnonzero(valid_pc_mask)
result_min_dist = { result_min_dist = {
group: np.full(n_postcodes, np.nan, dtype=np.float32) for group in groups group: np.full(n_postcodes, np.nan, dtype=np.float32) for group in groups
} }
batch_size = 50000 if n_pois == 0 or len(valid_pc_indices) == 0:
n_batches = (n_postcodes + batch_size - 1) // batch_size print(" No valid postcode/POI coordinates; returning NaN distances")
print(f" Processing {n_postcodes:,} postcodes in {n_batches} batches...") return pl.DataFrame(
{
"postcode": pc_codes,
**{
f"{group}_nearest_km": values
for group, values in result_min_dist.items()
},
}
)
for batch_idx in range(n_batches): poi_lats = pois["lat"].to_numpy()
start_idx = batch_idx * batch_size poi_lngs = pois["lng"].to_numpy()
end_idx = min(start_idx + batch_size, n_postcodes) poi_cats = pois["category"].to_numpy()
valid_poi_mask = np.isfinite(poi_lats) & np.isfinite(poi_lngs)
origin_lat = float(np.nanmean(pc_lats[valid_pc_mask]))
query_xy = _project_lat_lng_km(
pc_lats[valid_pc_indices], pc_lons[valid_pc_indices], origin_lat
)
if batch_idx % 5 == 0: batch_size = 200_000
print( n_batches = (len(valid_pc_indices) + batch_size - 1) // batch_size
f" Batch {batch_idx + 1}/{n_batches}: postcodes {start_idx:,} - {end_idx:,}"
)
for i in range(start_idx, end_idx): for group, categories in groups.items():
nearby = _get_nearby_indices(pc_lats[i], pc_lons[i], poi_grid, grid_size) group_indices = np.flatnonzero(valid_poi_mask & np.isin(poi_cats, categories))
if nearby is None: print(f" {group}: {len(group_indices):,} POIs")
continue if len(group_indices) == 0:
continue
distances = haversine_km( poi_xy = _project_lat_lng_km(
poi_lats[nearby], poi_lngs[nearby], pc_lats[i], pc_lons[i] poi_lats[group_indices], poi_lngs[group_indices], origin_lat
) )
tree = cKDTree(poi_xy)
k = min(8, len(group_indices))
for group, cat_mask in category_masks.items(): for batch_idx in range(n_batches):
group_mask = cat_mask[nearby] start_idx = batch_idx * batch_size
if group_mask.any(): end_idx = min(start_idx + batch_size, len(valid_pc_indices))
result_min_dist[group][i] = distances[group_mask].min() batch_pc_indices = valid_pc_indices[start_idx:end_idx]
batch_xy = query_xy[start_idx:end_idx]
if batch_idx == 0 or (batch_idx + 1) % 5 == 0:
print(
f" Batch {batch_idx + 1}/{n_batches}: postcodes {start_idx:,} - {end_idx:,}"
)
_, nearest = tree.query(batch_xy, k=k)
nearest = np.asarray(nearest)
if k == 1:
candidate_indices = group_indices[nearest]
distances = haversine_km(
poi_lats[candidate_indices],
poi_lngs[candidate_indices],
pc_lats[batch_pc_indices],
pc_lons[batch_pc_indices],
)
else:
candidate_indices = group_indices[nearest]
distances = haversine_km(
poi_lats[candidate_indices],
poi_lngs[candidate_indices],
pc_lats[batch_pc_indices, None],
pc_lons[batch_pc_indices, None],
).min(axis=1)
result_min_dist[group][batch_pc_indices] = distances.astype(np.float32)
result_data = {"postcode": pc_codes} result_data = {"postcode": pc_codes}
for group in groups: for group in groups:

View file

@ -113,9 +113,9 @@ def test_min_distance_finds_nearest(postcodes, pois):
# Restaurant is co-located — distance ~0 # Restaurant is co-located — distance ~0
assert ec1a["restaurants_nearest_km"][0] < 0.01 assert ec1a["restaurants_nearest_km"][0] < 0.01
# Far-away postcode should have NaN (no POIs within grid range) # Far-away postcode should still get the global nearest distance.
zz99 = result.filter(pl.col("postcode") == "ZZ99 9ZZ") zz99 = result.filter(pl.col("postcode") == "ZZ99 9ZZ")
assert np.isnan(zz99["train_tube_nearest_km"][0]) assert zz99["train_tube_nearest_km"][0] > 300
def test_min_distance_no_pois_returns_nan(postcodes): def test_min_distance_no_pois_returns_nan(postcodes):

View file

@ -111,20 +111,23 @@ fi
# R5 writes .mapdb temp files next to OSM/GTFS files during network construction. # R5 writes .mapdb temp files next to OSM/GTFS files during network construction.
# Copy source data to a writable build dir to avoid polluting the originals. # Copy source data to a writable build dir to avoid polluting the originals.
mkdir -p "$NETWORK_DIR" mkdir -p "$NETWORK_DIR"
OSM_PBF="property-data/england-latest.osm.pbf"
TRANSIT_SRC="property-data/transit" TRANSIT_SRC="property-data/transit"
NETWORK_DATA_DIR="$TRANSIT_SRC" NETWORK_DATA_DIR="$NETWORK_DIR/build"
if [ ! -f "$NETWORK_DIR/network.dat" ]; then if [ ! -f "$NETWORK_DIR/network.dat" ]; then
BUILD_DIR="$NETWORK_DIR/build" BUILD_DIR="$NETWORK_DIR/build"
echo "--- No cached network — copying transit data to build dir ---" echo "--- No cached network — copying transit data to build dir ---"
mkdir -p "$BUILD_DIR" mkdir -p "$BUILD_DIR"
if ! cp "$TRANSIT_SRC"/raw/*.osm.pbf "$BUILD_DIR/" 2>/dev/null; then if [ ! -f "$OSM_PBF" ]; then
echo "Warning: no .osm.pbf files found in $TRANSIT_SRC/raw/" echo "Error: OSM PBF not found at $OSM_PBF"
echo "Download it from https://download.geofabrik.de/europe/united-kingdom/england-latest.osm.pbf"
exit 1
fi fi
cp "$OSM_PBF" "$BUILD_DIR/"
if ! cp "$TRANSIT_SRC"/*.zip "$BUILD_DIR/" 2>/dev/null; then if ! cp "$TRANSIT_SRC"/*.zip "$BUILD_DIR/" 2>/dev/null; then
echo "Warning: no .zip files found in $TRANSIT_SRC/" echo "Warning: no GTFS .zip files found in $TRANSIT_SRC/ — transit routing will be unavailable"
fi fi
NETWORK_DATA_DIR="$BUILD_DIR"
fi fi
# --- Step 5: Run batch --- # --- Step 5: Run batch ---

View file

@ -1,5 +1,5 @@
use crate::consts::NAN_U16; use crate::consts::NAN_U16;
use crate::data::QuantRef; use crate::data::{PostcodePoiMetrics, QuantRef};
/// Optional per-enum-value distribution tracking for a single feature. /// Optional per-enum-value distribution tracking for a single feature.
/// Counts how many rows have each enum value (by raw u16 index). /// Counts how many rows have each enum value (by raw u16 index).
@ -21,6 +21,69 @@ pub struct Aggregator {
pub enum_dist: Option<EnumDist>, pub enum_dist: Option<EnumDist>,
} }
/// Running min / max / sum / count per POI metric, for postcode-level POI
/// metrics that are stored outside `feature_data`.
/// Only constructed when a request selects POI metric fields.
pub struct PoiAggregator {
    /// Smallest decoded value seen per metric (starts at `+inf`).
    pub mins: Box<[f32]>,
    /// Largest decoded value seen per metric (starts at `-inf`).
    pub maxs: Box<[f32]>,
    /// Sum of decoded values per metric, accumulated in `f64`.
    pub sums: Box<[f64]>,
    /// Number of non-missing values folded in per metric.
    pub counts: Box<[u32]>,
}

impl PoiAggregator {
    /// Creates an empty aggregator with one slot per metric.
    pub fn new(num_features: usize) -> Self {
        let mins = vec![f32::INFINITY; num_features];
        let maxs = vec![f32::NEG_INFINITY; num_features];
        let sums = vec![0.0f64; num_features];
        let counts = vec![0u32; num_features];
        Self {
            mins: mins.into_boxed_slice(),
            maxs: maxs.into_boxed_slice(),
            sums: sums.into_boxed_slice(),
            counts: counts.into_boxed_slice(),
        }
    }

    /// Folds the metrics listed in `indices` for property `row` into the
    /// accumulator. Does nothing when the property has no metric row; raw
    /// values equal to `NAN_U16` are skipped.
    #[inline]
    pub fn add_row_selective(
        &mut self,
        poi_metrics: &PostcodePoiMetrics,
        row: usize,
        indices: &[usize],
    ) {
        let metric_row = match poi_metrics.metric_row_for_property(row) {
            Some(found) => found,
            None => return,
        };

        for &slot in indices {
            let encoded = poi_metrics.raw_for_metric_row(metric_row, slot);
            if encoded != NAN_U16 {
                let decoded = poi_metrics.decode_raw(slot, encoded);
                if decoded < self.mins[slot] {
                    self.mins[slot] = decoded;
                }
                if decoded > self.maxs[slot] {
                    self.maxs[slot] = decoded;
                }
                self.sums[slot] += decoded as f64;
                self.counts[slot] += 1;
            }
        }
    }

    /// Combines another aggregator into this one, slot by slot. Slots that
    /// `other` never touched (count of zero) are left unchanged.
    pub fn merge(&mut self, other: &PoiAggregator) {
        for slot in 0..self.counts.len() {
            let incoming = other.counts[slot];
            if incoming != 0 {
                if other.mins[slot] < self.mins[slot] {
                    self.mins[slot] = other.mins[slot];
                }
                if other.maxs[slot] > self.maxs[slot] {
                    self.maxs[slot] = other.maxs[slot];
                }
                self.sums[slot] += other.sums[slot];
                self.counts[slot] += incoming;
            }
        }
    }
}
/// Configuration for enum distribution tracking, passed to Aggregator::new. /// Configuration for enum distribution tracking, passed to Aggregator::new.
/// (feature_index, number_of_enum_values) /// (feature_index, number_of_enum_values)
pub type EnumDistConfig = Option<(usize, usize)>; pub type EnumDistConfig = Option<(usize, usize)>;

View file

@ -0,0 +1,807 @@
use std::sync::LazyLock;
use std::time::{SystemTime, UNIX_EPOCH};
use anyhow::{anyhow, Context};
use serde_json::Value;
use tokio::sync::Mutex;
use tracing::warn;
use crate::auth::PocketBaseUser;
use crate::pocketbase::get_superuser_token;
use crate::pocketbase_locks::acquire_pocketbase_lock;
use crate::routes::pricing::{count_licensed_users, price_for_count};
use crate::state::AppState;
/// Currency code sent to Stripe and stored on every checkout reservation.
pub const CHECKOUT_CURRENCY: &str = "gbp";
/// Lifetime of a checkout reservation in seconds (31 minutes); added to the
/// current time to derive the reservation / Stripe session expiry.
const CHECKOUT_SESSION_TTL_SECS: u64 = 31 * 60;
/// Product name shown on the Stripe line item.
const CHECKOUT_PRODUCT_NAME: &str = "Perfect Postcodes Lifetime License";
/// PocketBase collection that stores checkout reservations.
const CHECKOUT_COLLECTION: &str = "checkout_sessions";
/// Name of the PocketBase-backed lock taken while pricing a checkout.
const CHECKOUT_PRICING_LOCK_NAME: &str = "checkout:pricing";
/// TTL (seconds) requested for the pricing lock.
const CHECKOUT_PRICING_LOCK_TTL_SECS: u64 = 5 * 60;
// NOTE(review): not referenced in the visible part of this file; presumably
// consumed by `expected_total_for_checkout` — confirm.
const REFERRAL_DISCOUNT_PERCENT: u64 = 30;
/// In-process mutex held for the whole of `start_license_checkout`, so only
/// one checkout start runs at a time within this process.
static CHECKOUT_RESERVATION_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
/// Outcome of starting a license checkout.
pub enum CheckoutStart {
/// The computed price was zero: the license was granted directly and no
/// Stripe session was created.
Free,
/// A Stripe Checkout session is available at `url`.
Stripe { url: String },
}
/// Result of validating a completed Stripe session against its reservation.
pub enum CheckoutCompletion {
/// Every check passed; the caller may grant the license described inside.
Grant(VerifiedCheckout),
/// The reservation was already in the `completed` state.
AlreadyHandled,
/// Validation failed; the string is a human-readable reason.
Rejected(String),
}
/// Details of a checkout that passed `verify_checkout_completion`.
pub struct VerifiedCheckout {
/// Id of the PocketBase reservation record.
pub reservation_id: String,
/// User id stored on the reservation.
pub user_id: String,
/// `amount_total` reported by the Stripe session.
pub paid_amount_pence: u64,
/// Referral invite id stored on the reservation (may be empty).
pub referral_invite_id: String,
}
/// A checkout reservation record as read back from PocketBase.
#[derive(Debug)]
struct PendingCheckout {
// PocketBase record id of the reservation.
id: String,
// User the reservation belongs to; compared with Stripe's `client_reference_id`.
user_id: String,
// Stripe session id attached to the reservation.
stripe_session_id: String,
// Stripe-hosted checkout URL; an empty value is treated as "no usable session".
checkout_url: String,
// Price before discounts; compared with Stripe's `amount_subtotal`.
amount_pence: u64,
// Price after discounts; compared with Stripe's `amount_total`.
expected_total_pence: u64,
// Currency code; compared with the lower-cased Stripe `currency`.
currency: String,
// Referral invite associated with the reservation (may be empty).
referral_invite_id: String,
// Reservation state; values used in this file include "pending", "expired",
// "completed", "failed" and "invalid".
status: String,
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
pub fn now_unix_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Starts (or resumes) a license checkout for `user` while holding both the
/// in-process reservation mutex and the PocketBase pricing lock.
///
/// The pricing lock is released after the inner call finishes, whether it
/// succeeded or failed; a release failure is only logged and does not change
/// the returned result.
pub async fn start_license_checkout(
    state: &AppState,
    user: &PocketBaseUser,
    success_url: &str,
    cancel_url: &str,
    discount_coupon_id: Option<&str>,
    referral_invite_id: Option<&str>,
) -> anyhow::Result<CheckoutStart> {
    // Local mutex first, then the PocketBase-backed lock.
    let _local_guard = CHECKOUT_RESERVATION_LOCK.lock().await;
    let remote_lock = acquire_pocketbase_lock(
        state,
        CHECKOUT_PRICING_LOCK_NAME,
        CHECKOUT_PRICING_LOCK_TTL_SECS,
    )
    .await?;

    let outcome = start_license_checkout_locked(
        state,
        user,
        success_url,
        cancel_url,
        discount_coupon_id,
        referral_invite_id,
    )
    .await;

    match remote_lock.release().await {
        Ok(_) => {}
        Err(err) => warn!("Failed to release checkout pricing lock: {err}"),
    }

    outcome
}
/// Body of `start_license_checkout`; the caller holds the checkout locks.
///
/// Reuses the user's active reservation when it already has a checkout URL,
/// otherwise prices a new checkout, records a pending reservation, creates the
/// Stripe session and attaches it to the reservation. Returns
/// `CheckoutStart::Free` (after granting the license) when the price is zero.
async fn start_license_checkout_locked(
state: &AppState,
user: &PocketBaseUser,
success_url: &str,
cancel_url: &str,
discount_coupon_id: Option<&str>,
referral_invite_id: Option<&str>,
) -> anyhow::Result<CheckoutStart> {
let now = now_unix_secs();
// Expire stale reservations before looking for an active one or counting them.
expire_stale_pending_checkouts(state, now).await?;
if let Some(existing) = find_active_checkout_for_user(
state,
&user.id,
discount_coupon_id.unwrap_or_default(),
referral_invite_id.unwrap_or_default(),
now,
)
.await?
{
// An active reservation with a URL is handed back as-is.
if !existing.checkout_url.is_empty() {
return Ok(CheckoutStart::Stripe {
url: existing.checkout_url,
});
}
// A reservation without a URL is marked failed (best effort) and a fresh
// one is created below.
if let Err(err) = mark_checkout_status(state, &existing.id, "failed").await {
warn!(
reservation_id = %existing.id,
"Failed to fail incomplete checkout reservation: {err}"
);
}
}
// The price depends on licensed users plus currently pending reservations.
let licensed_count = count_licensed_users(state).await?;
let pending_count = count_active_pending_checkouts(state, now).await?;
let price_pence = price_for_count(licensed_count + pending_count);
if price_pence == 0 {
grant_license(state, &user.id).await?;
return Ok(CheckoutStart::Free);
}
let expires_at_unix = now + CHECKOUT_SESSION_TTL_SECS;
let expected_total_pence = expected_total_for_checkout(price_pence, discount_coupon_id);
// The reservation is written before the Stripe session exists so its id can
// be passed to `create_stripe_session`.
let reservation_id = create_pending_checkout(
state,
PendingCheckoutInput {
user_id: &user.id,
amount_pence: price_pence,
expected_total_pence,
currency: CHECKOUT_CURRENCY,
discount_coupon_id: discount_coupon_id.unwrap_or_default(),
referral_invite_id: referral_invite_id.unwrap_or_default(),
expires_at_unix,
},
)
.await?;
let stripe_result = create_stripe_session(
state,
user,
&reservation_id,
price_pence,
success_url,
cancel_url,
expires_at_unix,
discount_coupon_id,
)
.await;
let (stripe_session_id, url) = match stripe_result {
Ok(session) => session,
Err(err) => {
// Stripe failed: mark the reservation failed (best effort) and return
// the original Stripe error.
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
warn!(
reservation_id,
"Failed to mark checkout reservation failed: {mark_err}"
);
}
return Err(err);
}
};
if let Err(err) = attach_stripe_session(state, &reservation_id, &stripe_session_id, &url).await
{
// Attaching failed: mark the reservation failed (best effort) and return
// the attach error.
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
warn!(
reservation_id,
"Failed to mark checkout reservation failed: {mark_err}"
);
}
return Err(err);
}
Ok(CheckoutStart::Stripe { url })
}
/// Validates a Stripe checkout session object against its stored reservation.
///
/// Returns `Grant` only when the session id, user, payment status, currency,
/// subtotal and total all match the reservation. Most mismatches also mark the
/// reservation `invalid`; a non-`paid` payment status is rejected without
/// changing the reservation. `Err` is returned only for lookup or status
/// update failures.
pub async fn verify_checkout_completion(
state: &AppState,
session: &Value,
) -> anyhow::Result<CheckoutCompletion> {
let session_id = match session["id"].as_str() {
Some(id) if is_safe_stripe_session_id(id) => id,
_ => {
return Ok(CheckoutCompletion::Rejected(
"missing or invalid session id".into(),
))
}
};
let checkout = match find_checkout_by_stripe_session(state, session_id).await? {
Some(checkout) => checkout,
None => {
return Ok(CheckoutCompletion::Rejected(
"checkout session has no reservation".into(),
))
}
};
if checkout.status == "completed" {
return Ok(CheckoutCompletion::AlreadyHandled);
}
// Only `pending` and `expired` reservations continue to the checks below.
if checkout.status != "pending" && checkout.status != "expired" {
return Ok(CheckoutCompletion::Rejected(format!(
"checkout reservation is {}",
checkout.status
)));
}
if checkout.stripe_session_id != session_id {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout reservation session id mismatch".into(),
));
}
// A missing `client_reference_id` is read as "" and compared like any other value.
let client_reference_id = session["client_reference_id"].as_str().unwrap_or_default();
if client_reference_id != checkout.user_id {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout client_reference_id mismatch".into(),
));
}
// Unpaid sessions are rejected but the reservation status is left untouched.
let payment_status = session["payment_status"].as_str().unwrap_or_default();
if payment_status != "paid" {
return Ok(CheckoutCompletion::Rejected(format!(
"checkout payment_status is {payment_status}"
)));
}
let currency = session["currency"]
.as_str()
.unwrap_or_default()
.to_ascii_lowercase();
if currency != checkout.currency {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout currency mismatch".into(),
));
}
// Subtotal must equal the reserved pre-discount price.
let amount_subtotal = match number_field(session, "amount_subtotal") {
Some(amount) => amount,
None => {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout amount_subtotal missing".into(),
));
}
};
if amount_subtotal != checkout.amount_pence {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout amount_subtotal mismatch".into(),
));
}
// Total must equal the reserved expected total.
let amount_total = match number_field(session, "amount_total") {
Some(amount) => amount,
None => {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout amount_total missing".into(),
));
}
};
if amount_total != checkout.expected_total_pence {
mark_checkout_status(state, &checkout.id, "invalid").await?;
return Ok(CheckoutCompletion::Rejected(
"checkout amount_total mismatch".into(),
));
}
Ok(CheckoutCompletion::Grant(VerifiedCheckout {
reservation_id: checkout.id,
user_id: checkout.user_id,
paid_amount_pence: amount_total,
referral_invite_id: checkout.referral_invite_id,
}))
}
/// Marks a checkout reservation as `completed` in PocketBase, recording the
/// paid amount and the completion time (Unix seconds, stored as a string).
pub async fn mark_checkout_completed(
    state: &AppState,
    reservation_id: &str,
    paid_amount_pence: u64,
) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let endpoint = format!("{base}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let payload = serde_json::json!({
        "status": "completed",
        "paid_amount_pence": paid_amount_pence,
        "completed_at_unix": now_unix_secs().to_string(),
    });

    let response = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload)
        .send()
        .await?;

    ensure_success(response)
        .await
        .context("PocketBase checkout completion update failed")
}
/// Sets the user's `subscription` field to `licensed` in PocketBase and then
/// invalidates any cached tokens for that user.
pub async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let endpoint = format!("{base}/api/collections/users/records/{user_id}");
    let payload = serde_json::json!({ "subscription": "licensed" });

    let response = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload)
        .send()
        .await?;
    ensure_success(response)
        .await
        .context("PocketBase license update failed")?;

    // Drop cached auth state only after the update succeeded.
    state.token_cache.invalidate_by_user_id(user_id);
    Ok(())
}
/// Records that `user_id` used the referral invite `invite_id`.
///
/// An empty `invite_id` is a no-op. The call also succeeds without writing when
/// the invite is already recorded for the same user, and fails when it is
/// recorded for a different one or when either id is not a safe PocketBase id.
pub async fn mark_referral_invite_used(
    state: &AppState,
    invite_id: &str,
    user_id: &str,
) -> anyhow::Result<()> {
    if invite_id.is_empty() {
        return Ok(());
    }
    let ids_are_safe = is_safe_pocketbase_id(invite_id) && is_safe_pocketbase_id(user_id);
    if !ids_are_safe {
        return Err(anyhow!("invalid PocketBase id"));
    }

    let token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');

    let current_owner = fetch_invite_used_by(state, base, &token, invite_id).await?;
    match current_owner.as_str() {
        // Already recorded for this user: nothing to do.
        owner if owner == user_id => return Ok(()),
        // Unused invite: fall through to the update.
        "" => {}
        _ => return Err(anyhow!("referral invite already used by another account")),
    }

    let endpoint = format!("{base}/api/collections/invites/records/{invite_id}");
    let payload = serde_json::json!({
        "used_by_id": user_id,
        "used_at": now_unix_secs().to_string(),
    });
    let response = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload)
        .send()
        .await?;

    ensure_success(response)
        .await
        .context("PocketBase invite usage update failed")
}
/// Fetches an invite record and returns its `used_by_id` field, or an empty
/// string when the field is missing or not a string.
async fn fetch_invite_used_by(
    state: &AppState,
    pb_url: &str,
    token: &str,
    invite_id: &str,
) -> anyhow::Result<String> {
    let endpoint = format!("{pb_url}/api/collections/invites/records/{invite_id}");
    let response = state
        .http_client
        .get(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;
    ensure_success_ref(&response).await?;

    let record: Value = response.json().await?;
    let used_by = record["used_by_id"].as_str().unwrap_or_default();
    Ok(used_by.to_string())
}
/// Returns the `user` field of the first pending, unexpired checkout
/// reservation that references `invite_id`, or `None` when there is none.
///
/// Fails when `invite_id` is not a safe PocketBase id.
pub async fn active_referral_checkout_user(
    state: &AppState,
    invite_id: &str,
) -> anyhow::Result<Option<String>> {
    if !is_safe_pocketbase_id(invite_id) {
        return Err(anyhow!("invalid PocketBase invite id"));
    }

    let now = now_unix_secs();
    let token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let query = format!(
        "status=\"pending\" && expires_at_unix>={now} && referral_invite_id=\"{invite_id}\""
    );
    let endpoint = format!(
        "{base}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
        urlencoding::encode(&query)
    );

    let response = state
        .http_client
        .get(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;
    ensure_success_ref(&response).await?;

    let page: Value = response.json().await?;
    let holder = page["items"]
        .as_array()
        .and_then(|records| records.first())
        .and_then(|record| record["user"].as_str());
    Ok(holder.map(str::to_string))
}
/// Counts pending checkout reservations that have not expired at `now`, using
/// the `totalItems` field of a one-item page; a missing field counts as zero.
async fn count_active_pending_checkouts(state: &AppState, now: u64) -> anyhow::Result<u64> {
    let token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let query = format!("status=\"pending\" && expires_at_unix>={now}");
    let endpoint = format!(
        "{base}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
        urlencoding::encode(&query)
    );

    let response = state
        .http_client
        .get(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;
    ensure_success_ref(&response).await?;

    let page: Value = response.json().await?;
    let total = page["totalItems"].as_u64().unwrap_or(0);
    Ok(total)
}
/// Looks up the user's pending, unexpired checkout reservation that matches
/// the given coupon and referral invite.
///
/// Returns `Ok(None)` when no such reservation exists.
///
/// # Errors
/// Fails when `user_id` is not a safe PocketBase id, when
/// `discount_coupon_id` or `referral_invite_id` contains a character that
/// could break out of the quoted filter literal, or when the PocketBase
/// request or record parsing fails.
async fn find_active_checkout_for_user(
    state: &AppState,
    user_id: &str,
    discount_coupon_id: &str,
    referral_invite_id: &str,
    now: u64,
) -> anyhow::Result<Option<PendingCheckout>> {
    // A value can be embedded in a double-quoted filter literal only if it can
    // neither terminate the literal nor start an escape sequence inside it.
    fn is_filter_literal_safe(value: &str) -> bool {
        !value
            .chars()
            .any(|c| c == '"' || c == '\\' || c.is_control())
    }

    if !is_safe_pocketbase_id(user_id) {
        return Err(anyhow!("invalid PocketBase user id"));
    }
    // These two values are interpolated into the filter below as well, so they
    // must be rejected before they can alter the query.
    if !is_filter_literal_safe(discount_coupon_id) {
        return Err(anyhow!("invalid discount coupon id"));
    }
    if !is_filter_literal_safe(referral_invite_id) {
        return Err(anyhow!("invalid referral invite id"));
    }

    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let filter = format!(
        "status=\"pending\" && expires_at_unix>={now} && user=\"{}\" && discount_coupon_id=\"{}\" && referral_invite_id=\"{}\"",
        user_id, discount_coupon_id, referral_invite_id
    );
    let url = format!(
        "{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
        urlencoding::encode(&filter)
    );
    let resp = state
        .http_client
        .get(&url)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;
    ensure_success_ref(&resp).await?;

    let body: Value = resp.json().await?;
    // Only the first matching record is considered.
    let item = body["items"]
        .as_array()
        .and_then(|items| items.first())
        .cloned();
    item.map(parse_pending_checkout).transpose()
}
/// Marks pending reservations whose expiry is before `now` as `expired`.
///
/// Processes a single page of up to 50 records per call. A failure to update
/// an individual record is logged and does not abort the remaining updates.
async fn expire_stale_pending_checkouts(state: &AppState, now: u64) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let stale_query = format!("status=\"pending\" && expires_at_unix<{now}");
    let endpoint = format!(
        "{base}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=50",
        urlencoding::encode(&stale_query)
    );

    let response = state
        .http_client
        .get(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;
    ensure_success_ref(&response).await?;
    let page: Value = response.json().await?;

    if let Some(records) = page["items"].as_array() {
        for record in records {
            // Records without a string id are skipped.
            let Some(id) = record["id"].as_str() else {
                continue;
            };
            if let Err(err) = mark_checkout_status(state, id, "expired").await {
                warn!(
                    reservation_id = id,
                    "Failed to expire checkout reservation: {err}"
                );
            }
        }
    }
    Ok(())
}
/// Borrowed field values for a new pending checkout reservation record.
struct PendingCheckoutInput<'a> {
// Stored in the record's `user` field.
user_id: &'a str,
// Price before discounts, in pence.
amount_pence: u64,
// Price expected after discounts, in pence.
expected_total_pence: u64,
// Currency code stored on the reservation.
currency: &'a str,
// Coupon id; callers pass "" when no coupon applies.
discount_coupon_id: &'a str,
// Referral invite id; callers pass "" when there is none.
referral_invite_id: &'a str,
// Expiry as Unix seconds.
expires_at_unix: u64,
}
/// Creates a `pending` checkout reservation record in PocketBase and returns
/// its id. The Stripe session id and checkout URL start out empty.
///
/// Fails when the request is unsuccessful or the response has no string `id`.
async fn create_pending_checkout(
    state: &AppState,
    input: PendingCheckoutInput<'_>,
) -> anyhow::Result<String> {
    let token = get_superuser_token(state).await?;
    let base = state.pocketbase_url.trim_end_matches('/');
    let endpoint = format!("{base}/api/collections/{CHECKOUT_COLLECTION}/records");
    let record = serde_json::json!({
        "user": input.user_id,
        "stripe_session_id": "",
        "checkout_url": "",
        "amount_pence": input.amount_pence,
        "expected_total_pence": input.expected_total_pence,
        "currency": input.currency,
        "discount_coupon_id": input.discount_coupon_id,
        "referral_invite_id": input.referral_invite_id,
        "status": "pending",
        "expires_at_unix": input.expires_at_unix,
        "paid_amount_pence": 0,
        "completed_at_unix": "",
    });

    let response = state
        .http_client
        .post(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&record)
        .send()
        .await?;
    ensure_success_ref(&response).await?;

    let created: Value = response.json().await?;
    match created["id"].as_str() {
        Some(id) => Ok(id.to_string()),
        None => Err(anyhow!("PocketBase checkout reservation missing id")),
    }
}
#[allow(clippy::too_many_arguments)]
async fn create_stripe_session(
state: &AppState,
user: &PocketBaseUser,
reservation_id: &str,
price_pence: u64,
success_url: &str,
cancel_url: &str,
expires_at_unix: u64,
discount_coupon_id: Option<&str>,
) -> anyhow::Result<(String, String)> {
let mut form_params = vec![
("mode", "payment".to_string()),
("payment_method_types[0]", "card".to_string()),
(
"line_items[0][price_data][unit_amount]",
price_pence.to_string(),
),
(
"line_items[0][price_data][currency]",
CHECKOUT_CURRENCY.to_string(),
),
(
"line_items[0][price_data][product_data][name]",
CHECKOUT_PRODUCT_NAME.to_string(),
),
("line_items[0][quantity]", "1".to_string()),
("success_url", success_url.to_string()),
("cancel_url", cancel_url.to_string()),
("expires_at", expires_at_unix.to_string()),
("client_reference_id", user.id.clone()),
("customer_email", user.email.clone()),
("metadata[pending_checkout_id]", reservation_id.to_string()),
("metadata[expected_amount_pence]", price_pence.to_string()),
(
"metadata[expected_total_pence]",
expected_total_for_checkout(price_pence, discount_coupon_id).to_string(),
),
("metadata[expected_currency]", CHECKOUT_CURRENCY.to_string()),
];
if let Some(coupon_id) = discount_coupon_id.filter(|id| !id.is_empty()) {
form_params.push(("discounts[0][coupon]", coupon_id.to_string()));
form_params.push(("metadata[discount_coupon_id]", coupon_id.to_string()));
}
let resp = state
.http_client
.post("https://api.stripe.com/v1/checkout/sessions")
.basic_auth(&state.stripe_secret_key, None::<&str>)
.form(&form_params)
.send()
.await
.context("Stripe checkout request failed")?;
ensure_success_ref(&resp)
.await
.context("Stripe checkout failed")?;
let body: Value = resp
.json()
.await
.context("Failed to parse Stripe response")?;
let session_id = body["id"]
.as_str()
.filter(|id| is_safe_stripe_session_id(id))
.map(str::to_string)
.ok_or_else(|| anyhow!("Stripe session missing valid id"))?;
let url = body["url"]
.as_str()
.map(str::to_string)
.filter(|url| !url.is_empty())
.ok_or_else(|| anyhow!("Stripe session missing URL"))?;
Ok((session_id, url))
}
/// Stores the Stripe session id and checkout URL on an existing checkout
/// reservation record via a PocketBase `PATCH`.
///
/// # Errors
/// Fails when the superuser token cannot be obtained, the request fails, or
/// PocketBase answers with a non-success status.
async fn attach_stripe_session(
    state: &AppState,
    reservation_id: &str,
    stripe_session_id: &str,
    checkout_url: &str,
) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let endpoint =
        format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");

    let changes = serde_json::json!({
        "stripe_session_id": stripe_session_id,
        "checkout_url": checkout_url,
    });

    let request = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&changes);
    let resp = request.send().await?;

    let outcome = ensure_success(resp).await;
    outcome.context("PocketBase checkout session attach failed")
}
/// Overwrites the `status` field of a checkout reservation record via a
/// PocketBase `PATCH`.
///
/// # Errors
/// Fails when the superuser token cannot be obtained, the request fails, or
/// PocketBase answers with a non-success status; the latter error names the
/// reservation id.
async fn mark_checkout_status(
    state: &AppState,
    reservation_id: &str,
    status: &str,
) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let endpoint =
        format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");

    let changes = serde_json::json!({ "status": status });
    let request = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&changes);
    let resp = request.send().await?;

    let outcome = ensure_success(resp).await;
    outcome.with_context(|| {
        format!("PocketBase checkout status update failed for {reservation_id}")
    })
}
async fn find_checkout_by_stripe_session(
state: &AppState,
stripe_session_id: &str,
) -> anyhow::Result<Option<PendingCheckout>> {
let token = get_superuser_token(state).await?;
let pb_url = state.pocketbase_url.trim_end_matches('/');
let filter = format!("stripe_session_id=\"{}\"", stripe_session_id);
let url = format!(
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
urlencoding::encode(&filter)
);
let resp = state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
ensure_success_ref(&resp).await?;
let body: Value = resp.json().await?;
let item = body["items"]
.as_array()
.and_then(|items| items.first())
.cloned();
item.map(parse_pending_checkout).transpose()
}
/// Converts one PocketBase checkout record (as JSON) into a `PendingCheckout`.
///
/// `id`, `user`, `amount_pence` and `expected_total_pence` are required; the
/// remaining text fields fall back to an empty string when absent or not a
/// string. `currency` is lower-cased (ASCII only).
///
/// # Errors
/// Returns an error naming the first required field that is missing or has
/// the wrong type.
fn parse_pending_checkout(item: Value) -> anyhow::Result<PendingCheckout> {
    // Required fields are checked in the same order the struct lists them.
    let Some(id) = item["id"].as_str() else {
        return Err(anyhow!("checkout reservation missing id"));
    };
    let Some(user_id) = item["user"].as_str() else {
        return Err(anyhow!("checkout reservation missing user"));
    };
    let Some(amount_pence) = number_field(&item, "amount_pence") else {
        return Err(anyhow!("checkout reservation missing amount_pence"));
    };
    let Some(expected_total_pence) = number_field(&item, "expected_total_pence") else {
        return Err(anyhow!(
            "checkout reservation missing expected_total_pence"
        ));
    };

    // Optional text fields: absent or non-string values become "".
    let text_or_empty = |key: &str| item[key].as_str().unwrap_or_default();

    Ok(PendingCheckout {
        id: id.to_string(),
        user_id: user_id.to_string(),
        stripe_session_id: text_or_empty("stripe_session_id").to_string(),
        checkout_url: text_or_empty("checkout_url").to_string(),
        amount_pence,
        expected_total_pence,
        currency: text_or_empty("currency").to_ascii_lowercase(),
        referral_invite_id: text_or_empty("referral_invite_id").to_string(),
        status: text_or_empty("status").to_string(),
    })
}
/// Returns the total expected for a checkout of `amount_pence`.
///
/// With a non-empty coupon id the amount is reduced by
/// `REFERRAL_DISCOUNT_PERCENT` using integer (floor) division and is never
/// allowed to drop below 1. Without a coupon, or with an empty coupon id,
/// the amount is returned unchanged.
fn expected_total_for_checkout(amount_pence: u64, discount_coupon_id: Option<&str>) -> u64 {
    match discount_coupon_id {
        Some(coupon_id) if !coupon_id.is_empty() => {
            let payable_percent = 100 - REFERRAL_DISCOUNT_PERCENT;
            let discounted = amount_pence * payable_percent / 100;
            discounted.max(1)
        }
        _ => amount_pence,
    }
}
/// Reads `value[field]` as a non-negative whole number.
///
/// Accepts JSON integers directly. A JSON float is accepted only when it is
/// finite, non-negative, has no fractional part, and is strictly below 2^64.
/// Without the upper bound, a value such as `1e30` would pass the other
/// checks and be silently saturated to `u64::MAX` by the `as u64` cast.
///
/// Returns `None` when the field is absent, is not a number, or fails any of
/// the checks above.
fn number_field(value: &Value, field: &str) -> Option<u64> {
    // `u64::MAX as f64` rounds up to exactly 2^64, the first float that no
    // longer fits in a u64.
    const U64_EXCLUSIVE_LIMIT: f64 = u64::MAX as f64;

    let raw = &value[field];
    raw.as_u64().or_else(|| {
        raw.as_f64()
            .filter(|n| {
                n.is_finite() && *n >= 0.0 && *n < U64_EXCLUSIVE_LIMIT && n.fract() == 0.0
            })
            .map(|n| n as u64)
    })
}
/// Returns `true` when `id` is 1 to 128 bytes long and consists solely of
/// ASCII letters, ASCII digits, `_` and `-`.
fn is_safe_stripe_session_id(id: &str) -> bool {
    if id.is_empty() || id.len() > 128 {
        return false;
    }
    id.bytes().all(|byte| {
        matches!(byte, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'-')
    })
}
/// Returns `true` when `id` is 1 to 32 bytes long and consists solely of
/// ASCII letters and ASCII digits.
fn is_safe_pocketbase_id(id: &str) -> bool {
    let length_ok = (1..=32).contains(&id.len());
    length_ok && id.chars().all(|ch| ch.is_ascii_alphanumeric())
}
/// Consumes `resp` and returns `Ok(())` for a success status.
///
/// For any other status the response body is read (an unreadable body is
/// treated as empty) and returned in the error together with the status.
async fn ensure_success(resp: reqwest::Response) -> anyhow::Result<()> {
    let status = resp.status();
    if !status.is_success() {
        let text = resp.text().await.unwrap_or_default();
        return Err(anyhow!("upstream returned {status}: {text}"));
    }
    Ok(())
}
/// Checks the status of a borrowed response without consuming it, so the
/// caller can still read the body afterwards.
///
/// Returns `Ok(())` for a success status; otherwise an error carrying only
/// the status (the body is not read).
async fn ensure_success_ref(resp: &reqwest::Response) -> anyhow::Result<()> {
    let status = resp.status();
    if status.is_success() {
        Ok(())
    } else {
        Err(anyhow!("upstream returned {}", status))
    }
}

View file

@ -97,7 +97,7 @@ fn build_search_text(name: &str, place_type: &str) -> String {
} }
if place_type == "station" { if place_type == "station" {
let suffix_aliases: [(&str, &[&str]); 5] = [ let suffix_aliases: [(&str, &[&str]); 6] = [
( (
" tube station", " tube station",
&[" underground station", " station", " tube", " underground"], &[" underground station", " station", " tube", " underground"],
@ -118,6 +118,7 @@ fn build_search_text(name: &str, place_type: &str) -> String {
" elizabeth line station", " elizabeth line station",
&[" station", " elizabeth line", " crossrail station"], &[" station", " elizabeth line", " crossrail station"],
), ),
(" dlr station", &[" station", " dlr"]),
]; ];
for (suffix, replacements) in suffix_aliases { for (suffix, replacements) in suffix_aliases {
@ -139,10 +140,15 @@ fn extract_str_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<String>> {
let string_column = column let string_column = column
.str() .str()
.with_context(|| format!("Column '{name}' is not a string column"))?; .with_context(|| format!("Column '{name}' is not a string column"))?;
Ok(string_column string_column
.into_iter() .into_iter()
.map(|value| value.unwrap_or("").to_string()) .enumerate()
.collect()) .map(|(row, value)| {
value
.map(ToString::to_string)
.with_context(|| format!("Column '{name}' has null at row {row}"))
})
.collect()
} }
fn extract_f32_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<f32>> { fn extract_f32_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<f32>> {
@ -155,33 +161,37 @@ fn extract_f32_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<f32>> {
let float_column = cast let float_column = cast
.f32() .f32()
.with_context(|| format!("Column '{name}' is not a float32 column"))?; .with_context(|| format!("Column '{name}' is not a float32 column"))?;
Ok(float_column float_column
.into_iter() .into_iter()
.map(|value| value.unwrap_or(0.0)) .enumerate()
.collect()) .map(|(row, value)| value.with_context(|| format!("Column '{name}' has null at row {row}")))
.collect()
} }
fn extract_bool_col_or_default( fn extract_bool_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<bool>> {
df: &DataFrame, let column = df
name: &str, .column(name)
default_value: bool, .with_context(|| format!("Missing column '{name}' in places data"))?;
) -> anyhow::Result<Vec<bool>> {
let Ok(column) = df.column(name) else {
return Ok(vec![default_value; df.height()]);
};
let bool_column = column let bool_column = column
.bool() .bool()
.with_context(|| format!("Column '{name}' is not a boolean column"))?; .with_context(|| format!("Column '{name}' is not a boolean column"))?;
Ok(bool_column bool_column
.into_iter() .into_iter()
.map(|value| value.unwrap_or(default_value)) .enumerate()
.collect()) .map(|(row, value)| value.with_context(|| format!("Column '{name}' has null at row {row}")))
.collect()
} }
impl PlaceData { impl PlaceData {
pub fn load(parquet_path: &Path) -> anyhow::Result<Self> { pub fn load(parquet_path: &Path) -> anyhow::Result<Self> {
super::run_polars_io(|| Self::load_inner(parquet_path))
}
fn load_inner(parquet_path: &Path) -> anyhow::Result<Self> {
info!("Loading place data from {:?}...", parquet_path); info!("Loading place data from {:?}...", parquet_path);
let parquet_path = PlRefPath::try_from_path(parquet_path)
.context("Failed to normalize places parquet path")?;
let df = LazyFrame::scan_parquet(parquet_path, Default::default()) let df = LazyFrame::scan_parquet(parquet_path, Default::default())
.context("Failed to scan places parquet")? .context("Failed to scan places parquet")?
.collect() .collect()
@ -210,7 +220,7 @@ impl PlaceData {
let type_rank_vec: Vec<u8> = place_type_raw.iter().map(|pt| type_rank(pt)).collect(); let type_rank_vec: Vec<u8> = place_type_raw.iter().map(|pt| type_rank(pt)).collect();
let place_type = InternedColumn::build(&place_type_raw); let place_type = InternedColumn::build(&place_type_raw);
let travel_destination = if df.column("travel_destination").is_ok() { let travel_destination = if df.column("travel_destination").is_ok() {
extract_bool_col_or_default(&df, "travel_destination", true)? extract_bool_col(&df, "travel_destination")?
} else { } else {
place_type_raw place_type_raw
.iter() .iter()
@ -296,6 +306,7 @@ mod tests {
assert!(build_search_text("King's Cross tube station", "station") assert!(build_search_text("King's Cross tube station", "station")
.contains("kings cross underground")); .contains("kings cross underground"));
assert!(build_search_text("St Albans", "city").contains("saint albans")); assert!(build_search_text("St Albans", "city").contains("saint albans"));
assert!(build_search_text("Shadwell DLR station", "station").contains("shadwell station"));
} }
#[test] #[test]

View file

@ -5,6 +5,7 @@ use anyhow::{bail, Context};
use polars::frame::DataFrame; use polars::frame::DataFrame;
use polars::lazy::frame::LazyFrame; use polars::lazy::frame::LazyFrame;
use polars::prelude::*; use polars::prelude::*;
use rustc_hash::FxHashSet;
use serde::Serialize; use serde::Serialize;
use tracing::info; use tracing::info;
@ -17,6 +18,94 @@ pub struct POICategoryGroup {
pub categories: Vec<String>, pub categories: Vec<String>,
} }
/// Category names listed under the "Groceries" entry of `DASHBOARD_POI_GROUPS`.
///
/// The four generic store types come first, followed by brand names in
/// roughly alphabetical order.
const GROCERY_DASHBOARD_CATEGORIES: &[&str] = &[
"Supermarket",
"Convenience Store",
"Bakery",
"Greengrocer",
"Aldi",
"Amazon",
"Asda",
"Booths",
"Budgens",
"Centra",
"Co-op",
"COOK",
"Costco",
"Dunnes Stores",
"Farmfoods",
"Heron Foods",
"Iceland",
"Lidl",
"Makro",
"M&S",
"Morrisons",
"Planet Organic",
"Sainsbury's",
"Spar",
"Tesco",
"The Food Warehouse",
"Waitrose",
"Whole Foods Market",
];
/// `(group name, category names)` pairs describing the dashboard's preferred
/// category order within each POI group.
///
/// A category's position in its slice is its sort rank.
// NOTE(review): `POIData::category_groups` appears to sort listed categories by
// this rank and place unlisted ones after them alphabetically; confirm there.
const DASHBOARD_POI_GROUPS: &[(&str, &[&str])] = &[
(
"Public Transport",
&[
"Rail station",
"Tube station",
"Bus station",
"Bus stop",
"Airport",
],
),
("Groceries", GROCERY_DASHBOARD_CATEGORIES),
("Food & Drink", &["Café", "Restaurant", "Pub", "Fast Food"]),
("Green Space", &["Park", "Playground"]),
("Education", &["School"]),
(
"Health",
&["GP Surgery", "Pharmacy", "Dentist", "Hospital & Clinic"],
),
(
"Leisure",
&[
"Gym & Fitness",
"Sports Centre",
"Cinema",
"Theatre",
"Library",
],
),
(
"Practical",
&["Post Office", "Bank", "EV Charging", "Fuel Station"],
),
];
/// Inserts into `selected` the index of the first entry of `category_values`
/// that is exactly equal to `category`; does nothing when there is no match.
fn add_category_filter_index(
    category_values: &[String],
    category: &str,
    selected: &mut FxHashSet<u16>,
) {
    let first_match = category_values
        .iter()
        .enumerate()
        .find(|(_, value)| value.as_str() == category);
    if let Some((index, _)) = first_match {
        selected.insert(index as u16);
    }
}
/// Resolves a comma-separated list of category names into the set of their
/// indices in `category_values`.
///
/// Each part is trimmed; empty parts are skipped and names that match no
/// entry contribute nothing, so an unknown-only list yields an empty set.
pub fn resolve_poi_category_filter(category_values: &[String], categories: &str) -> FxHashSet<u16> {
    let mut selected = FxHashSet::default();
    categories
        .split(',')
        .map(str::trim)
        .filter(|category| !category.is_empty())
        .for_each(|category| add_category_filter_index(category_values, category, &mut selected));
    selected
}
pub struct POIData { pub struct POIData {
/// Contiguous buffer holding all POI ID strings end-to-end. /// Contiguous buffer holding all POI ID strings end-to-end.
id_buffer: String, id_buffer: String,
@ -53,13 +142,18 @@ fn extract_str_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<String>> {
let string_column = column let string_column = column
.str() .str()
.with_context(|| format!("Column '{name}' is not a string column"))?; .with_context(|| format!("Column '{name}' is not a string column"))?;
Ok(string_column string_column
.into_iter() .into_iter()
.map(|value| value.unwrap_or("").to_string()) .enumerate()
.collect()) .map(|(row, value)| {
value
.map(ToString::to_string)
.with_context(|| format!("Column '{name}' has null at row {row}"))
})
.collect()
} }
fn extract_f32_col(df: &DataFrame, name: &str, default: f32) -> anyhow::Result<Vec<f32>> { fn extract_f32_col(df: &DataFrame, name: &str) -> anyhow::Result<Vec<f32>> {
let column = df let column = df
.column(name) .column(name)
.with_context(|| format!("Missing column '{name}' in POI data"))?; .with_context(|| format!("Missing column '{name}' in POI data"))?;
@ -69,16 +163,23 @@ fn extract_f32_col(df: &DataFrame, name: &str, default: f32) -> anyhow::Result<V
let float_column = cast let float_column = cast
.f32() .f32()
.with_context(|| format!("Column '{name}' is not a float32 column"))?; .with_context(|| format!("Column '{name}' is not a float32 column"))?;
Ok(float_column float_column
.into_iter() .into_iter()
.map(|value| value.unwrap_or(default)) .enumerate()
.collect()) .map(|(row, value)| value.with_context(|| format!("Column '{name}' has null at row {row}")))
.collect()
} }
impl POIData { impl POIData {
pub fn load(parquet_path: &Path) -> anyhow::Result<Self> { pub fn load(parquet_path: &Path) -> anyhow::Result<Self> {
super::run_polars_io(|| Self::load_inner(parquet_path))
}
fn load_inner(parquet_path: &Path) -> anyhow::Result<Self> {
info!("Loading POI data from {:?}...", parquet_path); info!("Loading POI data from {:?}...", parquet_path);
let parquet_path = PlRefPath::try_from_path(parquet_path)
.context("Failed to normalize POI parquet path")?;
let df = LazyFrame::scan_parquet(parquet_path, Default::default()) let df = LazyFrame::scan_parquet(parquet_path, Default::default())
.context("Failed to scan POI parquet")? .context("Failed to scan POI parquet")?
.collect() .collect()
@ -91,18 +192,10 @@ impl POIData {
let name = extract_str_col(&df, "name")?; let name = extract_str_col(&df, "name")?;
let category_raw = extract_str_col(&df, "category")?; let category_raw = extract_str_col(&df, "category")?;
let group_raw = extract_str_col(&df, "group")?; let group_raw = extract_str_col(&df, "group")?;
let lat = extract_f32_col(&df, "lat", 0.0)?; let lat = extract_f32_col(&df, "lat")?;
let lng = extract_f32_col(&df, "lng", 0.0)?; let lng = extract_f32_col(&df, "lng")?;
let emoji_raw = extract_str_col(&df, "emoji")?; let emoji_raw = extract_str_col(&df, "emoji")?;
let icon_category_raw = if df let icon_category_raw = extract_str_col(&df, "icon_category")?;
.get_column_names()
.iter()
.any(|name| name.as_str() == "icon_category")
{
extract_str_col(&df, "icon_category")?
} else {
category_raw.clone()
};
// Pack POI IDs into a contiguous buffer // Pack POI IDs into a contiguous buffer
let total_id_bytes: usize = id_raw.iter().map(|s| s.len()).sum(); let total_id_bytes: usize = id_raw.iter().map(|s| s.len()).sum();
@ -152,7 +245,7 @@ impl POIData {
}) })
} }
/// Build category groups from the loaded POI data, validated against POI_GROUP_ORDER. /// Build dashboard category groups from every category present in the loaded POI data.
pub fn category_groups(&self) -> anyhow::Result<Vec<POICategoryGroup>> { pub fn category_groups(&self) -> anyhow::Result<Vec<POICategoryGroup>> {
let mut group_cats: HashMap<String, HashSet<String>> = HashMap::new(); let mut group_cats: HashMap<String, HashSet<String>> = HashMap::new();
let num_pois = self.category.indices.len(); let num_pois = self.category.indices.len();
@ -174,18 +267,78 @@ impl POIData {
); );
} }
POI_GROUP_ORDER let preferred_order: HashMap<&str, HashMap<&str, usize>> = DASHBOARD_POI_GROUPS
.iter() .iter()
.map(|group_name| { .map(|(group, categories)| {
let name = group_name.to_string(); (
let mut categories: Vec<String> = group_cats *group,
.remove(&name) categories
.context("POI group validated but missing from map")? .iter()
.into_iter() .enumerate()
.collect(); .map(|(idx, category)| (*category, idx))
categories.sort(); .collect(),
Ok(POICategoryGroup { name, categories }) )
}) })
.collect() .collect();
let groups: Vec<POICategoryGroup> = POI_GROUP_ORDER
.iter()
.filter_map(|group_name| {
let mut categories: Vec<String> = group_cats
.get(*group_name)
.map(|categories| categories.iter().cloned().collect())
.unwrap_or_default();
if categories.is_empty() {
return None;
}
let group_order = preferred_order.get(*group_name);
categories.sort_by(|a, b| {
let a_order = group_order.and_then(|order| order.get(a.as_str())).copied();
let b_order = group_order.and_then(|order| order.get(b.as_str())).copied();
match (a_order, b_order) {
(Some(left), Some(right)) => left.cmp(&right),
(Some(_), None) => std::cmp::Ordering::Less,
(None, Some(_)) => std::cmp::Ordering::Greater,
(None, None) => a.cmp(b),
}
});
Some(POICategoryGroup {
name: (*group_name).to_string(),
categories,
})
})
.collect();
Ok(groups)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn category_filter_matches_exact_present_categories() {
let values = vec![
"Supermarket".to_string(),
"Tesco".to_string(),
"Aldi".to_string(),
"Rail station".to_string(),
];
let selected = resolve_poi_category_filter(&values, "Supermarket,Rail station");
assert!(selected.contains(&0));
assert!(selected.contains(&3));
assert_eq!(selected.len(), 2);
}
#[test]
fn unknown_category_filter_matches_nothing() {
let values = vec!["Supermarket".to_string()];
let selected = resolve_poi_category_filter(&values, "Unknown");
assert!(selected.is_empty());
} }
} }

View file

@ -195,33 +195,38 @@ impl PostcodeData {
// Extract all outer rings from the geometry // Extract all outer rings from the geometry
let rings: Vec<Vec<[f32; 2]>> = match feature.geometry { let rings: Vec<Vec<[f32; 2]>> = match feature.geometry {
Geometry::Polygon { coordinates } => coordinates Geometry::Polygon { coordinates } => {
.first() let ring = coordinates.first().with_context(|| {
.map(|ring| { format!("Postcode '{postcode}' polygon has no outer ring")
vec![ring })?;
.iter() vec![ring
.map(|[lon, lat]| [*lon as f32, *lat as f32]) .iter()
.collect()] .map(|[lon, lat]| [*lon as f32, *lat as f32])
}) .collect()]
.unwrap_or_default(), }
Geometry::MultiPolygon { coordinates } => coordinates Geometry::MultiPolygon { coordinates } => coordinates
.iter() .iter()
.filter_map(|poly| { .enumerate()
poly.first().map(|ring| { .map(|(idx, poly)| {
ring.iter() let ring = poly.first().with_context(|| {
.map(|[lon, lat]| [*lon as f32, *lat as f32]) format!(
.collect() "Postcode '{postcode}' multipolygon part {idx} has no outer ring"
}) )
})?;
Ok(ring
.iter()
.map(|[lon, lat]| [*lon as f32, *lat as f32])
.collect())
}) })
.collect(), .collect::<anyhow::Result<Vec<_>>>()?,
}; };
// Compute centroid across all vertices from all rings // Compute centroid across all vertices from all rings
let total_vertices: usize = rings.iter().map(|ring| ring.len()).sum(); let total_vertices: usize = rings.iter().map(|ring| ring.len()).sum();
let centroid = if total_vertices == 0 { if total_vertices == 0 {
tracing::warn!(postcode = %postcode, "Postcode polygon has zero vertices, defaulting centroid to (0,0)"); anyhow::bail!("Postcode '{postcode}' polygon has zero vertices");
(0.0, 0.0) }
} else { let centroid = {
let mut sum_lat: f32 = 0.0; let mut sum_lat: f32 = 0.0;
let mut sum_lon: f32 = 0.0; let mut sum_lon: f32 = 0.0;
for ring in &rings { for ring in &rings {

View file

@ -14,6 +14,7 @@ const ADDRESS_SEARCH_CANDIDATE_LIMIT: usize = 50_000;
const ADDRESS_SEARCH_MAX_POSTINGS_PER_TOKEN: usize = 250_000; const ADDRESS_SEARCH_MAX_POSTINGS_PER_TOKEN: usize = 250_000;
const ADDRESS_SEARCH_PREFIX_MIN_LEN: usize = 4; const ADDRESS_SEARCH_PREFIX_MIN_LEN: usize = 4;
const ADDRESS_SEARCH_PREFIX_MAX_LEN: usize = 8; const ADDRESS_SEARCH_PREFIX_MAX_LEN: usize = 8;
const NO_POI_METRIC_ROW: u32 = u32::MAX;
fn is_numeric_dtype(dtype: &DataType) -> bool { fn is_numeric_dtype(dtype: &DataType) -> bool {
matches!( matches!(
@ -495,6 +496,187 @@ impl QuantRef<'_> {
} }
} }
/// Postcode-level side table of POI metrics, stored as quantized `u16`
/// values, plus a lookup from property row to postcode metric row.
pub struct PostcodePoiMetrics {
/// Metric names; the position of a name is its metric index.
pub feature_names: Vec<String>,
/// Reverse lookup from metric name to its index in `feature_names`.
pub name_to_index: FxHashMap<String, usize>,
/// Metric-major storage: columns[metric_idx][postcode_metric_idx].
pub columns: Vec<Vec<u16>>,
/// Per-metric statistics, indexed like `feature_names`.
pub feature_stats: Vec<FeatureStats>,
/// Per-property row lookup into the postcode metric table.
/// `NO_POI_METRIC_ROW` marks a property with no metric row.
row_to_metric_idx: Vec<u32>,
/// Per-metric dequantization scale: `quant_range / QUANT_SCALE`, or 0 when the range is 0.
dequant_a: Vec<f32>,
/// Per-metric lower bound used when quantizing.
quant_min: Vec<f32>,
/// Per-metric `max - min` used when quantizing; 0 when `max` is not greater than `min`.
quant_range: Vec<f32>,
}
// Construction, row mapping and decoding helpers for the quantized postcode
// POI metric table.
impl PostcodePoiMetrics {
/// Builds a table with no metrics in which each of `row_count` property
/// rows maps to the `NO_POI_METRIC_ROW` sentinel.
fn empty(row_count: usize) -> Self {
Self {
feature_names: Vec::new(),
name_to_index: FxHashMap::default(),
columns: Vec::new(),
feature_stats: Vec::new(),
row_to_metric_idx: vec![NO_POI_METRIC_ROW; row_count],
dequant_a: Vec::new(),
quant_min: Vec::new(),
quant_range: Vec::new(),
}
}
/// Builds the quantized metric table from the columns of `df` named in
/// `feature_names`.
///
/// Returns `Self::empty(0)` when `feature_names` is empty. The returned
/// table has an empty property-row mapping; until `set_row_mapping` is
/// called, `metric_row_for_property` yields `None` for every row.
///
/// # Errors
/// Fails when a named column is missing from `df`, cannot be converted by
/// `column_to_f32_vec`, or has no bounds config in `features::bounds_for`.
fn from_postcode_df(df: &DataFrame, feature_names: Vec<String>) -> anyhow::Result<Self> {
if feature_names.is_empty() {
return Ok(Self::empty(0));
}
tracing::info!(
metrics = feature_names.len(),
postcodes = df.height(),
"Building postcode POI metric side table"
);
// One f32 vector per metric, in `feature_names` order.
let col_major: Vec<Vec<f32>> = feature_names
.par_iter()
.map(|name| {
let column = df
.column(name.as_str())
.with_context(|| format!("Missing POI metric column '{name}'"))?;
column_to_f32_vec(column)
})
.collect::<anyhow::Result<Vec<_>>>()?;
// Per-metric statistics computed from the unquantized values.
let feature_stats: Vec<FeatureStats> = col_major
.par_iter()
.enumerate()
.map(|(metric_idx, vals)| {
let name = feature_names[metric_idx].as_str();
let bounds = features::bounds_for(name)
.with_context(|| format!("No bounds config for POI metric '{name}'"))?;
Ok(compute_feature_stats(
vals,
&bounds,
features::has_integer_bins(name),
))
})
.collect::<anyhow::Result<Vec<_>>>()?;
let mut quant_min = Vec::with_capacity(feature_names.len());
let mut quant_range = Vec::with_capacity(feature_names.len());
for (metric_idx, stats) in feature_stats.iter().enumerate() {
// Fixed bounds win; any other bounds kind falls back to the histogram's min/max.
let (min, max) = match features::bounds_for(feature_names[metric_idx].as_str()) {
Some(Bounds::Fixed { min, max }) => (min, max),
_ => (stats.histogram.min, stats.histogram.max),
};
quant_min.push(min);
// A degenerate or inverted interval is recorded as range 0.
quant_range.push(if max > min { max - min } else { 0.0 });
}
let dequant_a: Vec<f32> = quant_range
.iter()
.map(|&range| {
if range > 0.0 {
range / QUANT_SCALE
} else {
0.0
}
})
.collect();
// Quantize each value: non-finite -> NAN_U16, zero range -> 0, otherwise
// normalize to [0, 1], scale by QUANT_SCALE, round and clamp.
let columns: Vec<Vec<u16>> = col_major
.par_iter()
.enumerate()
.map(|(metric_idx, vals)| {
let range = quant_range[metric_idx];
let min = quant_min[metric_idx];
vals.iter()
.map(|&value| {
if !value.is_finite() {
NAN_U16
} else if range > 0.0 {
let normalized = (value - min) / range;
(normalized * QUANT_SCALE).round().clamp(0.0, QUANT_SCALE) as u16
} else {
0
}
})
.collect()
})
.collect();
let name_to_index = feature_names
.iter()
.enumerate()
.map(|(idx, name)| (name.clone(), idx))
.collect();
Ok(Self {
feature_names,
name_to_index,
columns,
feature_stats,
// Filled in later through `set_row_mapping`.
row_to_metric_idx: Vec::new(),
dequant_a,
quant_min,
quant_range,
})
}
/// Replaces the property-row to metric-row lookup.
fn set_row_mapping(&mut self, row_to_metric_idx: Vec<u32>) {
self.row_to_metric_idx = row_to_metric_idx;
}
/// Returns `true` when the table holds no metrics.
pub fn is_empty(&self) -> bool {
self.feature_names.is_empty()
}
/// Returns the number of metrics in the table.
pub fn num_features(&self) -> usize {
self.feature_names.len()
}
/// Borrows the quantization parameters as a `QuantRef`, with
/// `num_numeric` set to the number of metrics.
pub fn quant_ref(&self) -> QuantRef<'_> {
QuantRef {
dequant_a: &self.dequant_a,
quant_min: &self.quant_min,
quant_range: &self.quant_range,
num_numeric: self.feature_names.len(),
}
}
/// Returns the metric-table row for property `row`, or `None` when `row`
/// is outside the mapping or mapped to `NO_POI_METRIC_ROW`.
#[inline]
pub fn metric_row_for_property(&self, row: usize) -> Option<usize> {
self.row_to_metric_idx
.get(row)
.copied()
.filter(|&idx| idx != NO_POI_METRIC_ROW)
.map(|idx| idx as usize)
}
/// Returns the quantized value of metric `metric_idx` at `metric_row`.
/// Indexes directly, so an out-of-range index panics.
#[inline]
pub fn raw_for_metric_row(&self, metric_row: usize, metric_idx: usize) -> u16 {
self.columns[metric_idx][metric_row]
}
/// Returns the quantized value of metric `metric_idx` for property `row`,
/// or `NAN_U16` when the property has no metric row.
#[inline]
pub fn raw_for_property_row(&self, row: usize, metric_idx: usize) -> u16 {
let Some(metric_row) = self.metric_row_for_property(row) else {
return NAN_U16;
};
self.raw_for_metric_row(metric_row, metric_idx)
}
/// Decodes a quantized value: `NAN_U16` becomes NaN, anything else becomes
/// `raw * dequant_a[metric_idx] + quant_min[metric_idx]`.
#[inline]
pub fn decode_raw(&self, metric_idx: usize, raw: u16) -> f32 {
if raw == NAN_U16 {
f32::NAN
} else {
raw as f32 * self.dequant_a[metric_idx] + self.quant_min[metric_idx]
}
}
/// Returns the decoded value of metric `metric_idx` for property `row`;
/// NaN when the property has no metric row or the stored value is `NAN_U16`.
#[inline]
pub fn get_for_property_row(&self, row: usize, metric_idx: usize) -> f32 {
self.decode_raw(metric_idx, self.raw_for_property_row(row, metric_idx))
}
}
pub struct PropertyData { pub struct PropertyData {
pub lat: Vec<f32>, pub lat: Vec<f32>,
pub lon: Vec<f32>, pub lon: Vec<f32>,
@ -514,6 +696,7 @@ pub struct PropertyData {
/// Per-feature: max - min (for encoding filter bounds). /// Per-feature: max - min (for encoding filter bounds).
quant_range: Vec<f32>, quant_range: Vec<f32>,
pub feature_stats: Vec<FeatureStats>, pub feature_stats: Vec<FeatureStats>,
pub poi_metrics: PostcodePoiMetrics,
/// Unquantized last sale price used by the price-history chart. /// Unquantized last sale price used by the price-history chart.
last_known_price_raw: Vec<f32>, last_known_price_raw: Vec<f32>,
/// Contiguous buffer holding all address strings end-to-end. /// Contiguous buffer holding all address strings end-to-end.
@ -1055,19 +1238,54 @@ pub fn precompute_h3(lat: &[f32], lon: &[f32]) -> anyhow::Result<Vec<u64>> {
impl PropertyData { impl PropertyData {
pub fn load(properties_path: &Path, postcode_features_path: &Path) -> anyhow::Result<Self> { pub fn load(properties_path: &Path, postcode_features_path: &Path) -> anyhow::Result<Self> {
super::run_polars_io(|| Self::load_inner(properties_path, postcode_features_path))
}
fn load_inner(properties_path: &Path, postcode_features_path: &Path) -> anyhow::Result<Self> {
// Load postcode.parquet // Load postcode.parquet
tracing::info!( tracing::info!(
"Loading postcode features from {:?}", "Loading postcode features from {:?}",
postcode_features_path postcode_features_path
); );
let postcode_features_path = PlRefPath::try_from_path(postcode_features_path)
.context("Failed to normalize postcode parquet path")?;
let postcode_df = LazyFrame::scan_parquet(postcode_features_path, Default::default()) let postcode_df = LazyFrame::scan_parquet(postcode_features_path, Default::default())
.context("Failed to scan postcode parquet")? .context("Failed to scan postcode parquet")?
.collect() .collect()
.context("Failed to read postcode parquet")?; .context("Failed to read postcode parquet")?;
tracing::info!(rows = postcode_df.height(), "Postcode features loaded"); tracing::info!(rows = postcode_df.height(), "Postcode features loaded");
let mut poi_metric_names: Vec<String> = postcode_df
.get_column_names()
.iter()
.map(|name| name.as_str())
.filter(|&name| features::is_dynamic_poi_feature(name))
.map(str::to_string)
.collect();
poi_metric_names.sort_by_key(|name| features::dynamic_poi_feature_sort_key(name));
let poi_metric_by_postcode: FxHashMap<String, u32> = if poi_metric_names.is_empty() {
FxHashMap::default()
} else {
let postcode_column = postcode_df
.column("Postcode")
.context("Postcode feature parquet missing 'Postcode' column")?
.str()
.context("'Postcode' column in postcode feature parquet is not a string")?;
postcode_column
.into_iter()
.enumerate()
.filter_map(|(idx, postcode)| {
postcode.map(|postcode| (postcode.to_string(), idx as u32))
})
.collect()
};
let mut poi_metrics = PostcodePoiMetrics::from_postcode_df(&postcode_df, poi_metric_names)?;
// Load properties.parquet and join with postcode data for lat/lon + area features // Load properties.parquet and join with postcode data for lat/lon + area features
tracing::info!("Loading properties from {:?}", properties_path); tracing::info!("Loading properties from {:?}", properties_path);
let properties_path = PlRefPath::try_from_path(properties_path)
.context("Failed to normalize properties parquet path")?;
let properties_lf = LazyFrame::scan_parquet(properties_path, Default::default()) let properties_lf = LazyFrame::scan_parquet(properties_path, Default::default())
.context("Failed to scan properties parquet")?; .context("Failed to scan properties parquet")?;
let combined = properties_lf let combined = properties_lf
@ -1082,14 +1300,20 @@ impl PropertyData {
let total_rows = combined.height(); let total_rows = combined.height();
tracing::info!(rows = total_rows, "Properties joined with postcodes"); tracing::info!(rows = total_rows, "Properties joined with postcodes");
// Get configured feature/enum names in config order // Get configured feature/enum names in config order. Dynamic POI
let numeric_names = features::all_numeric_feature_names(); // metrics live in a postcode-level side table so they do not widen the
// hot row-major property feature matrix.
let configured_numeric_names = features::all_numeric_feature_names();
let enum_names = features::all_enum_feature_names(); let enum_names = features::all_enum_feature_names();
let schema = combined.schema(); let schema = combined.schema();
let numeric_names: Vec<String> = configured_numeric_names
.iter()
.map(|name| (*name).to_string())
.collect();
for name in &numeric_names { for name in &numeric_names {
match schema.get(name) { match schema.get(name.as_str()) {
Some(dtype) if is_numeric_dtype(dtype) => {} Some(dtype) if is_numeric_dtype(dtype) => {}
Some(dtype) => bail!( Some(dtype) => bail!(
"Configured numeric feature '{}' has non-numeric type {:?}", "Configured numeric feature '{}' has non-numeric type {:?}",
@ -1120,8 +1344,8 @@ impl PropertyData {
// Combine numeric and enum feature names (numeric first, then enum) // Combine numeric and enum feature names (numeric first, then enum)
let feature_names: Vec<String> = numeric_names let feature_names: Vec<String> = numeric_names
.iter() .iter()
.chain(enum_names.iter())
.map(|name| name.to_string()) .map(|name| name.to_string())
.chain(enum_names.iter().map(|name| name.to_string()))
.collect(); .collect();
let num_features = feature_names.len(); let num_features = feature_names.len();
let num_numeric = numeric_names.len(); let num_numeric = numeric_names.len();
@ -1138,16 +1362,16 @@ impl PropertyData {
select_exprs.push(col("lon").cast(DataType::Float32)); select_exprs.push(col("lon").cast(DataType::Float32));
// Select numeric features as Float32 (datetime columns → fractional year) // Select numeric features as Float32 (datetime columns → fractional year)
for &name in &numeric_names { for name in &numeric_names {
if is_datetime_dtype(schema.get(name).unwrap()) { if is_datetime_dtype(schema.get(name.as_str()).unwrap()) {
select_exprs.push( select_exprs.push(
(col(name).dt().year().cast(DataType::Float32) (col(name.as_str()).dt().year().cast(DataType::Float32)
+ (col(name).dt().month().cast(DataType::Float32) - lit(1.0f32)) + (col(name.as_str()).dt().month().cast(DataType::Float32) - lit(1.0f32))
/ lit(12.0f32)) / lit(12.0f32))
.alias(name), .alias(name.as_str()),
); );
} else { } else {
select_exprs.push(col(name).cast(DataType::Float32)); select_exprs.push(col(name.as_str()).cast(DataType::Float32));
} }
} }
@ -1233,7 +1457,7 @@ impl PropertyData {
.par_iter() .par_iter()
.map(|name| { .map(|name| {
let column = df let column = df
.column(name) .column(name.as_str())
.with_context(|| format!("Missing feature column '{name}'"))?; .with_context(|| format!("Missing feature column '{name}'"))?;
column_to_f32_vec(column) column_to_f32_vec(column)
}) })
@ -1244,10 +1468,10 @@ impl PropertyData {
.par_iter() .par_iter()
.enumerate() .enumerate()
.map(|(feat_index, vals)| { .map(|(feat_index, vals)| {
let name = numeric_names[feat_index]; let name = numeric_names[feat_index].as_str();
let bounds = features::bounds_for(name) let bounds = features::bounds_for(name)
.with_context(|| format!("No bounds config for feature '{}'", name))?; .with_context(|| format!("No bounds config for feature '{}'", name))?;
let stats = compute_feature_stats(vals, bounds, features::has_integer_bins(name)); let stats = compute_feature_stats(vals, &bounds, features::has_integer_bins(name));
tracing::debug!( tracing::debug!(
feature = %name, feature = %name,
slider_min = format_args!("{:.2}", stats.slider_min), slider_min = format_args!("{:.2}", stats.slider_min),
@ -1268,8 +1492,8 @@ impl PropertyData {
let mut quant_min = Vec::with_capacity(num_features); let mut quant_min = Vec::with_capacity(num_features);
let mut quant_range = Vec::with_capacity(num_features); let mut quant_range = Vec::with_capacity(num_features);
for (feat_idx, stats) in numeric_feature_stats.iter().enumerate() { for (feat_idx, stats) in numeric_feature_stats.iter().enumerate() {
let (min, max) = match features::bounds_for(numeric_names[feat_idx]) { let (min, max) = match features::bounds_for(numeric_names[feat_idx].as_str()) {
Some(Bounds::Fixed { min, max }) => (*min, *max), Some(Bounds::Fixed { min, max }) => (min, max),
_ => (stats.histogram.min, stats.histogram.max), _ => (stats.histogram.min, stats.histogram.max),
}; };
quant_min.push(min); quant_min.push(min);
@ -1284,10 +1508,15 @@ impl PropertyData {
let string_column = column let string_column = column
.str() .str()
.with_context(|| format!("Column '{name}' is not a string column"))?; .with_context(|| format!("Column '{name}' is not a string column"))?;
Ok(string_column string_column
.into_iter() .into_iter()
.map(|value| value.unwrap_or("").to_string()) .enumerate()
.collect()) .map(|(row, value)| {
value
.map(ToString::to_string)
.with_context(|| format!("Required column '{name}' has null at row {row}"))
})
.collect()
}; };
let address_raw = extract_string_col(&df, "Address per Property Register")?; let address_raw = extract_string_col(&df, "Address per Property Register")?;
@ -1325,18 +1554,18 @@ impl PropertyData {
// enum_col_major: Vec<(values_list, encoded_as_f32)> // enum_col_major: Vec<(values_list, encoded_as_f32)>
let enum_col_major: Vec<(Vec<String>, Vec<f32>)> = enum_names let enum_col_major: Vec<(Vec<String>, Vec<f32>)> = enum_names
.par_iter() .par_iter()
.filter_map(|&name| { .map(|&name| -> anyhow::Result<(Vec<String>, Vec<f32>)> {
let column_data = df.column(name).ok()?; let column_data = df
let string_column = column_data.str().ok()?; .column(name)
.with_context(|| format!("Required enum column '{name}' not found"))?;
let string_column = column_data
.str()
.with_context(|| format!("Enum column '{name}' is not a string column"))?;
let unique_set: std::collections::HashSet<String> = string_column let unique_set: std::collections::HashSet<String> = string_column
.into_iter() .into_iter()
.filter_map(|value| { .filter_map(|value| {
let text = value.unwrap_or(""); let text = value?.trim();
if text.is_empty() { (!text.is_empty()).then(|| text.to_string())
None
} else {
Some(text.to_string())
}
}) })
.collect(); .collect();
@ -1373,20 +1602,22 @@ impl PropertyData {
let encoded: Vec<f32> = string_column let encoded: Vec<f32> = string_column
.into_iter() .into_iter()
.map(|value| { .enumerate()
let text = value.unwrap_or(""); .map(|(row, value)| {
if text.is_empty() { let Some(text) = value.map(str::trim).filter(|text| !text.is_empty())
f32::NAN else {
} else { return Ok(f32::NAN);
*value_to_idx.get(text).unwrap_or(&f32::NAN) };
} value_to_idx.get(text).copied().with_context(|| {
format!("Enum column '{name}' has unknown value '{text}' at row {row}")
})
}) })
.collect(); .collect::<anyhow::Result<Vec<_>>>()?;
tracing::debug!(column = %name, unique_values = unique.len(), "Enum feature encoded as f32"); tracing::debug!(column = %name, unique_values = unique.len(), "Enum feature encoded as f32");
Some((unique, encoded)) Ok((unique, encoded))
}) })
.collect(); .collect::<anyhow::Result<Vec<_>>>()?;
// Extract is_approx_build_date: 0.0 = exact, anything else (1.0/NaN) = approximate // Extract is_approx_build_date: 0.0 = exact, anything else (1.0/NaN) = approximate
let is_approx_build_date_raw: Vec<bool> = if has_approx_col { let is_approx_build_date_raw: Vec<bool> = if has_approx_col {
@ -1487,13 +1718,13 @@ impl PropertyData {
.collect(); .collect();
let last_known_price_raw: Vec<f32> = numeric_names let last_known_price_raw: Vec<f32> = numeric_names
.iter() .iter()
.position(|&name| name == "Last known price") .position(|name| name == "Last known price")
.map(|price_idx| { .map(|price_idx| {
perm.iter() perm.iter()
.map(|&perm_index| numeric_col_major[price_idx][perm_index as usize]) .map(|&perm_index| numeric_col_major[price_idx][perm_index as usize])
.collect() .collect()
}) })
.unwrap_or_else(|| vec![f32::NAN; row_count]); .context("Required numeric column 'Last known price' not configured")?;
// Build contiguous address buffer and address search index (permuted) // Build contiguous address buffer and address search index (permuted)
tracing::info!("Building interned strings"); tracing::info!("Building interned strings");
@ -1561,6 +1792,20 @@ impl PropertyData {
} }
let postcode_interner = postcode_rodeo.into_reader(); let postcode_interner = postcode_rodeo.into_reader();
let row_to_poi_metric_idx: Vec<u32> = if poi_metrics.is_empty() {
vec![NO_POI_METRIC_ROW; row_count]
} else {
perm.iter()
.map(|&old_row| {
poi_metric_by_postcode
.get(postcode_raw[old_row as usize].as_str())
.copied()
.unwrap_or(NO_POI_METRIC_ROW)
})
.collect()
};
poi_metrics.set_row_mapping(row_to_poi_metric_idx);
// Pack is_approx_build_date into a bitvec (8 bools per byte) // Pack is_approx_build_date into a bitvec (8 bools per byte)
let num_bytes = row_count.div_ceil(8); let num_bytes = row_count.div_ceil(8);
let mut approx_build_date_bits = vec![0u8; num_bytes]; let mut approx_build_date_bits = vec![0u8; num_bytes];
@ -1697,6 +1942,7 @@ impl PropertyData {
quant_min, quant_min,
quant_range, quant_range,
feature_stats, feature_stats,
poi_metrics,
last_known_price_raw, last_known_price_raw,
address_buffer, address_buffer,
address_offsets, address_offsets,

View file

@ -5,6 +5,7 @@ use std::sync::Arc;
use anyhow::Context; use anyhow::Context;
use parking_lot::Mutex; use parking_lot::Mutex;
use polars::lazy::frame::LazyFrame; use polars::lazy::frame::LazyFrame;
use polars::prelude::PlRefPath;
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::{FxHashMap, FxHashSet};
use tracing::info; use tracing::info;
@ -155,15 +156,23 @@ impl TravelTimeStore {
/// Returns a cached or freshly-loaded postcode → travel_minutes mapping. /// Returns a cached or freshly-loaded postcode → travel_minutes mapping.
pub fn get(&self, mode: &str, slug: &str) -> anyhow::Result<TravelData> { pub fn get(&self, mode: &str, slug: &str) -> anyhow::Result<TravelData> {
let key = (mode.to_string(), slug.to_string()); let key = (mode.to_string(), slug.to_string());
if let Some(data) = self.get_cached(&key) {
// Check cache first return Ok(data);
{
let mut cache = self.cache.lock();
if let Some(data) = cache.get(&key) {
return Ok(data);
}
} }
super::run_polars_io(|| self.load_uncached(key))
}
/// Look up `key` (a `(mode, slug)` pair) in the in-memory cache.
///
/// The cache mutex is held only for the duration of the lookup.
fn get_cached(&self, key: &(String, String)) -> Option<TravelData> {
    self.cache.lock().get(key)
}
fn load_uncached(&self, key: (String, String)) -> anyhow::Result<TravelData> {
if let Some(data) = self.get_cached(&key) {
return Ok(data);
}
let (mode, slug) = &key;
// Resolve slug to actual filename (may have numeric prefix). // Resolve slug to actual filename (may have numeric prefix).
// Reject unknown slugs rather than falling back to raw input to prevent path traversal. // Reject unknown slugs rather than falling back to raw input to prevent path traversal.
let file_stem = self let file_stem = self
@ -175,7 +184,9 @@ impl TravelTimeStore {
.join(mode) .join(mode)
.join(format!("{}.parquet", file_stem)); .join(format!("{}.parquet", file_stem));
let df = LazyFrame::scan_parquet(&path, Default::default()) let parquet_path = PlRefPath::try_from_path(&path)
.with_context(|| format!("Failed to normalize path: {}", path.display()))?;
let df = LazyFrame::scan_parquet(parquet_path, Default::default())
.with_context(|| format!("Failed to scan: {}", path.display()))? .with_context(|| format!("Failed to scan: {}", path.display()))?
.collect() .collect()
.with_context(|| format!("Failed to read: {}", path.display()))?; .with_context(|| format!("Failed to read: {}", path.display()))?;

View file

@ -1,6 +1,7 @@
//! Static feature configuration. Every numeric and enum column in wide.parquet //! Static feature configuration. Every numeric and enum column in wide.parquet
//! must be declared here. Unknown columns cause a startup panic. //! must be declared here. Unknown columns cause a startup panic.
#[derive(Clone, Copy)]
pub enum Bounds { pub enum Bounds {
/// Fixed min/max values for the slider /// Fixed min/max values for the slider
Fixed { min: f32, max: f32 }, Fixed { min: f32, max: f32 },
@ -61,6 +62,26 @@ pub struct FeatureGroup {
} }
pub static FEATURE_GROUPS: &[FeatureGroup] = &[ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
FeatureGroup {
name: "Transport",
features: &[
Feature::Numeric(FeatureConfig {
name: "Distance to nearest train or tube station (km)",
bounds: Bounds::Percentile {
low: 2.0,
high: 98.0,
},
step: 0.1,
description: "Distance to the closest train or tube station",
detail: "Straight-line distance in kilometres from the postcode to the nearest rail station or Tube/metro/tram stop.",
source: "naptan",
prefix: "",
suffix: " km",
raw: false,
absolute: false,
}),
],
},
FeatureGroup { FeatureGroup {
name: "Properties", name: "Properties",
features: &[ features: &[
@ -78,6 +99,21 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
detail: "From HM Land Registry Price Paid data. Freehold means you own the building and the land it stands on. Leasehold means you own the building but not the land: you have a lease from the freeholder for a set number of years.", detail: "From HM Land Registry Price Paid data. Freehold means you own the building and the land it stands on. Leasehold means you own the building but not the land: you have a lease from the freeholder for a set number of years.",
source: "price-paid", source: "price-paid",
}), }),
Feature::Numeric(FeatureConfig {
name: "Estimated current price",
bounds: Bounds::Fixed {
min: 0.0,
max: 2_500_000.0,
},
step: 10000.0,
description: "Modelled estimate of the current property value",
detail: "Based on the last sale price, local repeat-sales price movement, and nearby recently sold properties. The repeat-sales index is tracked by postcode sector and property type, with smoothing and neighbour blending where data is sparse. Recent sales stay close to the recorded price; older sales depend more on the model.",
source: "price-paid",
prefix: "£",
suffix: "",
raw: false,
absolute: true,
}),
Feature::Numeric(FeatureConfig { Feature::Numeric(FeatureConfig {
name: "Last known price", name: "Last known price",
bounds: Bounds::Fixed { bounds: Bounds::Fixed {
@ -94,19 +130,19 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
absolute: true, absolute: true,
}), }),
Feature::Numeric(FeatureConfig { Feature::Numeric(FeatureConfig {
name: "Estimated current price", name: "Est. price per sqm",
bounds: Bounds::Fixed { bounds: Bounds::Percentile {
min: 0.0, low: 0.0,
max: 2_500_000.0, high: 98.0,
}, },
step: 10000.0, step: 100.0,
description: "Inflation-adjusted estimate of the current property value", description: "Estimated current price divided by total floor area",
detail: "Based on the last sale price, adjusted for local price changes over time using a repeat-sales index (tracked per postcode sector and property type). If post-sale improvements are detected from EPC records, a renovation premium is added. Recent sales will be close to the original price; older sales are adjusted more.", detail: "Calculated by dividing the modelled estimated current price by the total floor area from the EPC certificate. Provides a more up-to-date price-per-area comparison than the historical sale price per sqm.",
source: "price-paid", source: "price-paid",
prefix: "£", prefix: "£",
suffix: "", suffix: "",
raw: false, raw: false,
absolute: true, absolute: false,
}), }),
Feature::Numeric(FeatureConfig { Feature::Numeric(FeatureConfig {
name: "Price per sqm", name: "Price per sqm",
@ -123,21 +159,6 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
raw: false, raw: false,
absolute: false, absolute: false,
}), }),
Feature::Numeric(FeatureConfig {
name: "Est. price per sqm",
bounds: Bounds::Percentile {
low: 0.0,
high: 98.0,
},
step: 100.0,
description: "Estimated current price divided by total floor area",
detail: "Calculated by dividing the inflation-adjusted estimated current price (including any renovation premium) by the total floor area from the EPC certificate. Provides a more up-to-date price-per-area comparison than the historical sale price per sqm.",
source: "price-paid",
prefix: "£",
suffix: "",
raw: false,
absolute: false,
}),
Feature::Numeric(FeatureConfig { Feature::Numeric(FeatureConfig {
name: "Estimated monthly rent", name: "Estimated monthly rent",
bounds: Bounds::Percentile { low: 2.0, high: 98.0 }, bounds: Bounds::Percentile { low: 2.0, high: 98.0 },
@ -248,26 +269,6 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
}), }),
], ],
}, },
FeatureGroup {
name: "Transport",
features: &[
Feature::Numeric(FeatureConfig {
name: "Distance to nearest train or tube station (km)",
bounds: Bounds::Percentile {
low: 2.0,
high: 98.0,
},
step: 0.1,
description: "Distance to the closest train or tube station",
detail: "Straight-line distance in kilometres from the postcode to the nearest rail station or Tube/metro/tram stop.",
source: "naptan",
prefix: "",
suffix: " km",
raw: false,
absolute: false,
}),
],
},
FeatureGroup { FeatureGroup {
name: "Education", name: "Education",
features: &[ features: &[
@ -393,18 +394,18 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
}), }),
Feature::Numeric(FeatureConfig { Feature::Numeric(FeatureConfig {
name: "Education, Skills and Training Score", name: "Education, Skills and Training Score",
bounds: Bounds::Percentile { bounds: Bounds::Fixed {
low: 2.0, min: 0.0,
high: 98.0, max: 100.0,
}, },
step: 0.1, step: 1.0,
description: "Education quality score for the local area (higher = better)", description: "Education and skills deprivation percentile (higher = less deprived)",
detail: "From the English Indices of Deprivation (inverted so higher = better). Covers school attainment, entry to higher education, adult qualifications, and English language proficiency. Higher scores indicate less deprivation.", detail: "From the English Indices of Deprivation, converted to a national percentile where 0% is most deprived and 100% is least deprived. Covers school attainment, entry to higher education, adult qualifications, and English language proficiency.",
source: "iod", source: "iod",
prefix: "", prefix: "",
suffix: "", suffix: "%",
raw: false, raw: true,
absolute: false, absolute: true,
}), }),
], ],
}, },
@ -413,72 +414,78 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
features: &[ features: &[
Feature::Numeric(FeatureConfig { Feature::Numeric(FeatureConfig {
name: "Income Score", name: "Income Score",
bounds: Bounds::Fixed { min: 0.0, max: 1.0 }, bounds: Bounds::Fixed {
step: 0.01, min: 0.0,
description: "Income deprivation rate, inverted (higher = less deprived)", max: 100.0,
detail: "From the English Indices of Deprivation (inverted so higher = better). Higher values indicate less income deprivation. Based on Income Support, income-based Jobseeker's Allowance, income-based Employment and Support Allowance, Pension Credit, Working Tax Credit and Child Tax Credit, Universal Credit, and asylum seekers.", },
step: 1.0,
description: "Income deprivation percentile (higher = less deprived)",
detail: "From the English Indices of Deprivation, converted to a national percentile where 0% is most income deprived and 100% is least income deprived. Based on Income Support, income-based Jobseeker's Allowance, income-based Employment and Support Allowance, Pension Credit, Working Tax Credit and Child Tax Credit, Universal Credit, and asylum seekers.",
source: "iod", source: "iod",
prefix: "", prefix: "",
suffix: "", suffix: "%",
raw: false, raw: true,
absolute: false, absolute: true,
}), }),
Feature::Numeric(FeatureConfig { Feature::Numeric(FeatureConfig {
name: "Employment Score", name: "Employment Score",
bounds: Bounds::Fixed { min: 0.0, max: 1.0 }, bounds: Bounds::Fixed {
step: 0.01, min: 0.0,
description: "Employment deprivation rate, inverted (higher = less deprived)", max: 100.0,
detail: "From the English Indices of Deprivation (inverted so higher = better). Higher values indicate less employment deprivation. Based on claimants of Jobseeker's Allowance, Employment and Support Allowance, Incapacity Benefit, Severe Disablement Allowance, Carer's Allowance, and relevant Universal Credit claimants.", },
step: 1.0,
description: "Employment deprivation percentile (higher = less deprived)",
detail: "From the English Indices of Deprivation, converted to a national percentile where 0% is most employment deprived and 100% is least employment deprived. Based on claimants of Jobseeker's Allowance, Employment and Support Allowance, Incapacity Benefit, Severe Disablement Allowance, Carer's Allowance, and relevant Universal Credit claimants.",
source: "iod", source: "iod",
prefix: "", prefix: "",
suffix: "", suffix: "%",
raw: false, raw: true,
absolute: false, absolute: true,
}), }),
Feature::Numeric(FeatureConfig { Feature::Numeric(FeatureConfig {
name: "Health Deprivation and Disability Score", name: "Health Deprivation and Disability Score",
bounds: Bounds::Percentile { bounds: Bounds::Fixed {
low: 2.0, min: 0.0,
high: 98.0, max: 100.0,
}, },
step: 0.1, step: 1.0,
description: "Health and disability score (higher = better health outcomes)", description: "Health and disability deprivation percentile (higher = better outcomes)",
detail: "From the English Indices of Deprivation (inverted so higher = better). Higher scores indicate lower risk of premature death and better quality of life. Derived from years of potential life lost, comparative illness and disability ratio, acute morbidity, and mood and anxiety disorders.", detail: "From the English Indices of Deprivation, converted to a national percentile where 0% is most health deprived and 100% is least health deprived. Derived from years of potential life lost, comparative illness and disability ratio, acute morbidity, and mood and anxiety disorders.",
source: "iod", source: "iod",
prefix: "", prefix: "",
suffix: "", suffix: "%",
raw: false, raw: true,
absolute: false, absolute: true,
}), }),
Feature::Numeric(FeatureConfig { Feature::Numeric(FeatureConfig {
name: "Housing Conditions Score", name: "Housing Conditions Score",
bounds: Bounds::Percentile { bounds: Bounds::Fixed {
low: 2.0, min: 0.0,
high: 98.0, max: 100.0,
}, },
step: 0.1, step: 1.0,
description: "Housing quality and conditions (higher = better)", description: "Housing conditions percentile (higher = better conditions)",
detail: "From the English Indices of Deprivation, Living Environment domain (inverted so higher = better). Measures the quality of housing stock: central heating availability, housing condition, and Decent Homes standards. Higher scores indicate better housing conditions.", detail: "From the English Indices of Deprivation, Living Environment domain, converted to a national percentile where 0% is most deprived and 100% is least deprived. Measures the quality of housing stock: central heating availability, housing condition, and Decent Homes standards.",
source: "iod", source: "iod",
prefix: "", prefix: "",
suffix: "", suffix: "%",
raw: false, raw: true,
absolute: false, absolute: true,
}), }),
Feature::Numeric(FeatureConfig { Feature::Numeric(FeatureConfig {
name: "Air Quality and Road Safety Score", name: "Air Quality and Road Safety Score",
bounds: Bounds::Percentile { bounds: Bounds::Fixed {
low: 2.0, min: 0.0,
high: 98.0, max: 100.0,
}, },
step: 0.1, step: 1.0,
description: "Air quality and road safety (higher = better)", description: "Air quality and road safety percentile (higher = better conditions)",
detail: "From the English Indices of Deprivation, Living Environment domain (inverted so higher = better). Measures the outdoor living environment quality through air quality indicators and road traffic accident casualties involving pedestrians and cyclists. Higher scores indicate better outdoor environments.", detail: "From the English Indices of Deprivation, Living Environment domain, converted to a national percentile where 0% is most deprived and 100% is least deprived. Measures the outdoor living environment through air quality indicators and road traffic accident casualties involving pedestrians and cyclists.",
source: "iod", source: "iod",
prefix: "", prefix: "",
suffix: "", suffix: "%",
raw: false, raw: true,
absolute: false, absolute: true,
}), }),
], ],
}, },
@ -996,6 +1003,126 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[
raw: false, raw: false,
absolute: false, absolute: false,
}), }),
Feature::Numeric(FeatureConfig {
name: "Distance to nearest grocery store (km)",
bounds: Bounds::Percentile {
low: 2.0,
high: 98.0,
},
step: 0.1,
description: "Distance to the closest grocery shop or supermarket",
detail: "Straight-line distance in kilometres from the postcode to the nearest grocery shop, supermarket, or convenience store. Uses OpenStreetMap POIs, with Waitrose and Tesco coverage from GEOLYTIX retail points.",
source: "osm-pois",
prefix: "",
suffix: " km",
raw: false,
absolute: false,
}),
Feature::Numeric(FeatureConfig {
name: "Distance to nearest tube station (km)",
bounds: Bounds::Percentile {
low: 2.0,
high: 98.0,
},
step: 0.1,
description: "Distance to the closest Tube, metro, tram, or DLR stop",
detail: "Straight-line distance in kilometres from the postcode to the nearest NaPTAN station classified as Tube, metro, tram, or DLR.",
source: "naptan",
prefix: "",
suffix: " km",
raw: false,
absolute: false,
}),
Feature::Numeric(FeatureConfig {
name: "Distance to nearest rail station (km)",
bounds: Bounds::Percentile {
low: 2.0,
high: 98.0,
},
step: 0.1,
description: "Distance to the closest National Rail station",
detail: "Straight-line distance in kilometres from the postcode to the nearest NaPTAN railway station.",
source: "naptan",
prefix: "",
suffix: " km",
raw: false,
absolute: false,
}),
Feature::Numeric(FeatureConfig {
name: "Distance to nearest Waitrose (km)",
bounds: Bounds::Percentile {
low: 2.0,
high: 98.0,
},
step: 0.1,
description: "Distance to the closest Waitrose store",
detail: "Straight-line distance in kilometres from the postcode to the nearest Waitrose or Little Waitrose store in the GEOLYTIX Grocery Retail Points dataset.",
source: "geolytix-retail-points",
prefix: "",
suffix: " km",
raw: false,
absolute: false,
}),
Feature::Numeric(FeatureConfig {
name: "Distance to nearest Tesco (km)",
bounds: Bounds::Percentile {
low: 2.0,
high: 98.0,
},
step: 0.1,
description: "Distance to the closest Tesco store",
detail: "Straight-line distance in kilometres from the postcode to the nearest Tesco store in the GEOLYTIX Grocery Retail Points dataset.",
source: "geolytix-retail-points",
prefix: "",
suffix: " km",
raw: false,
absolute: false,
}),
Feature::Numeric(FeatureConfig {
name: "Distance to nearest cafe (km)",
bounds: Bounds::Percentile {
low: 2.0,
high: 98.0,
},
step: 0.1,
description: "Distance to the closest cafe",
detail: "Straight-line distance in kilometres from the postcode to the nearest cafe, ice-cream shop, or internet cafe mapped in OpenStreetMap.",
source: "osm-pois",
prefix: "",
suffix: " km",
raw: false,
absolute: false,
}),
Feature::Numeric(FeatureConfig {
name: "Distance to nearest pub (km)",
bounds: Bounds::Percentile {
low: 2.0,
high: 98.0,
},
step: 0.1,
description: "Distance to the closest pub",
detail: "Straight-line distance in kilometres from the postcode to the nearest pub, social club, brewery, distillery, or winery mapped in OpenStreetMap.",
source: "osm-pois",
prefix: "",
suffix: " km",
raw: false,
absolute: false,
}),
Feature::Numeric(FeatureConfig {
name: "Distance to nearest restaurant (km)",
bounds: Bounds::Percentile {
low: 2.0,
high: 98.0,
},
step: 0.1,
description: "Distance to the closest restaurant",
detail: "Straight-line distance in kilometres from the postcode to the nearest restaurant or food court mapped in OpenStreetMap.",
source: "osm-pois",
prefix: "",
suffix: " km",
raw: false,
absolute: false,
}),
Feature::Numeric(FeatureConfig { Feature::Numeric(FeatureConfig {
name: "Number of parks within 1km", name: "Number of parks within 1km",
bounds: Bounds::Percentile { bounds: Bounds::Percentile {
@ -1105,20 +1232,76 @@ pub fn order_for(name: &str) -> Option<&'static [&'static str]> {
/// Whether this feature should use integer-width histogram bins. /// Whether this feature should use integer-width histogram bins.
pub fn has_integer_bins(name: &str) -> bool { pub fn has_integer_bins(name: &str) -> bool {
INTEGER_BIN_FEATURES.contains(&name) INTEGER_BIN_FEATURES.contains(&name) || dynamic_poi_count_radius(name).is_some()
} }
/// Look up the Bounds config for a numeric feature by name. /// Look up the Bounds config for a numeric feature by name.
pub fn bounds_for(name: &str) -> Option<&'static Bounds> { pub fn bounds_for(name: &str) -> Option<Bounds> {
if dynamic_poi_distance_category(name).is_some() {
return Some(Bounds::Percentile {
low: 2.0,
high: 98.0,
});
}
if dynamic_poi_count_radius(name).is_some() {
return Some(Bounds::Percentile {
low: 5.0,
high: 95.0,
});
}
FEATURE_GROUPS FEATURE_GROUPS
.iter() .iter()
.flat_map(|group| group.features.iter()) .flat_map(|group| group.features.iter())
.find_map(|feature| match feature { .find_map(|feature| match feature {
Feature::Numeric(c) if c.name == name => Some(&c.bounds), Feature::Numeric(c) if c.name == name => Some(c.bounds),
_ => None, _ => None,
}) })
} }
/// Extract the category from a dynamic POI distance feature name.
///
/// Recognised names have the form `Distance to nearest <category> POI (km)`.
/// Returns `None` when the prefix or suffix does not match, or when the
/// category between them is empty.
pub fn dynamic_poi_distance_category(name: &str) -> Option<&str> {
    let after_prefix = name.strip_prefix("Distance to nearest ")?;
    let category = after_prefix.strip_suffix(" POI (km)")?;
    if category.is_empty() {
        None
    } else {
        Some(category)
    }
}
/// Extract the search radius (in km) from a dynamic POI count feature name.
///
/// Recognised names have the form `Number of <category> POIs within <radius>`
/// where `<radius>` is exactly `2km` or `5km`. Returns `None` for any other
/// name, including one whose category is empty, so that this parser accepts
/// exactly the same names as `dynamic_poi_count_category`.
pub fn dynamic_poi_count_radius(name: &str) -> Option<u8> {
    let rest = name.strip_prefix("Number of ")?;
    // Split on the last occurrence so a category that itself contains the
    // separator text is still parsed against the trailing radius.
    let (category, suffix) = rest.rsplit_once(" POIs within ")?;
    if category.is_empty() {
        // A name without a category is not a dynamic POI feature.
        return None;
    }
    match suffix {
        "2km" => Some(2),
        "5km" => Some(5),
        _ => None,
    }
}
/// Extract the category from a dynamic POI count feature name.
///
/// Recognised names have the form `Number of <category> POIs within <radius>`
/// with `<radius>` being `2km` or `5km`. Returns `None` for any other shape
/// or when the category is empty.
pub fn dynamic_poi_count_category(name: &str) -> Option<&str> {
    let (category, radius) = name
        .strip_prefix("Number of ")
        .and_then(|rest| rest.rsplit_once(" POIs within "))?;
    match radius {
        "2km" | "5km" if !category.is_empty() => Some(category),
        _ => None,
    }
}
/// Whether `name` matches either dynamic POI naming scheme: a per-category
/// distance feature or a per-category count-within-radius feature.
pub fn is_dynamic_poi_feature(name: &str) -> bool {
    let is_distance = dynamic_poi_distance_category(name).is_some();
    is_distance || dynamic_poi_count_category(name).is_some()
}
/// Build a sort key for feature names: a metric rank followed by a lowercase label.
///
/// Ranks: `0` for distance features, `1` for 2km counts, `2` for 5km counts,
/// `3` for a count whose radius is not recognised, and `9` for names that are
/// not dynamic POI features. The label is the ASCII-lowercased category, or
/// the ASCII-lowercased full name for rank `9`.
pub fn dynamic_poi_feature_sort_key(name: &str) -> (u8, String) {
    let (rank, label) = if let Some(category) = dynamic_poi_distance_category(name) {
        (0, category)
    } else if let Some(category) = dynamic_poi_count_category(name) {
        let radius_rank = match dynamic_poi_count_radius(name) {
            Some(2) => 1,
            Some(5) => 2,
            _ => 3,
        };
        (radius_rank, category)
    } else {
        (9, name)
    };
    (rank, label.to_ascii_lowercase())
}
/// Canonical display order for POI category groups. /// Canonical display order for POI category groups.
/// The server will panic at startup if the data contains groups not in this list or vice versa. /// The server will panic at startup if the data contains groups not in this list or vice versa.
pub const POI_GROUP_ORDER: &[&str] = &[ pub const POI_GROUP_ORDER: &[&str] = &[

View file

@ -2,6 +2,7 @@
mod aggregation; mod aggregation;
mod auth; mod auth;
mod checkout_sessions;
mod consts; mod consts;
mod data; mod data;
mod features; mod features;
@ -10,6 +11,7 @@ mod metrics;
mod og_middleware; mod og_middleware;
pub mod parsing; pub mod parsing;
mod pocketbase; mod pocketbase;
mod pocketbase_locks;
mod routes; mod routes;
mod state; mod state;
pub mod utils; pub mod utils;

View file

@ -4,8 +4,11 @@ mod filters;
mod h3; mod h3;
pub use bounds::{bounds_intersect, h3_cell_bounds, parse_bounds, require_bounds}; pub use bounds::{bounds_intersect, h3_cell_bounds, parse_bounds, require_bounds};
pub use fields::{parse_enum_dist, parse_field_indices, parse_field_set}; pub use fields::{
parse_enum_dist, parse_field_indices, parse_field_indices_with_poi, parse_field_set,
};
pub use filters::{ pub use filters::{
count_filter_impacts, parse_filters, row_passes_filters, ParsedEnumFilter, ParsedFilter, count_filter_impacts, parse_filters, parse_filters_with_poi, row_passes_filters,
row_passes_poi_filters, ParsedEnumFilter, ParsedFilter, ParsedPoiFilter,
}; };
pub use h3::{cell_for_row, cell_for_row_cached, needs_parent, validate_h3_resolution}; pub use h3::{cell_for_row, cell_for_row_cached, needs_parent, validate_h3_resolution};

View file

@ -31,6 +31,55 @@ pub fn parse_field_indices(
Ok(Some(indices)) Ok(Some(indices))
} }
/// Result of `parse_field_indices_with_poi`: the requested field names
/// resolved to indices, split by which lookup table each name was found in.
pub struct ParsedFieldIndices {
/// None means no `fields` param was supplied, so normal aggregation keeps
/// its existing "all configured features" behavior.
pub normal: Option<Vec<usize>>,
/// Indices of requested names that were found in the POI name table rather
/// than the normal feature table. Empty when no such names were requested.
pub poi: Vec<usize>,
}
/// Resolve the `?fields=` query parameter against both the row-major feature
/// matrix and the postcode-level POI side table.
///
/// * `fields == None`: `normal` is `None` and `poi` is empty.
/// * Otherwise the value is split on `;;`; each entry is trimmed and blank
///   entries are skipped (so an empty string yields `Some(vec![])` and an
///   empty `poi`).
/// * Each name is looked up in `name_to_index` first, then in
///   `poi_name_to_index`.
///
/// Returns `400 Bad Request` with `Unknown field: <name>` for a name found in
/// neither table.
pub fn parse_field_indices_with_poi(
    fields: Option<&str>,
    name_to_index: &FxHashMap<String, usize>,
    poi_name_to_index: &FxHashMap<String, usize>,
) -> Result<ParsedFieldIndices, (StatusCode, String)> {
    let raw = match fields {
        Some(raw) => raw,
        None => {
            return Ok(ParsedFieldIndices {
                normal: None,
                poi: Vec::new(),
            });
        }
    };

    let mut matrix_indices = Vec::new();
    let mut side_table_indices = Vec::new();
    let requested = raw
        .split(";;")
        .map(str::trim)
        .filter(|field| !field.is_empty());
    for field in requested {
        match name_to_index.get(field) {
            Some(&idx) => matrix_indices.push(idx),
            // The POI table is consulted only when the normal table has no match.
            None => match poi_name_to_index.get(field) {
                Some(&idx) => side_table_indices.push(idx),
                None => {
                    return Err((StatusCode::BAD_REQUEST, format!("Unknown field: {}", field)));
                }
            },
        }
    }

    Ok(ParsedFieldIndices {
        normal: Some(matrix_indices),
        poi: side_table_indices,
    })
}
/// Parse an optional `?enum_dist=` query param into (feature_index, num_values) for /// Parse an optional `?enum_dist=` query param into (feature_index, num_values) for
/// per-value distribution counting. Returns None if not requested. /// per-value distribution counting. Returns None if not requested.
/// Returns 400 if the feature name is unknown or not an enum feature. /// Returns 400 if the feature name is unknown or not an enum feature.
@ -73,3 +122,28 @@ pub fn parse_field_set(fields: Option<&str>) -> (bool, HashSet<String>) {
.unwrap_or_default(); .unwrap_or_default();
(fields_specified, field_set) (fields_specified, field_set)
} }
// Unit tests for the field-parsing helpers in this module.
#[cfg(test)]
mod tests {
use super::*;
// A `;;`-separated list mixing a normal feature name and a POI metric name
// must route each name to the index list of the table it was found in.
#[test]
fn parse_field_indices_with_poi_splits_normal_and_side_fields() {
// Normal feature table: two names, only "Price" is requested below.
let normal: FxHashMap<String, usize> = [("Price".to_string(), 0), ("Area".to_string(), 1)]
.into_iter()
.collect();
// POI table: its index (2) is independent of the normal table's indices.
let poi: FxHashMap<String, usize> = [("Distance to nearest cafe POI (km)".to_string(), 2)]
.into_iter()
.collect();
let parsed = parse_field_indices_with_poi(
Some("Price;;Distance to nearest cafe POI (km)"),
&normal,
&poi,
)
.unwrap();
assert_eq!(parsed.normal, Some(vec![0]));
assert_eq!(parsed.poi, vec![2]);
}
}

View file

@ -1,7 +1,7 @@
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::{FxHashMap, FxHashSet};
use crate::consts::NAN_U16; use crate::consts::NAN_U16;
use crate::data::QuantRef; use crate::data::{PostcodePoiMetrics, QuantRef};
/// Filter for numeric features: value must be in [min_u16, max_u16] range (quantized). /// Filter for numeric features: value must be in [min_u16, max_u16] range (quantized).
#[derive(Debug)] #[derive(Debug)]
@ -19,6 +19,20 @@ pub struct ParsedEnumFilter {
pub allowed: FxHashSet<u16>, pub allowed: FxHashSet<u16>,
} }
/// Filter for postcode-level POI metrics stored in the side table.
///
/// Built by `parse_filters_with_poi` from a `name:min:max` entry whose name
/// resolves through the POI name map rather than the feature-matrix map, and
/// evaluated by `row_passes_poi_filters`.
#[derive(Debug)]
pub struct ParsedPoiFilter {
/// Index of the metric within the POI side table (value from the POI name map).
pub metric_idx: usize,
/// Inclusive lower bound, already quantized with the POI quantization.
pub min_u16: u16,
/// Inclusive upper bound, already quantized with the POI quantization.
pub max_u16: u16,
}
/// Result of `parse_filters_with_poi`, in order: numeric feature filters,
/// enum feature filters, and POI side-table filters.
pub type ParsedFiltersWithPoi = (
Vec<ParsedFilter>,
Vec<ParsedEnumFilter>,
Vec<ParsedPoiFilter>,
);
/// Parse `;;`-separated filter string into numeric and enum filters. /// Parse `;;`-separated filter string into numeric and enum filters.
/// Numeric format: `name:min:max` /// Numeric format: `name:min:max`
/// Enum format: `name:val1|val2|val3` (pipe-separated string values) /// Enum format: `name:val1|val2|val3` (pipe-separated string values)
@ -110,6 +124,101 @@ pub fn parse_filters(
Ok((numeric, enums)) Ok((numeric, enums))
} }
/// Parse a `;;`-separated filter string, additionally accepting POI metric
/// names that live outside the row-major property feature matrix.
///
/// Each entry is `name:rest`. The name is looked up in
/// `feature_name_to_index` first:
/// * if that feature has an entry in `enum_values`, `rest` is a
///   `|`-separated list of allowed values (each must be a known value);
/// * otherwise `rest` is `min:max`, quantized with `quant`.
///
/// Names not found there are looked up in `poi_name_to_index`, where `rest`
/// is `min:max`, quantized with `poi_quant`.
///
/// `None` or an empty string yields three empty lists. Each returned list is
/// sorted so that the narrowest filter (smallest quantized span, or fewest
/// allowed enum values) comes first.
///
/// # Errors
/// Returns a message for entries without `:`, unknown names, unknown enum
/// values, or malformed numeric bounds.
pub fn parse_filters_with_poi(
    filter_str: Option<&str>,
    feature_name_to_index: &FxHashMap<String, usize>,
    enum_values: &FxHashMap<usize, Vec<String>>,
    quant: &QuantRef,
    poi_name_to_index: &FxHashMap<String, usize>,
    poi_quant: &QuantRef,
) -> Result<ParsedFiltersWithPoi, String> {
    let mut numeric = Vec::new();
    let mut enums = Vec::new();
    let mut poi = Vec::new();

    let Some(input) = filter_str.filter(|text| !text.is_empty()) else {
        return Ok((numeric, enums, poi));
    };

    for entry in input.split(";;") {
        let Some((raw_name, raw_rest)) = entry.split_once(':') else {
            return Err(format!("Malformed filter entry (missing ':'): '{entry}'"));
        };
        let name = raw_name.trim();
        let rest = raw_rest.trim();

        match feature_name_to_index.get(name).copied() {
            Some(feat_idx) => match enum_values.get(&feat_idx) {
                // Enum feature: translate each listed value to its position.
                Some(values) => {
                    let mut allowed: FxHashSet<u16> = FxHashSet::default();
                    for candidate in rest.split('|').map(str::trim) {
                        let Some(position) = values.iter().position(|known| known == candidate)
                        else {
                            return Err(format!(
                                "Unknown value '{}' for enum feature '{}'. Valid values: {:?}",
                                candidate, name, values
                            ));
                        };
                        allowed.insert(position as u16);
                    }
                    enums.push(ParsedEnumFilter { feat_idx, allowed });
                }
                // Numeric feature stored in the main matrix.
                None => {
                    let (min, max) = parse_numeric_filter_bounds(name, rest, entry)?;
                    let min_u16 = quant.encode_min(feat_idx, min);
                    let max_u16 = quant.encode_max(feat_idx, max);
                    numeric.push(ParsedFilter {
                        feat_idx,
                        min_u16,
                        max_u16,
                    });
                }
            },
            None => {
                // Not a matrix feature: fall back to the POI side table.
                let Some(&metric_idx) = poi_name_to_index.get(name) else {
                    return Err(format!("Unknown feature in filter: '{name}'"));
                };
                let (min, max) = parse_numeric_filter_bounds(name, rest, entry)?;
                let min_u16 = poi_quant.encode_min(metric_idx, min);
                let max_u16 = poi_quant.encode_max(metric_idx, max);
                poi.push(ParsedPoiFilter {
                    metric_idx,
                    min_u16,
                    max_u16,
                });
            }
        }
    }

    // Narrowest filters first.
    numeric.sort_unstable_by_key(|filter| filter.max_u16.saturating_sub(filter.min_u16));
    enums.sort_unstable_by_key(|filter| filter.allowed.len());
    poi.sort_unstable_by_key(|filter| filter.max_u16.saturating_sub(filter.min_u16));

    Ok((numeric, enums, poi))
}
/// Parse the `min:max` tail of a numeric filter entry.
///
/// * `name`  – filter name, used only in error messages.
/// * `rest`  – text after the first `:` of the entry, expected as `min:max`.
/// * `entry` – the full original entry, echoed in the format error.
///
/// Infinite bounds are accepted and are exempt from the inverted-range
/// check. NaN bounds are rejected: `"NaN"` parses successfully as an `f32`,
/// but every comparison with NaN is false, so it would otherwise bypass the
/// range validation and reach the quantizer as a meaningless bound.
///
/// # Errors
/// Returns a message when the `:` separator is missing, a bound is not a
/// number (or is NaN), or both bounds are finite and `min > max`.
fn parse_numeric_filter_bounds(name: &str, rest: &str, entry: &str) -> Result<(f32, f32), String> {
    let Some((min_text, max_text)) = rest.split_once(':') else {
        return Err(format!(
            "Numeric filter '{name}' must have format 'name:min:max', got '{entry}'"
        ));
    };

    let min = min_text
        .trim()
        .parse::<f32>()
        .map_err(|err| format!("Invalid min value in filter '{name}': {err}"))?;
    let max = max_text
        .trim()
        .parse::<f32>()
        .map_err(|err| format!("Invalid max value in filter '{name}': {err}"))?;

    if min.is_nan() {
        return Err(format!(
            "Invalid min value in filter '{name}': NaN is not a valid bound"
        ));
    }
    if max.is_nan() {
        return Err(format!(
            "Invalid max value in filter '{name}': NaN is not a valid bound"
        ));
    }

    // Only finite pairs are checked for ordering; infinities pass through.
    if min.is_finite() && max.is_finite() && min > max {
        return Err(format!(
            "Numeric filter '{name}' has inverted range: min ({min}) > max ({max})"
        ));
    }

    Ok((min, max))
}
/// Check if a row passes all filters. /// Check if a row passes all filters.
/// All features (numeric and enum) are stored in feature_data as quantized u16. /// All features (numeric and enum) are stored in feature_data as quantized u16.
pub fn row_passes_filters( pub fn row_passes_filters(
@ -130,6 +239,18 @@ pub fn row_passes_filters(
}) })
} }
/// Return `true` when the property at `row` satisfies every POI filter.
///
/// For each filter the quantized metric is read via
/// `poi_metrics.raw_for_property_row(row, metric_idx)`. A `NAN_U16` value
/// fails the filter; otherwise the value must lie within the inclusive
/// `[min_u16, max_u16]` range. An empty filter list passes.
#[inline]
pub fn row_passes_poi_filters(
    row: usize,
    filters: &[ParsedPoiFilter],
    poi_metrics: &PostcodePoiMetrics,
) -> bool {
    for filter in filters {
        let raw = poi_metrics.raw_for_property_row(row, filter.metric_idx);
        let rejected = raw == NAN_U16 || raw < filter.min_u16 || raw > filter.max_u16;
        if rejected {
            return false;
        }
    }
    true
}
/// Single-pass marginal impact counting. /// Single-pass marginal impact counting.
/// ///
/// Returns `(total_passing, impacts)` where `impacts[i]` is how many MORE rows /// Returns `(total_passing, impacts)` where `impacts[i]` is how many MORE rows
@ -330,6 +451,35 @@ mod tests {
assert_eq!(enums[0].allowed.len(), 2); assert_eq!(enums[0].allowed.len(), 2);
} }
/// One numeric, one enum and one POI entry must land in their respective
/// lists, with the POI bounds quantized through the POI quantization.
#[test]
fn parse_filters_with_poi_splits_side_table_filters() {
    let feature_quant = test_quant(3, 2);
    let poi_quant = test_quant(2, 2);

    let mut poi_names: FxHashMap<String, usize> = FxHashMap::default();
    poi_names.insert("Distance to nearest cafe POI (km)".into(), 0);
    poi_names.insert("Number of cafe POIs within 2km".into(), 1);

    let query = "price:100:500;;rating:A;;Distance to nearest cafe POI (km):0:1.5";
    let parsed = parse_filters_with_poi(
        Some(query),
        &feature_name_to_index(),
        &enum_values(),
        &feature_quant.as_ref(),
        &poi_names,
        &poi_quant.as_ref(),
    );
    let (numeric, enums, poi) = parsed.unwrap();

    assert_eq!(numeric.len(), 1);
    assert_eq!(enums.len(), 1);
    assert_eq!(poi.len(), 1);

    let side_filter = &poi[0];
    assert_eq!(side_filter.metric_idx, 0);
    assert_eq!(side_filter.min_u16, 0);
    assert_eq!(side_filter.max_u16, 99);
}
#[test] #[test]
fn parse_filters_empty() { fn parse_filters_empty() {
let tq = test_quant(3, 2); let tq = test_quant(3, 2);

View file

@ -88,6 +88,8 @@ struct CreateCollection {
update_rule: Option<String>, update_rule: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
delete_rule: Option<String>, delete_rule: Option<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
indexes: Vec<String>,
} }
#[derive(Serialize)] #[derive(Serialize)]
@ -308,12 +310,13 @@ async fn ensure_user_fields(client: &Client, base_url: &str, token: &str) -> any
let has_ai_tokens_used = fields.iter().any(|f| f["name"] == "ai_tokens_used"); let has_ai_tokens_used = fields.iter().any(|f| f["name"] == "ai_tokens_used");
let has_ai_tokens_week = fields.iter().any(|f| f["name"] == "ai_tokens_week"); let has_ai_tokens_week = fields.iter().any(|f| f["name"] == "ai_tokens_week");
if has_is_admin let has_all_required_fields = has_is_admin
&& has_subscription && has_subscription
&& has_newsletter && has_newsletter
&& has_ai_tokens_used && has_ai_tokens_used
&& has_ai_tokens_week && has_ai_tokens_week;
{
if has_all_required_fields {
info!("PocketBase users collection already has all required fields"); info!("PocketBase users collection already has all required fields");
return Ok(()); return Ok(());
} }
@ -372,6 +375,52 @@ async fn ensure_user_fields(client: &Client, base_url: &str, token: &str) -> any
Ok(()) Ok(())
} }
/// Harden the `users` collection API rules so clients can manage ordinary
/// account data but cannot self-grant paid or admin-only state. Superuser
/// writes from the Rust API bypass these rules.
///
/// PATCHes the collection with:
/// * list / view / delete: only the authenticated user's own record;
/// * create: the request body must not set any protected field;
/// * update: own record only, and no protected field may change.
///
/// Protected fields: `subscription`, `is_admin`, `ai_tokens_used`,
/// `ai_tokens_week`.
///
/// # Errors
/// Fails if the request cannot be sent or PocketBase answers with a
/// non-success status.
async fn ensure_user_auth_rules(
    client: &Client,
    base_url: &str,
    token: &str,
) -> anyhow::Result<()> {
    const SELF_ONLY: &str = "id = @request.auth.id";
    const PROTECTED_FIELDS_ABSENT: &str = "@request.body.subscription:isset = false && @request.body.is_admin:isset = false && @request.body.ai_tokens_used:isset = false && @request.body.ai_tokens_week:isset = false";
    const PROTECTED_FIELDS_UNCHANGED: &str = "@request.body.subscription:changed = false && @request.body.is_admin:changed = false && @request.body.ai_tokens_used:changed = false && @request.body.ai_tokens_week:changed = false";

    let rules = serde_json::json!({
        "listRule": SELF_ONLY,
        "viewRule": SELF_ONLY,
        "createRule": PROTECTED_FIELDS_ABSENT,
        "updateRule": format!("{SELF_ONLY} && {PROTECTED_FIELDS_UNCHANGED}"),
        "deleteRule": SELF_ONLY,
    });

    let endpoint = format!("{base_url}/api/collections/users");
    let response = client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&rules)
        .send()
        .await?;

    let status = response.status();
    if status.is_success() {
        info!("PocketBase users collection API rules hardened");
        return Ok(());
    }

    let text = response.text().await.unwrap_or_default();
    anyhow::bail!("Failed to update users collection API rules ({status}): {text}");
}
/// Ensure a collection has API rules allowing users to manage their own records. /// Ensure a collection has API rules allowing users to manage their own records.
async fn ensure_user_owned_rules( async fn ensure_user_owned_rules(
client: &Client, client: &Client,
@ -404,6 +453,263 @@ async fn ensure_user_owned_rules(
Ok(()) Ok(())
} }
/// Lock a collection so it is reachable only through server-side superuser
/// calls, by PATCHing all five API rules (list, view, create, update,
/// delete) to JSON `null`.
///
/// # Errors
/// Fails if the request cannot be sent or PocketBase answers with a
/// non-success status.
async fn ensure_server_only_rules(
    client: &Client,
    base_url: &str,
    token: &str,
    collection_name: &str,
) -> anyhow::Result<()> {
    let locked_rules = serde_json::json!({
        "listRule": null,
        "viewRule": null,
        "createRule": null,
        "updateRule": null,
        "deleteRule": null,
    });

    let endpoint = format!("{base_url}/api/collections/{collection_name}");
    let response = client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&locked_rules)
        .send()
        .await?;

    let status = response.status();
    if status.is_success() {
        info!("PocketBase collection '{collection_name}' locked to superuser access");
        return Ok(());
    }

    let text = response.text().await.unwrap_or_default();
    anyhow::bail!("Failed to lock {collection_name} API rules ({status}): {text}");
}
/// Bring an existing `checkout_sessions` collection up to the expected
/// schema by appending any field that is missing by name.
///
/// Fetches the collection, looks up the `users` collection id (needed for
/// the `user` relation), and PATCHes the full field list only if at least
/// one field had to be added. Existing fields are sent back unchanged.
///
/// # Errors
/// Fails on transport errors, non-success responses, a response without a
/// `fields` array, or if the `users` collection id cannot be resolved.
async fn ensure_checkout_sessions_fields(
    client: &Client,
    base_url: &str,
    token: &str,
) -> anyhow::Result<()> {
    let endpoint = format!("{base_url}/api/collections/checkout_sessions");
    let response = client
        .get(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;

    let fetch_status = response.status();
    if !fetch_status.is_success() {
        let text = response.text().await.unwrap_or_default();
        anyhow::bail!("Failed to fetch checkout_sessions collection ({fetch_status}): {text}");
    }

    let body: serde_json::Value = response.json().await?;
    let fields = body["fields"]
        .as_array()
        .ok_or_else(|| anyhow::anyhow!("checkout_sessions collection has no fields array"))?;

    let users_id = find_users_collection_id(client, base_url, token).await?;

    // Small builders for the two recurring field shapes.
    let text_field = |name: &str, required: bool| {
        serde_json::json!({ "name": name, "type": "text", "required": required })
    };
    let number_field = |name: &str| serde_json::json!({ "name": name, "type": "number" });

    // Expected schema, in the order missing fields are appended.
    let expected: Vec<(&str, serde_json::Value)> = vec![
        (
            "user",
            serde_json::json!({
                "name": "user",
                "type": "relation",
                "required": true,
                "maxSelect": 1,
                "collectionId": users_id,
            }),
        ),
        ("stripe_session_id", text_field("stripe_session_id", false)),
        ("checkout_url", text_field("checkout_url", false)),
        ("amount_pence", number_field("amount_pence")),
        ("expected_total_pence", number_field("expected_total_pence")),
        ("currency", text_field("currency", true)),
        ("discount_coupon_id", text_field("discount_coupon_id", false)),
        ("referral_invite_id", text_field("referral_invite_id", false)),
        ("status", text_field("status", true)),
        ("expires_at_unix", number_field("expires_at_unix")),
        ("paid_amount_pence", number_field("paid_amount_pence")),
        ("completed_at_unix", text_field("completed_at_unix", false)),
    ];

    let mut new_fields = fields.clone();
    for (name, definition) in expected {
        let already_present = fields.iter().any(|f| f["name"] == name);
        if !already_present {
            new_fields.push(definition);
        }
    }

    // Nothing appended: the schema is already complete.
    if new_fields.len() == fields.len() {
        return Ok(());
    }

    let patch_response = client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&serde_json::json!({ "fields": new_fields }))
        .send()
        .await?;

    let patch_status = patch_response.status();
    if !patch_status.is_success() {
        let text = patch_response.text().await.unwrap_or_default();
        anyhow::bail!("Failed to patch checkout_sessions fields ({patch_status}): {text}");
    }

    info!("PocketBase checkout_sessions collection fields updated");
    Ok(())
}
/// Bring an existing `checkout_locks` collection up to the expected schema
/// (`name`, `owner`, `expires_at_unix`) by appending any field that is
/// missing by name. PATCHes only when something was added.
///
/// # Errors
/// Fails on transport errors, non-success responses, or a response without
/// a `fields` array.
async fn ensure_checkout_locks_fields(
    client: &Client,
    base_url: &str,
    token: &str,
) -> anyhow::Result<()> {
    let endpoint = format!("{base_url}/api/collections/checkout_locks");
    let response = client
        .get(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;

    let fetch_status = response.status();
    if !fetch_status.is_success() {
        let text = response.text().await.unwrap_or_default();
        anyhow::bail!("Failed to fetch checkout_locks collection ({fetch_status}): {text}");
    }

    let body: serde_json::Value = response.json().await?;
    let fields = body["fields"]
        .as_array()
        .ok_or_else(|| anyhow::anyhow!("checkout_locks collection has no fields array"))?;

    // Expected schema, in the order missing fields are appended.
    let expected: Vec<(&str, serde_json::Value)> = vec![
        (
            "name",
            serde_json::json!({ "name": "name", "type": "text", "required": true }),
        ),
        (
            "owner",
            serde_json::json!({ "name": "owner", "type": "text", "required": true }),
        ),
        (
            "expires_at_unix",
            serde_json::json!({ "name": "expires_at_unix", "type": "number" }),
        ),
    ];

    let mut new_fields = fields.clone();
    for (name, definition) in expected {
        let already_present = fields.iter().any(|f| f["name"] == name);
        if !already_present {
            new_fields.push(definition);
        }
    }

    // Nothing appended: the schema is already complete.
    if new_fields.len() == fields.len() {
        return Ok(());
    }

    let patch_response = client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&serde_json::json!({ "fields": new_fields }))
        .send()
        .await?;

    let patch_status = patch_response.status();
    if !patch_status.is_success() {
        let text = patch_response.text().await.unwrap_or_default();
        anyhow::bail!("Failed to patch checkout_locks fields ({patch_status}): {text}");
    }

    info!("PocketBase checkout_locks collection fields updated");
    Ok(())
}
/// Make sure a collection carries every index in `required_indexes`.
///
/// Each requirement is `(index_name, create_sql)`. An index counts as
/// present when any existing index definition string contains `index_name`
/// as a substring; otherwise `create_sql` is appended. The collection is
/// PATCHed with the combined list only when something was appended. A
/// missing or non-array `indexes` property is treated as an empty list.
///
/// # Errors
/// Fails on transport errors or non-success responses.
async fn ensure_collection_indexes(
    client: &Client,
    base_url: &str,
    token: &str,
    collection_name: &str,
    required_indexes: &[(&str, &str)],
) -> anyhow::Result<()> {
    let endpoint = format!("{base_url}/api/collections/{collection_name}");
    let response = client
        .get(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;

    let fetch_status = response.status();
    if !fetch_status.is_success() {
        let text = response.text().await.unwrap_or_default();
        anyhow::bail!("Failed to fetch {collection_name} collection ({fetch_status}): {text}");
    }

    let body: serde_json::Value = response.json().await?;
    let indexes = body["indexes"].as_array().cloned().unwrap_or_default();

    // Presence is always judged against the indexes as fetched, not against
    // entries appended during this call.
    let is_present = |index_name: &str| {
        indexes
            .iter()
            .filter_map(serde_json::Value::as_str)
            .any(|definition| definition.contains(index_name))
    };

    let mut new_indexes = indexes.clone();
    for &(index_name, create_sql) in required_indexes {
        if !is_present(index_name) {
            new_indexes.push(serde_json::Value::String(create_sql.to_string()));
        }
    }

    if new_indexes.len() == indexes.len() {
        return Ok(());
    }

    let patch_response = client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&serde_json::json!({ "indexes": new_indexes }))
        .send()
        .await?;

    let patch_status = patch_response.status();
    if !patch_status.is_success() {
        let text = patch_response.text().await.unwrap_or_default();
        anyhow::bail!("Failed to patch {collection_name} indexes ({patch_status}): {text}");
    }

    info!("PocketBase collection '{collection_name}' indexes updated");
    Ok(())
}
/// Ensure the `saved_searches` collection has API rules allowing users to manage their own records. /// Ensure the `saved_searches` collection has API rules allowing users to manage their own records.
async fn ensure_saved_searches_rules( async fn ensure_saved_searches_rules(
client: &Client, client: &Client,
@ -608,6 +914,7 @@ pub async fn ensure_collections(
let existing = list_collections(client, base_url, &token).await?; let existing = list_collections(client, base_url, &token).await?;
ensure_user_fields(client, base_url, &token).await?; ensure_user_fields(client, base_url, &token).await?;
ensure_user_auth_rules(client, base_url, &token).await?;
if !existing.iter().any(|n| n == "saved_searches") { if !existing.iter().any(|n| n == "saved_searches") {
let users_id = find_users_collection_id(client, base_url, &token).await?; let users_id = find_users_collection_id(client, base_url, &token).await?;
@ -633,6 +940,7 @@ pub async fn ensure_collections(
create_rule: user_only.clone(), create_rule: user_only.clone(),
update_rule: user_only.clone(), update_rule: user_only.clone(),
delete_rule: user_only, delete_rule: user_only,
indexes: Vec::new(),
}, },
) )
.await?; .await?;
@ -667,6 +975,7 @@ pub async fn ensure_collections(
create_rule: user_only.clone(), create_rule: user_only.clone(),
update_rule: user_only.clone(), update_rule: user_only.clone(),
delete_rule: user_only, delete_rule: user_only,
indexes: Vec::new(),
}, },
) )
.await?; .await?;
@ -698,6 +1007,7 @@ pub async fn ensure_collections(
create_rule: None, create_rule: None,
update_rule: None, update_rule: None,
delete_rule: None, delete_rule: None,
indexes: Vec::new(),
}, },
) )
.await?; .await?;
@ -705,6 +1015,86 @@ pub async fn ensure_collections(
ensure_autodate_fields(client, base_url, &token, "invites").await?; ensure_autodate_fields(client, base_url, &token, "invites").await?;
} }
if !existing.iter().any(|n| n == "checkout_sessions") {
let users_id = find_users_collection_id(client, base_url, &token).await?;
create_collection(
client,
base_url,
&token,
CreateCollection {
name: "checkout_sessions".to_string(),
r#type: "base".to_string(),
fields: vec![
Field::relation("user", &users_id),
Field::text("stripe_session_id", false),
Field::text("checkout_url", false),
Field::number("amount_pence"),
Field::number("expected_total_pence"),
Field::text("currency", true),
Field::text("discount_coupon_id", false),
Field::text("referral_invite_id", false),
Field::text("status", true),
Field::number("expires_at_unix"),
Field::number("paid_amount_pence"),
Field::text("completed_at_unix", false),
Field::autodate("created", true, false),
Field::autodate("updated", true, true),
],
list_rule: None,
view_rule: None,
create_rule: None,
update_rule: None,
delete_rule: None,
indexes: Vec::new(),
},
)
.await?;
} else {
ensure_server_only_rules(client, base_url, &token, "checkout_sessions").await?;
ensure_checkout_sessions_fields(client, base_url, &token).await?;
ensure_autodate_fields(client, base_url, &token, "checkout_sessions").await?;
}
let checkout_locks_name_index =
"CREATE UNIQUE INDEX idx_checkout_locks_name ON checkout_locks (name)";
if !existing.iter().any(|n| n == "checkout_locks") {
create_collection(
client,
base_url,
&token,
CreateCollection {
name: "checkout_locks".to_string(),
r#type: "base".to_string(),
fields: vec![
Field::text("name", true),
Field::text("owner", true),
Field::number("expires_at_unix"),
Field::autodate("created", true, false),
Field::autodate("updated", true, true),
],
list_rule: None,
view_rule: None,
create_rule: None,
update_rule: None,
delete_rule: None,
indexes: vec![checkout_locks_name_index.to_string()],
},
)
.await?;
} else {
ensure_server_only_rules(client, base_url, &token, "checkout_locks").await?;
ensure_checkout_locks_fields(client, base_url, &token).await?;
ensure_autodate_fields(client, base_url, &token, "checkout_locks").await?;
ensure_collection_indexes(
client,
base_url,
&token,
"checkout_locks",
&[("idx_checkout_locks_name", checkout_locks_name_index)],
)
.await?;
}
if !existing.iter().any(|n| n == "short_urls") { if !existing.iter().any(|n| n == "short_urls") {
create_collection( create_collection(
client, client,
@ -724,6 +1114,7 @@ pub async fn ensure_collections(
create_rule: None, create_rule: None,
update_rule: None, update_rule: None,
delete_rule: None, delete_rule: None,
indexes: Vec::new(),
}, },
) )
.await?; .await?;
@ -753,6 +1144,7 @@ pub async fn ensure_collections(
create_rule: None, create_rule: None,
update_rule: None, update_rule: None,
delete_rule: None, delete_rule: None,
indexes: Vec::new(),
}, },
) )
.await?; .await?;
@ -785,6 +1177,7 @@ pub async fn ensure_collections(
create_rule: None, create_rule: None,
update_rule: None, update_rule: None,
delete_rule: None, delete_rule: None,
indexes: Vec::new(),
}, },
) )
.await?; .await?;

View file

@ -0,0 +1,264 @@
use std::time::{Duration, Instant};
use anyhow::{anyhow, bail, Context};
use rand::RngExt;
use serde_json::Value;
use tokio::time::sleep;
use tracing::warn;
use crate::pocketbase::get_superuser_token;
use crate::state::AppState;
/// PocketBase collection that stores one record per held lock.
const LOCK_COLLECTION: &str = "checkout_locks";
/// How long `acquire_pocketbase_lock` keeps retrying before giving up.
const LOCK_ACQUIRE_TIMEOUT_SECS: u64 = 10;
/// Pause between acquisition attempts while the lock is held elsewhere.
const LOCK_RETRY_DELAY_MS: u64 = 100;
/// Handle to a lock held as a record in the PocketBase lock collection.
///
/// Release it explicitly with `release()`; if it is dropped while still
/// holding a record, the `Drop` impl spawns a background task that deletes
/// the record.
pub struct PocketBaseLock {
/// HTTP client used to delete the lock record on release/drop.
client: reqwest::Client,
/// PocketBase base URL with trailing slashes removed.
pb_url: String,
/// Superuser token used to authorize the delete.
token: String,
/// Id of the lock record; `None` once the lock has been released.
record_id: Option<String>,
/// Lock name, kept for error context and log messages.
name: String,
}
/// A lock record found in PocketBase for a given name (see `find_lock`).
struct ExistingLock {
/// PocketBase record id of the lock.
id: String,
/// Expiry as Unix seconds; `find_lock` stores 0 when the field is missing or invalid.
expires_at_unix: u64,
}
/// Acquire a named lock backed by a record in the PocketBase lock
/// collection, retrying for up to `LOCK_ACQUIRE_TIMEOUT_SECS`.
///
/// Each attempt tries to create the lock record with an expiry of
/// `now + ttl_secs`. If creation is refused, the current holder's record is
/// looked up; a record whose expiry has passed is treated as stale and
/// deleted so the next attempt can succeed.
///
/// * `state`    – application state providing the HTTP client, PocketBase
///   URL and superuser token.
/// * `name`     – lock name; must pass `validate_lock_name`.
/// * `ttl_secs` – lifetime of the lock record before others may treat it as
///   stale.
///
/// # Errors
/// Fails if the name is invalid, the superuser token cannot be obtained, a
/// PocketBase request fails, or the lock is not acquired before the
/// deadline.
pub async fn acquire_pocketbase_lock(
    state: &AppState,
    name: &str,
    ttl_secs: u64,
) -> anyhow::Result<PocketBaseLock> {
    validate_lock_name(name)?;
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/').to_string();
    let owner = random_owner();
    let deadline = Instant::now() + Duration::from_secs(LOCK_ACQUIRE_TIMEOUT_SECS);

    loop {
        let now = now_unix_secs();
        // Saturate so an extreme TTL cannot overflow the expiry timestamp.
        let expires_at = now.saturating_add(ttl_secs);

        if let Some(record_id) =
            try_create_lock(state, &pb_url, &token, name, &owner, expires_at).await?
        {
            return Ok(PocketBaseLock {
                client: state.http_client.clone(),
                pb_url,
                token,
                record_id: Some(record_id),
                name: name.to_string(),
            });
        }

        if let Some(existing) = find_lock(state, &pb_url, &token, name).await? {
            if existing.expires_at_unix <= now {
                match delete_lock_record(state, &pb_url, &token, &existing.id).await {
                    // Stale lock removed: retry immediately.
                    Ok(()) => continue,
                    // Deletion failed: do NOT retry immediately. Fall through
                    // to the deadline check and back-off below, otherwise a
                    // persistently failing delete would spin forever without
                    // ever timing out.
                    Err(err) => {
                        warn!(
                            lock_name = name,
                            lock_id = %existing.id,
                            "Failed to delete stale PocketBase lock: {err}"
                        );
                    }
                }
            }
        }

        if Instant::now() >= deadline {
            bail!("Timed out acquiring PocketBase lock '{name}'");
        }
        sleep(Duration::from_millis(LOCK_RETRY_DELAY_MS)).await;
    }
}
impl PocketBaseLock {
    /// Release the lock by deleting its PocketBase record.
    ///
    /// Consumes the handle. The record id is taken out first, so the `Drop`
    /// impl that runs afterwards has nothing left to delete. A handle that
    /// no longer holds a record returns `Ok(())` without any request.
    ///
    /// # Errors
    /// Returns the delete failure wrapped with the lock name as context.
    pub async fn release(mut self) -> anyhow::Result<()> {
        match self.record_id.take() {
            None => Ok(()),
            Some(record_id) => {
                let outcome =
                    release_lock_record(&self.client, &self.pb_url, &self.token, &record_id).await;
                outcome.with_context(|| format!("Failed to release PocketBase lock '{}'", self.name))
            }
        }
    }
}
impl Drop for PocketBaseLock {
    /// Best-effort release of a lock that was dropped without calling
    /// `release()`.
    ///
    /// `drop` cannot await, so the delete runs in a detached task on the
    /// current Tokio runtime. If no runtime is active (for example the
    /// handle is dropped on a non-runtime thread or during shutdown), the
    /// delete is skipped with a warning instead of calling `tokio::spawn`,
    /// which would panic; a panic inside `drop` during unwinding aborts the
    /// process. The record then remains until its expiry passes, after which
    /// `acquire_pocketbase_lock` treats it as stale and deletes it.
    fn drop(&mut self) {
        let Some(record_id) = self.record_id.take() else {
            // Already released.
            return;
        };

        let Ok(runtime) = tokio::runtime::Handle::try_current() else {
            warn!(
                lock_name = %self.name,
                lock_id = %record_id,
                "No Tokio runtime available to release PocketBase lock on drop; leaving it to expire"
            );
            return;
        };

        // The task must own its data because it may outlive `self`.
        let client = self.client.clone();
        let pb_url = self.pb_url.clone();
        let token = self.token.clone();
        let name = self.name.clone();
        runtime.spawn(async move {
            if let Err(err) = release_lock_record(&client, &pb_url, &token, &record_id).await {
                warn!(
                    lock_name = %name,
                    lock_id = %record_id,
                    "Failed to release PocketBase lock on drop: {err}"
                );
            }
        });
    }
}
/// Attempt to create the lock record for `name`.
///
/// Returns:
/// * `Ok(Some(id))` – the record was created; `id` is its PocketBase id.
/// * `Ok(None)`     – PocketBase answered with a 4xx status, which the
///   caller treats as "lock not obtained this round".
/// * `Err(_)`       – transport failure, any other non-success status, or a
///   success response without an `id`.
async fn try_create_lock(
    state: &AppState,
    pb_url: &str,
    token: &str,
    name: &str,
    owner: &str,
    expires_at_unix: u64,
) -> anyhow::Result<Option<String>> {
    let endpoint = format!("{pb_url}/api/collections/{LOCK_COLLECTION}/records");
    let payload = serde_json::json!({
        "name": name,
        "owner": owner,
        "expires_at_unix": expires_at_unix,
    });

    let response = state
        .http_client
        .post(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload)
        .send()
        .await?;

    let status = response.status();
    if status.is_success() {
        let body: Value = response.json().await?;
        return match body["id"].as_str() {
            Some(id) => Ok(Some(id.to_string())),
            None => Err(anyhow!("PocketBase lock record missing id")),
        };
    }

    let text = response.text().await.unwrap_or_default();
    if status.is_client_error() {
        return Ok(None);
    }
    Err(anyhow!("PocketBase lock create failed ({status}): {text}"))
}
/// Look up the lock record currently stored under `name`.
///
/// Queries the lock collection with a `name="<name>"` filter and
/// `perPage=1`. Returns `Ok(None)` when no record matches. A record whose
/// `expires_at_unix` is missing or not a non-negative whole number is
/// reported with an expiry of 0.
///
/// # Errors
/// Fails on transport errors, non-success responses, an unparsable body, or
/// a matching record without an `id`.
async fn find_lock(
    state: &AppState,
    pb_url: &str,
    token: &str,
    name: &str,
) -> anyhow::Result<Option<ExistingLock>> {
    let filter = format!("name=\"{}\"", name);
    let encoded_filter = urlencoding::encode(&filter);
    let endpoint = format!(
        "{pb_url}/api/collections/{LOCK_COLLECTION}/records?filter={}&perPage=1",
        encoded_filter
    );

    let response = state
        .http_client
        .get(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;
    ensure_success_ref(&response).await?;

    let body: Value = response.json().await?;
    let first_item = body["items"].as_array().and_then(|items| items.first());
    let item = match first_item {
        Some(item) => item,
        None => return Ok(None),
    };

    let id = match item["id"].as_str() {
        Some(id) => id.to_string(),
        None => return Err(anyhow!("PocketBase lock missing id")),
    };
    let expires_at_unix = number_field(item, "expires_at_unix").unwrap_or(0);

    Ok(Some(ExistingLock {
        id,
        expires_at_unix,
    }))
}
/// Delete a lock record by id using the application's shared HTTP client.
///
/// Thin adapter over `release_lock_record` for callers that hold an
/// `AppState` rather than a bare client.
async fn delete_lock_record(
    state: &AppState,
    pb_url: &str,
    token: &str,
    record_id: &str,
) -> anyhow::Result<()> {
    let client = &state.http_client;
    release_lock_record(client, pb_url, token, record_id).await
}
/// Delete the lock record `record_id` from the lock collection.
///
/// A `404 Not Found` response counts as success: the record is already
/// gone, which is the desired end state.
///
/// # Errors
/// Fails on transport errors or any other non-success status.
async fn release_lock_record(
    client: &reqwest::Client,
    pb_url: &str,
    token: &str,
    record_id: &str,
) -> anyhow::Result<()> {
    let endpoint = format!("{pb_url}/api/collections/{LOCK_COLLECTION}/records/{record_id}");
    let response = client
        .delete(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;

    let status = response.status();
    let already_gone = status == reqwest::StatusCode::NOT_FOUND;
    if status.is_success() || already_gone {
        return Ok(());
    }

    let text = response.text().await.unwrap_or_default();
    Err(anyhow!("PocketBase lock delete failed ({status}): {text}"))
}
/// Validate a lock name before it is embedded in a PocketBase filter.
///
/// Accepts 1 to 80 bytes drawn from ASCII letters, ASCII digits, `:`, `_`
/// and `-`.
///
/// # Errors
/// Fails with a length error for empty or over-long names, otherwise with a
/// character error when any other byte is present.
fn validate_lock_name(name: &str) -> anyhow::Result<()> {
    let length_ok = (1..=80).contains(&name.len());
    if !length_ok {
        bail!("invalid PocketBase lock name length");
    }

    let is_allowed = |byte: u8| {
        matches!(byte, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b':' | b'_' | b'-')
    };
    if name.bytes().any(|byte| !is_allowed(byte)) {
        bail!("invalid PocketBase lock name characters");
    }

    Ok(())
}
/// Generate a random 24-character owner tag made of `0-9` and `a-z`.
fn random_owner() -> String {
    const ALPHABET: &[u8; 36] = b"0123456789abcdefghijklmnopqrstuvwxyz";

    let mut rng = rand::rng();
    let mut owner = String::with_capacity(24);
    for _ in 0..24 {
        let idx: u8 = rng.random_range(0..36);
        owner.push(ALPHABET[usize::from(idx)] as char);
    }
    owner
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_unix_secs() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Read `value[field]` as a non-negative whole number.
///
/// Accepts a JSON unsigned integer directly, or a JSON float that is
/// finite, non-negative and has no fractional part. Anything else
/// (missing, negative, fractional, non-numeric) yields `None`.
fn number_field(value: &Value, field: &str) -> Option<u64> {
    let raw = &value[field];
    if let Some(integer) = raw.as_u64() {
        return Some(integer);
    }
    match raw.as_f64() {
        Some(float) if float.is_finite() && float >= 0.0 && float.fract() == 0.0 => {
            Some(float as u64)
        }
        _ => None,
    }
}
/// Turn a non-success HTTP status into an error without consuming the
/// response, so the caller can still read the body afterwards.
async fn ensure_success_ref(resp: &reqwest::Response) -> anyhow::Result<()> {
    let status = resp.status();
    if !status.is_success() {
        return Err(anyhow!("upstream returned {}", status));
    }
    Ok(())
}

View file

@ -8,10 +8,8 @@ use serde::{Deserialize, Serialize};
use tracing::{info, warn}; use tracing::{info, warn};
use crate::auth::OptionalUser; use crate::auth::OptionalUser;
use crate::pocketbase::get_superuser_token; use crate::checkout_sessions::{start_license_checkout, CheckoutStart};
use crate::state::{AppState, SharedState}; use crate::state::SharedState;
use super::pricing::{count_licensed_users, price_for_count};
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct CheckoutRequest { pub struct CheckoutRequest {
@ -23,8 +21,8 @@ struct CheckoutResponse {
url: String, url: String,
} }
/// Create a Stripe Checkout session for the lifetime license (or grant for free if in free tier). /// Create a reserved Stripe Checkout session for the lifetime license.
/// Requires authentication. Optionally accepts a referral code to apply a coupon. /// Requires authentication. Referral discounts are issued via invite redemption.
pub async fn post_checkout( pub async fn post_checkout(
State(shared): State<Arc<SharedState>>, State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>, Extension(user): Extension<OptionalUser>,
@ -36,147 +34,27 @@ pub async fn post_checkout(
None => return StatusCode::UNAUTHORIZED.into_response(), None => return StatusCode::UNAUTHORIZED.into_response(),
}; };
let count = match count_licensed_users(&state).await {
Ok(c) => c,
Err(err) => {
warn!("Failed to count licensed users at checkout: {err}");
return StatusCode::SERVICE_UNAVAILABLE.into_response();
}
};
let price_pence = price_for_count(count);
let public_url = &state.public_url; let public_url = &state.public_url;
let success_url = format!("{public_url}/pricing?license_success=1"); let success_url = format!("{public_url}/pricing?license_success=1");
// Free tier — grant license directly without Stripe
if price_pence == 0 {
if let Err(err) = grant_license(&state, &user.id).await {
warn!(user_id = %user.id, "Failed to grant free license: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
info!(user_id = %user.id, "Granted free early-bird license");
return Json(CheckoutResponse { url: success_url }).into_response();
}
// Paid tier — create Stripe checkout with dynamic price
let secret_key = &state.stripe_secret_key;
let cancel_url = format!("{public_url}/pricing"); let cancel_url = format!("{public_url}/pricing");
let mut form_params = vec![ if req.referral_code.is_some() {
("mode", "payment".to_string()), return (
( StatusCode::BAD_REQUEST,
"line_items[0][price_data][unit_amount]", "Referral codes must be redeemed from the invite link",
price_pence.to_string(), )
), .into_response();
("line_items[0][price_data][currency]", "gbp".to_string()),
(
"line_items[0][price_data][product_data][name]",
"Perfect Postcodes Lifetime License".to_string(),
),
("line_items[0][quantity]", "1".to_string()),
("success_url", success_url),
("cancel_url", cancel_url),
("client_reference_id", user.id.clone()),
("customer_email", user.email.clone()),
];
// If a referral code is provided and valid, look it up and apply the coupon
if let Some(ref code) = req.referral_code {
if validate_referral_invite(&state, code).await {
form_params.push((
"discounts[0][coupon]",
state.stripe_referral_coupon_id.clone(),
));
info!(code = %code, "Applying referral coupon to checkout");
} else {
warn!(code = %code, "Referral code validation failed, proceeding without discount");
}
} }
let res = state match start_license_checkout(&state, &user, &success_url, &cancel_url, None, None).await {
.http_client Ok(CheckoutStart::Free) => {
.post("https://api.stripe.com/v1/checkout/sessions") info!(user_id = %user.id, "Granted free early-bird license");
.basic_auth(secret_key, None::<&str>) Json(CheckoutResponse { url: success_url }).into_response()
.form(&form_params)
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
let body: serde_json::Value = match resp.json().await {
Ok(v) => v,
Err(err) => {
warn!("Failed to parse Stripe response: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
let url = body["url"].as_str().unwrap_or_default().to_string();
if url.is_empty() {
warn!("Stripe session missing URL");
return StatusCode::BAD_GATEWAY.into_response();
}
info!(user_id = %user.id, price_pence, "Created Stripe checkout session");
Json(CheckoutResponse { url }).into_response()
}
Ok(resp) => {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
warn!("Stripe checkout failed ({status}): {text}");
StatusCode::BAD_GATEWAY.into_response()
} }
Ok(CheckoutStart::Stripe { url }) => Json(CheckoutResponse { url }).into_response(),
Err(err) => { Err(err) => {
warn!("Stripe request error: {err}"); warn!(user_id = %user.id, "Failed to start checkout: {err:?}");
StatusCode::BAD_GATEWAY.into_response() StatusCode::BAD_GATEWAY.into_response()
} }
} }
} }
/// Grant a license by setting the user's `subscription` to `"licensed"` in
/// PocketBase, then invalidate that user's cached tokens so the new state
/// is picked up.
///
/// # Errors
/// Fails if the superuser token cannot be obtained, the request cannot be
/// sent, or PocketBase answers with a non-success status (in which case the
/// token cache is left untouched).
async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let endpoint = format!("{pb_url}/api/collections/users/records/{user_id}");
    let payload = serde_json::json!({ "subscription": "licensed" });

    let response = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload)
        .send()
        .await?;

    let status = response.status();
    if !status.is_success() {
        let text = response.text().await.unwrap_or_default();
        anyhow::bail!("PocketBase update failed ({status}): {text}");
    }

    state.token_cache.invalidate_by_user_id(user_id);
    Ok(())
}
/// Report whether `code` matches at least one unused referral invite.
///
/// The code must be 1 to 20 ASCII alphanumeric characters; anything else is
/// rejected without contacting PocketBase. Otherwise the `invites` collection
/// is queried for a record with this code, `invite_type` of `referral` and an
/// empty `used_by_id`. A transport error, a non-success status, or a response
/// without a positive `totalItems` all yield `false`.
async fn validate_referral_invite(state: &AppState, code: &str) -> bool {
    // Only alphanumeric codes are allowed, to prevent PocketBase filter injection.
    let well_formed = !code.is_empty()
        && code.len() <= 20
        && code.bytes().all(|b| b.is_ascii_alphanumeric());
    if !well_formed {
        return false;
    }

    let filter = format!("code=\"{code}\" && invite_type=\"referral\" && used_by_id=\"\"");
    let url = format!(
        "{}/api/collections/invites/records?filter={}&perPage=1",
        state.pocketbase_url.trim_end_matches('/'),
        urlencoding::encode(&filter)
    );

    let resp = match state.http_client.get(&url).send().await {
        Ok(resp) => resp,
        Err(_) => return false,
    };
    if !resp.status().is_success() {
        return false;
    }

    // An unparseable body falls back to the default JSON value, which has no
    // `totalItems` and therefore counts as zero matches.
    let body: serde_json::Value = resp.json().await.unwrap_or_default();
    body["totalItems"].as_u64().map_or(false, |total| total > 0)
}

View file

@ -1,6 +1,7 @@
use std::collections::hash_map::DefaultHasher; use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use axum::extract::{Query, State}; use axum::extract::{Query, State};
use axum::http::{header, HeaderMap, StatusCode}; use axum::http::{header, HeaderMap, StatusCode};
@ -13,14 +14,18 @@ use tracing::{info, warn};
use crate::auth::OptionalUser; use crate::auth::OptionalUser;
use crate::consts::NAN_U16; use crate::consts::NAN_U16;
use crate::data::QuantRef; use crate::data::{PostcodePoiMetrics, QuantRef};
use crate::features::INTEGER_BIN_FEATURES; use crate::features;
use crate::licensing::check_license_bounds; use crate::licensing::check_license_bounds;
use crate::parsing::{parse_field_indices, parse_filters, require_bounds, row_passes_filters}; use crate::parsing::{
parse_field_indices_with_poi, parse_filters_with_poi, require_bounds, row_passes_filters,
row_passes_poi_filters,
};
use crate::routes::{fetch_screenshot_bytes, FeatureInfo}; use crate::routes::{fetch_screenshot_bytes, FeatureInfo};
use crate::state::SharedState; use crate::state::SharedState;
const MAX_EXPORT_POSTCODES: usize = 250; const MAX_EXPORT_POSTCODES: usize = 250;
const EXPORT_SCREENSHOT_TIMEOUT_SECS: u64 = 12;
/// Height (in pixels) reserved for the screenshot row /// Height (in pixels) reserved for the screenshot row
const IMAGE_ROW_HEIGHT: f64 = 225.0; const IMAGE_ROW_HEIGHT: f64 = 225.0;
@ -41,11 +46,11 @@ struct PostcodeExportAgg {
} }
impl PostcodeExportAgg { impl PostcodeExportAgg {
fn new(num_features: usize) -> Self { fn new(total_features: usize) -> Self {
Self { Self {
count: 0, count: 0,
sums: vec![0.0; num_features], sums: vec![0.0; total_features],
finite_counts: vec![0; num_features], finite_counts: vec![0; total_features],
enum_freqs: FxHashMap::default(), enum_freqs: FxHashMap::default(),
} }
} }
@ -58,6 +63,7 @@ impl PostcodeExportAgg {
num_features: usize, num_features: usize,
enum_indices: &FxHashMap<usize, ()>, enum_indices: &FxHashMap<usize, ()>,
quant: &QuantRef, quant: &QuantRef,
poi_metrics: &PostcodePoiMetrics,
) { ) {
self.count += 1; self.count += 1;
let base = row * num_features; let base = row * num_features;
@ -79,6 +85,18 @@ impl PostcodeExportAgg {
self.finite_counts[feat_idx] += 1; self.finite_counts[feat_idx] += 1;
} }
} }
let poi_offset = num_features;
for metric_idx in 0..poi_metrics.num_features() {
let raw = poi_metrics.raw_for_property_row(row, metric_idx);
if raw == NAN_U16 {
continue;
}
let value = poi_metrics.decode_raw(metric_idx, raw);
let out_idx = poi_offset + metric_idx;
self.sums[out_idx] += value as f64;
self.finite_counts[out_idx] += 1;
}
} }
} }
@ -138,13 +156,17 @@ pub async fn get_export(
check_license_bounds(&user.0, (south, west, north, east), None)?; check_license_bounds(&user.0, (south, west, north, east), None)?;
let quant = state.data.quant_ref(); let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters( let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(), params.filters.as_deref(),
&state.feature_name_to_index, &state.feature_name_to_index,
&state.data.enum_values, &state.data.enum_values,
&quant, &quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
) )
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let has_poi_filters = !parsed_poi_filters.is_empty();
let filters_str = params.filters; let filters_str = params.filters;
let fields_str = params.fields; let fields_str = params.fields;
@ -164,16 +186,28 @@ pub async fn get_export(
// Fetch screenshot (async, before spawn_blocking) // Fetch screenshot (async, before spawn_blocking)
let auth_header = headers.get(header::AUTHORIZATION); let auth_header = headers.get(header::AUTHORIZATION);
let screenshot_bytes = match fetch_screenshot_bytes(&state, &frontend_params, auth_header).await let screenshot_fetch = fetch_screenshot_bytes(&state, &frontend_params, auth_header);
let screenshot_bytes = match tokio::time::timeout(
Duration::from_secs(EXPORT_SCREENSHOT_TIMEOUT_SECS),
screenshot_fetch,
)
.await
{ {
Ok(bytes) => { Ok(Ok(bytes)) => {
info!(bytes = bytes.len(), "Fetched screenshot for export"); info!(bytes = bytes.len(), "Fetched screenshot for export");
Some(bytes) Some(bytes)
} }
Err(err) => { Ok(Err(err)) => {
warn!("Screenshot failed for export: {err}"); warn!("Screenshot failed for export: {err}");
None None
} }
Err(_) => {
warn!(
timeout_secs = EXPORT_SCREENSHOT_TIMEOUT_SECS,
"Screenshot timed out for export"
);
None
}
}; };
// Build feature name → description map from the precomputed features response // Build feature name → description map from the precomputed features response
@ -200,6 +234,9 @@ pub async fn get_export(
let feature_names = &state.data.feature_names; let feature_names = &state.data.feature_names;
let enum_values = &state.data.enum_values; let enum_values = &state.data.enum_values;
let postcode_data = &state.postcode_data; let postcode_data = &state.postcode_data;
let poi_metrics = &state.data.poi_metrics;
let poi_offset = num_features;
let total_export_features = num_features + poi_metrics.num_features();
// Build set of enum feature indices for quick lookup // Build set of enum feature indices for quick lookup
let enum_indices: FxHashMap<usize, ()> = enum_values.keys().map(|&idx| (idx, ())).collect(); let enum_indices: FxHashMap<usize, ()> = enum_values.keys().map(|&idx| (idx, ())).collect();
@ -219,6 +256,10 @@ pub async fn get_export(
) { ) {
return; return;
} }
if has_poi_filters && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
{
return;
}
let postcode = state.data.postcode(row); let postcode = state.data.postcode(row);
if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) { if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) {
postcode_rows.entry(pc_idx).or_default().push(row); postcode_rows.entry(pc_idx).or_default().push(row);
@ -229,9 +270,16 @@ pub async fn get_export(
let mut postcode_aggs: Vec<(usize, PostcodeExportAgg)> = let mut postcode_aggs: Vec<(usize, PostcodeExportAgg)> =
Vec::with_capacity(postcode_rows.len()); Vec::with_capacity(postcode_rows.len());
for (pc_idx, rows) in postcode_rows { for (pc_idx, rows) in postcode_rows {
let mut agg = PostcodeExportAgg::new(num_features); let mut agg = PostcodeExportAgg::new(total_export_features);
for &row in &rows { for &row in &rows {
agg.add_row(feature_data, row, num_features, &enum_indices, &quant); agg.add_row(
feature_data,
row,
num_features,
&enum_indices,
&quant,
poi_metrics,
);
} }
if agg.count > 0 { if agg.count > 0 {
postcode_aggs.push((pc_idx, agg)); postcode_aggs.push((pc_idx, agg));
@ -265,14 +313,19 @@ pub async fn get_export(
// Determine column order: filter features first, then remaining // Determine column order: filter features first, then remaining
let filter_feature_names = extract_filter_feature_names(filters_str.as_deref()); let filter_feature_names = extract_filter_feature_names(filters_str.as_deref());
let field_indices = let field_indices = parse_field_indices_with_poi(
parse_field_indices(fields_str.as_deref(), &state.feature_name_to_index) fields_str.as_deref(),
.map_err(|err| err.1)?; &state.feature_name_to_index,
&state.data.poi_metrics.name_to_index,
)
.map_err(|err| err.1)?;
let all_feature_indices: Vec<usize> = if let Some(ref indices) = field_indices { let all_feature_indices: Vec<usize> = if let Some(ref indices) = field_indices.normal {
indices.clone() let mut selected = indices.clone();
selected.extend(field_indices.poi.iter().map(|idx| poi_offset + *idx));
selected
} else { } else {
let mut ordered = Vec::with_capacity(num_features); let mut ordered = Vec::with_capacity(total_export_features);
let mut used = FxHashSet::default(); let mut used = FxHashSet::default();
for name in &filter_feature_names { for name in &filter_feature_names {
@ -280,6 +333,11 @@ pub async fn get_export(
if used.insert(idx) { if used.insert(idx) {
ordered.push(idx); ordered.push(idx);
} }
} else if let Some(&idx) = state.data.poi_metrics.name_to_index.get(name.as_str()) {
let virtual_idx = poi_offset + idx;
if used.insert(virtual_idx) {
ordered.push(virtual_idx);
}
} }
} }
for idx in 0..num_features { for idx in 0..num_features {
@ -287,15 +345,42 @@ pub async fn get_export(
ordered.push(idx); ordered.push(idx);
} }
} }
for idx in 0..poi_metrics.num_features() {
let virtual_idx = poi_offset + idx;
if used.insert(virtual_idx) {
ordered.push(virtual_idx);
}
}
ordered ordered
}; };
// Filter-only feature indices for the Selected sheet // Filter-only feature indices for the Selected sheet
let filter_feature_indices: Vec<usize> = filter_feature_names let filter_feature_indices: Vec<usize> = filter_feature_names
.iter() .iter()
.filter_map(|name| state.feature_name_to_index.get(name.as_str()).copied()) .filter_map(|name| {
state
.feature_name_to_index
.get(name.as_str())
.copied()
.or_else(|| {
state
.data
.poi_metrics
.name_to_index
.get(name.as_str())
.map(|idx| poi_offset + *idx)
})
})
.collect(); .collect();
let feature_name_for_idx = |idx: usize| -> &str {
if idx < num_features {
&feature_names[idx]
} else {
&poi_metrics.feature_names[idx - poi_offset]
}
};
// Build feature unit map (feat_idx → (prefix, suffix)) for number formatting // Build feature unit map (feat_idx → (prefix, suffix)) for number formatting
let feature_units: FxHashMap<usize, (&str, &str)> = state let feature_units: FxHashMap<usize, (&str, &str)> = state
.features_response .features_response
@ -309,16 +394,25 @@ pub async fn get_export(
suffix, suffix,
.. ..
} => { } => {
let idx = state.feature_name_to_index.get(name.as_str())?; if let Some(&idx) = state.feature_name_to_index.get(name.as_str()) {
Some((*idx, (*prefix, *suffix))) Some((idx, (*prefix, *suffix)))
} else {
state
.data
.poi_metrics
.name_to_index
.get(name.as_str())
.map(|idx| (poi_offset + *idx, (*prefix, *suffix)))
}
} }
_ => None, _ => None,
}) })
.collect(); .collect();
let integer_feature_indices: FxHashSet<usize> = INTEGER_BIN_FEATURES let integer_feature_indices: FxHashSet<usize> = all_feature_indices
.iter() .iter()
.filter_map(|name| state.feature_name_to_index.get(*name).copied()) .copied()
.filter(|&idx| features::has_integer_bins(feature_name_for_idx(idx)))
.collect(); .collect();
// Build Excel number formats per feature index for unit display // Build Excel number formats per feature index for unit display
@ -435,7 +529,7 @@ pub async fn get_export(
.write_string_with_format( .write_string_with_format(
header_row, header_row,
col, col,
&feature_names[feat_idx], feature_name_for_idx(feat_idx),
&header_fmt, &header_fmt,
) )
.map_err(|e| format!("Failed to write header: {e}"))?; .map_err(|e| format!("Failed to write header: {e}"))?;
@ -453,7 +547,7 @@ pub async fn get_export(
for (col_offset, &feat_idx) in feat_indices.iter().enumerate() { for (col_offset, &feat_idx) in feat_indices.iter().enumerate() {
let col = (col_offset + 2) as u16; let col = (col_offset + 2) as u16;
let desc = feature_descriptions let desc = feature_descriptions
.get(&feature_names[feat_idx]) .get(feature_name_for_idx(feat_idx))
.map(String::as_str) .map(String::as_str)
.unwrap_or(""); .unwrap_or("");
sheet sheet
@ -477,7 +571,7 @@ pub async fn get_export(
for (col_offset, &feat_idx) in feat_indices.iter().enumerate() { for (col_offset, &feat_idx) in feat_indices.iter().enumerate() {
let col = (col_offset + 2) as u16; let col = (col_offset + 2) as u16;
if enum_indices.contains_key(&feat_idx) { if feat_idx < num_features && enum_indices.contains_key(&feat_idx) {
if let Some(freqs) = agg.enum_freqs.get(&feat_idx) { if let Some(freqs) = agg.enum_freqs.get(&feat_idx) {
if let Some((&mode_bits, _)) = if let Some((&mode_bits, _)) =
freqs.iter().max_by_key(|(_, &count)| count) freqs.iter().max_by_key(|(_, &count)| count)
@ -543,7 +637,7 @@ pub async fn get_export(
.map_err(|e| format!("Failed to set column width: {e}"))?; .map_err(|e| format!("Failed to set column width: {e}"))?;
for col_offset in 0..feat_indices.len() { for col_offset in 0..feat_indices.len() {
let col = (col_offset + 2) as u16; let col = (col_offset + 2) as u16;
let feat_name = &feature_names[feat_indices[col_offset]]; let feat_name = feature_name_for_idx(feat_indices[col_offset]);
let width = (feat_name.len() as f64 * 1.1).clamp(10.0, 30.0); let width = (feat_name.len() as f64 * 1.1).clamp(10.0, 30.0);
sheet sheet
.set_column_width(col, width) .set_column_width(col, width)

View file

@ -7,7 +7,7 @@ use serde::Serialize;
use tracing::info; use tracing::info;
use crate::data::{Histogram, PropertyData}; use crate::data::{Histogram, PropertyData};
use crate::features::{Feature, FEATURE_GROUPS}; use crate::features::{self, Feature, FEATURE_GROUPS};
use crate::state::SharedState; use crate::state::SharedState;
fn is_empty(val: &str) -> bool { fn is_empty(val: &str) -> bool {
@ -28,9 +28,9 @@ pub enum FeatureInfo {
max: f32, max: f32,
step: f32, step: f32,
histogram: Histogram, histogram: Histogram,
description: &'static str, description: String,
detail: &'static str, detail: String,
source: &'static str, source: String,
#[serde(skip_serializing_if = "is_empty")] #[serde(skip_serializing_if = "is_empty")]
prefix: &'static str, prefix: &'static str,
#[serde(skip_serializing_if = "is_empty")] #[serde(skip_serializing_if = "is_empty")]
@ -45,9 +45,9 @@ pub enum FeatureInfo {
name: String, name: String,
values: Vec<String>, values: Vec<String>,
counts: HashMap<String, u64>, counts: HashMap<String, u64>,
description: &'static str, description: String,
detail: &'static str, detail: String,
source: &'static str, source: String,
}, },
} }
@ -85,9 +85,9 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
max: stats.slider_max, max: stats.slider_max,
step: config.step, step: config.step,
histogram: stats.histogram.clone(), histogram: stats.histogram.clone(),
description: config.description, description: config.description.to_string(),
detail: config.detail, detail: config.detail.to_string(),
source: config.source, source: config.source.to_string(),
prefix: config.prefix, prefix: config.prefix,
suffix: config.suffix, suffix: config.suffix,
raw: config.raw, raw: config.raw,
@ -118,9 +118,9 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
name: config.name.to_string(), name: config.name.to_string(),
values: values.clone(), values: values.clone(),
counts, counts,
description: config.description, description: config.description.to_string(),
detail: config.detail, detail: config.detail.to_string(),
source: config.source, source: config.source.to_string(),
}); });
} }
} }
@ -136,6 +136,58 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
} }
} }
let mut dynamic_poi_features = Vec::new();
for (feat_idx, name) in data.poi_metrics.feature_names.iter().enumerate() {
if let Some(category) = features::dynamic_poi_distance_category(name) {
let stats = &data.poi_metrics.feature_stats[feat_idx];
dynamic_poi_features.push(FeatureInfo::Numeric {
name: name.clone(),
min: stats.slider_min,
max: stats.slider_max,
step: 0.1,
histogram: stats.histogram.clone(),
description: format!("Distance to the closest {category} POI"),
detail: format!(
"Straight-line distance in kilometres from the postcode to the nearest {category} point of interest in the POI dataset."
),
source: "osm-pois".to_string(),
prefix: "",
suffix: " km",
raw: false,
absolute: false,
});
} else if let Some(category) = features::dynamic_poi_count_category(name) {
let stats = &data.poi_metrics.feature_stats[feat_idx];
let radius = features::dynamic_poi_count_radius(name).unwrap_or(0);
dynamic_poi_features.push(FeatureInfo::Numeric {
name: name.clone(),
min: stats.slider_min,
max: stats.slider_max,
step: 1.0,
histogram: stats.histogram.clone(),
description: format!("Number of {category} POIs within {radius}km"),
detail: format!(
"Count of {category} points of interest within a {radius}km radius of the property's postcode centroid."
),
source: "osm-pois".to_string(),
prefix: "",
suffix: "",
raw: false,
absolute: false,
});
}
}
if !dynamic_poi_features.is_empty() {
dynamic_poi_features.sort_by_key(|feature| match feature {
FeatureInfo::Numeric { name, .. } => features::dynamic_poi_feature_sort_key(name),
FeatureInfo::Enum { name, .. } => features::dynamic_poi_feature_sort_key(name),
});
groups.push(FeatureGroupResponse {
name: "Nearby POIs".to_string(),
features: dynamic_poi_features,
});
}
FeaturesResponse { groups } FeaturesResponse { groups }
} }

View file

@ -9,7 +9,7 @@ use tracing::info;
use crate::consts::NAN_U16; use crate::consts::NAN_U16;
use crate::data::travel_time::TravelData; use crate::data::travel_time::TravelData;
use crate::parsing::{parse_filters, require_bounds}; use crate::parsing::{parse_filters_with_poi, require_bounds};
use crate::routes::travel_time::parse_optional_travel; use crate::routes::travel_time::parse_optional_travel;
use crate::state::SharedState; use crate::state::SharedState;
@ -36,18 +36,21 @@ pub async fn get_filter_counts(
require_bounds(params.bounds).map_err(IntoResponse::into_response)?; require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
let quant = state.data.quant_ref(); let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters( let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(), params.filters.as_deref(),
&state.feature_name_to_index, &state.feature_name_to_index,
&state.data.enum_values, &state.data.enum_values,
&quant, &quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
) )
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let travel_entries = parse_optional_travel(params.travel.as_deref()) let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_regular = parsed_filters.len() + parsed_enum_filters.len(); let num_regular = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
// Only travel entries with a filter range count as filters for impact tracking // Only travel entries with a filter range count as filters for impact tracking
let travel_filter_indices: Vec<usize> = travel_entries let travel_filter_indices: Vec<usize> = travel_entries
.iter() .iter()
@ -65,6 +68,7 @@ pub async fn get_filter_counts(
} }
let filters_str = params.filters; let filters_str = params.filters;
let has_poi_filters = !parsed_poi_filters.is_empty();
let response = tokio::task::spawn_blocking(move || -> Result<FilterCountsResponse, String> { let response = tokio::task::spawn_blocking(move || -> Result<FilterCountsResponse, String> {
let t0 = std::time::Instant::now(); let t0 = std::time::Instant::now();
@ -124,6 +128,23 @@ pub async fn get_filter_counts(
} }
} }
// Test POI metric filters
if fail_count <= 1 && has_poi_filters {
for (i, f) in parsed_poi_filters.iter().enumerate() {
let raw = state
.data
.poi_metrics
.raw_for_property_row(row, f.metric_idx);
if raw == NAN_U16 || raw < f.min_u16 || raw > f.max_u16 {
fail_count += 1;
fail_index = parsed_filters.len() + parsed_enum_filters.len() + i;
if fail_count > 1 {
break;
}
}
}
}
// Test travel time filters // Test travel time filters
if fail_count <= 1 && has_travel { if fail_count <= 1 && has_travel {
let postcode = pc_interner.resolve(&pc_keys[row]); let postcode = pc_interner.resolve(&pc_keys[row]);
@ -169,8 +190,15 @@ pub async fn get_filter_counts(
let name = if i < parsed_filters.len() { let name = if i < parsed_filters.len() {
state.data.feature_names[parsed_filters[i].feat_idx].clone() state.data.feature_names[parsed_filters[i].feat_idx].clone()
} else if i < num_regular { } else if i < num_regular {
let ei = i - parsed_filters.len(); let enum_start = parsed_filters.len();
state.data.feature_names[parsed_enum_filters[ei].feat_idx].clone() let poi_start = enum_start + parsed_enum_filters.len();
if i < poi_start {
let ei = i - enum_start;
state.data.feature_names[parsed_enum_filters[ei].feat_idx].clone()
} else {
let pi = i - poi_start;
state.data.poi_metrics.feature_names[parsed_poi_filters[pi].metric_idx].clone()
}
} else { } else {
let slot = i - num_regular; let slot = i - num_regular;
let ti = travel_filter_indices[slot]; let ti = travel_filter_indices[slot];

View file

@ -13,8 +13,8 @@ use tracing::{info, warn};
use crate::auth::OptionalUser; use crate::auth::OptionalUser;
use crate::licensing::{check_license_bounds, resolve_share_code}; use crate::licensing::{check_license_bounds, resolve_share_code};
use crate::parsing::{ use crate::parsing::{
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_set, parse_filters, cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_set, parse_filters_with_poi,
row_passes_filters, validate_h3_resolution, row_passes_filters, row_passes_poi_filters, validate_h3_resolution,
}; };
use crate::state::SharedState; use crate::state::SharedState;
@ -110,15 +110,19 @@ pub async fn get_hexagon_stats(
let h3_str = params.h3; let h3_str = params.h3;
let quant = state.data.quant_ref(); let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters( let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(), params.filters.as_deref(),
&state.feature_name_to_index, &state.feature_name_to_index,
&state.data.enum_values, &state.data.enum_values,
&quant, &quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
) )
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len(); let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters; let filters_str = params.filters;
let has_poi_filters = !parsed_poi_filters.is_empty();
let (fields_specified, field_set) = parse_field_set(params.fields.as_deref()); let (fields_specified, field_set) = parse_field_set(params.fields.as_deref());
@ -161,6 +165,12 @@ pub async fn get_hexagon_stats(
feature_data, feature_data,
num_features, num_features,
) )
&& (!has_poi_filters
|| row_passes_poi_filters(
row,
&parsed_poi_filters,
&state.data.poi_metrics,
))
{ {
if has_travel { if has_travel {
let postcode = state.data.postcode(row); let postcode = state.data.postcode(row);
@ -233,7 +243,7 @@ pub async fn get_hexagon_stats(
let price_history = let price_history =
stats::extract_price_history(&matching_rows, &state.data, &state.feature_name_to_index); stats::extract_price_history(&matching_rows, &state.data, &state.feature_name_to_index);
let (numeric_features, enum_features_out) = stats::compute_feature_stats( let (mut numeric_features, enum_features_out) = stats::compute_feature_stats(
&matching_rows, &matching_rows,
&state.data, &state.data,
&state.data.feature_names, &state.data.feature_names,
@ -242,6 +252,12 @@ pub async fn get_hexagon_stats(
fields_specified, fields_specified,
&field_set, &field_set,
); );
numeric_features.extend(stats::compute_poi_feature_stats(
&matching_rows,
&state.data.poi_metrics,
fields_specified,
&field_set,
));
let elapsed = start_time.elapsed(); let elapsed = start_time.elapsed();
info!( info!(

View file

@ -11,14 +11,15 @@ use serde::{Deserialize, Serialize};
use serde_json::{Map, Value}; use serde_json::{Map, Value};
use tracing::info; use tracing::info;
use crate::aggregation::{Aggregator, EnumDistConfig}; use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator};
use crate::auth::OptionalUser; use crate::auth::OptionalUser;
use crate::consts::MAX_CELLS_PER_REQUEST; use crate::consts::MAX_CELLS_PER_REQUEST;
use crate::data::travel_time::TravelData; use crate::data::travel_time::TravelData;
use crate::licensing::{check_license_bounds, resolve_share_code}; use crate::licensing::{check_license_bounds, resolve_share_code};
use crate::parsing::{ use crate::parsing::{
cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices, parse_filters, cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices_with_poi,
require_bounds, row_passes_filters, validate_h3_resolution, parse_filters_with_poi, require_bounds, row_passes_filters, row_passes_poi_filters,
validate_h3_resolution,
}; };
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg}; use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
use crate::state::SharedState; use crate::state::SharedState;
@ -29,6 +30,7 @@ const PARALLEL_THRESHOLD: usize = 50_000;
/// Per-thread aggregation result: feature accumulators + travel time accumulators. /// Per-thread aggregation result: feature accumulators + travel time accumulators.
type ChunkResult = ( type ChunkResult = (
FxHashMap<u64, Aggregator>, FxHashMap<u64, Aggregator>,
FxHashMap<u64, PoiAggregator>,
Vec<FxHashMap<u64, TravelTimeAgg>>, Vec<FxHashMap<u64, TravelTimeAgg>>,
); );
@ -79,11 +81,14 @@ pub struct HexagonParams {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn build_feature_maps( fn build_feature_maps(
groups: &FxHashMap<u64, Aggregator>, groups: &FxHashMap<u64, Aggregator>,
poi_groups: &FxHashMap<u64, PoiAggregator>,
min_keys: &[String], min_keys: &[String],
max_keys: &[String], max_keys: &[String],
avg_keys: &[String], avg_keys: &[String],
num_features: usize, num_features: usize,
indices: Option<&[usize]>, indices: Option<&[usize]>,
poi_feature_names: &[String],
poi_indices: &[usize],
query_bounds: (f64, f64, f64, f64), query_bounds: (f64, f64, f64, f64),
resolution: h3o::Resolution, resolution: h3o::Resolution,
travel_aggs: &[FxHashMap<u64, TravelTimeAgg>], travel_aggs: &[FxHashMap<u64, TravelTimeAgg>],
@ -163,6 +168,25 @@ fn build_feature_maps(
} }
} }
if let Some(poi_aggregation) = poi_groups.get(&cell_id) {
for &metric_idx in poi_indices {
if poi_aggregation.counts[metric_idx] > 0 {
let avg = poi_aggregation.sums[metric_idx]
/ poi_aggregation.counts[metric_idx] as f64;
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64),
serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64),
serde_json::Number::from_f64(avg),
) {
let name = &poi_feature_names[metric_idx];
map.insert(format!("min_{name}"), Value::Number(min_num));
map.insert(format!("max_{name}"), Value::Number(max_num));
map.insert(format!("avg_{name}"), Value::Number(avg_num));
}
}
}
}
// Add travel time aggregation fields (using pre-computed key strings) // Add travel time aggregation fields (using pre-computed key strings)
for (ti, agg_map) in travel_aggs.iter().enumerate() { for (ti, agg_map) in travel_aggs.iter().enumerate() {
if let Some(agg) = agg_map.get(&cell_id) { if let Some(agg) = agg_map.get(&cell_id) {
@ -209,18 +233,25 @@ pub async fn get_hexagons(
check_license_bounds(&user.0, (south, west, north, east), share_bounds)?; check_license_bounds(&user.0, (south, west, north, east), share_bounds)?;
let quant = state.data.quant_ref(); let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters( let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(), params.filters.as_deref(),
&state.feature_name_to_index, &state.feature_name_to_index,
&state.data.enum_values, &state.data.enum_values,
&quant, &quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
) )
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len(); let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters; let filters_str = params.filters;
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index) let field_indices = parse_field_indices_with_poi(
.map_err(|err| (err.0, err.1).into_response())?; params.fields.as_deref(),
&state.feature_name_to_index,
&state.data.poi_metrics.name_to_index,
)
.map_err(|err| (err.0, err.1).into_response())?;
let travel_entries = parse_optional_travel(params.travel.as_deref()) let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
@ -269,6 +300,11 @@ pub async fn get_hexagons(
let min_keys = &state.min_keys; let min_keys = &state.min_keys;
let max_keys = &state.max_keys; let max_keys = &state.max_keys;
let avg_keys = &state.avg_keys; let avg_keys = &state.avg_keys;
let poi_metrics = &state.data.poi_metrics;
let poi_field_indices = field_indices.poi.as_slice();
let has_poi_fields = !poi_field_indices.is_empty();
let has_poi_filters = !parsed_poi_filters.is_empty();
let poi_num_features = poi_metrics.num_features();
let h3_res = h3o::Resolution::try_from(resolution) let h3_res = h3o::Resolution::try_from(resolution)
.map_err(|error| format!("Invalid H3 resolution {}: {}", resolution, error))?; .map_err(|error| format!("Invalid H3 resolution {}: {}", resolution, error))?;
@ -276,6 +312,7 @@ pub async fn get_hexagons(
let need_parent = needs_parent(resolution); let need_parent = needs_parent(resolution);
let mut groups: FxHashMap<u64, Aggregator> = FxHashMap::default(); let mut groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
let mut poi_groups: FxHashMap<u64, PoiAggregator> = FxHashMap::default();
let mut travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0..travel_entries.len()) let mut travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0..travel_entries.len())
.map(|_| FxHashMap::default()) .map(|_| FxHashMap::default())
.collect(); .collect();
@ -296,6 +333,7 @@ pub async fn get_hexagons(
.par_chunks(chunk_size) .par_chunks(chunk_size)
.map(|chunk| { .map(|chunk| {
let mut local_groups: FxHashMap<u64, Aggregator> = FxHashMap::default(); let mut local_groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
let mut local_poi_groups: FxHashMap<u64, PoiAggregator> = FxHashMap::default();
let mut local_travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0 let mut local_travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0
..travel_entries.len()) ..travel_entries.len())
.map(|_| FxHashMap::default()) .map(|_| FxHashMap::default())
@ -315,6 +353,11 @@ pub async fn get_hexagons(
) { ) {
continue; continue;
} }
if has_poi_filters
&& !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
{
continue;
}
if has_travel { if has_travel {
travel_minutes.clear(); travel_minutes.clear();
@ -352,7 +395,7 @@ pub async fn get_hexagons(
let agg = local_groups let agg = local_groups
.entry(cell_id) .entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config)); .or_insert_with(|| Aggregator::new(num_features, enum_dist_config));
if let Some(sel_indices) = field_indices.as_deref() { if let Some(sel_indices) = field_indices.normal.as_deref() {
agg.add_row_selective( agg.add_row_selective(
feature_data, feature_data,
row, row,
@ -364,6 +407,13 @@ pub async fn get_hexagons(
agg.add_row(feature_data, row, num_features, &quant); agg.add_row(feature_data, row, num_features, &quant);
} }
if has_poi_fields {
local_poi_groups
.entry(cell_id)
.or_insert_with(|| PoiAggregator::new(poi_num_features))
.add_row_selective(poi_metrics, row, poi_field_indices);
}
for (ti, minutes) in travel_minutes.iter().enumerate() { for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes { if let Some(mins) = minutes {
let tagg = local_travel_aggs[ti] let tagg = local_travel_aggs[ti]
@ -374,18 +424,24 @@ pub async fn get_hexagons(
} }
} }
(local_groups, local_travel_aggs) (local_groups, local_poi_groups, local_travel_aggs)
}) })
.collect(); .collect();
// Merge thread-local results into the main accumulators // Merge thread-local results into the main accumulators
for (local_groups, local_travel) in thread_results { for (local_groups, local_poi_groups, local_travel) in thread_results {
for (cell_id, local_agg) in local_groups { for (cell_id, local_agg) in local_groups {
groups groups
.entry(cell_id) .entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config)) .or_insert_with(|| Aggregator::new(num_features, enum_dist_config))
.merge(&local_agg); .merge(&local_agg);
} }
for (cell_id, local_agg) in local_poi_groups {
poi_groups
.entry(cell_id)
.or_insert_with(|| PoiAggregator::new(poi_num_features))
.merge(&local_agg);
}
for (ti, local_ta) in local_travel.into_iter().enumerate() { for (ti, local_ta) in local_travel.into_iter().enumerate() {
for (cell_id, local_tt) in local_ta { for (cell_id, local_tt) in local_ta {
travel_aggs[ti] travel_aggs[ti]
@ -414,6 +470,11 @@ pub async fn get_hexagons(
) { ) {
return; return;
} }
if has_poi_filters
&& !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
{
return;
}
if has_travel { if has_travel {
travel_minutes.clear(); travel_minutes.clear();
@ -444,7 +505,7 @@ pub async fn get_hexagons(
let aggregation = groups let aggregation = groups
.entry(cell_id) .entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config)); .or_insert_with(|| Aggregator::new(num_features, enum_dist_config));
if let Some(sel_indices) = field_indices.as_deref() { if let Some(sel_indices) = field_indices.normal.as_deref() {
aggregation.add_row_selective( aggregation.add_row_selective(
feature_data, feature_data,
row, row,
@ -456,6 +517,13 @@ pub async fn get_hexagons(
aggregation.add_row(feature_data, row, num_features, &quant); aggregation.add_row(feature_data, row, num_features, &quant);
} }
if has_poi_fields {
poi_groups
.entry(cell_id)
.or_insert_with(|| PoiAggregator::new(poi_num_features))
.add_row_selective(poi_metrics, row, poi_field_indices);
}
for (ti, minutes) in travel_minutes.iter().enumerate() { for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes { if let Some(mins) = minutes {
let agg = travel_aggs[ti] let agg = travel_aggs[ti]
@ -471,11 +539,14 @@ pub async fn get_hexagons(
let mut features = build_feature_maps( let mut features = build_feature_maps(
&groups, &groups,
&poi_groups,
min_keys, min_keys,
max_keys, max_keys,
avg_keys, avg_keys,
num_features, num_features,
field_indices.as_deref(), field_indices.normal.as_deref(),
&poi_metrics.feature_names,
poi_field_indices,
(south, west, north, east), (south, west, north, east),
h3_res, h3_res,
&travel_aggs, &travel_aggs,
@ -499,7 +570,11 @@ pub async fn get_hexagons(
bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east), bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east),
filters = num_filters, filters = num_filters,
filters_raw = filters_str.as_deref().unwrap_or("-"), filters_raw = filters_str.as_deref().unwrap_or("-"),
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1), fields = field_indices
.normal
.as_ref()
.map(|v| (v.len() + poi_field_indices.len()) as i32)
.unwrap_or(-1),
travel_entries = travel_entries.len(), travel_entries = travel_entries.len(),
grid_ms = format_args!("{:.1}", t_grid.as_secs_f64() * 1000.0), grid_ms = format_args!("{:.1}", t_grid.as_secs_f64() * 1000.0),
agg_ms = format_args!("{:.1}", (t_agg - t_grid).as_secs_f64() * 1000.0), agg_ms = format_args!("{:.1}", (t_agg - t_grid).as_secs_f64() * 1000.0),

View file

@ -9,11 +9,16 @@ use serde::{Deserialize, Serialize};
use tracing::{info, warn}; use tracing::{info, warn};
use crate::auth::{OptionalUser, PocketBaseUser}; use crate::auth::{OptionalUser, PocketBaseUser};
use crate::checkout_sessions::{
active_referral_checkout_user, start_license_checkout, CheckoutStart,
};
use crate::pocketbase::get_superuser_token; use crate::pocketbase::get_superuser_token;
use crate::pocketbase_locks::acquire_pocketbase_lock;
use crate::state::{AppState, SharedState}; use crate::state::{AppState, SharedState};
static INVITE_REDEMPTIONS_IN_PROGRESS: LazyLock<Mutex<HashSet<String>>> = static INVITE_REDEMPTIONS_IN_PROGRESS: LazyLock<Mutex<HashSet<String>>> =
LazyLock::new(|| Mutex::new(HashSet::new())); LazyLock::new(|| Mutex::new(HashSet::new()));
const INVITE_REDEMPTION_LOCK_TTL_SECS: u64 = 5 * 60;
struct InviteRedemptionGuard { struct InviteRedemptionGuard {
code: String, code: String,
@ -103,7 +108,7 @@ fn validate_invite_code(code: &str) -> Result<(), &'static str> {
} }
fn generate_invite_code() -> String { fn generate_invite_code() -> String {
use rand::Rng; use rand::RngExt;
let mut rng = rand::rng(); let mut rng = rand::rng();
let chars: Vec<char> = (0..12) let chars: Vec<char> = (0..12)
.map(|_| { .map(|_| {
@ -246,74 +251,26 @@ async fn grant_license_for_invite(
async fn create_referral_checkout( async fn create_referral_checkout(
state: &AppState, state: &AppState,
user: &PocketBaseUser, user: &PocketBaseUser,
invite_id: &str,
) -> Result<String, Response> { ) -> Result<String, Response> {
let count = match super::pricing::count_licensed_users(state).await {
Ok(count) => count,
Err(err) => {
warn!("Failed to count licensed users for invite checkout: {err}");
return Err(StatusCode::SERVICE_UNAVAILABLE.into_response());
}
};
let price_pence = super::pricing::price_for_count(count);
let public_url = &state.public_url; let public_url = &state.public_url;
let success_url = format!("{public_url}/pricing?license_success=1"); let success_url = format!("{public_url}/pricing?license_success=1");
let cancel_url = format!("{public_url}/pricing"); let cancel_url = format!("{public_url}/pricing");
let form_params = vec![ match start_license_checkout(
("mode", "payment".to_string()), state,
( user,
"line_items[0][price_data][unit_amount]", &success_url,
price_pence.to_string(), &cancel_url,
), Some(&state.stripe_referral_coupon_id),
("line_items[0][price_data][currency]", "gbp".to_string()), Some(invite_id),
( )
"line_items[0][price_data][product_data][name]", .await
"Perfect Postcodes Lifetime License".to_string(), {
), Ok(CheckoutStart::Free) => Ok(success_url),
("line_items[0][quantity]", "1".to_string()), Ok(CheckoutStart::Stripe { url }) => Ok(url),
("success_url", success_url),
("cancel_url", cancel_url),
("client_reference_id", user.id.clone()),
("customer_email", user.email.clone()),
(
"discounts[0][coupon]",
state.stripe_referral_coupon_id.clone(),
),
];
let stripe_res = state
.http_client
.post("https://api.stripe.com/v1/checkout/sessions")
.basic_auth(&state.stripe_secret_key, None::<&str>)
.form(&form_params)
.send()
.await;
match stripe_res {
Ok(resp) if resp.status().is_success() => {
let stripe_body: serde_json::Value = match resp.json().await {
Ok(value) => value,
Err(err) => {
warn!("Failed to parse Stripe checkout response: {err}");
return Err(StatusCode::BAD_GATEWAY.into_response());
}
};
let checkout_url = stripe_body["url"].as_str().unwrap_or_default().to_string();
if checkout_url.is_empty() {
warn!("Stripe checkout response did not include a URL");
return Err(StatusCode::BAD_GATEWAY.into_response());
}
Ok(checkout_url)
}
Ok(resp) => {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
warn!("Failed to create Stripe checkout for referral invite ({status}): {text}");
Err(StatusCode::BAD_GATEWAY.into_response())
}
Err(err) => { Err(err) => {
warn!("Stripe request error for referral invite: {err}"); warn!("Failed to create reserved Stripe checkout for referral invite: {err:?}");
Err(StatusCode::BAD_GATEWAY.into_response()) Err(StatusCode::BAD_GATEWAY.into_response())
} }
} }
@ -541,6 +498,10 @@ pub async fn post_redeem_invite(
.into_response(); .into_response();
} }
if user.is_admin || user.subscription == "licensed" {
return (StatusCode::CONFLICT, "Account already has full access").into_response();
}
let pb_url = state.pocketbase_url.trim_end_matches('/'); let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = match get_superuser_token(&state).await { let token = match get_superuser_token(&state).await {
@ -561,6 +522,19 @@ pub async fn post_redeem_invite(
.into_response() .into_response()
} }
}; };
let lock_name = format!("invite:{}", req.code);
let _distributed_redemption_guard =
match acquire_pocketbase_lock(&state, &lock_name, INVITE_REDEMPTION_LOCK_TTL_SECS).await {
Ok(guard) => guard,
Err(err) => {
warn!(code = %req.code, "Failed to acquire invite redemption lock: {err}");
return (
StatusCode::CONFLICT,
"Invite redemption is already in progress",
)
.into_response();
}
};
let invite = match lookup_unused_invite(&state, pb_url, &token, &req.code).await { let invite = match lookup_unused_invite(&state, pb_url, &token, &req.code).await {
Ok(Some(invite)) => invite, Ok(Some(invite)) => invite,
@ -591,11 +565,11 @@ pub async fn post_redeem_invite(
}; };
if invite_type == "admin" { if invite_type == "admin" {
if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await { if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
return response; return response;
} }
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await { if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await {
return response; return response;
} }
@ -607,15 +581,26 @@ pub async fn post_redeem_invite(
.into_response(); .into_response();
} }
let checkout_url = match create_referral_checkout(&state, &user).await { match active_referral_checkout_user(&state, invite_id).await {
Ok(Some(active_user_id)) if active_user_id != user.id => {
return (
StatusCode::CONFLICT,
"Invite checkout is already in progress",
)
.into_response()
}
Ok(_) => {}
Err(err) => {
warn!(code = %req.code, "Failed to check active referral checkout: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
}
let checkout_url = match create_referral_checkout(&state, &user, invite_id).await {
Ok(url) => url, Ok(url) => url,
Err(response) => return response, Err(response) => return response,
}; };
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
return response;
}
info!(user_id = %user.id, code = %req.code, "Referral invite redeemed; checkout created"); info!(user_id = %user.id, code = %req.code, "Referral invite redeemed; checkout created");
Json(RedeemResponse { Json(RedeemResponse {
result: "checkout".to_string(), result: "checkout".to_string(),

View file

@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use tracing::info; use tracing::info;
use crate::consts::MAX_POIS_PER_REQUEST; use crate::consts::MAX_POIS_PER_REQUEST;
use crate::data::POICategoryGroup; use crate::data::{resolve_poi_category_filter, POICategoryGroup};
use crate::parsing::require_bounds; use crate::parsing::require_bounds;
use crate::state::SharedState; use crate::state::SharedState;
@ -47,20 +47,7 @@ pub async fn get_pois(
.categories .categories
.as_deref() .as_deref()
.filter(|text| !text.is_empty()) .filter(|text| !text.is_empty())
.map(|text| { .map(|text| resolve_poi_category_filter(&state.poi_data.category.values, text));
text.split(',')
.filter_map(|part| {
let name = part.trim();
state
.poi_data
.category
.values
.iter()
.position(|v| v == name)
.map(|pos| pos as u16)
})
.collect()
});
let categories_raw = params.categories; let categories_raw = params.categories;
let num_categories = category_filter.as_ref().map(|cats| cats.len()).unwrap_or(0); let num_categories = category_filter.as_ref().map(|cats| cats.len()).unwrap_or(0);

View file

@ -10,7 +10,7 @@ use tracing::{info, warn};
use crate::auth::OptionalUser; use crate::auth::OptionalUser;
use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT, POSTCODE_SEARCH_OFFSET}; use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT, POSTCODE_SEARCH_OFFSET};
use crate::licensing::{check_license_point, resolve_share_code}; use crate::licensing::{check_license_point, resolve_share_code};
use crate::parsing::{parse_filters, row_passes_filters}; use crate::parsing::{parse_filters_with_poi, row_passes_filters, row_passes_poi_filters};
use crate::state::SharedState; use crate::state::SharedState;
use crate::utils::normalize_postcode; use crate::utils::normalize_postcode;
@ -62,15 +62,19 @@ pub async fn get_postcode_properties(
)?; )?;
let quant = state.data.quant_ref(); let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters( let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(), params.filters.as_deref(),
&state.feature_name_to_index, &state.feature_name_to_index,
&state.data.enum_values, &state.data.enum_values,
&quant, &quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
) )
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len(); let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters; let filters_str = params.filters;
let has_poi_filters = !parsed_poi_filters.is_empty();
let travel_entries = parse_optional_travel(params.travel.as_deref()) let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
@ -111,6 +115,12 @@ pub async fn get_postcode_properties(
feature_data, feature_data,
num_features, num_features,
) )
&& (!has_poi_filters
|| row_passes_poi_filters(
row,
&parsed_poi_filters,
&state.data.poi_metrics,
))
{ {
if has_travel if has_travel
&& !row_passes_travel_filters( && !row_passes_travel_filters(

View file

@ -10,7 +10,9 @@ use tracing::{info, warn};
use crate::auth::OptionalUser; use crate::auth::OptionalUser;
use crate::consts::POSTCODE_SEARCH_OFFSET; use crate::consts::POSTCODE_SEARCH_OFFSET;
use crate::licensing::{check_license_point, resolve_share_code}; use crate::licensing::{check_license_point, resolve_share_code};
use crate::parsing::{parse_field_set, parse_filters, row_passes_filters}; use crate::parsing::{
parse_field_set, parse_filters_with_poi, row_passes_filters, row_passes_poi_filters,
};
use crate::state::SharedState; use crate::state::SharedState;
use crate::utils::normalize_postcode; use crate::utils::normalize_postcode;
@ -64,15 +66,19 @@ pub async fn get_postcode_stats(
)?; )?;
let quant = state.data.quant_ref(); let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters( let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(), params.filters.as_deref(),
&state.feature_name_to_index, &state.feature_name_to_index,
&state.data.enum_values, &state.data.enum_values,
&quant, &quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
) )
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len(); let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters; let filters_str = params.filters;
let has_poi_filters = !parsed_poi_filters.is_empty();
let (fields_specified, field_set) = parse_field_set(params.fields.as_deref()); let (fields_specified, field_set) = parse_field_set(params.fields.as_deref());
let travel_entries = parse_optional_travel(params.travel.as_deref()) let travel_entries = parse_optional_travel(params.travel.as_deref())
@ -108,6 +114,12 @@ pub async fn get_postcode_stats(
feature_data, feature_data,
num_features, num_features,
) )
&& (!has_poi_filters
|| row_passes_poi_filters(
row,
&parsed_poi_filters,
&state.data.poi_metrics,
))
{ {
if has_travel if has_travel
&& !row_passes_travel_filters(row_postcode, &travel_entries, &travel_data) && !row_passes_travel_filters(row_postcode, &travel_entries, &travel_data)
@ -123,7 +135,7 @@ pub async fn get_postcode_stats(
let price_history = let price_history =
stats::extract_price_history(&matching_rows, &state.data, &state.feature_name_to_index); stats::extract_price_history(&matching_rows, &state.data, &state.feature_name_to_index);
let (numeric_features, enum_features_out) = stats::compute_feature_stats( let (mut numeric_features, enum_features_out) = stats::compute_feature_stats(
&matching_rows, &matching_rows,
&state.data, &state.data,
&state.data.feature_names, &state.data.feature_names,
@ -132,6 +144,12 @@ pub async fn get_postcode_stats(
fields_specified, fields_specified,
&field_set, &field_set,
); );
numeric_features.extend(stats::compute_poi_feature_stats(
&matching_rows,
&state.data.poi_metrics,
fields_specified,
&field_set,
));
let elapsed = start_time.elapsed(); let elapsed = start_time.elapsed();
info!( info!(

View file

@ -10,14 +10,14 @@ use serde::{Deserialize, Serialize};
use serde_json::{Map, Value}; use serde_json::{Map, Value};
use tracing::info; use tracing::info;
use crate::aggregation::{Aggregator, EnumDistConfig}; use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator};
use crate::auth::OptionalUser; use crate::auth::OptionalUser;
use crate::consts::MAX_CELLS_PER_REQUEST; use crate::consts::MAX_CELLS_PER_REQUEST;
use crate::data::travel_time::TravelData; use crate::data::travel_time::TravelData;
use crate::licensing::{check_license_bounds, resolve_share_code}; use crate::licensing::{check_license_bounds, resolve_share_code};
use crate::parsing::{ use crate::parsing::{
bounds_intersect, parse_enum_dist, parse_field_indices, parse_filters, require_bounds, bounds_intersect, parse_enum_dist, parse_field_indices_with_poi, parse_filters_with_poi,
row_passes_filters, require_bounds, row_passes_filters, row_passes_poi_filters,
}; };
use crate::pocketbase::log_user_location; use crate::pocketbase::log_user_location;
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg}; use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
@ -64,18 +64,25 @@ pub async fn get_postcodes(
check_license_bounds(&user.0, (south, west, north, east), share_bounds)?; check_license_bounds(&user.0, (south, west, north, east), share_bounds)?;
let quant = state.data.quant_ref(); let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters( let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(), params.filters.as_deref(),
&state.feature_name_to_index, &state.feature_name_to_index,
&state.data.enum_values, &state.data.enum_values,
&quant, &quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
) )
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len(); let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters; let filters_str = params.filters;
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index) let field_indices = parse_field_indices_with_poi(
.map_err(|err| (err.0, err.1).into_response())?; params.fields.as_deref(),
&state.feature_name_to_index,
&state.data.poi_metrics.name_to_index,
)
.map_err(|err| (err.0, err.1).into_response())?;
let travel_entries = parse_optional_travel(params.travel.as_deref()) let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
@ -123,12 +130,18 @@ pub async fn get_postcodes(
let min_keys = &state.min_keys; let min_keys = &state.min_keys;
let max_keys = &state.max_keys; let max_keys = &state.max_keys;
let avg_keys = &state.avg_keys; let avg_keys = &state.avg_keys;
let poi_metrics = &state.data.poi_metrics;
let poi_field_indices = field_indices.poi.as_slice();
let has_poi_fields = !poi_field_indices.is_empty();
let has_poi_filters = !parsed_poi_filters.is_empty();
let poi_num_features = poi_metrics.num_features();
let has_selective = field_indices.is_some(); let has_selective = field_indices.normal.is_some();
let sel_indices = field_indices.as_deref().unwrap_or(&[]); let sel_indices = field_indices.normal.as_deref().unwrap_or(&[]);
// Single-pass: aggregate directly into postcode_aggs while iterating properties in bounds // Single-pass: aggregate directly into postcode_aggs while iterating properties in bounds
let mut postcode_aggs: FxHashMap<usize, Aggregator> = FxHashMap::default(); let mut postcode_aggs: FxHashMap<usize, Aggregator> = FxHashMap::default();
let mut poi_aggs: FxHashMap<usize, PoiAggregator> = FxHashMap::default();
state state
.grid .grid
@ -143,6 +156,10 @@ pub async fn get_postcodes(
) { ) {
return; return;
} }
if has_poi_filters && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
{
return;
}
let postcode = state.data.postcode(row); let postcode = state.data.postcode(row);
if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) { if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) {
@ -154,6 +171,12 @@ pub async fn get_postcodes(
} else { } else {
agg.add_row(feature_data, row, num_features, &quant); agg.add_row(feature_data, row, num_features, &quant);
} }
if has_poi_fields {
poi_aggs
.entry(pc_idx)
.or_insert_with(|| PoiAggregator::new(poi_num_features))
.add_row_selective(poi_metrics, row, poi_field_indices);
}
} }
}); });
@ -250,11 +273,12 @@ pub async fn get_postcodes(
]), ]),
); );
let iter: Box<dyn Iterator<Item = usize>> = if let Some(idx) = field_indices.as_ref() { let iter: Box<dyn Iterator<Item = usize>> =
Box::new(idx.iter().copied()) if let Some(idx) = field_indices.normal.as_ref() {
} else { Box::new(idx.iter().copied())
Box::new(0..num_features) } else {
}; Box::new(0..num_features)
};
for feat_index in iter { for feat_index in iter {
if aggregation.feat_counts[feat_index] > 0 { if aggregation.feat_counts[feat_index] > 0 {
@ -272,6 +296,25 @@ pub async fn get_postcodes(
} }
} }
if let Some(poi_aggregation) = poi_aggs.get(&pc_idx) {
for &metric_idx in poi_field_indices {
if poi_aggregation.counts[metric_idx] > 0 {
let avg = poi_aggregation.sums[metric_idx]
/ poi_aggregation.counts[metric_idx] as f64;
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64),
serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64),
serde_json::Number::from_f64(avg),
) {
let name = &poi_metrics.feature_names[metric_idx];
props.insert(format!("min_{name}"), Value::Number(min_num));
props.insert(format!("max_{name}"), Value::Number(max_num));
props.insert(format!("avg_{name}"), Value::Number(avg_num));
}
}
}
}
// Add travel time aggregation fields // Add travel time aggregation fields
if let Some(tt_aggs) = travel_aggs.get(&pc_idx) { if let Some(tt_aggs) = travel_aggs.get(&pc_idx) {
for (ti, agg) in tt_aggs.iter().enumerate() { for (ti, agg) in tt_aggs.iter().enumerate() {
@ -322,7 +365,11 @@ pub async fn get_postcodes(
bounds = format_args!("{:.6},{:.6},{:.6},{:.6}", south, west, north, east), bounds = format_args!("{:.6},{:.6},{:.6},{:.6}", south, west, north, east),
filters = num_filters, filters = num_filters,
filters_raw = filters_str.as_deref().unwrap_or("-"), filters_raw = filters_str.as_deref().unwrap_or("-"),
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1), fields = field_indices
.normal
.as_ref()
.map(|v| (v.len() + poi_field_indices.len()) as i32)
.unwrap_or(-1),
travel_entries = travel_entries.len(), travel_entries = travel_entries.len(),
agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0), agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0),
json_ms = format_args!("{:.1}", (t_total - t_agg).as_secs_f64() * 1000.0), json_ms = format_args!("{:.1}", (t_total - t_agg).as_secs_f64() * 1000.0),

View file

@ -14,8 +14,8 @@ use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT};
use crate::data::RenovationEvent; use crate::data::RenovationEvent;
use crate::licensing::{check_license_bounds, resolve_share_code}; use crate::licensing::{check_license_bounds, resolve_share_code};
use crate::parsing::{ use crate::parsing::{
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_filters, row_passes_filters, cell_for_row_cached, h3_cell_bounds, needs_parent, parse_filters_with_poi, row_passes_filters,
validate_h3_resolution, row_passes_poi_filters, validate_h3_resolution,
}; };
use crate::state::{AppState, SharedState}; use crate::state::{AppState, SharedState};
@ -117,6 +117,12 @@ pub fn build_property(
features.insert(feat_name.clone(), value); features.insert(feat_name.clone(), value);
} }
} }
for (metric_idx, metric_name) in state.data.poi_metrics.feature_names.iter().enumerate() {
let value = state.data.poi_metrics.get_for_property_row(row, metric_idx);
if value.is_finite() {
features.insert(metric_name.clone(), value);
}
}
Property { Property {
address: non_empty_string(state.data.address(row)), address: non_empty_string(state.data.address(row)),
@ -199,15 +205,19 @@ pub async fn get_hexagon_properties(
let h3_str = params.h3; let h3_str = params.h3;
let quant = state.data.quant_ref(); let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters( let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(), params.filters.as_deref(),
&state.feature_name_to_index, &state.feature_name_to_index,
&state.data.enum_values, &state.data.enum_values,
&quant, &quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
) )
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len(); let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters; let filters_str = params.filters;
let has_poi_filters = !parsed_poi_filters.is_empty();
let travel_entries = parse_optional_travel(params.travel.as_deref()) let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
@ -242,6 +252,12 @@ pub async fn get_hexagon_properties(
feature_data, feature_data,
num_features, num_features,
) )
&& (!has_poi_filters
|| row_passes_poi_filters(
row,
&parsed_poi_filters,
&state.data.poi_metrics,
))
{ {
if has_travel { if has_travel {
let postcode = state.data.postcode(row); let postcode = state.data.postcode(row);

View file

@ -4,7 +4,7 @@ use rustc_hash::FxHashMap;
use tracing::warn; use tracing::warn;
use crate::consts::MAX_PRICE_HISTORY_POINTS; use crate::consts::MAX_PRICE_HISTORY_POINTS;
use crate::data::{FeatureStats, PropertyData}; use crate::data::{FeatureStats, PostcodePoiMetrics, PropertyData};
use super::hexagon_stats::{EnumFeatureStats, HistogramStats, NumericFeatureStats, PricePoint}; use super::hexagon_stats::{EnumFeatureStats, HistogramStats, NumericFeatureStats, PricePoint};
@ -243,3 +243,80 @@ pub fn compute_feature_stats(
(numeric_features, enum_features_out) (numeric_features, enum_features_out)
} }
/// Computes summary statistics for each POI metric over `matching_rows`.
///
/// Metrics are visited in `poi_metrics.feature_names` order. When
/// `fields_specified` is true, only metrics whose name is in `field_set` are
/// considered. For each metric, every row's value is read through
/// `get_for_property_row`; non-finite values are ignored. A metric with no
/// finite value among the matching rows produces no entry in the result.
///
/// The returned histogram reuses the metric's global histogram layout
/// (`min`, `max`, `p1`, `p99` and number of bins) so that the local counts
/// line up with the global ones: the first bin collects values below `p1`,
/// the last bin values at or above `p99`, and the bins in between split
/// `[p1, p99)` into equal-width intervals.
pub fn compute_poi_feature_stats(
    matching_rows: &[usize],
    poi_metrics: &PostcodePoiMetrics,
    fields_specified: bool,
    field_set: &HashSet<String>,
) -> Vec<NumericFeatureStats> {
    let mut out = Vec::new();

    for (metric_idx, name) in poi_metrics.feature_names.iter().enumerate() {
        if fields_specified && !field_set.contains(name.as_str()) {
            continue;
        }

        let global_hist = &poi_metrics.feature_stats[metric_idx].histogram;
        let p1 = global_hist.p1;
        let p99 = global_hist.p99;
        let num_bins = global_hist.counts.len();
        // The two outer bins are reserved for the tails; only the remaining
        // bins share the [p1, p99) range.
        let middle_bins = num_bins.saturating_sub(2);
        let middle_width = if middle_bins > 0 && p99 > p1 {
            (p99 - p1) / middle_bins as f32
        } else {
            0.0
        };

        let mut count = 0usize;
        let mut min_value = f32::INFINITY;
        let mut max_value = f32::NEG_INFINITY;
        let mut sum = 0.0f64;
        let mut bins = vec![0u64; num_bins];

        for &row in matching_rows {
            let value = poi_metrics.get_for_property_row(row, metric_idx);
            if !value.is_finite() {
                continue;
            }

            count += 1;
            min_value = min_value.min(value);
            max_value = max_value.max(value);
            sum += value as f64;

            // An empty global histogram has no bin to count into; the scalar
            // statistics above are still reported in that case.
            if let Some(bin) = poi_histogram_bin(value, p1, p99, middle_width, num_bins) {
                bins[bin] += 1;
            }
        }

        if count > 0 {
            out.push(NumericFeatureStats {
                name: name.clone(),
                count,
                min: min_value as f64,
                max: max_value as f64,
                mean: sum / count as f64,
                histogram: HistogramStats {
                    min: global_hist.min as f64,
                    max: global_hist.max as f64,
                    p1: p1 as f64,
                    p99: p99 as f64,
                    counts: bins,
                },
            });
        }
    }

    out
}

/// Picks the histogram bin for `value`, or `None` when there are no bins.
///
/// `middle_width` is the width of one interior bin; a value of `0.0` means the
/// interior cannot be subdivided, in which case in-range values fall into the
/// bin at index `num_bins / 2`.
fn poi_histogram_bin(
    value: f32,
    p1: f32,
    p99: f32,
    middle_width: f32,
    num_bins: usize,
) -> Option<usize> {
    if num_bins == 0 {
        // Guard: `num_bins - 1` would underflow and indexing would panic.
        return None;
    }

    let bin = if value < p1 {
        0
    } else if value >= p99 {
        num_bins - 1
    } else if middle_width > 0.0 {
        // `middle_width > 0.0` implies at least three bins, so
        // `num_bins - 2` cannot underflow here.
        let middle_bin = ((value - p1) / middle_width) as usize;
        (1 + middle_bin).min(num_bins - 2)
    } else {
        num_bins / 2
    };

    Some(bin)
}

View file

@ -1,78 +1,40 @@
use std::collections::VecDeque; use std::sync::Arc;
use std::sync::{Arc, LazyLock};
use axum::body::Bytes; use axum::body::Bytes;
use axum::extract::State; use axum::extract::State;
use axum::http::{HeaderMap, StatusCode}; use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use hmac::{Hmac, Mac}; use hmac::{Hmac, KeyInit, Mac};
use parking_lot::Mutex;
use rustc_hash::FxHashSet;
use sha2::Sha256; use sha2::Sha256;
use tracing::{info, warn}; use tracing::{info, warn};
use crate::pocketbase::get_superuser_token; use crate::checkout_sessions::{
grant_license, mark_checkout_completed, mark_referral_invite_used, verify_checkout_completion,
CheckoutCompletion,
};
use crate::state::SharedState; use crate::state::SharedState;
type HmacSha256 = Hmac<Sha256>; type HmacSha256 = Hmac<Sha256>;
/// Process-local, bounded record of recently processed Stripe event IDs.
/// Stripe retries deliver the same event ID; we drop duplicates so we don't
/// re-run side effects (subscription writes, token cache invalidation, logs).
/// Eviction is FIFO (the oldest-inserted ID is dropped first), not LRU: a
/// duplicate hit does not refresh an ID's position in the queue.
/// Capacity is intentionally generous: at typical webhook volumes this covers
/// far more than Stripe's retry window.
struct EventDedup {
// Membership set answering "has this event ID been seen recently?".
seen: FxHashSet<String>,
// The same IDs in insertion order; the front is the next one to be evicted.
queue: VecDeque<String>,
// Maximum number of IDs remembered at once.
capacity: usize,
}
impl EventDedup {
    /// Creates an empty tracker that remembers at most `capacity` event IDs.
    ///
    /// The eviction queue is pre-allocated for `capacity` entries.
    fn new(capacity: usize) -> Self {
        Self {
            seen: FxHashSet::default(),
            queue: VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    /// Returns `true` if this event ID is new (and records it),
    /// `false` if it was already seen recently.
    ///
    /// When the tracker is full, the oldest-inserted ID is forgotten to make
    /// room for the new one (FIFO eviction; a duplicate hit does not refresh
    /// an ID's position).
    fn check_and_insert(&mut self, id: &str) -> bool {
        if self.seen.contains(id) {
            return false;
        }

        // A zero-capacity tracker can never remember anything, so every ID is
        // reported as new and no state is kept.
        if self.capacity == 0 {
            return true;
        }

        // Evict *before* inserting so the queue never holds more than
        // `capacity` entries and stays within the buffer reserved in `new`
        // (push-then-pop would transiently need `capacity + 1` slots and can
        // force a reallocation).
        if self.queue.len() >= self.capacity {
            if let Some(oldest) = self.queue.pop_front() {
                self.seen.remove(&oldest);
            }
        }

        let owned = id.to_string();
        self.seen.insert(owned.clone());
        self.queue.push_back(owned);
        true
    }
}
/// Single dedup tracker shared by the whole process: initialised lazily on
/// first access, guarded by a mutex, and sized to remember 1024 event IDs.
static EVENT_DEDUP: LazyLock<Mutex<EventDedup>> =
LazyLock::new(|| Mutex::new(EventDedup::new(1024)));
/// Verify Stripe webhook signature (v1 scheme). /// Verify Stripe webhook signature (v1 scheme).
fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool { fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
// Parse timestamp and signature from header: "t=TIMESTAMP,v1=SIGNATURE" // Parse timestamp and signature from header: "t=TIMESTAMP,v1=SIGNATURE"
let mut timestamp = None; let mut timestamp = None;
let mut signature = None; let mut signatures = Vec::new();
for part in sig_header.split(',') { for part in sig_header.split(',') {
if let Some(ts) = part.strip_prefix("t=") { if let Some(ts) = part.strip_prefix("t=") {
timestamp = Some(ts); timestamp = Some(ts);
} else if let Some(sig) = part.strip_prefix("v1=") { } else if let Some(sig) = part.strip_prefix("v1=") {
signature = Some(sig); signatures.push(sig);
} }
} }
let (ts, sig_hex) = match (timestamp, signature) { let Some(ts) = timestamp else {
(Some(t), Some(s)) => (t, s), return false;
_ => return false,
}; };
if signatures.is_empty() {
return false;
}
// Reject webhooks older than 5 minutes to prevent replay attacks // Reject webhooks older than 5 minutes to prevent replay attacks
if let Ok(ts_secs) = ts.parse::<i64>() { if let Ok(ts_secs) = ts.parse::<i64>() {
@ -87,20 +49,21 @@ fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
return false; return false;
} }
// Compute expected signature: HMAC-SHA256(secret, "TIMESTAMP.PAYLOAD") let mut signed_payload = Vec::with_capacity(ts.len() + 1 + payload.len());
let signed_payload = format!("{ts}.{}", String::from_utf8_lossy(payload)); signed_payload.extend_from_slice(ts.as_bytes());
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) { signed_payload.push(b'.');
Ok(m) => m, signed_payload.extend_from_slice(payload);
Err(_) => return false,
};
mac.update(signed_payload.as_bytes());
// Decode the provided hex signature and verify with constant-time comparison signatures.into_iter().any(|sig_hex| {
let sig_bytes = match hex::decode(sig_hex) { let Ok(sig_bytes) = hex::decode(sig_hex) else {
Ok(bytes) => bytes, return false;
Err(_) => return false, };
}; let Ok(mut mac) = HmacSha256::new_from_slice(secret.as_bytes()) else {
mac.verify_slice(&sig_bytes).is_ok() return false;
};
mac.update(&signed_payload);
mac.verify_slice(&sig_bytes).is_ok()
})
} }
/// Handle Stripe webhook events. /// Handle Stripe webhook events.
@ -140,65 +103,64 @@ pub async fn post_stripe_webhook(
let event_type = event["type"].as_str().unwrap_or(""); let event_type = event["type"].as_str().unwrap_or("");
let event_id = event["id"].as_str().unwrap_or(""); let event_id = event["id"].as_str().unwrap_or("");
// Idempotency: drop replays/retries of an already-processed event.
// We always answer 200 so Stripe stops retrying.
if !event_id.is_empty() && !EVENT_DEDUP.lock().check_and_insert(event_id) {
info!(event_id, event_type, "Dropping duplicate Stripe webhook");
return StatusCode::OK.into_response();
}
info!(event_id, event_type, "Received Stripe webhook"); info!(event_id, event_type, "Received Stripe webhook");
if event_type == "checkout.session.completed" { if event_type == "checkout.session.completed" {
let user_id = event["data"]["object"]["client_reference_id"] let session = &event["data"]["object"];
.as_str() match verify_checkout_completion(&state, session).await {
.unwrap_or(""); Ok(CheckoutCompletion::Grant(checkout)) => {
if user_id.is_empty() { if let Err(err) = mark_referral_invite_used(
warn!("checkout.session.completed missing client_reference_id"); &state,
return StatusCode::OK.into_response(); &checkout.referral_invite_id,
} &checkout.user_id,
if !user_id.bytes().all(|b| b.is_ascii_alphanumeric()) || user_id.len() > 20 { )
warn!(user_id, "Invalid client_reference_id format in webhook"); .await
return StatusCode::BAD_REQUEST.into_response(); {
} warn!(
user_id = %checkout.user_id,
// Update user subscription to "licensed" via PocketBase superuser auth reservation_id = %checkout.reservation_id,
let token = match get_superuser_token(&state).await { referral_invite_id = %checkout.referral_invite_id,
Ok(t) => t, "Failed to mark referral invite used after Stripe checkout: {err:?}"
Err(err) => { );
warn!("Failed to auth as PocketBase superuser in webhook: {err}"); return StatusCode::INTERNAL_SERVER_ERROR.into_response();
return StatusCode::INTERNAL_SERVER_ERROR.into_response(); }
} if let Err(err) = grant_license(&state, &checkout.user_id).await {
}; warn!(
user_id = %checkout.user_id,
let pb_url = state.pocketbase_url.trim_end_matches('/'); reservation_id = %checkout.reservation_id,
let url = format!("{pb_url}/api/collections/users/records/{user_id}"); "Failed to grant license after Stripe checkout: {err:?}"
let res = state );
.http_client return StatusCode::INTERNAL_SERVER_ERROR.into_response();
.patch(&url) }
.header("Authorization", format!("Bearer {token}")) if let Err(err) = mark_checkout_completed(
.json(&serde_json::json!({ "subscription": "licensed" })) &state,
.send() &checkout.reservation_id,
.await; checkout.paid_amount_pence,
)
match res { .await
Ok(resp) if resp.status().is_success() => { {
state.token_cache.invalidate_by_user_id(user_id); warn!(
user_id = %checkout.user_id,
reservation_id = %checkout.reservation_id,
"Failed to mark checkout completed after license grant: {err:?}"
);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
info!( info!(
user_id, user_id = %checkout.user_id,
"User subscription updated to licensed via Stripe webhook" reservation_id = %checkout.reservation_id,
"User subscription updated to licensed via verified Stripe checkout"
); );
} }
Ok(resp) => { Ok(CheckoutCompletion::AlreadyHandled) => {
let status = resp.status(); info!("Stripe checkout session was already handled");
let text = resp.text().await.unwrap_or_default(); }
warn!( Ok(CheckoutCompletion::Rejected(reason)) => {
user_id, warn!("Rejecting Stripe checkout completion: {reason}");
"Failed to update user subscription ({status}): {text}"
);
} }
Err(err) => { Err(err) => {
warn!(user_id, "PocketBase request error in webhook: {err}"); warn!("Failed to verify Stripe checkout completion: {err:?}");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
} }
} }
} }