"""Command-line entry point for fetching TfL journey times to a destination."""

import argparse
import asyncio
import random
from datetime import date, timedelta
from pathlib import Path

import polars as pl
from tqdm import tqdm

from .config import (
    DESTINATIONS,
    MAX_CONCURRENT,
    MAX_POSTCODES,
    MAX_DISTANCE_KM,
)
from .models import JourneyResult
from .results import CheckpointSaver, results_to_dataframe, save_results
from .tfl_client import fetch_journey_times
from pipeline.utils import haversine_km_expr


def _next_monday(today: date) -> date:
    """Return the first Monday strictly after *today*."""
    # weekday() is 0 on Mondays, so the modulo yields 0 there; `or 7` pushes a
    # Monday a full week ahead instead of returning the same day.
    days_until_monday = (7 - today.weekday()) % 7 or 7
    return today + timedelta(days=days_until_monday)


def _load_checkpoint_results(checkpoint_path: Path) -> list[JourneyResult]:
    """Read a checkpoint parquet file and return one JourneyResult per postcode."""
    checkpoint_df = pl.read_parquet(checkpoint_path)
    # Deduplicate checkpoint rows per postcode, preferring rows with data:
    # sorting with nulls last puts rows that have a quick-route time first,
    # and keep="first" then retains that row for each postcode.
    checkpoint_df = checkpoint_df.sort(
        "public_transport_quick_minutes", nulls_last=True
    ).unique(subset=["postcode"], keep="first")
    return [
        JourneyResult(
            postcode=row["postcode"],
            public_transport_easy_minutes=row["public_transport_easy_minutes"],
            public_transport_quick_minutes=row["public_transport_quick_minutes"],
            cycling_minutes=row["cycling_minutes"],
            error=row["error"],
        )
        for row in checkpoint_df.iter_rows(named=True)
    ]


def main():
    """Fetch journey times for postcodes near a destination and save them.

    Steps, in order:

    1. Parse ``--destination``, ``--output-dir`` and ``--postcodes``.
    2. Pick the journey slot: the next Monday (never today) at 08:45.
    3. Load the postcode parquet file and keep postcodes within
       ``MAX_DISTANCE_KM`` of the destination.
    4. If a checkpoint file exists, reload its results and skip the
       postcodes it already covers.
    5. If ``MAX_POSTCODES`` is set, randomly sample the outstanding postcodes
       so that prior plus new results stay within that cap.
    6. Fetch the outstanding journeys, join the results with the postcode
       coordinates, save them and remove the checkpoint.
    """
    parser = argparse.ArgumentParser(description="Fetch TfL journey times")
    parser.add_argument(
        "--destination",
        required=True,
        choices=list(DESTINATIONS.keys()),
        help="Destination key",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        type=Path,
        help="Directory for output and checkpoint files",
    )
    parser.add_argument(
        "--postcodes",
        required=True,
        type=Path,
        help="ArcGIS postcode parquet file",
    )
    args = parser.parse_args()

    destination = DESTINATIONS[args.destination]
    output_dir = args.output_dir

    # Journey slot: next Monday at 08:45 ("HHMM" is the form passed to the client).
    journey_date = _next_monday(date.today())
    journey_time = "0845"
    journey_time_display = f"{journey_time[:2]}:{journey_time[2:]}"

    print(f"Destination: {destination.name}")
    print(
        f"Journey: {journey_date.strftime('%A %Y-%m-%d')} "
        f"at {journey_time_display}"
    )

    postcodes_df = pl.read_parquet(args.postcodes).select(
        pl.col("pcds").alias("postcode"),
        "lat",
        "long",
    )
    print(f"Loaded {postcodes_df.height:,} postcodes")

    # Filter to postcodes within range of destination
    postcodes_df = postcodes_df.with_columns(
        haversine_km_expr("lat", "long", destination.lat, destination.lon).alias(
            "distance_km"
        )
    ).filter(pl.col("distance_km") <= MAX_DISTANCE_KM)
    print(f"Filtered to {postcodes_df.height:,} postcodes within {MAX_DISTANCE_KM}km")

    postcode_data = list(
        zip(
            postcodes_df["postcode"].to_list(),
            postcodes_df["lat"].to_list(),
            postcodes_df["long"].to_list(),
        )
    )

    checkpoint_saver = CheckpointSaver(
        destination_name=destination.name,
        output_dir=output_dir,
        on_save=lambda path, count: print(
            f"Checkpoint saved: {count:,} results to {path}"
        ),
    )

    # Resume from checkpoint if one exists
    checkpoint_path = checkpoint_saver._checkpoint_path()
    prior_results: list[JourneyResult] = []
    resumed = checkpoint_path.exists()
    if resumed:
        prior_results = _load_checkpoint_results(checkpoint_path)
        completed_postcodes = {result.postcode for result in prior_results}
        # Hand the saver its own copy: if add_result appends in place (its
        # implementation is not visible here), sharing the list would make
        # `prior_results + new_results` below contain every new result twice.
        checkpoint_saver.results = list(prior_results)
        checkpoint_saver._last_save_count = len(prior_results)
        postcode_data = [
            (pc, lat, lon)
            for pc, lat, lon in postcode_data
            if pc not in completed_postcodes
        ]

    # Sample only AFTER removing completed postcodes, and only up to the
    # remaining budget. Sampling first would draw a fresh random set on every
    # resume, so the run would fetch close to MAX_POSTCODES again and the
    # combined output would exceed the cap.
    if MAX_POSTCODES is not None:
        remaining_budget = max(MAX_POSTCODES - len(prior_results), 0)
        if len(postcode_data) > remaining_budget:
            postcode_data = random.sample(postcode_data, remaining_budget)
            print(f"Randomly sampled {remaining_budget} postcodes")

    if resumed:
        print(
            f"Resumed from checkpoint: {len(prior_results):,} already done, "
            f"{len(postcode_data):,} remaining"
        )

    with tqdm(total=len(postcode_data), desc="Fetching journeys") as pbar:

        def on_result(result):
            # Advance the progress bar and let the saver record the result.
            pbar.update(1)
            checkpoint_saver.add_result(result)

        new_results = asyncio.run(
            fetch_journey_times(
                postcode_data,
                destination,
                journey_date.strftime("%Y%m%d"),
                journey_time,
                MAX_CONCURRENT,
                progress_callback=on_result,
            )
        )

    all_results = prior_results + new_results
    results_df = results_to_dataframe(all_results)

    # Attach coordinates: start from the postcode table so each output row
    # carries lat/long, then left-join the journey results onto it.
    all_postcodes = {r.postcode for r in all_results}
    coords_df = postcodes_df.filter(pl.col("postcode").is_in(all_postcodes)).select(
        ["postcode", "lat", "long"]
    )
    results_df = coords_df.join(results_df, on="postcode", how="left")

    results_df = results_df.with_columns(
        pl.lit(destination.name).alias("destination"),
        pl.lit(journey_date.strftime("%Y-%m-%d")).alias("journey_date"),
        pl.lit(journey_time_display).alias("journey_time"),
    )

    # A row counts as successful when it has a cycling time.
    successful = results_df.filter(pl.col("cycling_minutes").is_not_null()).height
    print(f"Completed: {successful}/{len(all_results)} successful")

    parquet_path = save_results(results_df, destination.name, output_dir)
    checkpoint_saver.cleanup_checkpoint()
    print(f"Saved to {parquet_path}")


if __name__ == "__main__":
    main()