# Tests for pipeline.transform.transform_poi (568 lines, 20 KiB, Python).
import json
|
||
|
||
import polars as pl
|
||
|
||
from pipeline.transform.transform_poi import (
|
||
_load_ofsted_ratings,
|
||
_school_icon_category_expr,
|
||
osm_groceries_colocated_with_geolytix,
|
||
transform,
|
||
transform_grocery_retail_points,
|
||
)
|
||
|
||
|
||
def test_osm_groceries_colocated_with_geolytix_drops_only_brand_matched_duplicates():
    """Only an OSM grocery that is both colocated AND brand-matched is dropped.

    GEOLYTIX is treated as authoritative for its chains. The fixture holds a
    single GEOLYTIX Tesco and three OSM shops:

    * ``dup-brand``   - a "Tesco Express" within a few metres: the same
      physical store, so it must be reported for dropping.
    * ``independent`` - an unbranded shop within a few metres: must be kept.
    * ``far-brand``   - a "Tesco Express" tens of kilometres away from any
      GEOLYTIX point: must be kept.
    """
    geolytix_points = pl.DataFrame(
        {
            "category": ["Tesco"],
            "lat": [51.5000],
            "lng": [-0.1000],
        }
    )
    osm_shops = pl.DataFrame(
        {
            "id": ["dup-brand", "independent", "far-brand"],
            "name": ["Tesco Express", "Bob's Corner Shop", "Tesco Express"],
            # First two sit practically on top of the GEOLYTIX Tesco; the
            # third is far outside the 50 m radius used below.
            "lat": [51.50001, 51.50002, 52.0],
            "lng": [-0.10001, -0.1000, -1.0],
        }
    )

    to_drop = osm_groceries_colocated_with_geolytix(
        osm_shops, geolytix_points, radius_m=50.0
    )

    expected = ["dup-brand"]
    assert to_drop == expected
|
||
|
||
|
||
def test_osm_groceries_colocated_with_geolytix_dedupes_cooperative_spelling():
    """A "Co-op" / "The Co-operative Food" pair still counts as a brand match.

    The GEOLYTIX brand is spelled "Co-op" while the OSM name uses the long
    "Co-operative" form; the two spellings must be folded together so the
    colocated duplicate is still reported.
    """
    geolytix_points = pl.DataFrame(
        {
            "category": ["Co-op"],
            "lat": [53.0],
            "lng": [-1.5],
        }
    )
    osm_shops = pl.DataFrame(
        {
            "id": ["coop-dup"],
            "name": ["The Co-operative Food"],
            "lat": [53.00001],
            "lng": [-1.5],
        }
    )

    to_drop = osm_groceries_colocated_with_geolytix(
        osm_shops, geolytix_points, radius_m=50.0
    )

    assert to_drop == ["coop-dup"]
|
||
|
||
|
||
def test_osm_groceries_colocated_with_geolytix_handles_empty_inputs():
    """An empty frame on either side yields an empty drop list."""
    # Case 1: no OSM rows at all, one GEOLYTIX point.
    one_geolytix = pl.DataFrame(
        {"category": ["Tesco"], "lat": [51.5], "lng": [-0.1]}
    )
    no_osm = pl.DataFrame(
        schema={"id": pl.Utf8, "name": pl.Utf8, "lat": pl.Float64, "lng": pl.Float64}
    )
    result_without_osm = osm_groceries_colocated_with_geolytix(no_osm, one_geolytix)
    assert result_without_osm == []

    # Case 2: one OSM row, no GEOLYTIX points to match against.
    one_osm = pl.DataFrame(
        {"id": ["x"], "name": ["Tesco Express"], "lat": [51.5], "lng": [-0.1]}
    )
    no_geolytix = pl.DataFrame(
        schema={"category": pl.Utf8, "lat": pl.Float64, "lng": pl.Float64}
    )
    result_without_geolytix = osm_groceries_colocated_with_geolytix(
        one_osm, no_geolytix
    )
    assert result_without_geolytix == []
