254 lines
7.9 KiB
Python
254 lines
7.9 KiB
Python
import asyncio
|
|
import os
|
|
from typing import Literal
|
|
import warnings
|
|
from collections.abc import Callable
|
|
from http import HTTPStatus
|
|
|
|
import httpx
|
|
|
|
from .config import MAX_DELAY
|
|
from .models import Destination, JourneyResult
|
|
from .rate_limiter import RateLimiter
|
|
|
|
|
|
# Root URL of the TFL API. It is handed to httpx.AsyncClient as base_url, so
# the request paths built in this module are relative to it.
BASE_URL = "https://api.tfl.gov.uk"
|
|
|
|
|
|
# Value sent as "journeyPreference" for each journey type (None = not sent).
_JOURNEY_PREFERENCE = {
    "quick": "LeastTime",
    "easy": "LeastInterchange",
    "cycle": None,
}

# Value sent as "cyclePreference" for each journey type.
_CYCLE_PREFERENCE = {
    "quick": None,
    "easy": None,
    "cycle": "AllTheWay",
}

# Transport modes requested for each journey type.
# curl -s "https://api.tfl.gov.uk/Journey/Meta/Modes" | jq '.[].modeName'
_MODES = {
    "quick": [
        "bus",
        "overground",
        "national-rail",
        "international-rail",
        "elizabeth-line",
        "tube",
        "coach",
        "dlr",
        "cable-car",
        "replacement-bus",
        "tram",
        "river-bus",
        "walking",
        "cycle",
    ],
    "easy": [
        "bus",
        "overground",
        "national-rail",
        "international-rail",
        "elizabeth-line",
        "replacement-bus",
        "tube",
        "coach",
        "dlr",
        "cable-car",
        "tram",
        "river-bus",
    ],
    "cycle": ["cycle"],
}

# Response statuses that are retried with backoff; any other non-OK status
# ends the lookup immediately.
_RETRYABLE_STATUSES = (
    HTTPStatus.TOO_MANY_REQUESTS,
    HTTPStatus.INTERNAL_SERVER_ERROR,
    HTTPStatus.BAD_GATEWAY,
    HTTPStatus.SERVICE_UNAVAILABLE,
    HTTPStatus.GATEWAY_TIMEOUT,
)


def _shortest_duration(payload: dict) -> int | None:
    """Return the smallest non-null "duration" among payload["journeys"], or None."""
    journeys = payload.get("journeys") or []
    durations = [j["duration"] for j in journeys if j.get("duration") is not None]
    return min(durations, default=None)


async def fetch_journey_for_mode(
    client: httpx.AsyncClient,
    rate_limiter: RateLimiter,
    from_location: str,
    to_location: str,
    journey_date: str,
    journey_time: str,
    journey_type: Literal["quick", "easy", "cycle"],
    retry_count: int = 5,
) -> int | None:
    """Fetch journey time for a specific mode with rate limiting.

    Requests ``/Journey/JourneyResults/{from_location}/to/{to_location}`` and
    returns the shortest duration among the journeys in the response. The
    rate limiter is acquired before every attempt. Retryable HTTP statuses
    and exceptions raised while requesting or parsing are retried with an
    exponential backoff that starts at 1 second and is capped at MAX_DELAY.

    Args:
        client: HTTP client used for the GET request.
        rate_limiter: Limiter awaited (``acquire()``) before each attempt.
        from_location: Origin, interpolated into the request path.
        to_location: Destination, interpolated into the request path.
        journey_date: Sent as the ``date`` query parameter.
        journey_time: Sent as the ``time`` query parameter; the request uses
            ``timeIs=Arriving``.
        journey_type: Selects the mode list and the journey/cycle preference.
        retry_count: Maximum number of attempts.

    Returns:
        The minimum ``duration`` value found, or None when the response has
        no usable journey, the status is not retryable, or every attempt
        failed.

    Raises:
        ValueError: If ``journey_type`` is not one of the supported values.
    """
    if journey_type not in _MODES:
        # A bad argument can never succeed, so fail fast instead of retrying.
        raise ValueError(f"Unknown journey_type: {journey_type!r}")

    # The request is identical for every attempt, so build it once.
    # NOTE(review): cyclePreference is None for non-cycle journeys yet is still
    # passed to the client (unlike journeyPreference) - confirm how httpx
    # serialises None and whether the key should be omitted instead.
    params: dict = {
        "date": journey_date,
        "time": journey_time,
        "nationalSearch": "true",
        "timeIs": "Arriving",
        "cyclePreference": _CYCLE_PREFERENCE[journey_type],
        "bikeProficiency": "Fast",
        "walkingOptimization": str(journey_type == "quick").lower(),
        "mode": ",".join(_MODES[journey_type]),
    }
    journey_preference = _JOURNEY_PREFERENCE[journey_type]
    if journey_preference:
        params["journeyPreference"] = journey_preference

    url = f"/Journey/JourneyResults/{from_location}/to/{to_location}"

    backoff = 1.0
    for attempt in range(1, retry_count + 1):
        try:
            await rate_limiter.acquire()
            response = await client.get(url, params=params)

            if response.status_code == HTTPStatus.OK:
                return _shortest_duration(response.json())
            if response.status_code not in _RETRYABLE_STATUSES:
                return None
            failure = f"HTTP {response.status_code} for {journey_type} from {from_location}"
        except Exception as e:
            # Best effort: transport errors and malformed payloads are retried
            # rather than propagated.
            failure = f"Network error for {journey_type} from {from_location}: {e}"

        if attempt >= retry_count:
            # No retry follows the last attempt, so do not announce one and do
            # not waste a backoff sleep before giving up.
            warnings.warn(
                f"{failure} (attempt {attempt}/{retry_count})",
                stacklevel=2,
            )
            break

        warnings.warn(
            f"{failure}, retrying in {backoff:.1f}s (attempt {attempt}/{retry_count})",
            stacklevel=2,
        )
        await asyncio.sleep(backoff)
        backoff = min(backoff * 2, MAX_DELAY)

    warnings.warn(
        f"Failed to fetch {journey_type} from {from_location} after {retry_count} attempts",
        stacklevel=2,
    )
    return None
