Add radius, progress bar, and checkpointing
This commit is contained in:
parent
d227239651
commit
275e5afac6
3 changed files with 97 additions and 6 deletions
|
|
@ -5,8 +5,8 @@ from datetime import date, timedelta
|
||||||
import polars as pl
|
import polars as pl
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .config import DESTINATIONS, MAX_CONCURRENT, MAX_POSTCODES, OUTPUT_DIR
|
from .config import DESTINATIONS, MAX_CONCURRENT, MAX_POSTCODES, OUTPUT_DIR, MAX_DISTANCE_KM
|
||||||
from .results import results_to_dataframe, save_results
|
from .results import CheckpointSaver, results_to_dataframe, save_results
|
||||||
from .tfl_client import fetch_journey_times
|
from .tfl_client import fetch_journey_times
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -28,6 +28,35 @@ def main():
|
||||||
postcodes_df = pl.read_parquet(OUTPUT_DIR / "postcodes_h3.parquet")
|
postcodes_df = pl.read_parquet(OUTPUT_DIR / "postcodes_h3.parquet")
|
||||||
print(f"Loaded {postcodes_df.height:,} postcodes")
|
print(f"Loaded {postcodes_df.height:,} postcodes")
|
||||||
|
|
||||||
|
# Filter to postcodes within MAX_DISTANCE_KM of the destination (great-circle distance via the Haversine formula)
|
||||||
|
earth_radius_km = 6371
|
||||||
|
|
||||||
|
dest_lat_rad = destination.lat * 3.14159265359 / 180
|
||||||
|
dest_lon_rad = destination.lon * 3.14159265359 / 180
|
||||||
|
|
||||||
|
postcodes_df = postcodes_df.with_columns(
|
||||||
|
(
|
||||||
|
2
|
||||||
|
* earth_radius_km
|
||||||
|
* (
|
||||||
|
(
|
||||||
|
((pl.lit(dest_lat_rad) - pl.col("lat") * 3.14159265359 / 180) / 2).sin()
|
||||||
|
** 2
|
||||||
|
+ pl.lit(dest_lat_rad).cos()
|
||||||
|
* (pl.col("lat") * 3.14159265359 / 180).cos()
|
||||||
|
* (
|
||||||
|
(pl.lit(dest_lon_rad) - pl.col("long") * 3.14159265359 / 180) / 2
|
||||||
|
).sin()
|
||||||
|
** 2
|
||||||
|
)
|
||||||
|
.sqrt()
|
||||||
|
.arcsin()
|
||||||
|
)
|
||||||
|
).alias("distance_km")
|
||||||
|
).filter(pl.col("distance_km") <= MAX_DISTANCE_KM)
|
||||||
|
|
||||||
|
print(f"Filtered to {postcodes_df.height:,} postcodes within {MAX_DISTANCE_KM}km")
|
||||||
|
|
||||||
postcode_data = list(
|
postcode_data = list(
|
||||||
zip(
|
zip(
|
||||||
postcodes_df["postcode"].to_list(),
|
postcodes_df["postcode"].to_list(),
|
||||||
|
|
@ -40,6 +69,15 @@ def main():
|
||||||
postcode_data = random.sample(postcode_data, MAX_POSTCODES)
|
postcode_data = random.sample(postcode_data, MAX_POSTCODES)
|
||||||
print(f"Randomly sampled {MAX_POSTCODES} postcodes")
|
print(f"Randomly sampled {MAX_POSTCODES} postcodes")
|
||||||
|
|
||||||
|
checkpoint_saver = CheckpointSaver(
|
||||||
|
destination_name=destination.name,
|
||||||
|
on_save=lambda path, count: print(f"Checkpoint saved: {count:,} results to {path}"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_result(result):
|
||||||
|
pbar.update(1)
|
||||||
|
checkpoint_saver.add_result(result)
|
||||||
|
|
||||||
with tqdm(total=len(postcode_data), desc="Fetching journeys") as pbar:
|
with tqdm(total=len(postcode_data), desc="Fetching journeys") as pbar:
|
||||||
results = asyncio.run(
|
results = asyncio.run(
|
||||||
fetch_journey_times(
|
fetch_journey_times(
|
||||||
|
|
@ -48,7 +86,7 @@ def main():
|
||||||
journey_date.strftime("%Y%m%d"),
|
journey_date.strftime("%Y%m%d"),
|
||||||
journey_time,
|
journey_time,
|
||||||
MAX_CONCURRENT,
|
MAX_CONCURRENT,
|
||||||
progress_callback=lambda _: pbar.update(1),
|
progress_callback=on_result,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -70,6 +108,7 @@ def main():
|
||||||
print(f"Completed: {successful}/{len(results)} successful")
|
print(f"Completed: {successful}/{len(results)} successful")
|
||||||
|
|
||||||
parquet_path = save_results(results_df, destination.name)
|
parquet_path = save_results(results_df, destination.name)
|
||||||
|
checkpoint_saver.cleanup_checkpoint()
|
||||||
print(f"Saved to {parquet_path}")
|
print(f"Saved to {parquet_path}")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,11 @@ OUTPUT_DIR = DATA_DIR / "processed"
|
||||||
|
|
||||||
MAX_DELAY = 10
|
MAX_DELAY = 10
|
||||||
REQUESTS_PER_MIN = 500
|
REQUESTS_PER_MIN = 500
|
||||||
MAX_POSTCODES = 100 # Set to None to process all postcodes
|
MAX_POSTCODES = None
|
||||||
MAX_CONCURRENT = 5
|
MAX_CONCURRENT = 80
|
||||||
|
MAX_DISTANCE_KM = 110
|
||||||
|
CHECKPOINT_INTERVAL = 10000
|
||||||
|
|
||||||
|
|
||||||
DESTINATIONS = {
|
DESTINATIONS = {
|
||||||
"bank": Destination(51.5133, -0.0886, "Bank", "940GZZLUBNK"),
|
"bank": Destination(51.5133, -0.0886, "Bank", "940GZZLUBNK"),
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
from .config import OUTPUT_DIR
|
from .config import CHECKPOINT_INTERVAL, OUTPUT_DIR
|
||||||
from .models import JourneyResult
|
from .models import JourneyResult
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,6 +22,54 @@ def results_to_dataframe(results: list[JourneyResult]) -> pl.DataFrame:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointSaver:
    """Collect journey results and periodically persist them to a checkpoint.

    Results are accumulated in memory. Every time ``interval`` new results
    have been added since the last save, the full accumulated list is written
    to a single Parquet checkpoint file inside ``output_dir``.

    Args:
        destination_name: Name of the destination; lower-cased with spaces
            replaced by hyphens to build the checkpoint file name.
        output_dir: Directory for the checkpoint file. Falls back to
            ``OUTPUT_DIR`` when not given.
        interval: Number of newly added results that triggers a save.
        on_save: Optional callback invoked after each save with the
            checkpoint path and the number of results written.
    """

    def __init__(
        self,
        destination_name: str,
        output_dir: Path | None = None,
        interval: int = CHECKPOINT_INTERVAL,
        on_save: Callable[[Path, int], None] | None = None,
    ):
        self.destination_name = destination_name
        self.output_dir = output_dir or OUTPUT_DIR
        self.interval = interval
        self.on_save = on_save
        self.results: list[JourneyResult] = []
        # Number of results that were present at the most recent save.
        self._last_save_count = 0

    def add_result(self, result: JourneyResult) -> None:
        """Add a result and save a checkpoint if the interval is reached."""
        self.results.append(result)
        unsaved = len(self.results) - self._last_save_count
        if unsaved >= self.interval:
            self.save_checkpoint()

    def save_checkpoint(self) -> Path:
        """Write all collected results to the checkpoint file.

        The data is first written to a temporary sibling file and then moved
        over the real checkpoint, so an interruption during the write cannot
        leave a truncated checkpoint in place of the previous good one.

        Returns:
            The path of the checkpoint file.
        """
        path = self._checkpoint_path()
        # The output directory may not exist yet on a fresh run.
        path.parent.mkdir(parents=True, exist_ok=True)

        tmp_path = self._temp_path()
        results_to_dataframe(self.results).write_parquet(tmp_path)
        # Path.replace overwrites the destination in a single rename step.
        tmp_path.replace(path)

        count = len(self.results)
        self._last_save_count = count
        if self.on_save:
            self.on_save(path, count)
        return path

    def _checkpoint_path(self) -> Path:
        """Return the checkpoint file path for this destination."""
        safe_name = self.destination_name.lower().replace(" ", "-")
        return self.output_dir / f"journey_times_{safe_name}_checkpoint.parquet"

    def _temp_path(self) -> Path:
        """Return the temporary file path used while writing a checkpoint."""
        path = self._checkpoint_path()
        return path.with_name(path.name + ".tmp")

    def get_results(self) -> list[JourneyResult]:
        """Return all collected results."""
        return self.results

    def cleanup_checkpoint(self) -> None:
        """Remove the checkpoint file after successful completion.

        Also removes a leftover temporary file from an interrupted save.
        Missing files are ignored.
        """
        self._checkpoint_path().unlink(missing_ok=True)
        self._temp_path().unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
def save_results(
|
def save_results(
|
||||||
results: pl.DataFrame,
|
results: pl.DataFrame,
|
||||||
destination_name: str,
|
destination_name: str,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue