85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
from pathlib import Path
|
|
from typing import Callable
|
|
|
|
import polars as pl
|
|
|
|
from .config import CHECKPOINT_INTERVAL, OUTPUT_DIR
|
|
from .models import JourneyResult
|
|
|
|
|
|
def results_to_dataframe(results: list[JourneyResult]) -> pl.DataFrame:
    """Convert journey results into a Polars DataFrame, one row per result.

    Each row carries the columns ``postcode``,
    ``public_transport_easy_minutes``, ``public_transport_quick_minutes``,
    ``cycling_minutes`` and ``error``, copied from the attributes of the
    same name on each result.

    Args:
        results: Journey results to tabulate, in output row order.

    Returns:
        A DataFrame with one row per entry in ``results``.
    """
    rows = [
        {
            "postcode": r.postcode,
            "public_transport_easy_minutes": r.public_transport_easy_minutes,
            "public_transport_quick_minutes": r.public_transport_quick_minutes,
            "cycling_minutes": r.cycling_minutes,
            "error": r.error,
        }
        for r in results
    ]
    # For row-oriented input Polars infers column dtypes from only a leading
    # sample of rows by default. A column that is all-None in that sample
    # (e.g. ``error`` while the early journeys succeed) would then get the
    # wrong dtype and a later non-null value could fail the conversion.
    # ``infer_schema_length=None`` makes inference scan every row.
    return pl.DataFrame(rows, infer_schema_length=None)
|
|
|
|
|
|
class CheckpointSaver:
    """Collects results and saves checkpoints at regular intervals.

    Results accumulate in memory. Whenever ``interval`` new results have been
    added since the last save, the complete result list is written to a
    single checkpoint Parquet file, replacing the previous checkpoint.
    """

    def __init__(
        self,
        destination_name: str,
        output_dir: Path | None = None,
        interval: int = CHECKPOINT_INTERVAL,
        on_save: Callable[[Path, int], None] | None = None,
    ):
        """Create a saver for one destination.

        Args:
            destination_name: Name used to derive the checkpoint file name.
            output_dir: Directory for the checkpoint file; falls back to
                ``OUTPUT_DIR`` when omitted.
            interval: Number of newly added results that triggers a save.
            on_save: Optional callback invoked after each save with the
                checkpoint path and the number of results written.
        """
        self.destination_name = destination_name
        self.output_dir = output_dir or OUTPUT_DIR
        self.interval = interval
        self.on_save = on_save
        self.results: list[JourneyResult] = []
        # Number of results that were present at the most recent save.
        self._last_save_count = 0

    def add_result(self, result: JourneyResult) -> None:
        """Add a result and save checkpoint if interval is reached."""
        self.results.append(result)
        if len(self.results) - self._last_save_count >= self.interval:
            self.save_checkpoint()

    def save_checkpoint(self) -> Path:
        """Save current results to checkpoint file.

        The output directory is created if it does not exist. The data is
        first written to a temporary sibling file and then moved over the
        checkpoint, so an interrupted write never leaves a truncated
        checkpoint in place of the previous good one.

        Returns:
            Path of the checkpoint file that was written.
        """
        df = results_to_dataframe(self.results)
        path = self._checkpoint_path()
        path.parent.mkdir(parents=True, exist_ok=True)

        tmp_path = path.with_name(f"{path.name}.tmp")
        try:
            df.write_parquet(tmp_path)
            tmp_path.replace(path)
        finally:
            # No-op after a successful replace; removes a partial file when
            # the write or the move failed.
            tmp_path.unlink(missing_ok=True)

        saved_count = len(self.results)
        self._last_save_count = saved_count
        if self.on_save:
            self.on_save(path, saved_count)
        return path

    def _checkpoint_path(self) -> Path:
        """Return the checkpoint file path for this destination."""
        safe_name = self.destination_name.lower().replace(" ", "-")
        return self.output_dir / f"journey_times_{safe_name}_checkpoint.parquet"

    def get_results(self) -> list[JourneyResult]:
        """Return all collected results."""
        return self.results

    def cleanup_checkpoint(self) -> None:
        """Remove the checkpoint file after successful completion."""
        # missing_ok avoids the check-then-delete race of exists() + unlink().
        self._checkpoint_path().unlink(missing_ok=True)
|
|
|
|
|
|
def save_results(
    results: pl.DataFrame,
    destination_name: str,
    output_dir: Path | None = None,
) -> Path:
    """Write the final results to a Parquet file and return its path.

    The file is named ``journey_times_<name>.parquet``, where ``<name>`` is
    ``destination_name`` lower-cased with spaces replaced by hyphens.

    Args:
        results: DataFrame to persist.
        destination_name: Name used to derive the output file name.
        output_dir: Target directory; falls back to ``OUTPUT_DIR`` when
            omitted. It is created (including parents) if missing.

    Returns:
        Path of the Parquet file that was written.
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR
    # Create the directory up front so a fresh run does not fail on the write.
    output_dir.mkdir(parents=True, exist_ok=True)

    safe_name = destination_name.lower().replace(" ", "-")
    parquet_path = output_dir / f"journey_times_{safe_name}.parquet"
    results.write_parquet(parquet_path)

    return parquet_path
|