diff --git a/pipeline/journey_times/__main__.py b/pipeline/journey_times/__main__.py index e5f4e97..0951ae4 100644 --- a/pipeline/journey_times/__main__.py +++ b/pipeline/journey_times/__main__.py @@ -5,8 +5,8 @@ from datetime import date, timedelta import polars as pl from tqdm import tqdm -from .config import DESTINATIONS, MAX_CONCURRENT, MAX_POSTCODES, OUTPUT_DIR -from .results import results_to_dataframe, save_results +from .config import DESTINATIONS, MAX_CONCURRENT, MAX_POSTCODES, OUTPUT_DIR, MAX_DISTANCE_KM +from .results import CheckpointSaver, results_to_dataframe, save_results from .tfl_client import fetch_journey_times @@ -28,6 +28,35 @@ def main(): postcodes_df = pl.read_parquet(OUTPUT_DIR / "postcodes_h3.parquet") print(f"Loaded {postcodes_df.height:,} postcodes") + # Filter to postcodes within 150km of destination using Haversine formula + earth_radius_km = 6371 + + dest_lat_rad = destination.lat * 3.14159265359 / 180 + dest_lon_rad = destination.lon * 3.14159265359 / 180 + + postcodes_df = postcodes_df.with_columns( + ( + 2 + * earth_radius_km + * ( + ( + ((pl.lit(dest_lat_rad) - pl.col("lat") * 3.14159265359 / 180) / 2).sin() + ** 2 + + pl.lit(dest_lat_rad).cos() + * (pl.col("lat") * 3.14159265359 / 180).cos() + * ( + (pl.lit(dest_lon_rad) - pl.col("long") * 3.14159265359 / 180) / 2 + ).sin() + ** 2 + ) + .sqrt() + .arcsin() + ) + ).alias("distance_km") + ).filter(pl.col("distance_km") <= MAX_DISTANCE_KM) + + print(f"Filtered to {postcodes_df.height:,} postcodes within {MAX_DISTANCE_KM}km") + postcode_data = list( zip( postcodes_df["postcode"].to_list(), @@ -40,6 +69,15 @@ def main(): postcode_data = random.sample(postcode_data, MAX_POSTCODES) print(f"Randomly sampled {MAX_POSTCODES} postcodes") + checkpoint_saver = CheckpointSaver( + destination_name=destination.name, + on_save=lambda path, count: print(f"Checkpoint saved: {count:,} results to {path}"), + ) + + def on_result(result): + pbar.update(1) + checkpoint_saver.add_result(result) + with tqdm(total=len(postcode_data), desc="Fetching journeys") as pbar: results = asyncio.run( fetch_journey_times( @@ -48,7 +86,7 @@ def main(): journey_date.strftime("%Y%m%d"), journey_time, MAX_CONCURRENT, - progress_callback=lambda _: pbar.update(1), + progress_callback=on_result, ) ) @@ -70,6 +108,7 @@ def main(): print(f"Completed: {successful}/{len(results)} successful") parquet_path = save_results(results_df, destination.name) + checkpoint_saver.cleanup_checkpoint() print(f"Saved to {parquet_path}") diff --git a/pipeline/journey_times/config.py b/pipeline/journey_times/config.py index 4741653..f2401cc 100644 --- a/pipeline/journey_times/config.py +++ b/pipeline/journey_times/config.py @@ -9,8 +9,11 @@ OUTPUT_DIR = DATA_DIR / "processed" MAX_DELAY = 10 REQUESTS_PER_MIN = 500 -MAX_POSTCODES = 100 # Set to None to process all postcodes -MAX_CONCURRENT = 5 +MAX_POSTCODES = None +MAX_CONCURRENT = 80 +MAX_DISTANCE_KM = 110 +CHECKPOINT_INTERVAL = 10000 + DESTINATIONS = { "bank": Destination(51.5133, -0.0886, "Bank", "940GZZLUBNK"), diff --git a/pipeline/journey_times/results.py b/pipeline/journey_times/results.py index 4d4d54f..2845fb2 100644 --- a/pipeline/journey_times/results.py +++ b/pipeline/journey_times/results.py @@ -1,8 +1,9 @@ from pathlib import Path +from typing import Callable import polars as pl -from .config import OUTPUT_DIR +from .config import CHECKPOINT_INTERVAL, OUTPUT_DIR from .models import JourneyResult @@ -21,6 +22,54 @@ def results_to_dataframe(results: list[JourneyResult]) -> pl.DataFrame: ) +class CheckpointSaver: + """Collects results and saves checkpoints at regular intervals.""" + + def __init__( + self, + destination_name: str, + output_dir: Path | None = None, + interval: int = CHECKPOINT_INTERVAL, + on_save: Callable[[Path, int], None] | None = None, + ): + self.destination_name = destination_name + self.output_dir = output_dir or OUTPUT_DIR + self.interval = interval + self.on_save = on_save + self.results: list[JourneyResult] = [] + self._last_save_count = 0 + + def add_result(self, result: JourneyResult) -> None: + """Add a result and save checkpoint if interval is reached.""" + self.results.append(result) + if len(self.results) - self._last_save_count >= self.interval: + self.save_checkpoint() + + def save_checkpoint(self) -> Path: + """Save current results to checkpoint file.""" + df = results_to_dataframe(self.results) + path = self._checkpoint_path() + df.write_parquet(path) + self._last_save_count = len(self.results) + if self.on_save: + self.on_save(path, len(self.results)) + return path + + def _checkpoint_path(self) -> Path: + safe_name = self.destination_name.lower().replace(" ", "-") + return self.output_dir / f"journey_times_{safe_name}_checkpoint.parquet" + + def get_results(self) -> list[JourneyResult]: + """Return all collected results.""" + return self.results + + def cleanup_checkpoint(self) -> None: + """Remove the checkpoint file after successful completion.""" + path = self._checkpoint_path() + if path.exists(): + path.unlink() + + def save_results( results: pl.DataFrame, destination_name: str,