perfect-postcode/pipeline/journey_times/results.py

82 lines
2.5 KiB
Python

from pathlib import Path
from typing import Callable
import polars as pl
from .config import CHECKPOINT_INTERVAL
from .models import JourneyResult
def results_to_dataframe(results: list[JourneyResult]) -> pl.DataFrame:
    """Convert journey results into a polars DataFrame, one row per result.

    Columns (in order): ``postcode``, ``public_transport_easy_minutes``,
    ``public_transport_quick_minutes``, ``cycling_minutes``, ``error``.

    The frame is built column-wise rather than from a list of row dicts.
    Row-oriented construction only scans the first ``infer_schema_length`` rows
    (100 by default) to infer the schema. A value of a new type appearing later
    would then fail to load. A typical case is an ``error`` string after 100
    rows whose ``error`` is ``None``. Column-wise construction lets each
    column's dtype be inferred from all of its values. It also keeps the
    columns present when ``results`` is empty; their dtype in that case is
    whatever polars assigns to an empty list.
    """
    return pl.DataFrame(
        {
            "postcode": [r.postcode for r in results],
            "public_transport_easy_minutes": [
                r.public_transport_easy_minutes for r in results
            ],
            "public_transport_quick_minutes": [
                r.public_transport_quick_minutes for r in results
            ],
            "cycling_minutes": [r.cycling_minutes for r in results],
            "error": [r.error for r in results],
        }
    )
class CheckpointSaver:
    """Collect journey results and periodically persist them to a checkpoint file.

    Each time ``interval`` new results have accumulated since the last save,
    every result collected so far is written to a single parquet file in
    ``output_dir``, overwriting the previous checkpoint. The write goes to a
    temporary sibling file that is then renamed over the checkpoint. An
    interruption mid-write therefore leaves the previous checkpoint intact
    rather than a truncated file.

    Args:
        destination_name: Destination label. It is lower-cased and has spaces
            replaced by hyphens to build the checkpoint filename.
        output_dir: Directory the checkpoint is written to. It is not created
            here and must already exist.
        interval: Number of new results between automatic saves.
        on_save: Optional callback invoked after each save with the checkpoint
            path and the number of results written.
    """

    def __init__(
        self,
        destination_name: str,
        output_dir: Path,
        interval: int = CHECKPOINT_INTERVAL,
        on_save: Callable[[Path, int], None] | None = None,
    ):
        self.destination_name = destination_name
        self.output_dir = output_dir
        self.interval = interval
        self.on_save = on_save
        self.results: list[JourneyResult] = []
        # len(self.results) at the most recent save; drives the interval check.
        self._last_save_count = 0

    def add_result(self, result: JourneyResult) -> None:
        """Append ``result``, saving a checkpoint once ``interval`` new results
        have accumulated since the last save."""
        self.results.append(result)
        if len(self.results) - self._last_save_count >= self.interval:
            self.save_checkpoint()

    def save_checkpoint(self) -> Path:
        """Write all results collected so far to the checkpoint file.

        Returns:
            The path of the checkpoint file.
        """
        count = len(self.results)
        df = results_to_dataframe(self.results)
        path = self._checkpoint_path()
        tmp_path = self._tmp_path(path)
        # Write to a temp sibling and atomically swap it in, so a crash during
        # the write cannot destroy the last good checkpoint. After a successful
        # replace the temp file no longer exists, so the unlink is a no-op.
        # On failure it removes the partial temp file.
        try:
            df.write_parquet(tmp_path)
            tmp_path.replace(path)
        finally:
            tmp_path.unlink(missing_ok=True)
        self._last_save_count = count
        if self.on_save:
            self.on_save(path, count)
        return path

    def _checkpoint_path(self) -> Path:
        """Return the checkpoint file path derived from ``destination_name``."""
        safe_name = self.destination_name.lower().replace(" ", "-")
        return self.output_dir / f"journey_times_{safe_name}_checkpoint.parquet"

    @staticmethod
    def _tmp_path(path: Path) -> Path:
        """Return the temporary sibling used while writing ``path``."""
        return path.with_name(path.name + ".tmp")

    def get_results(self) -> list[JourneyResult]:
        """Return all collected results (the live internal list, not a copy)."""
        return self.results

    def cleanup_checkpoint(self) -> None:
        """Remove the checkpoint file (and any stale temp file) after completion.

        Missing files are ignored, so this is safe to call when no checkpoint
        was ever written.
        """
        path = self._checkpoint_path()
        path.unlink(missing_ok=True)
        # A hard kill during a write can bypass the cleanup in save_checkpoint.
        self._tmp_path(path).unlink(missing_ok=True)
def save_results(
    results: pl.DataFrame,
    destination_name: str,
    output_dir: Path,
) -> Path:
    """Write the final journey-time results for a destination to parquet.

    The file is ``journey_times_<safe_name>.parquet`` in ``output_dir``.
    ``safe_name`` is ``destination_name`` lower-cased with spaces replaced by
    hyphens. The data is written to a temporary sibling and then renamed into
    place, so a reader never observes a partially-written file and an
    interrupted write does not clobber an existing result.

    Args:
        results: DataFrame to persist.
        destination_name: Destination label used to build the filename.
        output_dir: Directory to write into. It must already exist.

    Returns:
        The path of the written parquet file.
    """
    safe_name = destination_name.lower().replace(" ", "-")
    parquet_path = output_dir / f"journey_times_{safe_name}.parquet"
    tmp_path = parquet_path.with_name(parquet_path.name + ".tmp")
    try:
        results.write_parquet(tmp_path)
        tmp_path.replace(parquet_path)
    finally:
        # No-op after a successful replace; removes the partial file on failure.
        tmp_path.unlink(missing_ok=True)
    return parquet_path