perfect-postcode/pipeline/journey_times/__main__.py
2026-02-08 10:21:37 +00:00

167 lines
5.4 KiB
Python

import argparse
import asyncio
import random
from datetime import date, timedelta
from pathlib import Path
import polars as pl
from tqdm import tqdm
from .config import (
DESTINATIONS,
MAX_CONCURRENT,
MAX_POSTCODES,
MAX_DISTANCE_KM,
)
from .models import JourneyResult
from .results import CheckpointSaver, results_to_dataframe, save_results
from .tfl_client import fetch_journey_times
from pipeline.utils import haversine_km_expr
def _next_monday(today: date) -> date:
    """Return the Monday after *today*; a Monday input yields the following week."""
    days_until_monday = (7 - today.weekday()) % 7 or 7
    return today + timedelta(days=days_until_monday)


def _load_postcodes(postcodes_path: Path, destination) -> pl.DataFrame:
    """Read postcode coordinates and keep those within range of *destination*.

    Args:
        postcodes_path: ArcGIS postcode parquet file with ``pcds``, ``lat`` and
            ``long`` columns.
        destination: Entry from ``DESTINATIONS``; only its ``lat`` and ``lon``
            attributes are read here.

    Returns:
        Frame with ``postcode``, ``lat``, ``long`` and ``distance_km`` columns,
        limited to rows with ``distance_km <= MAX_DISTANCE_KM``.
    """
    # Only the three columns used below are read from disk.
    postcodes_df = pl.read_parquet(
        postcodes_path, columns=["pcds", "lat", "long"]
    ).select(
        pl.col("pcds").alias("postcode"),
        "lat",
        "long",
    )
    print(f"Loaded {postcodes_df.height:,} postcodes")
    postcodes_df = postcodes_df.with_columns(
        haversine_km_expr("lat", "long", destination.lat, destination.lon).alias(
            "distance_km"
        )
    ).filter(pl.col("distance_km") <= MAX_DISTANCE_KM)
    print(f"Filtered to {postcodes_df.height:,} postcodes within {MAX_DISTANCE_KM}km")
    return postcodes_df


def _load_checkpoint(checkpoint_path: Path) -> list[JourneyResult]:
    """Rebuild journey results from a checkpoint parquet file.

    A postcode may appear more than once in the checkpoint; one row per
    postcode is kept. Rows are sorted by ``public_transport_quick_minutes``
    with nulls last before de-duplicating, so a row with a value is preferred
    over one without (the smallest value wins among several).
    """
    checkpoint_df = (
        pl.read_parquet(checkpoint_path)
        .sort("public_transport_quick_minutes", nulls_last=True)
        .unique(subset=["postcode"], keep="first")
    )
    return [
        JourneyResult(
            postcode=row["postcode"],
            public_transport_easy_minutes=row["public_transport_easy_minutes"],
            public_transport_quick_minutes=row["public_transport_quick_minutes"],
            cycling_minutes=row["cycling_minutes"],
            error=row["error"],
        )
        for row in checkpoint_df.iter_rows(named=True)
    ]


