"""Tests for the POI proximity helpers in ``pipeline.transform.poi_proximity``."""

import polars as pl

from pipeline.transform.poi_proximity import (
    GREENSPACE_PARK_FUNCTIONS,
    POI_GROUPS_2KM,
    _build_poi_category_groups,
    _dynamic_poi_metric_renames,
    _greenspace_count_frame,
    _groceries_categories,
)
from pipeline.utils.poi_counts import count_pois_per_postcode


def test_groceries_2km_counts_geolytix_brand_categories() -> None:
    """The static groceries 2km count must include GEOLYTIX brand POIs.

    GEOLYTIX stores the brand (e.g. "Tesco") in `category` with group
    "Groceries" and never emits the literal "Supermarket"; matching only the
    OSM strings counts the supermarket but drops the brand store.
    """
    # NOTE(review): postcodes carry `lon` while POIs carry `lng`; presumably
    # these mirror the column names count_pois_per_postcode reads — confirm
    # before "fixing" either one.
    postcodes = pl.DataFrame(
        {
            "postcode": ["SW1A 1AA"],
            "lat": [51.5010],
            "lon": [-0.1416],
        }
    )
    pois = pl.DataFrame(
        {
            "category": ["Tesco", "Supermarket"],
            "group": ["Groceries", "Groceries"],
            "lat": [51.5011, 51.5012],
            "lng": [-0.1417, -0.1418],
        }
    )
    groups_2km = {**POI_GROUPS_2KM, "groceries": _groceries_categories(pois)}

    result = count_pois_per_postcode(
        postcodes, pois, groups=groups_2km, radius_km=2
    )

    # Both the GEOLYTIX brand ("Tesco") and the OSM "Supermarket" must count.
    # Pre-fix the static list was ["Greengrocer", "Supermarket", "Convenience
    # Store"], so "Tesco" was dropped and this was 1.
    assert result["groceries_2km"][0] == 2


def test_dynamic_poi_groups_include_requested_categories_only() -> None:
    """Only the requested categories become dynamic POI groups.

    The fixture mixes requested categories with unrequested ones of similar
    volume (101 Tesco vs 100 Waitrose, 200 Pharmacy vs 200 School), plus a
    small "Park" category, and checks exactly which display names come out.
    """
    # (group, category, number of rows). Building the columns from one table
    # keeps "group" and "category" aligned and the row count derived, instead
    # of hand-syncing a magic total.
    fixture = [
        ("Public Transport", "Rail station", 1),
        ("Public Transport", "Bus stop", 1),
        ("Leisure", "Café", 1),
        ("Leisure", "Restaurant", 1),
        ("Groceries", "Tesco", 101),
        ("Groceries", "Waitrose", 100),
        ("Leisure", "Park", 10),
        ("Education", "School", 200),
        ("Health", "Pharmacy", 200),
    ]
    group_col: list[str] = []
    category_col: list[str] = []
    for group, category, count in fixture:
        group_col += [group] * count
        category_col += [category] * count
    n_rows = len(category_col)

    pois = pl.DataFrame(
        {
            "group": group_col,
            "category": category_col,
            "lat": [51.5] * n_rows,
            "lng": [-0.1] * n_rows,
        }
    )

    groups, display_names = _build_poi_category_groups(pois)

    assert set(display_names.values()) == {
        "Bus stop",
        "Café",
        "Pharmacy",
        "Rail station",
        "Restaurant",
        "Tesco",
    }
    assert "poi_waitrose" not in groups
    assert "poi_park" not in groups
    assert "poi_school" not in groups


def test_dynamic_poi_metric_renames_support_park_count_options() -> None:
    """A `parks` metric key gets nearest-distance, 2km and 5km display names."""
    assert _dynamic_poi_metric_renames({"parks": "Park"}) == {
        "parks_nearest_km": "Distance to nearest amenity (Park) (km)",
        "parks_2km": "Number of amenities (Park) within 2km",
        "parks_5km": "Number of amenities (Park) within 5km",
    }


def test_groceries_categories_exclude_speciality_food_retail() -> None:
    """The static groceries metric must not count bakeries/butchers/delis/
    off-licences (speciality retail, ~a third of the group), while keeping
    Supermarket, Convenience Store, Greengrocer and GEOLYTIX brands."""
    # (category, group) pairs; "Café" is the only non-Groceries row.
    rows = [
        ("Tesco", "Groceries"),
        ("Supermarket", "Groceries"),
        ("Convenience Store", "Groceries"),
        ("Greengrocer", "Groceries"),
        ("Bakery", "Groceries"),
        ("Butcher & Fishmonger", "Groceries"),
        ("Deli & Specialty", "Groceries"),
        ("Off-Licence", "Groceries"),
        ("Café", "Leisure"),
    ]
    pois = pl.DataFrame(
        {
            "category": [category for category, _ in rows],
            "group": [group for _, group in rows],
            "lat": [51.5] * len(rows),
            "lng": [-0.1] * len(rows),
        }
    )

    assert _groceries_categories(pois) == [
        "Convenience Store",
        "Greengrocer",
        "Supermarket",
        "Tesco",
    ]


def test_park_group_excludes_playgrounds_and_play_space() -> None:
    """Pin the greenspace functions that count towards the `parks` metric."""
    # "Play Space" (playgrounds) must not count as a Park; Public Park Or
    # Garden and Playing Field (open recreation grounds) are in scope.
    assert GREENSPACE_PARK_FUNCTIONS == {
        "parks": ["Public Park Or Garden", "Playing Field"]
    }


def test_greenspace_count_frame_collapses_to_one_row_per_site() -> None:
    """Greenspace access points collapse to one representative row per site."""
    # Three gates of one park (with a site centroid), one gate of another park
    # without a centroid, and one centroid-fallback row with a null site_id.
    greenspace = pl.DataFrame(
        {
            "lat": [51.50, 51.51, 51.52, 53.0, 54.0],
            "lng": [-0.10, -0.11, -0.12, -2.0, -3.0],
            "category": ["Public Park Or Garden"] * 3
            + ["Playing Field", "Public Park Or Garden"],
            "site_id": ["site-a", "site-a", "site-a", "site-b", None],
            "site_lat": [51.505, 51.505, 51.505, None, None],
            "site_lng": [-0.105, -0.105, -0.105, None, None],
        }
    )

    result = _greenspace_count_frame(greenspace).sort("lat")

    # One row per site (site-a collapses 3 → 1), null-site rows preserved.
    assert result.height == 3

    site_a = result.filter(pl.col("site_id") == "site-a")
    # The representative point is the site centroid…
    assert site_a["lat"].to_list() == [51.505]
    assert site_a["lng"].to_list() == [-0.105]

    # …or the first access point when no centroid is available.
    site_b = result.filter(pl.col("site_id") == "site-b")
    assert site_b["lat"].to_list() == [53.0]
    assert site_b["lng"].to_list() == [-2.0]

    # The null-site row is kept as-is rather than dropped or merged.
    null_site = result.filter(pl.col("site_id").is_null())
    assert null_site["lat"].to_list() == [54.0]
    assert null_site["lng"].to_list() == [-3.0]


def test_greenspace_count_frame_passes_legacy_parquet_through() -> None:
    """Frames without a `site_id` column are returned unchanged."""
    # The shipped parquet predates the site_id column; counting must not crash
    # (it keeps the old access-point grain until regenerated).
    legacy = pl.DataFrame(
        {
            "lat": [51.50, 51.51],
            "lng": [-0.10, -0.11],
            "category": ["Public Park Or Garden", "Play Space"],
        }
    )

    assert _greenspace_count_frame(legacy).equals(legacy)