|
||
|
||
|
||
def _write_boundary(tmp_path):
    """Write a one-feature GeoJSON boundary file and return its path.

    The single Polygon feature is the lon/lat rectangle from (-1.0, 51.0) to
    (1.0, 52.0), which contains the London-area coordinates used by the
    transform() fixtures in this module.

    Args:
        tmp_path: Directory (a ``pathlib.Path``-like object) to write into.

    Returns:
        The path ``tmp_path / "england.geojson"`` of the file just written.
    """
    boundary_path = tmp_path / "england.geojson"
    # Closed ring: the first and last vertex are identical, as GeoJSON requires.
    ring = [[-1.0, 51.0], [1.0, 51.0], [1.0, 52.0], [-1.0, 52.0], [-1.0, 51.0]]
    feature = {
        "type": "Feature",
        "properties": {},
        "geometry": {"type": "Polygon", "coordinates": [ring]},
    }
    collection = {"type": "FeatureCollection", "features": [feature]}
    # Pin the encoding so the bytes on disk do not depend on the locale.
    boundary_path.write_text(json.dumps(collection), encoding="utf-8")
    return boundary_path
|
||
|
||
|
||
def _write_transform_inputs(tmp_path, raw_pois: pl.DataFrame):
    """Write every input file the transform() tests pass in, around ``raw_pois``.

    ``raw_pois`` is written unchanged to ``pois.parquet``. The NaPTAN, grocery,
    GIAS and Ofsted parquet files each receive a small fixed fixture, and the
    boundary file comes from ``_write_boundary``.

    Returns:
        A dict with the keys ``input_path``, ``naptan_path``,
        ``boundary_path``, ``grocery_retail_points_path``, ``gias_path`` and
        ``ofsted_path``, which callers splat into ``transform(**...)``.
    """
    pois_file = tmp_path / "pois.parquet"
    raw_pois.write_parquet(pois_file)

    # One NaPTAN rail station.
    naptan_file = tmp_path / "naptan.parquet"
    naptan = pl.DataFrame(
        {
            "id": ["naptan-1"],
            "name": ["Test Rail Station"],
            "category": ["Rail station"],
            "lat": [51.51],
            "lng": [-0.13],
        }
    )
    naptan.write_parquet(naptan_file)

    # Five Tesco stores, all at the same point (51.52, -0.14).
    store_numbers = list(range(1, 6))
    store_count = len(store_numbers)
    grocery_file = tmp_path / "grocery.parquet"
    grocery = pl.DataFrame(
        {
            "id": store_numbers,
            "retailer": ["Tesco"] * store_count,
            "fascia": ["Tesco"] * store_count,
            "store_name": [f"Tesco Test {number}" for number in store_numbers],
            "long_wgs": [-0.14] * store_count,
            "lat_wgs": [51.52] * store_count,
        }
    )
    grocery.write_parquet(grocery_file)

    # One GIAS school, described as a single record and then widened into
    # one-element columns.
    school = {
        "urn": 1001,
        "name": "Test Primary School",
        "phase": "Primary",
        "type": "Community school",
        "type_group": "Local authority maintained schools",
        "age_range": "4–11",
        "gender": "Mixed",
        "religious_character": None,
        "admissions_policy": "Comprehensive",
        "nursery_provision": "No",
        "sixth_form": "No",
        "capacity": 200,
        "pupils": 180,
        "fsm_percent": 12.5,
        "trust": None,
        "address": "1 Test Street",
        "postcode": "E1 1AA",
        "local_authority": "Test LA",
        "website": None,
        "telephone": "02012345678",
        "head_name": "Jane Doe",
        "lat": 51.53,
        "lng": -0.12,
    }
    gias_file = tmp_path / "gias.parquet"
    gias = pl.DataFrame({column: [value] for column, value in school.items()})
    gias.write_parquet(gias_file)

    # One Ofsted row for the same URN as the GIAS school above.
    ofsted_file = tmp_path / "ofsted.parquet"
    ofsted = pl.DataFrame(
        {
            "URN": [1001],
            "Latest OEIF overall effectiveness": ["2"],
            "Ungraded inspection overall outcome": [None],
        }
    )
    ofsted.write_parquet(ofsted_file)

    boundary_file = _write_boundary(tmp_path)

    return {
        "input_path": pois_file,
        "naptan_path": naptan_file,
        "boundary_path": boundary_file,
        "grocery_retail_points_path": grocery_file,
        "gias_path": gias_file,
        "ofsted_path": ofsted_file,
    }
