perfect-postcode/pipeline/transform/test_transform_poi.py
2026-06-02 13:46:18 +01:00

294 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import polars as pl
from pipeline.transform.transform_poi import (
_load_ofsted_ratings,
_school_icon_category_expr,
transform_grocery_retail_points,
)
def test_transform_grocery_retail_points_outputs_chain_categories():
    """Known retailers resolve to brand categories with fascia-based icons.

    Also pins the ``glx-`` id prefix, that the doubled apostrophe in the raw
    store name comes out as a single apostrophe in ``name``, and that every
    row lands in the Groceries group with the cart emoji.
    """
    source = pl.DataFrame(
        {
            "id": [101, 102, 103],
            "retailer": ["Waitrose", "Sainsburys", "The Co-operative Group"],
            "fascia": ["Waitrose", "Sainsbury's Local", "Co-op Food"],
            "store_name": ["Waitrose Test", "Sainsbury''s Test", "Co-op Test"],
            "long_wgs": [-0.141, -0.142, -0.143],
            "lat_wgs": [51.515, 51.516, 51.517],
        }
    )
    # Lower the chain-size cutoff to 1 so these single-store brands are kept.
    pois = transform_grocery_retail_points(source, min_chain_locations=1)

    # Fields that are identical for every grocery POI.
    shared = {"group": "Groceries", "emoji": "🛒"}
    expected = [
        {
            "id": "glx-101",
            "name": "Waitrose Test",
            "category": "Waitrose",
            "icon_category": "Waitrose",
            **shared,
        },
        {
            "id": "glx-102",
            "name": "Sainsbury's Test",
            "category": "Sainsbury's",
            "icon_category": "Sainsbury's Local",
            **shared,
        },
        {
            "id": "glx-103",
            "name": "Co-op Test",
            "category": "Co-op",
            "icon_category": "Co-op",
            **shared,
        },
    ]
    wanted_columns = ["id", "name", "category", "icon_category", "group", "emoji"]
    assert pois.select(wanted_columns).to_dicts() == expected
def test_transform_grocery_retail_points_keeps_fascia_icon_category():
    """Sub-brand fascias become the icon while the category is the parent brand.

    Trailing "Esso"/"Shell" in a fascia are expected to be dropped from the
    icon category, and Iceland's "The Food Warehouse" keeps its own icon.
    """
    source = pl.DataFrame(
        {
            "id": [101, 102, 103, 104],
            "retailer": ["Tesco", "Iceland", "Waitrose", "Morrisons"],
            "fascia": [
                "Tesco Express Esso",
                "The Food Warehouse",
                "Little Waitrose Shell",
                "Morrisons Daily",
            ],
            "store_name": [
                "Tesco Test Express",
                "Iceland Test Food Warehouse",
                "Little Waitrose Test",
                "Morrisons Daily Test",
            ],
            "long_wgs": [-0.141, -0.142, -0.143, -0.144],
            "lat_wgs": [51.515, 51.516, 51.517, 51.518],
        }
    )
    pois = transform_grocery_retail_points(source, min_chain_locations=1)

    # (category, icon_category) expected for each input row, in order.
    expected_pairs = [
        ("Tesco", "Tesco Express"),
        ("Iceland", "The Food Warehouse"),
        ("Waitrose", "Little Waitrose"),
        ("Morrisons", "Morrisons Daily"),
    ]
    expected = [
        {"category": category, "icon_category": icon}
        for category, icon in expected_pairs
    ]
    assert pois.select("category", "icon_category").to_dicts() == expected
def test_transform_grocery_retail_points_merges_cooperative_societies():
    """Independent Co-operative societies all collapse into the "Co-op" brand.

    Covers a society-named fascia, the generic "The Co-operative Food" fascia
    and a missing (null) fascia.
    """
    societies = [
        "Central England Co-operative",
        "Lincolnshire Co-operative",
        "The Southern Co-operative",
    ]
    source = pl.DataFrame(
        {
            "id": [101, 102, 103],
            "retailer": societies,
            "fascia": [
                "Central England Co-operative",
                "The Co-operative Food",
                None,
            ],
            "store_name": [
                "Central Co-op Test",
                "Lincolnshire Co-op Test",
                "Southern Co-op Test",
            ],
            "long_wgs": [-0.141, -0.142, -0.143],
            "lat_wgs": [51.515, 51.516, 51.517],
        }
    )
    pois = transform_grocery_retail_points(source, min_chain_locations=1)

    # Every society should come out identically branded.
    merged_row = {"category": "Co-op", "icon_category": "Co-op"}
    expected = [merged_row] * len(societies)
    assert pois.select("category", "icon_category").to_dicts() == expected