|
|
|
|
|
|
async def fetch_all_modes(
    client: httpx.AsyncClient,
    rate_limiter: RateLimiter,
    postcode: str,
    lat: float,
    lon: float,
    to_location: str,
    journey_date: str,
    journey_time: str,
    semaphore: asyncio.Semaphore,
) -> JourneyResult:
    """Collect the easy, quick and cycling journey times for one postcode.

    The origin is sent as a "lat,lon" string. While holding ``semaphore`` the
    three journey types are requested one after another, in the order easy,
    quick, cycle.

    Returns:
        A JourneyResult carrying the three times for ``postcode``. If anything
        raises, the error is printed and a JourneyResult with only
        ``postcode`` and ``error`` set is returned instead.
    """
    async with semaphore:
        try:
            origin = f"{lat},{lon}"

            # Sequential on purpose: one request at a time per semaphore slot.
            minutes_by_type = {}
            for kind in ("easy", "quick", "cycle"):
                minutes_by_type[kind] = await fetch_journey_for_mode(
                    client,
                    rate_limiter,
                    origin,
                    to_location,
                    journey_date,
                    journey_time,
                    journey_type=kind,
                )

            return JourneyResult(
                postcode=postcode,
                public_transport_easy_minutes=minutes_by_type["easy"],
                public_transport_quick_minutes=minutes_by_type["quick"],
                cycling_minutes=minutes_by_type["cycle"],
            )
        except Exception as failure:
            print(f"Error: {failure}")
            return JourneyResult(postcode=postcode, error=str(failure))
|
|
|
|
|
|
async def fetch_journey_times(
    postcode_data: list[tuple[str, float, float]],
    dest: Destination,
    journey_date: str,
    journey_time: str,
    max_concurrent: int = 2,
    progress_callback: Callable[[JourneyResult], None] | None = None,
) -> list[JourneyResult]:
    """Fetch journey times for all postcodes with rate limiting.

    One task per input row is started; all tasks share a single HTTP client,
    rate limiter and semaphore. Results are reported to ``progress_callback``
    in completion order but returned in input order.

    Args:
        postcode_data: List of (postcode, lat, lon) tuples
        dest: Destination for journey planning
        journey_date: Date in YYYYMMDD format
        journey_time: Time in HHMM format
        max_concurrent: Size of the semaphore shared by the per-postcode
            tasks, i.e. how many postcodes are processed at the same time
        progress_callback: Optional callback called with each result as soon
            as it completes

    Returns:
        List of JourneyResult objects in the same order as postcode_data,
        one per input row (duplicate postcodes each keep their own result)

    Raises:
        RuntimeError: If the TFL_TOKEN environment variable is not set.
    """
    semaphore = asyncio.Semaphore(max_concurrent)
    to_location = dest.to_tfl_location()
    rate_limiter = RateLimiter()

    # TFL API authentication via app_key query parameter
    tfl_token = os.environ.get("TFL_TOKEN")
    if not tfl_token:
        raise RuntimeError("TFL_TOKEN environment variable not set")
    params = {"app_key": tfl_token}

    async with httpx.AsyncClient(
        base_url=BASE_URL,
        params=params,
        timeout=httpx.Timeout(30),
    ) as client:

        async def fetch_indexed(
            index: int, pc: str, lat: float, lon: float
        ) -> tuple[int, JourneyResult]:
            """Run one lookup and tag the result with its input position."""
            result = await fetch_all_modes(
                client,
                rate_limiter,
                pc,
                lat,
                lon,
                to_location,
                journey_date,
                journey_time,
                semaphore,
            )
            return index, result

        tasks = [
            fetch_indexed(index, pc, lat, lon)
            for index, (pc, lat, lon) in enumerate(postcode_data)
        ]

        # Key results by input position rather than by postcode: a postcode
        # that appears more than once must not overwrite its other entries.
        result_by_index: dict[int, JourneyResult] = {}
        for finished in asyncio.as_completed(tasks):
            index, result = await finished
            result_by_index[index] = result
            if progress_callback is not None:
                progress_callback(result)

        return [result_by_index[index] for index in range(len(tasks))]
|