diff --git a/pipeline/journey_times/__init__.py b/pipeline/journey_times/__init__.py index 678959a..799c14b 100644 --- a/pipeline/journey_times/__init__.py +++ b/pipeline/journey_times/__init__.py @@ -1,12 +1,10 @@ """Journey times calculation module for TfL transit data.""" from .config import ( - DATA_DIR, DESTINATIONS, MAX_CONCURRENT, MAX_DELAY, MAX_POSTCODES, - OUTPUT_DIR, REQUESTS_PER_MIN, ) from .models import Destination, JourneyResult @@ -14,8 +12,6 @@ from .results import results_to_dataframe, save_results from .tfl_client import fetch_journey_times __all__ = [ - "DATA_DIR", - "OUTPUT_DIR", "MAX_DELAY", "REQUESTS_PER_MIN", "MAX_POSTCODES", diff --git a/pipeline/journey_times/__main__.py b/pipeline/journey_times/__main__.py index 6297f09..a623c04 100644 --- a/pipeline/journey_times/__main__.py +++ b/pipeline/journey_times/__main__.py @@ -1,6 +1,8 @@ +import argparse import asyncio import random from datetime import date, timedelta +from pathlib import Path import polars as pl from tqdm import tqdm @@ -9,7 +11,6 @@ from .config import ( DESTINATIONS, MAX_CONCURRENT, MAX_POSTCODES, - OUTPUT_DIR, MAX_DISTANCE_KM, ) from .models import JourneyResult @@ -19,7 +20,29 @@ from pipeline.utils import haversine_km_expr def main(): - destination = DESTINATIONS["bank"] + parser = argparse.ArgumentParser(description="Fetch TfL journey times") + parser.add_argument( + "--destination", + required=True, + choices=list(DESTINATIONS.keys()), + help="Destination key", + ) + parser.add_argument( + "--output-dir", + required=True, + type=Path, + help="Directory for output and checkpoint files", + ) + parser.add_argument( + "--postcodes", + required=True, + type=Path, + help="ArcGIS postcode parquet file", + ) + args = parser.parse_args() + + destination = DESTINATIONS[args.destination] + output_dir = args.output_dir # Calculate next Monday at 8am today = date.today() @@ -33,7 +56,11 @@ def main(): f"at {journey_time[:2]}:{journey_time[2:]}" ) - postcodes_df = pl.read_parquet(OUTPUT_DIR / "postcodes_h3.parquet") + postcodes_df = pl.read_parquet(args.postcodes).select( + pl.col("pcds").alias("postcode"), + "lat", + "long", + ) print(f"Loaded {postcodes_df.height:,} postcodes") # Filter to postcodes within range of destination @@ -59,13 +86,12 @@ def main(): checkpoint_saver = CheckpointSaver( destination_name=destination.name, + output_dir=output_dir, on_save=lambda path, count: print( f"Checkpoint saved: {count:,} results to {path}" ), ) - # 25556/76273 - # Resume from checkpoint if one exists checkpoint_path = checkpoint_saver._checkpoint_path() prior_results: list[JourneyResult] = [] @@ -133,7 +159,7 @@ def main(): successful = results_df.filter(pl.col("cycling_minutes").is_not_null()).height print(f"Completed: {successful}/{len(all_results)} successful") - parquet_path = save_results(results_df, destination.name) + parquet_path = save_results(results_df, destination.name, output_dir) checkpoint_saver.cleanup_checkpoint() print(f"Saved to {parquet_path}") diff --git a/pipeline/journey_times/config.py b/pipeline/journey_times/config.py index f86a422..ed36d58 100644 --- a/pipeline/journey_times/config.py +++ b/pipeline/journey_times/config.py @@ -2,7 +2,6 @@ from .models import Destination - MAX_DELAY = 10 REQUESTS_PER_MIN = 500 MAX_POSTCODES = None @@ -20,4 +19,5 @@ DESTINATIONS = { ), "paddington": Destination(51.5154, -0.1755, "Paddington", "940GZZLUPAC"), "victoria": Destination(51.4965, -0.1447, "Victoria", "940GZZLUVIC"), + "fitzrovia": Destination(51.5165, -0.1310, "Fitzrovia", "940GZZLUTCR"), } diff --git a/pipeline/journey_times/data.py b/pipeline/journey_times/data.py deleted file mode 100644 index e69de29..0000000 diff --git a/pipeline/journey_times/results.py b/pipeline/journey_times/results.py index 2845fb2..293e28f 100644 --- a/pipeline/journey_times/results.py +++ b/pipeline/journey_times/results.py @@ -3,7 +3,7 @@ from typing import Callable import polars as pl -from .config import CHECKPOINT_INTERVAL, OUTPUT_DIR +from .config import CHECKPOINT_INTERVAL from .models import JourneyResult @@ -28,12 +28,12 @@ class CheckpointSaver: def __init__( self, destination_name: str, - output_dir: Path | None = None, + output_dir: Path, interval: int = CHECKPOINT_INTERVAL, on_save: Callable[[Path, int], None] | None = None, ): self.destination_name = destination_name - self.output_dir = output_dir or OUTPUT_DIR + self.output_dir = output_dir self.interval = interval self.on_save = on_save self.results: list[JourneyResult] = [] @@ -73,11 +73,8 @@ class CheckpointSaver: def save_results( results: pl.DataFrame, destination_name: str, - output_dir: Path | None = None, + output_dir: Path, ) -> Path: - if output_dir is None: - output_dir = OUTPUT_DIR - safe_name = destination_name.lower().replace(" ", "-") parquet_path = output_dir / f"journey_times_{safe_name}.parquet" results.write_parquet(parquet_path)