# perfect-postcode/pipeline/utils/test_poi_counts.py
# (repository-viewer metadata: 2026-05-13 12:11:54 +01:00, 162 lines, 5 KiB, Python)

import numpy as np
import polars as pl
import pytest
from pipeline.utils.poi_counts import count_pois_per_postcode, min_distance_per_postcode
# Category groups shared by the tests below. Each key becomes the prefix of a
# result column (e.g. "restaurants_2km", "train_tube_nearest_km"); each value
# lists the raw POI "category" strings that are counted under that group.
POI_GROUPS = {
    "restaurants": ["Restaurant", "Fast Food"],
    "groceries": ["Supermarket"],
    "parks": ["Park"],
    "train_tube": ["Rail station", "Tube station"],
}
@pytest.fixture
def pois():
    """POI table with ``lat``/``lng``/``category`` columns.

    Five POIs sit within ~70 m of central London (51.5074, -0.1278); a sixth
    Restaurant at (51.60, -0.20) is roughly 11 km away from that cluster.
    Note the longitude column is ``lng`` here, whereas postcodes use ``lon``.
    """
    rows = [
        (51.5074, -0.1278, "Restaurant"),
        (51.5075, -0.1280, "Fast Food"),
        (51.5080, -0.1275, "Supermarket"),
        (51.5076, -0.1279, "Park"),
        (51.5073, -0.1277, "Rail station"),
        (51.60, -0.20, "Restaurant"),  # far from every postcode used in these tests
    ]
    lats, lngs, categories = (list(column) for column in zip(*rows))
    return pl.DataFrame({"lat": lats, "lng": lngs, "category": categories})
@pytest.fixture
def postcodes():
    """Two postcodes with ``postcode``/``lat``/``lon`` columns.

    "EC1A 1BB" sits exactly on the central-London POI cluster; "ZZ99 9ZZ" at
    (55.0, -3.0) is several hundred kilometres from every POI.
    """
    codes = ["EC1A 1BB", "ZZ99 9ZZ"]
    latitudes = [51.5074, 55.0]
    longitudes = [-0.1278, -3.0]
    return pl.DataFrame({"postcode": codes, "lat": latitudes, "lon": longitudes})
def test_counts_pois_within_radius(postcodes, pois):
    """Each group counts only its own categories within the radius, per postcode."""
    result = count_pois_per_postcode(postcodes, pois, groups=POI_GROUPS, radius_km=2.0)

    # One "<group>_2km" column per group for a 2 km radius.
    expected_cols = {f"{g}_2km" for g in POI_GROUPS}
    assert expected_cols.issubset(set(result.columns))

    # Result must be aligned to postcodes: exactly one row per input postcode,
    # none dropped or duplicated (otherwise the [0] lookups below would fail
    # with an unhelpful IndexError).
    assert len(result) == 2
    assert sorted(result["postcode"].to_list()) == sorted(postcodes["postcode"].to_list())

    ec1a = result.filter(pl.col("postcode") == "EC1A 1BB")
    # Restaurant + Fast Food; the Restaurant ~11 km away is outside the radius.
    assert ec1a["restaurants_2km"][0] == 2
    assert ec1a["groceries_2km"][0] == 1  # Supermarket
    assert ec1a["parks_2km"][0] == 1  # Park
    assert ec1a["train_tube_2km"][0] == 1  # Rail station

    # Far-away postcode should have zero counts in every group.
    zz99 = result.filter(pl.col("postcode") == "ZZ99 9ZZ")
    for group in POI_GROUPS:
        assert zz99[f"{group}_2km"][0] == 0
def test_no_pois_returns_zeros(postcodes):
    """An empty POI table still yields every group column, filled with zeros."""
    schema = {"lat": pl.Float64, "lng": pl.Float64, "category": pl.String}
    empty_pois = pl.DataFrame(schema=schema)

    result = count_pois_per_postcode(postcodes, empty_pois, groups=POI_GROUPS, radius_km=2.0)

    for column_name in (f"{group}_2km" for group in POI_GROUPS):
        assert column_name in result.columns
        assert result[column_name].to_list() == [0, 0]
def test_custom_radius(pois):
    """A tiny radius should exclude POIs that are even slightly away."""
    postcodes = pl.DataFrame(
        {
            "postcode": ["EC1A 1BB"],
            "lat": [51.5074],
            "lon": [-0.1278],
        }
    )
    # radius_km=0.01 is 10 m. Only the Restaurant at (51.5074, -0.1278) lies
    # inside it (distance 0). The nearest other POIs, the Rail station (~13 m)
    # and Fast Food (~18 m), sit just outside, so a radius filter that leaks
    # even slightly beyond 10 m is caught by the exact counts below.
    result = count_pois_per_postcode(postcodes, pois, groups=POI_GROUPS, radius_km=0.01)

    # Columns for this radius carry the "_0km" suffix.
    counts = {group: result[f"{group}_0km"][0] for group in POI_GROUPS}
    assert counts == {"restaurants": 1, "groceries": 0, "parks": 0, "train_tube": 0}
def test_counts_pois_across_multiple_grid_cells_within_5km():
    """A POI around 4.8km away must not be dropped by grid candidate lookup."""
    postcodes = pl.DataFrame({"postcode": ["GRID 5KM"], "lat": [51.5], "lon": [0.049]})

    # Both Parks share the postcode's latitude and lie due east of it. At
    # 51.5 deg N one degree of longitude is roughly 69 km, so they are about
    # 4.8 km (inside the radius) and 5.2 km (outside) away respectively.
    pois = pl.DataFrame(
        {"lat": [51.5, 51.5], "lng": [0.1183, 0.1240], "category": ["Park", "Park"]}
    )

    result = count_pois_per_postcode(
        postcodes, pois, groups={"parks": ["Park"]}, radius_km=5.0
    )

    assert result["parks_5km"][0] == 1
def test_min_distance_finds_nearest(postcodes, pois):
    """min_distance_per_postcode returns distance to closest POI per group."""
    result = min_distance_per_postcode(postcodes, pois, groups=POI_GROUPS)
    assert len(result) == 2

    near = result.filter(pl.col("postcode") == "EC1A 1BB")
    far = result.filter(pl.col("postcode") == "ZZ99 9ZZ")

    # The Rail station at (51.5073, -0.1277) is a few metres from the
    # postcode at (51.5074, -0.1278): comfortably under 50 m.
    assert near["train_tube_nearest_km"][0] < 0.05
    # A Restaurant shares the postcode's exact coordinates, so distance is ~0.
    assert near["restaurants_nearest_km"][0] < 0.01

    # A distant postcode still reports its (large) nearest distance, not NaN.
    assert far["train_tube_nearest_km"][0] > 300
def test_min_distance_no_pois_returns_nan(postcodes):
    """With no POIs, every postcode still gets a row and all distances are NaN."""
    empty_pois = pl.DataFrame(
        {
            "lat": pl.Series([], dtype=pl.Float64),
            "lng": pl.Series([], dtype=pl.Float64),
            "category": pl.Series([], dtype=pl.String),
        }
    )
    result = min_distance_per_postcode(
        postcodes, empty_pois, groups={"train_tube": ["Rail station"]}
    )
    assert "train_tube_nearest_km" in result.columns
    # Guard against a vacuous pass: all() over an empty column is True, so an
    # implementation that dropped every postcode would otherwise slip through.
    assert len(result) == len(postcodes)
    assert all(np.isnan(v) for v in result["train_tube_nearest_km"].to_list())