This commit is contained in:
Andras Schmelczer 2026-02-10 22:04:55 +00:00
parent 1397b6afd5
commit d39d1b15fd
3 changed files with 175 additions and 0 deletions

17
r5-service/Dockerfile Normal file
View file

@ -0,0 +1,17 @@
FROM python:3.12-slim
# r5py needs a JVM to run the R5 routing engine
RUN apt-get update && \
apt-get install -y --no-install-recommends openjdk-21-jre-headless curl libexpat1 libgdal-dev && \
rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY main.py .
EXPOSE 8003
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8003"]

153
r5-service/main.py Normal file
View file

@ -0,0 +1,153 @@
"""R5 travel time service — FastAPI wrapper around r5py.
Loads an OSM PBF + GTFS feeds at startup, then serves many-to-one
travel time queries via POST /travel-times.
"""
import datetime
import logging
import os
from pathlib import Path
import geopandas as gpd
import pandas as pd
import r5py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from shapely.geometry import Point
logger = logging.getLogger("r5-service")
logging.basicConfig(level=logging.INFO)
app = FastAPI(title="R5 Travel Time Service")
# Global transport network — loaded once at startup
_transport_network: r5py.TransportNetwork | None = None
@app.on_event("startup")
def load_network() -> None:
global _transport_network
data_dir = Path(os.environ.get("TRANSIT_DATA_DIR", "/data/transit"))
osm_files = list(data_dir.glob("*.osm.pbf"))
if not osm_files:
raise RuntimeError(f"No .osm.pbf file found in {data_dir}")
osm_pbf = osm_files[0]
gtfs_files = list(data_dir.glob("*.zip"))
logger.info(
"Loading transport network: OSM=%s, GTFS=%s",
osm_pbf.name,
[f.name for f in gtfs_files],
)
_transport_network = r5py.TransportNetwork(
osm_pbf=osm_pbf,
gtfs=gtfs_files if gtfs_files else None,
)
logger.info("Transport network loaded successfully")
# ── Request / Response models ────────────────────────────────────────────────
# r5py 1.x uses transport_modes for direct modes and access_modes for
# first-mile walking/cycling to transit.
MODE_CONFIGS = {
"transit": {
"transport_modes": [r5py.TransportMode.TRANSIT],
"access_modes": [r5py.TransportMode.WALK],
},
"car": {
"transport_modes": [r5py.TransportMode.CAR],
},
"bicycle": {
"transport_modes": [r5py.TransportMode.BICYCLE],
},
}
class TravelTimeRequest(BaseModel):
origins: list[list[float]] # [[lat, lon], ...]
destination: list[float] # [lat, lon]
mode: str = "transit"
departure_time: str | None = None # ISO 8601, defaults to next weekday 8am
class TravelTimeResponse(BaseModel):
travel_times: list[float | None] # minutes per origin, null if unreachable
# ── Endpoints ────────────────────────────────────────────────────────────────
@app.get("/health")
def health() -> dict:
if _transport_network is None:
raise HTTPException(status_code=503, detail="Network not loaded")
return {"status": "ok"}
@app.post("/travel-times", response_model=TravelTimeResponse)
def compute_travel_times(req: TravelTimeRequest) -> TravelTimeResponse:
if _transport_network is None:
raise HTTPException(status_code=503, detail="Network not loaded")
if req.mode not in MODE_CONFIGS:
raise HTTPException(
status_code=400,
detail=f"Invalid mode '{req.mode}'. Must be one of: {list(MODE_CONFIGS)}",
)
if not req.origins:
return TravelTimeResponse(travel_times=[])
# Parse departure time
if req.departure_time:
departure = datetime.datetime.fromisoformat(req.departure_time)
else:
# Default: next weekday at 8:00 AM
now = datetime.datetime.now()
departure = now.replace(hour=8, minute=0, second=0, microsecond=0)
# Advance to next weekday if weekend
while departure.weekday() >= 5:
departure += datetime.timedelta(days=1)
# Build origin GeoDataFrame (note: Point takes (lon, lat))
origin_points = [Point(lon, lat) for lat, lon in req.origins]
origins_gdf = gpd.GeoDataFrame(
{"id": range(len(origin_points))},
geometry=origin_points,
crs="EPSG:4326",
)
# Build destination GeoDataFrame
dest_lat, dest_lon = req.destination
dest_gdf = gpd.GeoDataFrame(
{"id": [0]},
geometry=[Point(dest_lon, dest_lat)],
crs="EPSG:4326",
)
mode_config = MODE_CONFIGS[req.mode]
# r5py 1.x: TravelTimeMatrix is instantiated directly and IS the result
result = r5py.TravelTimeMatrix(
_transport_network,
origins=origins_gdf,
destinations=dest_gdf,
departure=departure,
**mode_config,
)
# Build response: one travel time per origin
# r5py 1.x returns a GeoDataFrame with columns: from_id, to_id, travel_time
travel_times: list[float | None] = [None] * len(req.origins)
for _, row in result.iterrows():
origin_idx = int(row["from_id"])
tt = row["travel_time"]
if pd.notna(tt) and tt >= 0:
travel_times[origin_idx] = float(tt)
return TravelTimeResponse(travel_times=travel_times)

View file

@ -0,0 +1,5 @@
r5py>=0.8
fastapi>=0.115
uvicorn>=0.34
geopandas>=1.0
shapely>=2.0