|
||
|
||
|
||
def test_transform_grocery_retail_points_outputs_chain_categories():
    """Each raw store becomes a ``glx-<id>`` Groceries POI under its chain.

    Checks the id, name, category, icon_category, group and emoji columns for
    a Waitrose, a Sainsbury's Local and a Co-op Food store.
    """
    raw_stores = pl.DataFrame(
        {
            "id": [101, 102, 103],
            "retailer": ["Waitrose", "Sainsburys", "The Co-operative Group"],
            "fascia": ["Waitrose", "Sainsbury's Local", "Co-op Food"],
            # NOTE(review): the doubled apostrophe is deliberate input; the
            # expected name below carries a single apostrophe.
            "store_name": ["Waitrose Test", "Sainsbury''s Test", "Co-op Test"],
            "long_wgs": [-0.141, -0.142, -0.143],
            "lat_wgs": [51.515, 51.516, 51.517],
        }
    )

    pois = transform_grocery_retail_points(raw_stores, min_chain_locations=1)

    # (id, name, category, icon_category) per expected output row.
    expected_rows = [
        ("glx-101", "Waitrose Test", "Waitrose", "Waitrose"),
        ("glx-102", "Sainsbury's Test", "Sainsbury's", "Sainsbury's Local"),
        ("glx-103", "Co-op Test", "Co-op", "Co-op"),
    ]
    expected = [
        {
            "id": poi_id,
            "name": name,
            "category": category,
            "icon_category": icon_category,
            "group": "Groceries",
            "emoji": "🛒",
        }
        for poi_id, name, category, icon_category in expected_rows
    ]

    actual = pois.select(
        "id", "name", "category", "icon_category", "group", "emoji"
    ).to_dicts()
    assert actual == expected
|
||
|
||
|
||
def test_transform_grocery_retail_points_keeps_fascia_icon_category():
    """``category`` is the chain while ``icon_category`` keeps the fascia.

    The forecourt-partner suffixes in the input fascias ("... Esso",
    "... Shell") are not expected in the resulting icon categories.
    """
    raw_stores = pl.DataFrame(
        {
            "id": [101, 102, 103, 104],
            "retailer": ["Tesco", "Iceland", "Waitrose", "Morrisons"],
            "fascia": [
                "Tesco Express Esso",
                "The Food Warehouse",
                "Little Waitrose Shell",
                "Morrisons Daily",
            ],
            "store_name": [
                "Tesco Test Express",
                "Iceland Test Food Warehouse",
                "Little Waitrose Test",
                "Morrisons Daily Test",
            ],
            "long_wgs": [-0.141, -0.142, -0.143, -0.144],
            "lat_wgs": [51.515, 51.516, 51.517, 51.518],
        }
    )

    pois = transform_grocery_retail_points(raw_stores, min_chain_locations=1)

    expected_pairs = [
        ("Tesco", "Tesco Express"),
        ("Iceland", "The Food Warehouse"),
        ("Waitrose", "Little Waitrose"),
        ("Morrisons", "Morrisons Daily"),
    ]
    expected = [
        {"category": category, "icon_category": icon_category}
        for category, icon_category in expected_pairs
    ]
    assert pois.select("category", "icon_category").to_dicts() == expected
|
||
|
||
|
||
def test_transform_grocery_retail_points_merges_cooperative_societies():
    """Regional co-operative societies all come out as the "Co-op" brand.

    Covers a society-named fascia, the generic "The Co-operative Food" fascia
    and a missing (null) fascia.
    """
    retailers = [
        "Central England Co-operative",
        "Lincolnshire Co-operative",
        "The Southern Co-operative",
    ]
    fascias = [
        "Central England Co-operative",
        "The Co-operative Food",
        None,
    ]
    store_names = [
        "Central Co-op Test",
        "Lincolnshire Co-op Test",
        "Southern Co-op Test",
    ]
    raw_stores = pl.DataFrame(
        {
            "id": [101, 102, 103],
            "retailer": retailers,
            "fascia": fascias,
            "store_name": store_names,
            "long_wgs": [-0.141, -0.142, -0.143],
            "lat_wgs": [51.515, 51.516, 51.517],
        }
    )

    pois = transform_grocery_retail_points(raw_stores, min_chain_locations=1)

    expected = [
        {"category": "Co-op", "icon_category": "Co-op"} for _ in retailers
    ]
    assert pois.select("category", "icon_category").to_dicts() == expected
