Add radius, progress bar, and checkpointing
This commit is contained in:
parent
d227239651
commit
275e5afac6
3 changed files with 97 additions and 6 deletions
|
|
@ -5,8 +5,8 @@ from datetime import date, timedelta
|
|||
import polars as pl
|
||||
from tqdm import tqdm
|
||||
|
||||
from .config import DESTINATIONS, MAX_CONCURRENT, MAX_POSTCODES, OUTPUT_DIR
|
||||
from .results import results_to_dataframe, save_results
|
||||
from .config import DESTINATIONS, MAX_CONCURRENT, MAX_POSTCODES, OUTPUT_DIR, MAX_DISTANCE_KM
|
||||
from .results import CheckpointSaver, results_to_dataframe, save_results
|
||||
from .tfl_client import fetch_journey_times
|
||||
|
||||
|
||||
|
|
@ -28,6 +28,35 @@ def main():
|
|||
postcodes_df = pl.read_parquet(OUTPUT_DIR / "postcodes_h3.parquet")
|
||||
print(f"Loaded {postcodes_df.height:,} postcodes")
|
||||
|
||||
# Filter to postcodes within MAX_DISTANCE_KM of the destination using the Haversine formula
|
||||
earth_radius_km = 6371
|
||||
|
||||
dest_lat_rad = destination.lat * 3.14159265359 / 180
|
||||
dest_lon_rad = destination.lon * 3.14159265359 / 180
|
||||
|
||||
postcodes_df = postcodes_df.with_columns(
|
||||
(
|
||||
2
|
||||
* earth_radius_km
|
||||
* (
|
||||
(
|
||||
((pl.lit(dest_lat_rad) - pl.col("lat") * 3.14159265359 / 180) / 2).sin()
|
||||
** 2
|
||||
+ pl.lit(dest_lat_rad).cos()
|
||||
* (pl.col("lat") * 3.14159265359 / 180).cos()
|
||||
* (
|
||||
(pl.lit(dest_lon_rad) - pl.col("long") * 3.14159265359 / 180) / 2
|
||||
).sin()
|
||||
** 2
|
||||
)
|
||||
.sqrt()
|
||||
.arcsin()
|
||||
)
|
||||
).alias("distance_km")
|
||||
).filter(pl.col("distance_km") <= MAX_DISTANCE_KM)
|
||||
|
||||
print(f"Filtered to {postcodes_df.height:,} postcodes within {MAX_DISTANCE_KM}km")
|
||||
|
||||
postcode_data = list(
|
||||
zip(
|
||||
postcodes_df["postcode"].to_list(),
|
||||
|
|
@ -40,6 +69,15 @@ def main():
|
|||
postcode_data = random.sample(postcode_data, MAX_POSTCODES)
|
||||
print(f"Randomly sampled {MAX_POSTCODES} postcodes")
|
||||
|
||||
checkpoint_saver = CheckpointSaver(
|
||||
destination_name=destination.name,
|
||||
on_save=lambda path, count: print(f"Checkpoint saved: {count:,} results to {path}"),
|
||||
)
|
||||
|
||||
def on_result(result):
|
||||
pbar.update(1)
|
||||
checkpoint_saver.add_result(result)
|
||||
|
||||
with tqdm(total=len(postcode_data), desc="Fetching journeys") as pbar:
|
||||
results = asyncio.run(
|
||||
fetch_journey_times(
|
||||
|
|
@ -48,7 +86,7 @@ def main():
|
|||
journey_date.strftime("%Y%m%d"),
|
||||
journey_time,
|
||||
MAX_CONCURRENT,
|
||||
progress_callback=lambda _: pbar.update(1),
|
||||
progress_callback=on_result,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -70,6 +108,7 @@ def main():
|
|||
print(f"Completed: {successful}/{len(results)} successful")
|
||||
|
||||
parquet_path = save_results(results_df, destination.name)
|
||||
checkpoint_saver.cleanup_checkpoint()
|
||||
print(f"Saved to {parquet_path}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,8 +9,11 @@ OUTPUT_DIR = DATA_DIR / "processed"
|
|||
|
||||
MAX_DELAY = 10
|
||||
REQUESTS_PER_MIN = 500
|
||||
MAX_POSTCODES = 100 # Set to None to process all postcodes
|
||||
MAX_CONCURRENT = 5
|
||||
MAX_POSTCODES = None
|
||||
MAX_CONCURRENT = 80
|
||||
MAX_DISTANCE_KM = 110
|
||||
CHECKPOINT_INTERVAL = 10000
|
||||
|
||||
|
||||
DESTINATIONS = {
|
||||
"bank": Destination(51.5133, -0.0886, "Bank", "940GZZLUBNK"),
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import polars as pl
|
||||
|
||||
from .config import OUTPUT_DIR
|
||||
from .config import CHECKPOINT_INTERVAL, OUTPUT_DIR
|
||||
from .models import JourneyResult
|
||||
|
||||
|
||||
|
|
@ -21,6 +22,54 @@ def results_to_dataframe(results: list[JourneyResult]) -> pl.DataFrame:
|
|||
)
|
||||
|
||||
|
||||
class CheckpointSaver:
    """Collects results and saves checkpoints at regular intervals.

    Results are accumulated in memory. Whenever ``interval`` new results
    have been added since the last save, the full result list is written
    to a single parquet checkpoint file, replacing the previous one.
    """

    def __init__(
        self,
        destination_name: str,
        output_dir: Path | None = None,
        interval: int = CHECKPOINT_INTERVAL,
        on_save: Callable[[Path, int], None] | None = None,
    ):
        """Create a saver for one destination.

        Args:
            destination_name: Used to derive the checkpoint file name.
            output_dir: Directory for the checkpoint file; falls back to
                ``OUTPUT_DIR`` when not given.
            interval: Number of newly added results that triggers a save.
            on_save: Optional callback invoked after each save with the
                checkpoint path and the number of results written.
        """
        self.destination_name = destination_name
        self.output_dir = OUTPUT_DIR if output_dir is None else output_dir
        self.interval = interval
        self.on_save = on_save
        self.results: list[JourneyResult] = []
        # Length of ``self.results`` at the time of the most recent save.
        self._last_save_count = 0

    def add_result(self, result: JourneyResult) -> None:
        """Add a result and save checkpoint if interval is reached."""
        self.results.append(result)
        if len(self.results) - self._last_save_count >= self.interval:
            self.save_checkpoint()

    def save_checkpoint(self) -> Path:
        """Save current results to checkpoint file.

        The data is first written to a sibling temporary file and then
        renamed over the checkpoint, so an interruption during the write
        cannot destroy the previously saved checkpoint.

        Returns:
            The path of the checkpoint file.
        """
        df = results_to_dataframe(self.results)
        path = self._checkpoint_path()
        # Make sure a custom output directory exists before writing into it.
        path.parent.mkdir(parents=True, exist_ok=True)

        tmp_path = self._temp_checkpoint_path()
        df.write_parquet(tmp_path)
        # Rename only after the write has fully succeeded.
        tmp_path.replace(path)

        saved_count = len(self.results)
        self._last_save_count = saved_count
        if self.on_save:
            self.on_save(path, saved_count)
        return path

    def _checkpoint_path(self) -> Path:
        """Return the checkpoint file path for this destination."""
        safe_name = self.destination_name.lower().replace(" ", "-")
        return self.output_dir / f"journey_times_{safe_name}_checkpoint.parquet"

    def _temp_checkpoint_path(self) -> Path:
        """Return the temporary path used while a checkpoint is being written."""
        path = self._checkpoint_path()
        return path.with_name(f"{path.name}.tmp")

    def get_results(self) -> list[JourneyResult]:
        """Return all collected results."""
        return self.results

    def cleanup_checkpoint(self) -> None:
        """Remove the checkpoint file after successful completion.

        Also removes a leftover temporary file from an interrupted save.
        Missing files are ignored.
        """
        self._checkpoint_path().unlink(missing_ok=True)
        self._temp_checkpoint_path().unlink(missing_ok=True)
|
||||
|
||||
|
||||
def save_results(
|
||||
results: pl.DataFrame,
|
||||
destination_name: str,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue