perfect-postcode/r5-service/main.py
2026-02-15 22:39:53 +00:00

153 lines
4.8 KiB
Python

"""R5 travel time service — FastAPI wrapper around r5py.
Loads an OSM PBF + GTFS feeds at startup, then serves many-to-one
travel time queries via POST /travel-times.
"""
import datetime
import logging
import os
from pathlib import Path
import geopandas as gpd
import pandas as pd
import r5py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from shapely.geometry import Point
# Configure root logging first, then obtain this service's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("r5-service")

app = FastAPI(title="R5 Travel Time Service")

# Shared transport network, built once by the startup hook. It remains None
# until that hook has assigned it; endpoints check for None and answer 503.
_transport_network: r5py.TransportNetwork | None = None
@app.on_event("startup")
def load_network() -> None:
    """Build the shared r5py transport network from the transit data directory.

    Reads ``TRANSIT_DATA_DIR`` (default ``/data/transit``). It uses the first
    ``*.osm.pbf`` file in sorted order as the street network and every
    ``*.zip`` file as a GTFS feed. The result is stored in the module-level
    ``_transport_network``.

    Raises:
        RuntimeError: if the directory contains no ``*.osm.pbf`` file.
    """
    global _transport_network
    data_dir = Path(os.environ.get("TRANSIT_DATA_DIR", "/data/transit"))

    # Path.glob yields entries in filesystem order, which is arbitrary. Sorting
    # makes the chosen PBF (and the GTFS feed order) deterministic for a given
    # data directory.
    osm_files = sorted(data_dir.glob("*.osm.pbf"))
    if not osm_files:
        raise RuntimeError(f"No .osm.pbf file found in {data_dir}")
    osm_pbf = osm_files[0]
    if len(osm_files) > 1:
        # Only one street network can be used, so say which ones are skipped.
        logger.warning(
            "Multiple .osm.pbf files in %s; using %s, ignoring %s",
            data_dir,
            osm_pbf.name,
            [f.name for f in osm_files[1:]],
        )

    gtfs_files = sorted(data_dir.glob("*.zip"))
    logger.info(
        "Loading transport network: OSM=%s, GTFS=%s",
        osm_pbf.name,
        [f.name for f in gtfs_files],
    )
    # Pass the (possibly empty) list itself. r5py documents `gtfs` as a path
    # or a list of paths, and an empty list means "no GTFS". None is not a
    # documented value.
    _transport_network = r5py.TransportNetwork(
        osm_pbf=osm_pbf,
        gtfs=gtfs_files,
    )
    logger.info("Transport network loaded successfully")
# ── Request / Response models ────────────────────────────────────────────────
# Keyword arguments for r5py.TravelTimeMatrix, keyed by the request's `mode`.
# In r5py 1.x, `transport_modes` selects the direct/main mode(s), and
# `access_modes` selects the first-mile mode used to reach transit.
# Insertion order matters: it is the order listed in the invalid-mode error.
MODE_CONFIGS = dict(
    transit=dict(
        transport_modes=[r5py.TransportMode.TRANSIT],
        access_modes=[r5py.TransportMode.WALK],
    ),
    car=dict(transport_modes=[r5py.TransportMode.CAR]),
    bicycle=dict(transport_modes=[r5py.TransportMode.BICYCLE]),
)
class TravelTimeRequest(BaseModel):
    """Body of ``POST /travel-times``: many origins routed to one destination.

    Coordinates are ``[lat, lon]`` pairs. They are typed as two-element tuples
    so that request validation rejects malformed points (e.g. ``[51.5]`` or
    ``[1, 2, 3]``). Such input is refused up front instead of failing later
    when the endpoint unpacks it.
    """

    origins: list[tuple[float, float]]  # [[lat, lon], ...]
    destination: tuple[float, float]  # [lat, lon]
    mode: str = "transit"  # must be a MODE_CONFIGS key; checked by the endpoint
    departure_time: str | None = None  # ISO 8601, defaults to next weekday 8am
class TravelTimeResponse(BaseModel):
    """Result of ``POST /travel-times``.

    ``travel_times[i]`` is the travel time in minutes from the i-th origin to
    the destination. It is ``null`` when that origin cannot reach it.
    """

    travel_times: list[float | None]
