Add radius, progress bar, and checkpointing

This commit is contained in:
Andras Schmelczer 2026-01-28 20:38:59 +00:00
parent d227239651
commit 275e5afac6
3 changed files with 97 additions and 6 deletions

View file

@ -5,8 +5,8 @@ from datetime import date, timedelta
import polars as pl
from tqdm import tqdm
from .config import DESTINATIONS, MAX_CONCURRENT, MAX_POSTCODES, OUTPUT_DIR
from .results import results_to_dataframe, save_results
from .config import DESTINATIONS, MAX_CONCURRENT, MAX_POSTCODES, OUTPUT_DIR, MAX_DISTANCE_KM
from .results import CheckpointSaver, results_to_dataframe, save_results
from .tfl_client import fetch_journey_times
@ -28,6 +28,35 @@ def main():
postcodes_df = pl.read_parquet(OUTPUT_DIR / "postcodes_h3.parquet")
print(f"Loaded {postcodes_df.height:,} postcodes")
# Filter to postcodes within MAX_DISTANCE_KM of the destination using the Haversine formula
earth_radius_km = 6371
dest_lat_rad = destination.lat * 3.14159265359 / 180
dest_lon_rad = destination.lon * 3.14159265359 / 180
postcodes_df = postcodes_df.with_columns(
(
2
* earth_radius_km
* (
(
((pl.lit(dest_lat_rad) - pl.col("lat") * 3.14159265359 / 180) / 2).sin()
** 2
+ pl.lit(dest_lat_rad).cos()
* (pl.col("lat") * 3.14159265359 / 180).cos()
* (
(pl.lit(dest_lon_rad) - pl.col("long") * 3.14159265359 / 180) / 2
).sin()
** 2
)
.sqrt()
.arcsin()
)
).alias("distance_km")
).filter(pl.col("distance_km") <= MAX_DISTANCE_KM)
print(f"Filtered to {postcodes_df.height:,} postcodes within {MAX_DISTANCE_KM}km")
postcode_data = list(
zip(
postcodes_df["postcode"].to_list(),
@ -40,6 +69,15 @@ def main():
postcode_data = random.sample(postcode_data, MAX_POSTCODES)
print(f"Randomly sampled {MAX_POSTCODES} postcodes")
checkpoint_saver = CheckpointSaver(
destination_name=destination.name,
on_save=lambda path, count: print(f"Checkpoint saved: {count:,} results to {path}"),
)
def on_result(result):
pbar.update(1)
checkpoint_saver.add_result(result)
with tqdm(total=len(postcode_data), desc="Fetching journeys") as pbar:
results = asyncio.run(
fetch_journey_times(
@ -48,7 +86,7 @@ def main():
journey_date.strftime("%Y%m%d"),
journey_time,
MAX_CONCURRENT,
progress_callback=lambda _: pbar.update(1),
progress_callback=on_result,
)
)
@ -70,6 +108,7 @@ def main():
print(f"Completed: {successful}/{len(results)} successful")
parquet_path = save_results(results_df, destination.name)
checkpoint_saver.cleanup_checkpoint()
print(f"Saved to {parquet_path}")

View file

@ -9,8 +9,11 @@ OUTPUT_DIR = DATA_DIR / "processed"
# NOTE(review): the unit and consumer of MAX_DELAY are not visible here — presumably a retry/backoff cap; confirm.
MAX_DELAY = 10
# Request budget per minute, going by the name; the code that enforces it is not shown here.
REQUESTS_PER_MIN = 500
# Cap on how many postcodes are processed; main() randomly samples this many.
MAX_POSTCODES = 100 # Set to None to process all postcodes
# Concurrency limit that main() passes to fetch_journey_times.
MAX_CONCURRENT = 5
MAX_POSTCODES = None
MAX_CONCURRENT = 80
# Postcodes whose computed distance_km to the destination exceeds this value are filtered out.
MAX_DISTANCE_KM = 110
# Default number of new results CheckpointSaver collects between checkpoint writes.
CHECKPOINT_INTERVAL = 10000
DESTINATIONS = {
"bank": Destination(51.5133, -0.0886, "Bank", "940GZZLUBNK"),

View file

@ -1,8 +1,9 @@
from pathlib import Path
from typing import Callable
import polars as pl
from .config import OUTPUT_DIR
from .config import CHECKPOINT_INTERVAL, OUTPUT_DIR
from .models import JourneyResult
@ -21,6 +22,54 @@ def results_to_dataframe(results: list[JourneyResult]) -> pl.DataFrame:
)
class CheckpointSaver:
    """Collect journey results and periodically persist them as a checkpoint.

    Every result passed to :meth:`add_result` is kept in memory. Once
    ``interval`` new results have accumulated since the last save, all
    results collected so far are written to a single parquet checkpoint
    file, overwriting the previous checkpoint.

    Args:
        destination_name: Name used to derive the checkpoint file name
            (lower-cased, spaces replaced by hyphens).
        output_dir: Directory for the checkpoint file; falls back to
            ``OUTPUT_DIR`` when omitted.
        interval: Number of new results between automatic checkpoint saves.
        on_save: Optional callback invoked after each save with the
            checkpoint path and the total number of results saved.
    """

    def __init__(
        self,
        destination_name: str,
        output_dir: Path | None = None,
        interval: int = CHECKPOINT_INTERVAL,
        on_save: Callable[[Path, int], None] | None = None,
    ):
        self.destination_name = destination_name
        self.output_dir = output_dir or OUTPUT_DIR
        self.interval = interval
        self.on_save = on_save
        self.results: list[JourneyResult] = []
        # Number of results that were included in the most recent checkpoint.
        self._last_save_count = 0

    def add_result(self, result: JourneyResult) -> None:
        """Add a result and save a checkpoint if the interval is reached."""
        self.results.append(result)
        if len(self.results) - self._last_save_count >= self.interval:
            self.save_checkpoint()

    def save_checkpoint(self) -> Path:
        """Write all collected results to the checkpoint file.

        The data is first written to a temporary sibling file and then
        moved over the real checkpoint, so an interrupted write never
        leaves a truncated checkpoint in place of the previous good one.

        Returns:
            The path of the checkpoint file.
        """
        df = results_to_dataframe(self.results)
        path = self._checkpoint_path()
        # A caller-supplied output_dir may not exist yet.
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = path.with_name(f"{path.name}.tmp")
        try:
            df.write_parquet(tmp_path)
            tmp_path.replace(path)
        finally:
            # No-op after a successful replace; removes a partial file on failure.
            tmp_path.unlink(missing_ok=True)
        self._last_save_count = len(self.results)
        if self.on_save:
            self.on_save(path, len(self.results))
        return path

    def _checkpoint_path(self) -> Path:
        """Return the checkpoint file path for this destination."""
        safe_name = self.destination_name.lower().replace(" ", "-")
        return self.output_dir / f"journey_times_{safe_name}_checkpoint.parquet"

    def get_results(self) -> list[JourneyResult]:
        """Return all collected results (the internal list, not a copy)."""
        return self.results

    def cleanup_checkpoint(self) -> None:
        """Remove the checkpoint file after successful completion."""
        # missing_ok avoids the check-then-delete race of exists() + unlink().
        self._checkpoint_path().unlink(missing_ok=True)
def save_results(
results: pl.DataFrame,
destination_name: str,