153 lines
4.8 KiB
Python
153 lines
4.8 KiB
Python
"""R5 travel time service — FastAPI wrapper around r5py.
|
|
|
|
Loads an OSM PBF extract plus any GTFS feeds at startup, then serves many-to-one
|
|
travel time queries via POST /travel-times.
|
|
"""
|
|
|
|
import datetime
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import geopandas as gpd
|
|
import pandas as pd
|
|
import r5py
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
from shapely.geometry import Point
|
|
|
|
# Root logging at INFO level, plus a named logger for this service's messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("r5-service")

app = FastAPI(title="R5 Travel Time Service")

# Set once by the startup hook (load_network) and read by every request.
# While it is still None the endpoints answer with HTTP 503.
_transport_network: r5py.TransportNetwork | None = None
|
|
|
|
|
|
# NOTE: recent FastAPI versions deprecate on_event() in favour of lifespan
# handlers; kept here because `app` is constructed without one.
@app.on_event("startup")
def load_network() -> None:
    """Build the shared r5py transport network from files on disk.

    Looks in the directory named by the ``TRANSIT_DATA_DIR`` environment
    variable (default ``/data/transit``). It uses one ``*.osm.pbf`` file as the
    street network and every ``*.zip`` file as a GTFS feed. The result is
    stored in the module-level ``_transport_network``.

    Raises:
        RuntimeError: if the directory contains no ``*.osm.pbf`` file.
    """
    global _transport_network

    data_dir = Path(os.environ.get("TRANSIT_DATA_DIR", "/data/transit"))

    # glob() yields files in filesystem order; sort so every restart picks
    # the same street network and loads the feeds in the same order.
    osm_files = sorted(data_dir.glob("*.osm.pbf"))
    if not osm_files:
        raise RuntimeError(f"No .osm.pbf file found in {data_dir}")
    osm_pbf = osm_files[0]
    if len(osm_files) > 1:
        logger.warning(
            "Multiple .osm.pbf files in %s; using %s and ignoring %s",
            data_dir,
            osm_pbf.name,
            [f.name for f in osm_files[1:]],
        )

    gtfs_files = sorted(data_dir.glob("*.zip"))
    logger.info(
        "Loading transport network: OSM=%s, GTFS=%s",
        osm_pbf.name,
        [f.name for f in gtfs_files],
    )

    # r5py's `gtfs` argument is a path or a list of paths, and its default is
    # an empty list. Pass the (possibly empty) list directly: passing None
    # would fail when r5py iterates over it, breaking street-only setups.
    _transport_network = r5py.TransportNetwork(
        osm_pbf=osm_pbf,
        gtfs=gtfs_files,
    )
    logger.info("Transport network loaded successfully")
|
|
|
|
|
|
# ── Request / Response models ────────────────────────────────────────────────
|
|
|
|
# Keyword arguments forwarded to r5py.TravelTimeMatrix for each accepted
# request ``mode``. In r5py 1.x the main mode(s) go in ``transport_modes``.
# For transit, ``access_modes`` sets how travellers reach the first stop
# (here: walking).
MODE_CONFIGS = dict(
    transit=dict(
        transport_modes=[r5py.TransportMode.TRANSIT],
        access_modes=[r5py.TransportMode.WALK],
    ),
    car=dict(transport_modes=[r5py.TransportMode.CAR]),
    bicycle=dict(transport_modes=[r5py.TransportMode.BICYCLE]),
)
|
|
|
|
|
|
class TravelTimeRequest(BaseModel):
    """Request body for ``POST /travel-times`` (many origins, one destination).

    Coordinates are ``[lat, lon]`` pairs (latitude first). The endpoint turns
    them into EPSG:4326 points itself.
    """

    # Fixed-length tuples make pydantic reject any entry that is not exactly
    # two numbers, as a validation error. Without this, a malformed pair would
    # only fail later during coordinate unpacking and surface as a 500.
    # JSON arrays such as [52.5, 13.4] are still accepted as before.
    origins: list[tuple[float, float]]  # [[lat, lon], ...]
    destination: tuple[float, float]  # [lat, lon]
    mode: str = "transit"  # must be a key of MODE_CONFIGS
    departure_time: str | None = None  # ISO 8601, defaults to next weekday 8am
|
|
|
|
|
|
class TravelTimeResponse(BaseModel):
    """Response of ``POST /travel-times``: one entry per request origin, in order."""

    # Travel time in minutes for each origin. None (JSON null) marks an
    # origin from which the destination is unreachable.
    travel_times: list[float | None]
|
|
|
|
|
|
# ── Endpoints ────────────────────────────────────────────────────────────────
|
|
|
|
|
|
@app.get("/health")
def health() -> dict:
    """Report readiness: ``{"status": "ok"}`` once the network is loaded, else HTTP 503."""
    if _transport_network is not None:
        return {"status": "ok"}
    raise HTTPException(status_code=503, detail="Network not loaded")
|
|
|
|
|
|
def _default_departure(now: datetime.datetime | None = None) -> datetime.datetime:
    """Return 08:00 on the first weekday (Mon-Fri) that is not before *now*.

    Args:
        now: Reference time; defaults to the current local (naive) time.

    Returns:
        A datetime at 08:00:00. If today's 08:00 has already passed, the
        search starts from tomorrow, so the result is never in the past.
    """
    if now is None:
        now = datetime.datetime.now()
    departure = now.replace(hour=8, minute=0, second=0, microsecond=0)
    if departure < now:
        # Today's 08:00 is already over; "next" means tomorrow at the earliest.
        departure += datetime.timedelta(days=1)
    # weekday(): Monday == 0 ... Saturday == 5, Sunday == 6.
    while departure.weekday() >= 5:
        departure += datetime.timedelta(days=1)
    return departure


def _parse_departure(departure_time: str | None) -> datetime.datetime:
    """Parse the request's ISO 8601 ``departure_time``, or fall back to the default.

    Raises:
        HTTPException: 400 if the string is not a valid ISO 8601 datetime.
    """
    if not departure_time:
        return _default_departure()
    try:
        # NOTE(review): offset-aware values are passed to r5py unchanged;
        # confirm how r5py interprets tzinfo on `departure`.
        return datetime.datetime.fromisoformat(departure_time)
    except ValueError as err:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid departure_time '{departure_time}'. Must be an ISO 8601 datetime.",
        ) from err


def _require_lat_lon(pair, label: str) -> None:
    """Raise HTTP 400 unless *pair* holds exactly two values, ``[lat, lon]``."""
    if len(pair) != 2:
        raise HTTPException(
            status_code=400,
            detail=f"{label} must be a [lat, lon] pair, got {len(pair)} values",
        )


@app.post("/travel-times", response_model=TravelTimeResponse)
def compute_travel_times(req: TravelTimeRequest) -> TravelTimeResponse:
    """Compute the travel time from every origin to the single destination.

    Returns:
        One entry per ``req.origins`` item, in the same order: minutes as a
        float, or None when r5py reports no (or a negative) travel time.

    Raises:
        HTTPException: 503 if the transport network is not loaded yet; 400
            for an unknown ``mode``, a malformed coordinate pair, or an
            unparseable ``departure_time``.
    """
    if _transport_network is None:
        raise HTTPException(status_code=503, detail="Network not loaded")

    if req.mode not in MODE_CONFIGS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid mode '{req.mode}'. Must be one of: {list(MODE_CONFIGS)}",
        )

    if not req.origins:
        return TravelTimeResponse(travel_times=[])

    # Validate up front: a pair of the wrong length would otherwise fail
    # during tuple unpacking below and surface as a 500 instead of a 400.
    _require_lat_lon(req.destination, "destination")
    for i, origin in enumerate(req.origins):
        _require_lat_lon(origin, f"origins[{i}]")

    departure = _parse_departure(req.departure_time)

    # Point takes (x, y) == (lon, lat), while the request sends [lat, lon].
    # The "id" column holds each origin's position, used to order the output.
    origins_gdf = gpd.GeoDataFrame(
        {"id": range(len(req.origins))},
        geometry=[Point(lon, lat) for lat, lon in req.origins],
        crs="EPSG:4326",
    )

    dest_lat, dest_lon = req.destination
    dest_gdf = gpd.GeoDataFrame(
        {"id": [0]},
        geometry=[Point(dest_lon, dest_lat)],
        crs="EPSG:4326",
    )

    # r5py 1.x: TravelTimeMatrix is instantiated directly and IS the result
    result = r5py.TravelTimeMatrix(
        _transport_network,
        origins=origins_gdf,
        destinations=dest_gdf,
        departure=departure,
        **MODE_CONFIGS[req.mode],
    )

    # The result has columns from_id, to_id, travel_time. from_id is the
    # origin's "id" (its index in req.origins). Origins that are missing or
    # have no valid travel time stay None.
    travel_times: list[float | None] = [None] * len(req.origins)
    for origin_id, tt in zip(result["from_id"], result["travel_time"]):
        if pd.notna(tt) and tt >= 0:
            travel_times[int(origin_id)] = float(tt)

    return TravelTimeResponse(travel_times=travel_times)
|