|
||
|
||
|
||
def test_transform_grocery_retail_points_pools_small_coop_societies_before_cutoff():
    """Co-op societies are pooled into one brand before the size cutoff.

    Each of the five societies contributes a single store, so none would pass
    the default minimum chain size alone; all five rows are expected to
    survive under the shared "Co-op" category.
    """
    societies = [
        "Central England Co-operative",
        "Lincolnshire Co-operative",
        "The Southern Co-operative",
        "Midcounties Co-operative",
        "Heart of England Co-operative",
    ]
    society_count = len(societies)
    store_numbers = list(range(1, society_count + 1))
    raw_stores = pl.DataFrame(
        {
            "id": store_numbers,
            "retailer": societies,
            "fascia": ["The Co-operative Food"] * society_count,
            "store_name": [f"Co-op Test {number}" for number in store_numbers],
            "long_wgs": [-0.141] * society_count,
            "lat_wgs": [51.515] * society_count,
        }
    )

    # Default min_chain_locations on purpose: the cutoff must be applied to
    # the pooled brand, not to each society.
    pois = transform_grocery_retail_points(raw_stores)

    assert pois.height == society_count
    distinct_categories = pois["category"].unique().to_list()
    assert distinct_categories == ["Co-op"]
|
||
|
||
|
||
def test_transform_grocery_retail_points_accepts_base_fascias():
    """Plain chain fascias pass through with matching category/icon_category.

    "Asda Superstore" keeps its fascia as the icon category while its category
    is the chain name "Asda".
    """
    raw_stores = pl.DataFrame(
        {
            "id": [101, 102, 103, 104],
            "retailer": ["Aldi", "Asda", "Booths", "Whole Foods Market"],
            "fascia": ["Aldi", "Asda Superstore", "Booths", "Whole Foods Market"],
            "store_name": [
                "Aldi Test",
                "Asda Test Superstore",
                "Booths Test",
                "Whole Foods Test",
            ],
            "long_wgs": [-0.141, -0.142, -0.143, -0.144],
            "lat_wgs": [51.515, 51.516, 51.517, 51.518],
        }
    )

    pois = transform_grocery_retail_points(raw_stores, min_chain_locations=1)

    expected_pairs = [
        ("Aldi", "Aldi"),
        ("Asda", "Asda Superstore"),
        ("Booths", "Booths"),
        ("Whole Foods Market", "Whole Foods Market"),
    ]
    expected = [
        {"category": category, "icon_category": icon_category}
        for category, icon_category in expected_pairs
    ]
    assert pois.select("category", "icon_category").to_dicts() == expected
|
||
|
||
|
||
def test_transform_grocery_retail_points_drops_invalid_rows():
    """A row with an empty retailer is dropped even if its fascia is set."""
    raw_stores = pl.DataFrame(
        {
            "id": [101, 102],
            # Second row: empty retailer, otherwise well-formed.
            "retailer": ["Waitrose", ""],
            "fascia": ["Waitrose", "Tesco"],
            "store_name": ["Waitrose Test", "Tesco Test"],
            "long_wgs": [-0.141, -0.142],
            "lat_wgs": [51.515, 51.516],
        }
    )

    pois = transform_grocery_retail_points(raw_stores, min_chain_locations=1)

    surviving_categories = pois["category"].to_list()
    assert surviving_categories == ["Waitrose"]
|
||
|
||
|
||
def test_transform_grocery_retail_points_includes_unmapped_chains_with_five_locations():
    """With the default cutoff, a 5-store chain is kept and a 4-store one is not.

    "Tian Tian" (five stores) must appear with its retailer name as both
    category and icon_category; "Corner Shop" (four stores) must be absent.
    """
    kept_count, dropped_count = 5, 4
    total = kept_count + dropped_count
    raw_stores = pl.DataFrame(
        {
            "id": list(range(1, total + 1)),
            "retailer": ["Tian Tian"] * kept_count + ["Corner Shop"] * dropped_count,
            "fascia": ["Tian Tian Market"] * kept_count
            + ["Corner Shop"] * dropped_count,
            "store_name": [f"Store {number}" for number in range(1, total + 1)],
            "long_wgs": [-0.1] * total,
            "lat_wgs": [51.5] * total,
        }
    )

    pois = transform_grocery_retail_points(raw_stores)

    expected = [
        {
            "id": f"glx-{number}",
            "category": "Tian Tian",
            "icon_category": "Tian Tian",
        }
        for number in range(1, kept_count + 1)
    ]
    assert pois.select("id", "category", "icon_category").to_dicts() == expected
