import polars as pl

from pipeline.transform.transform_poi import transform_grocery_retail_points


def test_transform_grocery_retail_points_outputs_chain_categories():
    """Check id, name, category, group and emoji for three retailer rows.

    Three raw rows (Waitrose, Sainsburys, The Co-operative Group) go through
    ``transform_grocery_retail_points``; the selected output columns must
    equal the expected records below, in the same order.
    """
    source_frame = pl.DataFrame(
        {
            "id": [101, 102, 103],
            "retailer": ["Waitrose", "Sainsburys", "The Co-operative Group"],
            "fascia": ["Waitrose", "Sainsbury's Local", "Co-op Food"],
            # NOTE(review): the doubled apostrophe looks intentional -- the
            # expected name below has a single one, so the transform
            # presumably normalises it. Confirm against the transform.
            "store_name": ["Waitrose Test", "Sainsbury''s Test", "Co-op Test"],
            "long_wgs": [-0.141, -0.142, -0.143],
            "lat_wgs": [51.515, 51.516, 51.517],
        }
    )

    result = transform_grocery_retail_points(source_frame)
    actual_records = result.select(
        "id", "name", "category", "group", "emoji"
    ).to_dicts()

    # Every expected record carries the same group and emoji.
    shared_fields = {"group": "Groceries", "emoji": "🛒"}
    expected_records = [
        {
            "id": "glx-101",
            "name": "Waitrose Test",
            "category": "Waitrose",
            **shared_fields,
        },
        {
            "id": "glx-102",
            "name": "Sainsbury's Test",
            "category": "Sainsbury's",
            **shared_fields,
        },
        {
            "id": "glx-103",
            "name": "Co-op Test",
            "category": "Co-op",
            **shared_fields,
        },
    ]

    assert actual_records == expected_records


def test_transform_grocery_retail_points_drops_invalid_rows():
    """Check that only the Waitrose row survives the transform.

    The second input row has an empty ``retailer`` (with fascia "Tesco");
    presumably that blank value is what marks it invalid. The output must
    contain a single category, "Waitrose".
    """
    source_frame = pl.DataFrame(
        {
            "id": [101, 102],
            "retailer": ["Waitrose", ""],
            "fascia": ["Waitrose", "Tesco"],
            "store_name": ["Waitrose Test", "Tesco Test"],
            "long_wgs": [-0.141, -0.142],
            "lat_wgs": [51.515, 51.516],
        }
    )

    result = transform_grocery_retail_points(source_frame)
    surviving_categories = result["category"].to_list()

    assert surviving_categories == ["Waitrose"]