def test_transform_grocery_retail_points_pools_small_coop_societies_before_cutoff():
    """Co-op societies are merged before the chain-size cutoff is applied.

    Each society contributes a single store here, so none would clear
    MIN_GROCERY_CHAIN_LOCATIONS on its own. Only after normalising to the
    shared "Co-op" brand do they clear it together, using the default cutoff.
    """
    societies = [
        "Central England Co-operative",
        "Lincolnshire Co-operative",
        "The Southern Co-operative",
        "Midcounties Co-operative",
        "Heart of England Co-operative",
    ]
    count = len(societies)
    row_ids = list(range(1, count + 1))
    source = pl.DataFrame(
        {
            "id": row_ids,
            "retailer": societies,
            "fascia": ["The Co-operative Food"] * count,
            "store_name": [f"Co-op Test {row_id}" for row_id in row_ids],
            "long_wgs": [-0.141] * count,
            "lat_wgs": [51.515] * count,
        }
    )
    # No min_chain_locations override: the production default is under test.
    pois = transform_grocery_retail_points(source)

    assert pois.height == count
    assert pois["category"].unique().to_list() == ["Co-op"]
def test_transform_grocery_retail_points_accepts_base_fascias():
    """Fascias that are just the brand name are accepted as icon categories.

    A format-specific fascia ("Asda Superstore") is expected to be kept as the
    icon while the category stays the parent brand.
    """
    source = pl.DataFrame(
        {
            "id": [101, 102, 103, 104],
            "retailer": ["Aldi", "Asda", "Booths", "Whole Foods Market"],
            "fascia": ["Aldi", "Asda Superstore", "Booths", "Whole Foods Market"],
            "store_name": [
                "Aldi Test",
                "Asda Test Superstore",
                "Booths Test",
                "Whole Foods Test",
            ],
            "long_wgs": [-0.141, -0.142, -0.143, -0.144],
            "lat_wgs": [51.515, 51.516, 51.517, 51.518],
        }
    )
    pois = transform_grocery_retail_points(source, min_chain_locations=1)

    # (category, icon_category) expected for each input row, in order.
    expected_pairs = [
        ("Aldi", "Aldi"),
        ("Asda", "Asda Superstore"),
        ("Booths", "Booths"),
        ("Whole Foods Market", "Whole Foods Market"),
    ]
    expected = [
        {"category": category, "icon_category": icon}
        for category, icon in expected_pairs
    ]
    assert pois.select("category", "icon_category").to_dicts() == expected
def test_transform_grocery_retail_points_drops_invalid_rows():
    """A row with an empty retailer is discarded even though its fascia is set."""
    source = pl.DataFrame(
        {
            "id": [101, 102],
            # The second row's retailer is blank and should be filtered out.
            "retailer": ["Waitrose", ""],
            "fascia": ["Waitrose", "Tesco"],
            "store_name": ["Waitrose Test", "Tesco Test"],
            "long_wgs": [-0.141, -0.142],
            "lat_wgs": [51.515, 51.516],
        }
    )
    surviving = transform_grocery_retail_points(source, min_chain_locations=1)
    assert surviving["category"].to_list() == ["Waitrose"]
def test_transform_grocery_retail_points_includes_unmapped_chains_with_five_locations():
    """Unmapped retailers are kept only when they reach the default cutoff.

    "Tian Tian" has five stores and is kept, with its retailer name used for
    both category and icon (not its "Tian Tian Market" fascia). "Corner Shop"
    has four stores and is dropped.
    """
    kept_count, dropped_count = 5, 4
    total = kept_count + dropped_count
    row_ids = list(range(1, total + 1))
    source = pl.DataFrame(
        {
            "id": row_ids,
            "retailer": ["Tian Tian"] * kept_count + ["Corner Shop"] * dropped_count,
            "fascia": (
                ["Tian Tian Market"] * kept_count + ["Corner Shop"] * dropped_count
            ),
            "store_name": [f"Store {row_id}" for row_id in row_ids],
            "long_wgs": [-0.1] * total,
            "lat_wgs": [51.5] * total,
        }
    )
    pois = transform_grocery_retail_points(source)

    expected = [
        {"id": f"glx-{row_id}", "category": "Tian Tian", "icon_category": "Tian Tian"}
        for row_id in range(1, kept_count + 1)
    ]
    assert pois.select("id", "category", "icon_category").to_dicts() == expected
def test_load_ofsted_ratings_falls_back_to_ungraded_outcome(tmp_path):
    """Graded OEIF results win; ungraded outcomes only fill missing grades.

    URNs 1-4 map graded results straight through. URNs 5-6 have no usable
    graded grade (null / "Not judged") but a good or outstanding ungraded
    outcome; URN 6 also carries a "(Concerns)" suffix. URN 7 is genuinely
    "Not judged". URN 8 has a real grade 3 that must NOT be overridden by
    its ungraded outcome.
    """
    cases = [
        # (urn, graded result, ungraded outcome, expected rating)
        (1, "1", None, "Outstanding"),
        (2, "2", None, "Good"),
        (3, "3", None, "Requires improvement"),
        (4, "4", None, "Inadequate"),
        (5, None, "School remains Outstanding", "Outstanding"),
        (6, "Not judged", "School remains Good (Concerns)", "Good"),
        (7, "Not judged", None, "Not judged"),
        (8, "3", "School remains Outstanding", "Requires improvement"),
    ]
    urns, graded, ungraded, ratings_wanted = [list(column) for column in zip(*cases)]

    ofsted_path = tmp_path / "ofsted.parquet"
    pl.DataFrame(
        {
            "URN": urns,
            "Latest OEIF overall effectiveness": graded,
            "Ungraded inspection overall outcome": ungraded,
        }
    ).write_parquet(ofsted_path)

    loaded = _load_ofsted_ratings(ofsted_path).collect().sort("urn")
    expected = [
        {"urn": urn, "ofsted_rating": rating}
        for urn, rating in zip(urns, ratings_wanted)
    ]
    assert loaded.to_dicts() == expected
def test_school_icon_category_handles_one_sided_age_ranges():
    """Every age-range shape from GIAS must classify to a specific category.

    ``gias._format_age_range`` emits "up to {high}", "{low}+" and a two-sided
    "{low}<en dash>{high}" form (U+2013 between the bounds). All three must
    classify rather than falling through to the generic "School". Only the
    row whose age range is null, with null phase and type group, should get
    "School".
    """
    # Spelled as an escape so the separator cannot be confused with an ASCII
    # hyphen or silently lost when the file is copied.
    # NOTE(review): confirm this matches the separator gias._format_age_range
    # actually emits.
    en_dash = "\u2013"
    df = pl.DataFrame(
        {
            "phase": [None, None, None, None, None],
            "type_group": [None, None, None, None, None],
            # "up to 5" -> nursery; "16+" -> sixth form; 3 to 18 -> all-through;
            # 4 to 11 -> primary; null age_range with null phase -> "School".
            "age_range": [
                "up to 5",
                "16+",
                f"3{en_dash}18",
                f"4{en_dash}11",
                None,
            ],
        },
        # Production reads these from a scanned parquet as String; an all-null
        # Python list would otherwise infer the Null dtype and break .str ops.
        schema_overrides={
            "phase": pl.String,
            "type_group": pl.String,
            "age_range": pl.String,
        },
    )
    categories = df.select(
        _school_icon_category_expr().alias("category")
    )["category"].to_list()
    assert categories == [
        "Nursery school",
        "Sixth form",
        "All-through school",
        "Primary school",
        "School",
    ]