|
||
|
||
|
||
def test_load_ofsted_ratings_falls_back_to_ungraded_outcome(tmp_path):
    """The ungraded outcome is used only when no usable graded result exists.

    Each case below is ``(urn, graded, ungraded, expected_rating)``:

    * URNs 1-4: graded results "1".."4" map straight to their rating names.
    * URNs 5-6: the graded value is null / "Not judged" but the ungraded
      outcome says Outstanding / Good (including a "(Concerns)" suffix).
    * URN 7: "Not judged" with no ungraded outcome stays "Not judged".
    * URN 8: a real grade 3 must not be overridden by an ungraded outcome.
    """
    cases = [
        (1, "1", None, "Outstanding"),
        (2, "2", None, "Good"),
        (3, "3", None, "Requires improvement"),
        (4, "4", None, "Inadequate"),
        (5, None, "School remains Outstanding", "Outstanding"),
        (6, "Not judged", "School remains Good (Concerns)", "Good"),
        (7, "Not judged", None, "Not judged"),
        (8, "3", "School remains Outstanding", "Requires improvement"),
    ]
    urns, graded, ungraded, expected_ratings = (
        list(column) for column in zip(*cases)
    )

    ofsted_file = tmp_path / "ofsted.parquet"
    source = pl.DataFrame(
        {
            "URN": urns,
            "Latest OEIF overall effectiveness": graded,
            "Ungraded inspection overall outcome": ungraded,
        }
    )
    source.write_parquet(ofsted_file)

    loaded = _load_ofsted_ratings(ofsted_file).collect()
    ratings = loaded.sort("urn").to_dicts()

    expected = [
        {"urn": urn, "ofsted_rating": rating}
        for urn, rating in zip(urns, expected_ratings)
    ]
    assert ratings == expected
|
||
|
||
|
||
def test_school_icon_category_handles_one_sided_age_ranges():
    """Age ranges of every shape classify, with phase and type_group null.

    Covers the "up to {high}", "{low}+" and "{low}–{high}" forms plus a null
    age range, which is the only case expected to fall back to plain "School".
    """
    # (age_range, expected icon category)
    cases = [
        ("up to 5", "Nursery school"),
        ("16+", "Sixth form"),
        ("3–18", "All-through school"),
        ("4–11", "Primary school"),
        (None, "School"),
    ]
    all_null = [None] * len(cases)

    schools = pl.DataFrame(
        {
            "phase": all_null,
            "type_group": all_null,
            "age_range": [age_range for age_range, _ in cases],
        },
        # An all-null Python list would infer the Null dtype and break .str
        # operations, so the columns are forced to String.
        schema_overrides={
            "phase": pl.String,
            "type_group": pl.String,
            "age_range": pl.String,
        },
    )

    classified = schools.select(_school_icon_category_expr().alias("category"))
    categories = classified["category"].to_list()

    assert categories == [expected for _, expected in cases]
|
||
|
||
|
||
def test_transform_dedupes_multi_tag_pois(tmp_path):
    """Two raw rows for one OSM object collapse to one row per category.

    The same id "n42" arrives once as "amenity/pharmacy" and once as
    "shop/chemist"; both are expected to map to "Pharmacy", and the output
    must contain that (id, category) pair exactly once.
    """
    raw_pois = pl.DataFrame(
        {
            "id": ["n42", "n42"],
            "name": ["Boots", "Boots"],
            "category": ["amenity/pharmacy", "shop/chemist"],
            "lat": [51.50, 51.50],
            "lng": [-0.10, -0.10],
        }
    )

    result = transform(**_write_transform_inputs(tmp_path, raw_pois)).collect()

    # Global uniqueness: the largest (id, category) group has one row.
    rows_per_pair = result.group_by("id", "category").len()
    assert rows_per_pair["len"].max() == 1

    # And the pharmacy itself is present, once.
    is_n42 = pl.col("id") == "n42"
    is_pharmacy = pl.col("category") == "Pharmacy"
    assert result.filter(is_n42 & is_pharmacy).height == 1
