294 lines
9.9 KiB
Python
294 lines
9.9 KiB
Python
import polars as pl
|
||
|
||
from pipeline.transform.transform_poi import (
|
||
_load_ofsted_ratings,
|
||
_school_icon_category_expr,
|
||
transform_grocery_retail_points,
|
||
)
|
||
|
||
|
||
def test_transform_grocery_retail_points_outputs_chain_categories():
    """Known chains get a normalised brand, a fascia icon and grocery metadata."""
    raw = pl.DataFrame(
        {
            "id": [101, 102, 103],
            "retailer": ["Waitrose", "Sainsburys", "The Co-operative Group"],
            "fascia": ["Waitrose", "Sainsbury's Local", "Co-op Food"],
            # The doubled apostrophe is deliberate input; the expected output
            # name below carries a single apostrophe.
            "store_name": ["Waitrose Test", "Sainsbury''s Test", "Co-op Test"],
            "long_wgs": [-0.141, -0.142, -0.143],
            "lat_wgs": [51.515, 51.516, 51.517],
        }
    )

    result = transform_grocery_retail_points(raw, min_chain_locations=1)

    # (id, name, category, icon_category); group/emoji are shared by all rows.
    expected_rows = [
        ("glx-101", "Waitrose Test", "Waitrose", "Waitrose"),
        ("glx-102", "Sainsbury's Test", "Sainsbury's", "Sainsbury's Local"),
        ("glx-103", "Co-op Test", "Co-op", "Co-op"),
    ]
    expected = [
        {
            "id": poi_id,
            "name": name,
            "category": category,
            "icon_category": icon,
            "group": "Groceries",
            "emoji": "🛒",
        }
        for poi_id, name, category, icon in expected_rows
    ]
    columns = ["id", "name", "category", "icon_category", "group", "emoji"]
    assert result.select(*columns).to_dicts() == expected
|
||
|
||
|
||
def test_transform_grocery_retail_points_keeps_fascia_icon_category():
    """Fascia sub-formats become the icon category while the brand is normalised.

    Forecourt suffixes (" Esso", " Shell") are expected to be dropped from the
    icon category, and a differently named fascia ("The Food Warehouse") is
    kept as-is.
    """
    # (retailer, fascia, store_name, expected category, expected icon_category)
    cases = [
        ("Tesco", "Tesco Express Esso", "Tesco Test Express", "Tesco", "Tesco Express"),
        (
            "Iceland",
            "The Food Warehouse",
            "Iceland Test Food Warehouse",
            "Iceland",
            "The Food Warehouse",
        ),
        (
            "Waitrose",
            "Little Waitrose Shell",
            "Little Waitrose Test",
            "Waitrose",
            "Little Waitrose",
        ),
        (
            "Morrisons",
            "Morrisons Daily",
            "Morrisons Daily Test",
            "Morrisons",
            "Morrisons Daily",
        ),
    ]
    retailers, fascias, store_names, categories, icons = (
        list(column) for column in zip(*cases)
    )

    raw = pl.DataFrame(
        {
            "id": list(range(101, 101 + len(cases))),
            "retailer": retailers,
            "fascia": fascias,
            "store_name": store_names,
            "long_wgs": [-0.141, -0.142, -0.143, -0.144],
            "lat_wgs": [51.515, 51.516, 51.517, 51.518],
        }
    )

    result = transform_grocery_retail_points(raw, min_chain_locations=1)

    expected = [
        {"category": category, "icon_category": icon}
        for category, icon in zip(categories, icons)
    ]
    assert result.select("category", "icon_category").to_dicts() == expected
