perfect-postcode/pipeline/journey_times/tfl_client.py

248 lines
8.2 KiB
Python

import asyncio
from typing import Literal
import warnings
from collections.abc import Callable
from http import HTTPStatus
from httpx import Timeout
from journey_client import Client
from journey_client.api.journey import (
journey_journey_results_by_path_from_path_to_query_via_query_national_search_query_date_qu as journey_api,
)
from journey_client.models import (
JourneyJourneyResultsByPathFromPathToQueryViaQueryNationalSearchQueryDateQuTimeIs as TimeIs,
)
from journey_client.models import (
JourneyJourneyResultsByPathFromPathToQueryViaQueryNationalSearchQueryDateQuJourneyPreference as JourneyPreference,
)
from journey_client.models import (
JourneyJourneyResultsByPathFromPathToQueryViaQueryNationalSearchQueryDateQuCyclePreference as CyclePreference,
)
from journey_client.models import (
JourneyJourneyResultsByPathFromPathToQueryViaQueryNationalSearchQueryDateQuBikeProficiency as BikeProficiency,
)
from journey_client.types import Unset
from .config import MAX_DELAY
from .models import Destination, JourneyResult
from .rate_limiter import RateLimiter
async def fetch_journey_for_mode(
    client: Client,
    rate_limiter: RateLimiter,
    from_location: str,
    to_location: str,
    journey_date: str,
    journey_time: str,
    journey_type: Literal["quick"] | Literal["easy"] | Literal["cycle"],
    retry_count: int = 5,
) -> int | None:
    """Fetch the shortest journey duration for one travel mode, with rate limiting.

    Each attempt awaits ``rate_limiter.acquire()`` and then asks the journey
    planner for journeys that *arrive* at ``journey_time`` on ``journey_date``.
    HTTP 429/500/502/503/504 responses and exceptions raised while making the
    request are retried with exponential backoff (starting at 1s, doubling,
    capped at ``MAX_DELAY``). Any other non-200 status ends the search
    immediately.

    Args:
        client: Journey API client used for the request.
        rate_limiter: Limiter awaited before every request attempt.
        from_location: Origin, sent as the ``from`` path parameter.
        to_location: Destination, sent as the ``to`` path parameter.
        journey_date: Date sent as the ``date`` query parameter.
        journey_time: Arrival time sent as the ``time`` query parameter.
        journey_type: One of three profiles:
            ``"easy"`` uses public transport only, with fewest interchanges.
            ``"quick"`` uses public transport plus walking and cycling, with
            the least-time preference and walking optimisation.
            ``"cycle"`` cycles all the way.
        retry_count: Maximum number of request attempts.

    Returns:
        The smallest ``duration`` among the returned journeys. Returns ``None``
        when no journey carried a duration, a non-retryable status was
        received, or every attempt failed.

    Raises:
        KeyError: If ``journey_type`` is not one of the supported values.
    """
    # Request settings do not change between attempts, so resolve them once.
    # Resolving them outside the retry loop also makes an unsupported
    # journey_type fail fast instead of being retried as a "network error".
    cycle_preference = {
        "quick": CyclePreference.TAKEONTRANSPORT,
        "easy": CyclePreference.NONE,
        "cycle": CyclePreference.ALLTHEWAY,
    }[journey_type]
    # options: public-bus,overground,train,tube,coach,dlr,cablecar,tram,river,walking,cycle
    public_transport = [
        "public-bus",
        "overground",
        "train",
        "tube",
        "coach",
        "dlr",
        "cablecar",
        "tram",
        "river",
    ]
    mode = {
        "quick": [*public_transport, "walking", "cycle"],
        "easy": public_transport,
        "cycle": ["cycle"],
    }[journey_type]
    # "easy" minimises changes; every other profile wants the fastest route.
    journey_preference = (
        JourneyPreference.LEASTINTERCHANGE
        if journey_type == "easy"
        else JourneyPreference.LEASTTIME
    )
    retryable_statuses = frozenset(
        {
            HTTPStatus.TOO_MANY_REQUESTS,
            HTTPStatus.INTERNAL_SERVER_ERROR,
            HTTPStatus.BAD_GATEWAY,
            HTTPStatus.SERVICE_UNAVAILABLE,
            HTTPStatus.GATEWAY_TIMEOUT,
        }
    )

    backoff = 1.0
    for attempt in range(retry_count):
        try:
            await rate_limiter.acquire()
            response = await journey_api.asyncio_detailed(
                from_=from_location,
                to=to_location,
                client=client,
                date=journey_date,
                time=journey_time,
                national_search=True,
                time_is=TimeIs.ARRIVING,
                journey_preference=journey_preference,
                cycle_preference=cycle_preference,
                bike_proficiency=BikeProficiency.FAST,
                walking_optimization=journey_type == "quick",
                mode=mode,
            )
        except Exception as e:
            warnings.warn(
                f"Network error for {journey_type} from {from_location}: {e}, "
                f"retrying in {backoff:.1f}s (attempt {attempt + 1}/{retry_count})",
                stacklevel=2,
            )
        else:
            if response.status_code == HTTPStatus.OK and response.parsed:
                journeys = response.parsed.journeys
                if isinstance(journeys, Unset) or not journeys:
                    return None
                durations = [
                    j.duration for j in journeys if not isinstance(j.duration, Unset)
                ]
                return min(durations) if durations else None
            if response.status_code not in retryable_statuses:
                # Client errors (and 200s without a parsed body) won't improve on retry.
                return None
            warnings.warn(
                f"HTTP {response.status_code.value} for {journey_type} from {from_location}, "
                f"retrying in {backoff:.1f}s (attempt {attempt + 1}/{retry_count})",
                stacklevel=2,
            )
        # Back off before the next attempt; there is nothing to wait for after the last one.
        if attempt + 1 < retry_count:
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, MAX_DELAY)

    warnings.warn(
        f"Failed to fetch {journey_type} from {from_location} after {retry_count} attempts",
        stacklevel=2,
    )
    return None
