Integrate journey times into taskfile

This commit is contained in:
Andras Schmelczer 2026-02-03 19:22:31 +00:00
parent 25865acd44
commit 0242722268
5 changed files with 37 additions and 18 deletions

View file

@ -1,12 +1,10 @@
"""Journey times calculation module for TfL transit data."""
from .config import (
DATA_DIR,
DESTINATIONS,
MAX_CONCURRENT,
MAX_DELAY,
MAX_POSTCODES,
OUTPUT_DIR,
REQUESTS_PER_MIN,
)
from .models import Destination, JourneyResult
@ -14,8 +12,6 @@ from .results import results_to_dataframe, save_results
from .tfl_client import fetch_journey_times
__all__ = [
"DATA_DIR",
"OUTPUT_DIR",
"MAX_DELAY",
"REQUESTS_PER_MIN",
"MAX_POSTCODES",

View file

@ -1,6 +1,8 @@
import argparse
import asyncio
import random
from datetime import date, timedelta
from pathlib import Path
import polars as pl
from tqdm import tqdm
@ -9,7 +11,6 @@ from .config import (
DESTINATIONS,
MAX_CONCURRENT,
MAX_POSTCODES,
OUTPUT_DIR,
MAX_DISTANCE_KM,
)
from .models import JourneyResult
@ -19,7 +20,29 @@ from pipeline.utils import haversine_km_expr
def main():
destination = DESTINATIONS["bank"]
parser = argparse.ArgumentParser(description="Fetch TfL journey times")
parser.add_argument(
"--destination",
required=True,
choices=list(DESTINATIONS.keys()),
help="Destination key",
)
parser.add_argument(
"--output-dir",
required=True,
type=Path,
help="Directory for output and checkpoint files",
)
parser.add_argument(
"--postcodes",
required=True,
type=Path,
help="ArcGIS postcode parquet file",
)
args = parser.parse_args()
destination = DESTINATIONS[args.destination]
output_dir = args.output_dir
# Calculate next Monday at 8am
today = date.today()
@ -33,7 +56,11 @@ def main():
f"at {journey_time[:2]}:{journey_time[2:]}"
)
postcodes_df = pl.read_parquet(OUTPUT_DIR / "postcodes_h3.parquet")
postcodes_df = pl.read_parquet(args.postcodes).select(
pl.col("pcds").alias("postcode"),
"lat",
"long",
)
print(f"Loaded {postcodes_df.height:,} postcodes")
# Filter to postcodes within range of destination
@ -59,13 +86,12 @@ def main():
checkpoint_saver = CheckpointSaver(
destination_name=destination.name,
output_dir=output_dir,
on_save=lambda path, count: print(
f"Checkpoint saved: {count:,} results to {path}"
),
)
# 25556/76273
# Resume from checkpoint if one exists
checkpoint_path = checkpoint_saver._checkpoint_path()
prior_results: list[JourneyResult] = []
@ -133,7 +159,7 @@ def main():
successful = results_df.filter(pl.col("cycling_minutes").is_not_null()).height
print(f"Completed: {successful}/{len(all_results)} successful")
parquet_path = save_results(results_df, destination.name)
parquet_path = save_results(results_df, destination.name, output_dir)
checkpoint_saver.cleanup_checkpoint()
print(f"Saved to {parquet_path}")

View file

@ -2,7 +2,6 @@
from .models import Destination
MAX_DELAY = 10
REQUESTS_PER_MIN = 500
MAX_POSTCODES = None
@ -20,4 +19,5 @@ DESTINATIONS = {
),
"paddington": Destination(51.5154, -0.1755, "Paddington", "940GZZLUPAC"),
"victoria": Destination(51.4965, -0.1447, "Victoria", "940GZZLUVIC"),
"fitzrovia": Destination(51.5165, -0.1310, "Fitzrovia", "940GZZLUTCR"),
}

View file

@ -3,7 +3,7 @@ from typing import Callable
import polars as pl
from .config import CHECKPOINT_INTERVAL, OUTPUT_DIR
from .config import CHECKPOINT_INTERVAL
from .models import JourneyResult
@ -28,12 +28,12 @@ class CheckpointSaver:
def __init__(
self,
destination_name: str,
output_dir: Path | None = None,
output_dir: Path,
interval: int = CHECKPOINT_INTERVAL,
on_save: Callable[[Path, int], None] | None = None,
):
self.destination_name = destination_name
self.output_dir = output_dir or OUTPUT_DIR
self.output_dir = output_dir
self.interval = interval
self.on_save = on_save
self.results: list[JourneyResult] = []
@ -73,11 +73,8 @@ class CheckpointSaver:
def save_results(
results: pl.DataFrame,
destination_name: str,
output_dir: Path | None = None,
output_dir: Path,
) -> Path:
if output_dir is None:
output_dir = OUTPUT_DIR
safe_name = destination_name.lower().replace(" ", "-")
parquet_path = output_dir / f"journey_times_{safe_name}.parquet"
results.write_parquet(parquet_path)