diff --git a/pipeline/transform/transform_poi.py b/pipeline/transform/transform_poi.py index d44b877..8482142 100644 --- a/pipeline/transform/transform_poi.py +++ b/pipeline/transform/transform_poi.py @@ -624,13 +624,20 @@ def transform(input_path: Path, naptan_path: Path | None = None) -> pl.LazyFrame if missing_emojis: raise ValueError(f"Empty emojis for: {missing_emojis}") + # Derive group from the first component of the raw category key, title-cased + group_mapping = { + k: k.split("/")[0].replace("_", " ").title() for k in CATEGORY_MAP + } + lf = lf.with_columns( + pl.col("category").replace_strict(group_mapping).alias("group"), pl.col("category").replace_strict(name_mapping).alias("category"), pl.col("category").replace_strict(emoji_mapping).alias("emoji"), ) naptan = pl.scan_parquet(naptan_path).with_columns( pl.col("category").replace_strict(NAPTAN_EMOJIS).alias("emoji"), + pl.lit("Public Transport").alias("group"), ) return pl.concat([lf, naptan], how="diagonal_relaxed")