perfect-postcode/pipeline/processors/journey_times_aggregator.py

127 lines
3.8 KiB
Python

"""Aggregate journey times data by H3 hexagonal cells."""
from pathlib import Path
import polars as pl
from pipeline.config import AGGREGATES_DIR, H3_RESOLUTIONS, PROCESSED_DIR
# Per-postcode journey-time columns read from the journey times parquet.
JOURNEY_COLS = [
    "public_transport_easy_minutes",
    "public_transport_quick_minutes",
    "cycling_minutes",
]

# Per-H3-cell columns this step writes into res{N}.parquet; any existing copies
# are dropped and recomputed on every run.
AGGREGATE_COLS = [
    "median_pt_easy_minutes",
    "median_pt_quick_minutes",
    "median_cycling_minutes",
    "median_journey_minutes",
]
def aggregate_journey_times(
    journey_times_path: Path | None = None,
    postcodes_h3_path: Path | None = None,
    aggregates_dir: Path | None = None,
) -> list[Path]:
    """
    Add journey times to existing H3 aggregate parquet files.

    Joins journey_times_bank_checkpoint.parquet with postcodes_h3.parquet on
    ``postcode``, takes the median of each journey-time column per H3 cell,
    and left-joins the result onto each existing ``res{N}.parquet``. Each file
    is rewritten via a temp file and an atomic rename, so a failed write never
    truncates the other aggregate columns it already holds. A resolution is
    skipped when its parquet file or its ``h3_res{N}`` column is missing.

    Args:
        journey_times_path: Per-postcode journey times. Defaults to
            ``PROCESSED_DIR / "journey_times_bank_checkpoint.parquet"``.
        postcodes_h3_path: Postcode lookup carrying ``h3_res{N}`` columns.
            Defaults to ``PROCESSED_DIR / "postcodes_h3.parquet"``.
        aggregates_dir: Directory containing ``res{N}.parquet``. Defaults to
            ``AGGREGATES_DIR``.

    Returns:
        Paths of the aggregate files that were updated. The list is empty when
        no row has any journey time or no postcode matched the H3 lookup.
    """
    journey_times_path = (
        journey_times_path
        or PROCESSED_DIR / "journey_times_bank_checkpoint.parquet"
    )
    postcodes_h3_path = postcodes_h3_path or PROCESSED_DIR / "postcodes_h3.parquet"
    aggregates_dir = aggregates_dir or AGGREGATES_DIR

    # Only the key and the journey-time columns are needed; read just those.
    journey_df = pl.read_parquet(
        journey_times_path, columns=["postcode", *JOURNEY_COLS]
    )
    # Drop postcodes where every journey-time column is null.
    journey_df = journey_df.filter(
        pl.any_horizontal([pl.col(c).is_not_null() for c in JOURNEY_COLS])
    )
    if journey_df.height == 0:
        print("No valid journey times found")
        return []

    # Attach the H3 indices (h3_res{N} columns) for each postcode.
    postcodes_df = pl.read_parquet(postcodes_h3_path)
    joined_df = journey_df.join(postcodes_df, on="postcode", how="inner")
    if joined_df.height == 0:
        print("No matching postcodes found")
        return []
    print(f"Joined {joined_df.height} postcodes with journey times")

    updated_paths: list[Path] = []
    for resolution in H3_RESOLUTIONS:
        h3_col = f"h3_res{resolution}"
        parquet_path = aggregates_dir / f"res{resolution}.parquet"
        if not parquet_path.exists():
            print(f"Skipping resolution {resolution} - {parquet_path} not found")
            continue
        if h3_col not in joined_df.columns:
            print(f"Skipping resolution {resolution} - column {h3_col} not found")
            continue

        journey_agg = _median_journey_times_by_cell(joined_df, h3_col)
        updated_df = _merge_into_aggregate(parquet_path, journey_agg)
        updated_paths.append(parquet_path)

        matched = updated_df.filter(
            pl.col("median_journey_minutes").is_not_null()
        ).height
        print(
            f"Updated {parquet_path.name}: {matched} rows with journey times "
            f"(out of {updated_df.height} total)"
        )
    return updated_paths


def _median_journey_times_by_cell(joined_df: pl.DataFrame, h3_col: str) -> pl.DataFrame:
    """Return per-cell medians of the journey-time columns, keyed as ``h3``."""
    return (
        joined_df.group_by(h3_col)
        .agg(
            pl.col("public_transport_easy_minutes")
            .median()
            .alias("median_pt_easy_minutes"),
            pl.col("public_transport_quick_minutes")
            .median()
            .alias("median_pt_quick_minutes"),
            pl.col("cycling_minutes")
            .median()
            .alias("median_cycling_minutes"),
            # NOTE(review): the headline journey time currently mirrors the
            # quick public-transport median; confirm this is intended.
            pl.col("public_transport_quick_minutes")
            .median()
            .alias("median_journey_minutes"),
        )
        # Rename the key to match the "h3" column of the res{N}.parquet files.
        .rename({h3_col: "h3"})
    )


def _merge_into_aggregate(parquet_path: Path, journey_agg: pl.DataFrame) -> pl.DataFrame:
    """Left-join ``journey_agg`` onto ``parquet_path``, rewrite it, return the result.

    Any journey-time columns already in the file are dropped first, so a
    re-run replaces them instead of colliding with the incoming columns.
    Cells without journey data end up with nulls.
    """
    existing_df = pl.read_parquet(parquet_path)
    existing_df = existing_df.drop(
        [c for c in AGGREGATE_COLS if c in existing_df.columns]
    )
    updated_df = existing_df.join(journey_agg, on="h3", how="left")
    _write_parquet_atomic(updated_df, parquet_path)
    return updated_df


def _write_parquet_atomic(df: pl.DataFrame, path: Path) -> None:
    """Write ``df`` to ``path`` through a sibling temp file and an atomic rename."""
    tmp_path = path.with_name(f"{path.name}.tmp")
    try:
        df.write_parquet(tmp_path)
        # A sibling path normally shares the target's filesystem, so replace()
        # is an atomic rename: readers see the old file or the new one, never
        # a partial write.
        tmp_path.replace(path)
    finally:
        # No-op after a successful replace; removes a partial temp file otherwise.
        tmp_path.unlink(missing_ok=True)
if __name__ == "__main__":
    # Script entry point: run with the default PROCESSED_DIR / AGGREGATES_DIR paths.
    aggregate_journey_times()