|
||
|
||
|
||
def test_transform_grocery_retail_points_merges_cooperative_societies():
    """Independent Co-op societies all collapse into the single "Co-op" brand.

    Covers a fascia naming the society itself, a generic Co-op fascia, and a
    missing (null) fascia; each must map to "Co-op" for both the category and
    the icon.
    """
    retailers = [
        "Central England Co-operative",
        "Lincolnshire Co-operative",
        "The Southern Co-operative",
    ]
    fascias = [
        "Central England Co-operative",
        "The Co-operative Food",
        None,
    ]
    store_names = [
        "Central Co-op Test",
        "Lincolnshire Co-op Test",
        "Southern Co-op Test",
    ]

    raw = pl.DataFrame(
        {
            "id": [101, 102, 103],
            "retailer": retailers,
            "fascia": fascias,
            "store_name": store_names,
            "long_wgs": [-0.141, -0.142, -0.143],
            "lat_wgs": [51.515, 51.516, 51.517],
        }
    )

    result = transform_grocery_retail_points(raw, min_chain_locations=1)

    coop_row = {"category": "Co-op", "icon_category": "Co-op"}
    expected = [dict(coop_row) for _ in retailers]
    assert result.select("category", "icon_category").to_dicts() == expected
|
||
|
||
|
||
def test_transform_grocery_retail_points_pools_small_coop_societies_before_cutoff():
    """Co-op societies are pooled under one brand before the chain-size cutoff.

    Each society below contributes a single store, so none would survive the
    cutoff alone. Only after normalising to the shared "Co-op" brand do they
    clear MIN_GROCERY_CHAIN_LOCATIONS together. This presumably is the default
    ``min_chain_locations``, which this test relies on by not passing it.
    """
    societies = [
        "Central England Co-operative",
        "Lincolnshire Co-operative",
        "The Southern Co-operative",
        "Midcounties Co-operative",
        "Heart of England Co-operative",
    ]
    count = len(societies)
    numbers = range(1, count + 1)

    raw = pl.DataFrame(
        {
            "id": list(numbers),
            "retailer": societies,
            "fascia": ["The Co-operative Food" for _ in societies],
            "store_name": [f"Co-op Test {n}" for n in numbers],
            "long_wgs": [-0.141 for _ in societies],
            "lat_wgs": [51.515 for _ in societies],
        }
    )

    # Deliberately use the default cutoff.
    result = transform_grocery_retail_points(raw)

    assert result.height == count
    assert result["category"].unique().to_list() == ["Co-op"]
|
||
|
||
|
||
def test_transform_grocery_retail_points_accepts_base_fascias():
    """Bare-brand fascias and simple format fascias are kept as the icon category."""
    # (retailer, fascia, store_name, expected category, expected icon_category)
    cases = [
        ("Aldi", "Aldi", "Aldi Test", "Aldi", "Aldi"),
        ("Asda", "Asda Superstore", "Asda Test Superstore", "Asda", "Asda Superstore"),
        ("Booths", "Booths", "Booths Test", "Booths", "Booths"),
        (
            "Whole Foods Market",
            "Whole Foods Market",
            "Whole Foods Test",
            "Whole Foods Market",
            "Whole Foods Market",
        ),
    ]
    retailers, fascias, store_names, categories, icons = (
        list(column) for column in zip(*cases)
    )

    raw = pl.DataFrame(
        {
            "id": list(range(101, 101 + len(cases))),
            "retailer": retailers,
            "fascia": fascias,
            "store_name": store_names,
            "long_wgs": [-0.141, -0.142, -0.143, -0.144],
            "lat_wgs": [51.515, 51.516, 51.517, 51.518],
        }
    )

    result = transform_grocery_retail_points(raw, min_chain_locations=1)

    expected = [
        {"category": category, "icon_category": icon}
        for category, icon in zip(categories, icons)
    ]
    assert result.select("category", "icon_category").to_dicts() == expected
|
||
|
||
|
||
def test_transform_grocery_retail_points_drops_invalid_rows():
    """A row with an empty retailer is discarded even if its fascia is a known brand."""
    raw = pl.DataFrame(
        {
            "id": [101, 102],
            # Row 102 has a blank retailer but a recognisable "Tesco" fascia.
            "retailer": ["Waitrose", ""],
            "fascia": ["Waitrose", "Tesco"],
            "store_name": ["Waitrose Test", "Tesco Test"],
            "long_wgs": [-0.141, -0.142],
            "lat_wgs": [51.515, 51.516],
        }
    )

    kept_categories = transform_grocery_retail_points(raw, min_chain_locations=1)[
        "category"
    ]

    assert kept_categories.to_list() == ["Waitrose"]
|
||
|
||
|
||
def test_transform_grocery_retail_points_includes_unmapped_chains_with_five_locations():
    """Unmapped chains are kept only when they reach five locations.

    "Tian Tian" (five stores) survives. Its category and icon both come out as
    the retailer name rather than the "Tian Tian Market" fascia. "Corner Shop"
    (four stores) is dropped under the default cutoff.
    """
    large_chain, small_chain = 5, 4
    total = large_chain + small_chain

    raw = pl.DataFrame(
        {
            "id": list(range(1, total + 1)),
            "retailer": ["Tian Tian"] * large_chain + ["Corner Shop"] * small_chain,
            "fascia": ["Tian Tian Market"] * large_chain
            + ["Corner Shop"] * small_chain,
            "store_name": [f"Store {i}" for i in range(1, total + 1)],
            "long_wgs": [-0.1] * total,
            "lat_wgs": [51.5] * total,
        }
    )

    result = transform_grocery_retail_points(raw)

    expected = [
        {"id": f"glx-{i}", "category": "Tian Tian", "icon_category": "Tian Tian"}
        for i in range(1, large_chain + 1)
    ]
    assert result.select("id", "category", "icon_category").to_dicts() == expected
|
||
|
||
|
||
def test_load_ofsted_ratings_falls_back_to_ungraded_outcome(tmp_path):
    """Use the ungraded inspection outcome only when no usable graded result exists.

    URNs 1-4: graded results map straight through to their labels.
    URNs 5-6: no usable grade (null / "Not judged") but a good or outstanding
    ungraded outcome, including a "(Concerns)" suffix.
    URN 7: genuinely "Not judged" with no ungraded outcome.
    URN 8: a real grade 3 must NOT be overridden by an ungraded outcome.
    """
    # (URN, latest graded result, ungraded outcome, expected rating)
    rows = [
        (1, "1", None, "Outstanding"),
        (2, "2", None, "Good"),
        (3, "3", None, "Requires improvement"),
        (4, "4", None, "Inadequate"),
        (5, None, "School remains Outstanding", "Outstanding"),
        (6, "Not judged", "School remains Good (Concerns)", "Good"),
        (7, "Not judged", None, "Not judged"),
        (8, "3", "School remains Outstanding", "Requires improvement"),
    ]
    urns, graded, ungraded, expected_ratings = (list(column) for column in zip(*rows))

    ofsted_path = tmp_path / "ofsted.parquet"
    pl.DataFrame(
        {
            "URN": urns,
            "Latest OEIF overall effectiveness": graded,
            "Ungraded inspection overall outcome": ungraded,
        }
    ).write_parquet(ofsted_path)

    lazy_ratings = _load_ofsted_ratings(ofsted_path)
    ratings = lazy_ratings.collect().sort("urn").to_dicts()

    assert ratings == [
        {"urn": urn, "ofsted_rating": rating}
        for urn, rating in zip(urns, expected_ratings)
    ]
|
||
|
||
|
||
def test_school_icon_category_handles_one_sided_age_ranges():
    """Every age-range format classifies to a specific school type.

    gias._format_age_range emits "up to {high}", "{low}+" and "{low}–{high}".
    Each of these must map to a concrete category instead of falling through
    to the generic "School". The null age range (with a null phase) is the
    one row expected to fall through.
    """
    # (age_range, expected category); phase and type_group are null throughout.
    cases = [
        ("up to 5", "Nursery school"),
        ("16+", "Sixth form"),
        ("3–18", "All-through school"),
        ("4–11", "Primary school"),
        (None, "School"),
    ]
    age_ranges = [age_range for age_range, _ in cases]
    nulls = [None] * len(cases)

    df = pl.DataFrame(
        {
            "phase": nulls,
            "type_group": list(nulls),
            "age_range": age_ranges,
        },
        # Production reads these from a scanned parquet as String; an all-null
        # Python list would otherwise infer the Null dtype and break .str ops.
        schema_overrides={
            "phase": pl.String,
            "type_group": pl.String,
            "age_range": pl.String,
        },
    )

    category_column = df.select(_school_icon_category_expr().alias("category"))
    categories = category_column["category"].to_list()

    assert categories == [expected for _, expected in cases]
|