async def fetch_all_modes(
    client: Client,
    rate_limiter: RateLimiter,
    postcode: str,
    lat: float,
    lon: float,
    to_location: str,
    journey_date: str,
    journey_time: str,
    semaphore: asyncio.Semaphore,
) -> JourneyResult:
    """Fetch easy, quick and cycling journey times from one postcode's coordinates.

    The origin is sent as ``"{lat},{lon}"``. The three modes are fetched one
    after another while holding ``semaphore``.

    Args:
        client: Journey API client shared by all requests.
        rate_limiter: Limiter passed through to every per-mode request.
        postcode: Postcode recorded on the returned result.
        lat: Origin latitude.
        lon: Origin longitude.
        to_location: Destination string passed to the journey planner.
        journey_date: Date passed to the journey planner.
        journey_time: Arrival time passed to the journey planner.
        semaphore: Bounds how many postcodes are processed concurrently.

    Returns:
        A ``JourneyResult`` with the three per-mode durations. Each duration
        may be ``None``. If an unexpected exception occurs, the result instead
        carries that exception's text in ``error``.
    """
    async with semaphore:
        from_location = f"{lat},{lon}"

        async def fetch(journey_type: Literal["quick", "easy", "cycle"]) -> int | None:
            # Only the journey type varies between the three requests.
            return await fetch_journey_for_mode(
                client,
                rate_limiter,
                from_location,
                to_location,
                journey_date,
                journey_time,
                journey_type=journey_type,
            )

        try:
            easy = await fetch("easy")
            quick = await fetch("quick")
            cycling = await fetch("cycle")
            return JourneyResult(
                postcode=postcode,
                public_transport_easy_minutes=easy,
                public_transport_quick_minutes=quick,
                cycling_minutes=cycling,
            )
        except Exception as e:
            # Report through warnings like the per-mode fetcher does, and name the
            # postcode so the failure can be traced back to its input row.
            warnings.warn(f"Error fetching journeys for {postcode}: {e}", stacklevel=2)
            return JourneyResult(postcode=postcode, error=str(e))
async def fetch_journey_times(
    postcode_data: list[tuple[str, float, float]],
    dest: Destination,
    journey_date: str,
    journey_time: str,
    max_concurrent: int = 2,
    progress_callback: Callable[[JourneyResult], None] | None = None,
) -> list[JourneyResult]:
    """Fetch journey times for all postcodes with rate limiting.

    Args:
        postcode_data: List of (postcode, lat, lon) tuples
        dest: Destination for journey planning
        journey_date: Date in YYYYMMDD format
        journey_time: Time in HHMM format
        max_concurrent: Maximum number of postcodes processed concurrently
        progress_callback: Optional callback called with each result, in
            completion order

    Returns:
        List of JourneyResult objects in the same order as postcode_data,
        one per input row (duplicate postcodes each get their own entry).
    """
    semaphore = asyncio.Semaphore(max_concurrent)
    to_location = dest.to_tfl_location()
    rate_limiter = RateLimiter()
    # Keyed by input position rather than postcode so duplicates don't collapse.
    results_by_index: dict[int, JourneyResult] = {}

    client = Client(base_url="https://api.tfl.gov.uk").with_timeout(Timeout(30))
    async with client as client:

        async def fetch_one(
            index: int, postcode: str, lat: float, lon: float
        ) -> tuple[int, JourneyResult]:
            result = await fetch_all_modes(
                client,
                rate_limiter,
                postcode,
                lat,
                lon,
                to_location,
                journey_date,
                journey_time,
                semaphore,
            )
            return index, result

        tasks = [
            asyncio.create_task(fetch_one(i, pc, lat, lon))
            for i, (pc, lat, lon) in enumerate(postcode_data)
        ]
        try:
            for next_finished in asyncio.as_completed(tasks):
                index, result = await next_finished
                results_by_index[index] = result
                if progress_callback is not None:
                    progress_callback(result)
        finally:
            # If the loop exits early (callback error, cancellation), stop the
            # outstanding requests before the client is closed underneath them,
            # and collect their outcomes so no task exception goes unretrieved.
            # cancel() is a no-op for tasks that already finished.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

    return [results_by_index[i] for i in range(len(postcode_data))]