def main():
    """CLI entry point: fetch TfL journey times from postcodes to one destination.

    Loads postcode coordinates, keeps those within ``MAX_DISTANCE_KM`` of the
    chosen destination, resumes from a checkpoint when one exists, optionally
    samples down to ``MAX_POSTCODES``, fetches the remaining journeys
    concurrently and writes the combined results with ``save_results``.
    """
    parser = argparse.ArgumentParser(description="Fetch TfL journey times")
    parser.add_argument(
        "--destination",
        required=True,
        choices=list(DESTINATIONS.keys()),
        help="Destination key",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        type=Path,
        help="Directory for output and checkpoint files",
    )
    parser.add_argument(
        "--postcodes",
        required=True,
        type=Path,
        help="ArcGIS postcode parquet file",
    )
    args = parser.parse_args()
    destination = DESTINATIONS[args.destination]
    output_dir = args.output_dir

    # Journeys are queried for the next Monday at 08:45.
    journey_date = _next_monday(date.today())
    journey_time = "0845"
    journey_time_display = f"{journey_time[:2]}:{journey_time[2:]}"
    print(f"Destination: {destination.name}")
    print(
        f"Journey: {journey_date.strftime('%A %Y-%m-%d')} "
        f"at {journey_time_display}"
    )

    postcodes_df = _load_postcodes(args.postcodes, destination)
    # (postcode, lat, lon) tuples for every in-range postcode.
    postcode_data = postcodes_df.select("postcode", "lat", "long").rows()
    in_range_count = len(postcode_data)

    checkpoint_saver = CheckpointSaver(
        destination_name=destination.name,
        output_dir=output_dir,
        on_save=lambda path, count: print(
            f"Checkpoint saved: {count:,} results to {path}"
        ),
    )

    # Resume from checkpoint if one exists
    checkpoint_path = checkpoint_saver._checkpoint_path()
    prior_results: list[JourneyResult] = []
    if checkpoint_path.exists():
        prior_results = _load_checkpoint(checkpoint_path)
        # Give the saver its own copy. add_result() presumably appends to
        # .results (verify in results.py); sharing the list would make
        # `prior_results + new_results` below count every new result twice.
        checkpoint_saver.results = list(prior_results)
        checkpoint_saver._last_save_count = len(prior_results)
        completed_postcodes = {r.postcode for r in prior_results}
        postcode_data = [
            (pc, lat, lon)
            for pc, lat, lon in postcode_data
            if pc not in completed_postcodes
        ]
        print(
            f"Resumed from checkpoint: {len(prior_results):,} already done, "
            f"{len(postcode_data):,} remaining"
        )

    if MAX_POSTCODES is not None:
        # Sample only after resuming, so a restarted run tops up the earlier
        # sample instead of drawing a fresh MAX_POSTCODES on top of it.
        already_done = in_range_count - len(postcode_data)
        budget = max(MAX_POSTCODES - already_done, 0)
        if len(postcode_data) > budget:
            postcode_data = random.sample(postcode_data, budget)
            print(f"Randomly sampled {budget} postcodes")

    with tqdm(total=len(postcode_data), desc="Fetching journeys") as pbar:

        def on_result(result):
            # Called once per finished journey: advance the bar and checkpoint.
            pbar.update(1)
            checkpoint_saver.add_result(result)

        new_results = asyncio.run(
            fetch_journey_times(
                postcode_data,
                destination,
                journey_date.strftime("%Y%m%d"),
                journey_time,
                MAX_CONCURRENT,
                progress_callback=on_result,
            )
        )

    all_results = prior_results + new_results
    results_df = results_to_dataframe(all_results)
    all_postcodes = {r.postcode for r in all_results}
    # Attach coordinates; checkpoint entries no longer in range drop out here.
    coords_df = postcodes_df.filter(pl.col("postcode").is_in(all_postcodes)).select(
        ["postcode", "lat", "long"]
    )
    results_df = coords_df.join(results_df, on="postcode", how="left")
    results_df = results_df.with_columns(
        pl.lit(destination.name).alias("destination"),
        pl.lit(journey_date.strftime("%Y-%m-%d")).alias("journey_date"),
        pl.lit(journey_time_display).alias("journey_time"),
    )
    # Numerator and denominator both come from the frame that is saved.
    successful = results_df.filter(pl.col("cycling_minutes").is_not_null()).height
    print(f"Completed: {successful}/{results_df.height} successful")
    parquet_path = save_results(results_df, destination.name, output_dir)
    checkpoint_saver.cleanup_checkpoint()
    print(f"Saved to {parquet_path}")
if __name__ == "__main__":
    # Runs the CLI when the package is executed directly,
    # e.g. `python -m pipeline.journey_times`.
    main()