|
||
|
||
|
||
def test_osm_supermarkets_dropped(tmp_path):
    """OSM "shop/supermarket" rows are removed; convenience stores remain.

    GEOLYTIX is the authoritative supermarket source, so the OSM supermarket
    must not produce a Groceries/"Supermarket" row, while the complementary
    "Convenience Store" category must still come through.
    """
    raw_pois = pl.DataFrame(
        {
            "id": ["n1", "n2"],
            "name": ["Some Supermarket", "Corner Shop"],
            "category": ["shop/supermarket", "shop/convenience"],
            "lat": [51.50, 51.51],
            "lng": [-0.10, -0.11],
        }
    )

    result = transform(**_write_transform_inputs(tmp_path, raw_pois)).collect()

    in_groceries = pl.col("group") == "Groceries"
    is_supermarket = pl.col("category") == "Supermarket"
    assert result.filter(in_groceries & is_supermarket).height == 0

    is_convenience = pl.col("category") == "Convenience Store"
    assert result.filter(is_convenience).height == 1
|
||
|
||
|
||
def test_transform_grocery_dedup_drops_only_grocery_aspect(tmp_path):
    """Grocery dedup removes only the grocery row of a multi-aspect OSM object.

    ``_write_transform_inputs`` seeds five GEOLYTIX "Tesco" stores at
    (51.52, -0.14). OSM object "n1" sits at that exact point, is named
    "Tesco Express" and carries two tags:

    * "shop/convenience"    - a duplicate of the GEOLYTIX store: dropped.
    * "amenity/post_office" - not a grocery aspect: must survive.

    OSM object "n2" is an unbranded corner shop at a different location and
    keeps its "Convenience Store" row.
    """
    raw_pois = pl.DataFrame(
        {
            "id": ["n1", "n1", "n2"],
            "name": ["Tesco Express", "Tesco Express", "Corner Shop"],
            "category": [
                "shop/convenience",
                "amenity/post_office",
                "shop/convenience",
            ],
            "lat": [51.52, 51.52, 51.40],
            "lng": [-0.14, -0.14, -0.05],
        }
    )

    result = transform(**_write_transform_inputs(tmp_path, raw_pois)).collect()

    is_n1 = pl.col("id") == "n1"
    is_n2 = pl.col("id") == "n2"

    # n1: the brand-matched grocery row is gone ...
    n1_groceries = result.filter(is_n1 & (pl.col("group") == "Groceries"))
    assert n1_groceries.height == 0

    # ... but the Post Office row for the same id is still there.
    n1_post_offices = result.filter(is_n1 & (pl.col("category") == "Post Office"))
    assert n1_post_offices.height == 1

    # n2: the independent shop is untouched.
    n2_convenience = result.filter(
        is_n2 & (pl.col("category") == "Convenience Store")
    )
    assert n2_convenience.height == 1
|
||
|
||
|
||
def test_transform_output_unique_per_id_category(tmp_path):
    """The full transform() output has at most one row per (id, category).

    The raw frame mixes one id carrying two pharmacy-like tags with two
    separate ids carrying church-like tags; the check runs over the whole
    output, across every source.
    """
    raw_pois = pl.DataFrame(
        {
            "id": ["n42", "n42", "n7", "n8"],
            "name": ["Boots", "Boots", "St Mary's", "St Mary's"],
            "category": [
                "amenity/pharmacy",
                "shop/chemist",
                "amenity/place_of_worship",
                "building/church",
            ],
            "lat": [51.50, 51.50, 51.55, 51.55],
            "lng": [-0.10, -0.10, -0.15, -0.15],
        }
    )

    result = transform(**_write_transform_inputs(tmp_path, raw_pois)).collect()

    rows_per_pair = result.group_by("id", "category").len()
    largest_group = rows_per_pair["len"].max()
    assert largest_group == 1
|