# ── Endpoints ────────────────────────────────────────────────────────────────
@app.get("/health")
def health() -> dict:
    """Readiness probe: 200 once the transport network is loaded, else 503."""
    if _transport_network is not None:
        return {"status": "ok"}
    raise HTTPException(status_code=503, detail="Network not loaded")
def _default_departure(now: datetime.datetime) -> datetime.datetime:
    """Return the first weekday 08:00 that is strictly after *now*."""
    departure = now.replace(hour=8, minute=0, second=0, microsecond=0)
    if departure <= now:
        # Today's 08:00 has already passed, so the "next" one is tomorrow or later.
        departure += datetime.timedelta(days=1)
    while departure.weekday() >= 5:  # 5 = Saturday, 6 = Sunday
        departure += datetime.timedelta(days=1)
    return departure


def _parse_departure(value: str | None) -> datetime.datetime:
    """Parse an ISO 8601 departure time, or fall back to the next weekday 08:00.

    Raises:
        HTTPException: 400 if *value* is not a valid ISO 8601 datetime.
    """
    if not value:
        # NOTE(review): datetime.now() is the server's local time; confirm it
        # matches the timezone the GTFS feeds are expressed in.
        return _default_departure(datetime.datetime.now())
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError as err:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid departure_time '{value}': expected ISO 8601",
        ) from err


def _lat_lon_point(coord, label: str) -> Point:
    """Convert a ``[lat, lon]`` pair into a shapely Point, which takes (lon, lat).

    Raises:
        HTTPException: 400 if *coord* does not hold exactly two values.
    """
    if len(coord) != 2:
        raise HTTPException(
            status_code=400,
            detail=f"{label} must be a [lat, lon] pair, got {list(coord)}",
        )
    lat, lon = coord
    return Point(lon, lat)


def _collect_travel_times(result: pd.DataFrame, n_origins: int) -> list[float | None]:
    """Map an r5py result table onto one travel time per origin index.

    Origins with no row, or whose ``travel_time`` is missing or negative,
    stay None (reported as unreachable).
    """
    travel_times: list[float | None] = [None] * n_origins
    times = result["travel_time"]
    reachable = result[times.notna() & (times >= 0)]
    for origin_idx, tt in zip(reachable["from_id"], reachable["travel_time"]):
        travel_times[int(origin_idx)] = float(tt)
    return travel_times


@app.post("/travel-times", response_model=TravelTimeResponse)
def compute_travel_times(req: TravelTimeRequest) -> TravelTimeResponse:
    """Compute the travel time from every origin to a single destination.

    Returns:
        One entry per origin, in request order: minutes, or None if unreachable.

    Raises:
        HTTPException: 503 if the transport network is not loaded yet. 400 for
            an unknown mode, a malformed coordinate, or an unparseable
            departure_time.
    """
    if _transport_network is None:
        raise HTTPException(status_code=503, detail="Network not loaded")
    if req.mode not in MODE_CONFIGS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid mode '{req.mode}'. Must be one of: {list(MODE_CONFIGS)}",
        )
    if not req.origins:
        return TravelTimeResponse(travel_times=[])

    departure = _parse_departure(req.departure_time)

    # Each origin's id is its position in the request, so results map back by index.
    origins_gdf = gpd.GeoDataFrame(
        {"id": range(len(req.origins))},
        geometry=[
            _lat_lon_point(coord, f"origins[{i}]")
            for i, coord in enumerate(req.origins)
        ],
        crs="EPSG:4326",
    )
    dest_gdf = gpd.GeoDataFrame(
        {"id": [0]},
        geometry=[_lat_lon_point(req.destination, "destination")],
        crs="EPSG:4326",
    )

    # r5py 1.x: TravelTimeMatrix is instantiated directly and IS the result,
    # a table with from_id, to_id and travel_time columns.
    result = r5py.TravelTimeMatrix(
        _transport_network,
        origins=origins_gdf,
        destinations=dest_gdf,
        departure=departure,
        **MODE_CONFIGS[req.mode],
    )
    return TravelTimeResponse(
        travel_times=_collect_travel_times(result, len(req.origins))
    )