# perfect-postcode/pipeline/download/transit_network.py

"""Download and prepare transit network data for R5 routing.
Downloads:
- England OSM PBF from Geofabrik (~1.5GB)
- BODS GTFS from Bus Open Data Service (~1.5GB; all England bus/tram/ferry,
plus London Underground, DLR, London Tramlink and the IFS Cloud Cable Car)
- National Rail CIF timetable → converted to GTFS (requires credentials;
includes the Elizabeth line, TOC "XR")
Then processes for R5 compatibility:
- Cleans BODS GTFS (fixes stop_times >72h, feed_info year >2100)
- Converts high-frequency metro/tram services to frequency-based GTFS
- Converts National Rail CIF to GTFS via dtd2mysql (requires MariaDB Docker)
- Validates every produced GTFS zip (active calendar window, plausible UK
stop coordinates, non-empty routes/trips/stop_times)
Note: the legacy TfL TransXChange feed (tfl.gov.uk journey-planner-timetables)
was removed: that URL serves a 2010-10-28 snapshot whose calendars all expired
in 2010 and whose stops have empty/0,0 coordinates, so it contributed zero
service. BODS covers all TfL modes that feed nominally provided.
Requires: osmium-tool, Docker (for national rail)
Output directory: property-data/transit/
raw/england.osm.pbf + bods_gtfs.zip + national_rail_gtfs.zip
"""
import argparse
import csv
import datetime as dt
import io
import json
import os
import shutil
import statistics
import subprocess
import tempfile
import time
import urllib.parse
import urllib.request
import zipfile
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm
from pipeline.local_temp import local_tmp_dir
# Geofabrik's England OSM extract, downloaded by download_osm_pbf().
ENGLAND_PBF_URL = (
    "https://download.geofabrik.de/europe/united-kingdom/england-latest.osm.pbf"
)
# Bus Open Data Service — pre-converted GTFS covering all England bus/tram/ferry
BODS_GTFS_URL = "https://data.bus-data.dft.gov.uk/timetable/download/gtfs-file/all/"
# National Rail Open Data API: NR_AUTH_URL takes a form-encoded username and
# password and returns JSON with a "token", which is then sent as the
# X-Auth-Token header when fetching NR_TIMETABLE_URL.
NR_AUTH_URL = "https://opendata.nationalrail.co.uk/authenticate"
NR_TIMETABLE_URL = "https://opendata.nationalrail.co.uk/api/staticfeeds/3.0/timetable"
# User-Agent header sent with every HTTP request this module makes.
USER_AGENT = "property-map-pipeline/1.0 (https://github.com)"
# GTFS validation: a feed must have service within this many days of the build
# date, and at least this fraction of stops must have plausible UK coordinates.
GTFS_CALENDAR_LOOKAHEAD_DAYS = 60
GTFS_MIN_VALID_STOP_FRACTION = 0.95
# Inclusive (min, max) bounds in decimal degrees; validate_gtfs_feed() counts a
# stop as plausibly in the UK only if its lat/lon fall inside both ranges.
UK_LAT_RANGE = (49.0, 61.0)
UK_LON_RANGE = (-9.0, 2.5)
def _download_http(
    url: str, dest: Path, *, desc: str, headers: dict | None = None
) -> None:
    """Stream-download ``url`` to ``dest`` with a tqdm progress bar.

    The body is written to ``<dest>.tmp`` and moved onto ``dest`` only after
    the transfer completes and its size has been checked. Callers treat an
    existing ``dest`` as "already downloaded", so a half-written file must
    never appear there.

    Args:
        url: URL to fetch.
        dest: Final file path; parent directories are created if needed.
        desc: Label shown on the progress bar.
        headers: Extra request headers merged over the default User-Agent.

    Raises:
        urllib.error.URLError / urllib.error.HTTPError: network or HTTP failure.
        RuntimeError: the server advertised a Content-Length and a different
            number of bytes arrived (for example, a dropped connection).
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = dest.with_suffix(dest.suffix + ".tmp")
    req_headers = {"User-Agent": USER_AGENT}
    if headers:
        req_headers.update(headers)
    req = urllib.request.Request(url, headers=req_headers)
    try:
        with (
            tqdm(unit="B", unit_scale=True, desc=desc) as bar,
            urllib.request.urlopen(req) as resp,
            open(tmp, "wb") as f,
        ):
            length = resp.headers.get("Content-Length")
            expected = int(length) if length else None
            if expected is not None:
                bar.total = expected
            received = 0
            while chunk := resp.read(1 << 20):
                f.write(chunk)
                received += len(chunk)
                bar.update(len(chunk))
        # http.client's read(amt) returns b"" instead of raising IncompleteRead
        # when the connection closes early, which simply ends the loop above.
        # Compare the byte count so a truncated file is never cached as done.
        if expected is not None and received != expected:
            raise RuntimeError(
                f"Incomplete download of {url}: received {received} of "
                f"{expected} bytes"
            )
        # replace() (unlike rename()) also overwrites an existing dest on Windows.
        tmp.replace(dest)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
    print(f" Saved to {dest}")
def download_osm_pbf(output_dir: Path) -> Path:
    """Fetch the Geofabrik England OSM extract into ``output_dir``.

    An existing ``england.osm.pbf`` is reused without touching the network.

    Returns:
        Path to ``<output_dir>/england.osm.pbf``.
    """
    pbf_path = output_dir / "england.osm.pbf"
    if not pbf_path.exists():
        print("Downloading England OSM PBF (~1.5 GB)...")
        _download_http(ENGLAND_PBF_URL, pbf_path, desc="england.osm.pbf")
    else:
        print(f"OSM PBF already exists: {pbf_path}")
    return pbf_path
def download_bods_gtfs(output_dir: Path) -> Path:
    """Fetch the all-England BODS GTFS zip (bus/tram/ferry) into ``output_dir``.

    Nothing is downloaded when ``bods_gtfs_raw.zip`` is already present.

    Returns:
        Path to ``<output_dir>/bods_gtfs_raw.zip``.
    """
    zip_path = output_dir / "bods_gtfs_raw.zip"
    if not zip_path.exists():
        print("Downloading BODS GTFS (~1.5 GB)...")
        _download_http(BODS_GTFS_URL, zip_path, desc="bods_gtfs_raw.zip")
    else:
        print(f"BODS GTFS already exists: {zip_path}")
    return zip_path
def _parse_csv_line(line: bytes | str) -> list[str]:
"""Parse a single GTFS CSV record."""
if isinstance(line, bytes):
line = line.decode("utf-8", errors="replace")
line = line.rstrip("\r\n")
if not line:
return []
return next(csv.reader([line]))
def _format_csv_row(parts: list[str]) -> bytes:
    """Encode one record as UTF-8 CSV bytes terminated by a bare LF.

    Fields are quoted only when needed (csv.QUOTE_MINIMAL).
    """
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow(parts)
    text = buf.getvalue()
    return text.encode("utf-8")
def _format_csv_rows(rows: list[list[str]]) -> str:
    """Serialize many CSV records into one string, each line ending in LF."""
    sink = io.StringIO()
    csv.writer(sink, lineterminator="\n").writerows(rows)
    return sink.getvalue()
def clean_gtfs(src: Path, dst: Path) -> None:
    """Fix R5-incompatible entries in a GTFS zip, writing a cleaned copy.

    - Removes stop_times rows whose arrival_time or departure_time hour is > 72.
    - Caps feed_info end_date values whose year is > 2100 to 20991231
      (any column whose name contains "end_date").

    Every other member is copied through unchanged. Does nothing if ``dst``
    already exists.

    The zip is assembled at ``<dst>.partial`` and renamed onto ``dst`` only
    once complete. The ``dst.exists()`` early return means an in-place write
    interrupted midway would leave a truncated zip that every later run
    silently reuses.
    """
    if dst.exists():
        print(f"Cleaned GTFS already exists: {dst}")
        return
    print("Cleaning GTFS for R5 compatibility...")

    def hour_over_72(value: str) -> bool:
        """True if an H:MM:SS value has an integer hour field > 72."""
        value = value.strip('"')
        if ":" not in value:
            return False
        try:
            return int(value.split(":", 1)[0]) > 72
        except ValueError:
            return False

    def write_stop_times(zin: zipfile.ZipFile, info: zipfile.ZipInfo,
                         zout: zipfile.ZipFile) -> None:
        """Stream stop_times.txt through a temp file, dropping >72h rows."""
        dropped = 0
        tmp = tempfile.NamedTemporaryFile(
            mode="wb",
            delete=False,
            suffix=".txt",
            dir=local_tmp_dir(),
        )
        try:
            with tmp, zin.open(info) as f:
                header = f.readline()
                cols = _parse_csv_line(header)
                time_idxs = [
                    cols.index(c)
                    for c in ("arrival_time", "departure_time")
                    if c in cols
                ]
                tmp.write(header)
                for line in f:
                    parts = _parse_csv_line(line)
                    if not parts:
                        continue
                    if any(hour_over_72(parts[i]) for i in time_idxs if i < len(parts)):
                        dropped += 1
                    else:
                        # Kept rows are copied byte-for-byte.
                        tmp.write(line)
            print(f" stop_times: dropped {dropped} rows with hours > 72")
            zout.write(tmp.name, "stop_times.txt")
        finally:
            os.unlink(tmp.name)

    def cleaned_feed_info(data: str) -> str:
        """Return feed_info.txt text with far-future end dates capped."""
        rows = list(csv.reader(io.StringIO(data)))
        if not rows:
            return data
        end_cols = [i for i, name in enumerate(rows[0]) if "end_date" in name.lower()]
        for parts in rows[1:]:
            for i in end_cols:
                if i >= len(parts):
                    continue
                date_val = parts[i].strip('"')
                if len(date_val) != 8:
                    continue
                try:
                    year = int(date_val[:4])
                except ValueError:
                    # Non-numeric date: leave it as-is rather than crash.
                    continue
                if year > 2100:
                    parts[i] = "20991231"
                    print(f" feed_info: capped end_date {date_val} → 20991231")
        return _format_csv_rows(rows)

    partial = dst.with_name(dst.name + ".partial")
    try:
        with (
            zipfile.ZipFile(src, "r") as zin,
            zipfile.ZipFile(partial, "w", zipfile.ZIP_DEFLATED) as zout,
        ):
            for info in zin.infolist():
                if info.filename == "stop_times.txt":
                    write_stop_times(zin, info, zout)
                elif info.filename == "feed_info.txt":
                    data = zin.read(info).decode("utf-8")
                    zout.writestr("feed_info.txt", cleaned_feed_info(data))
                else:
                    zout.writestr(info, zin.read(info))
        os.replace(partial, dst)
    except BaseException:
        partial.unlink(missing_ok=True)
        raise
    print(f" Saved to {dst}")
def _parse_gtfs_time(time_str: str) -> int | None:
"""Parse HH:MM:SS to seconds since midnight. Returns None on failure."""
time_str = time_str.strip('"')
if ":" not in time_str:
return None
try:
h, m, s = time_str.split(":")
return int(h) * 3600 + int(m) * 60 + int(s)
except ValueError:
return None
def _secs_to_gtfs_time(s: int) -> str:
    """Render seconds since midnight as zero-padded ``HH:MM:SS``.

    Hours are not wrapped at 24, so post-midnight GTFS times such as
    ``25:00:00`` round-trip.
    """
    hours, remainder = divmod(s, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
def _freq_read_existing_frequencies(
    zin: zipfile.ZipFile,
) -> tuple[list[str], list[list[str]]]:
    """Return (header, data rows) of the input's frequencies.txt, or ([], [])."""
    if "frequencies.txt" not in zin.namelist():
        return [], []
    text = zin.read("frequencies.txt").decode("utf-8-sig")
    rows = [row for row in csv.reader(io.StringIO(text)) if row]
    if not rows:
        return [], []
    return rows[0], rows[1:]


def _freq_target_route_ids(zin: zipfile.ZipFile) -> set[str]:
    """Return route_ids whose route_type is 0 (tram) or 1 (metro/subway)."""
    route_ids: set[str] = set()
    with zin.open("routes.txt") as f:
        cols = _parse_csv_line(f.readline())
        route_id_idx = cols.index("route_id")
        rt_idx = cols.index("route_type")
        for line in f:
            parts = _parse_csv_line(line)
            if parts and parts[rt_idx].strip('"') in ("0", "1"):
                route_ids.add(parts[route_id_idx].strip('"'))
    return route_ids


def _freq_target_trips(
    zin: zipfile.ZipFile, route_ids: set[str], skip_trip_ids: set[str]
) -> dict[str, tuple[str, str, str]]:
    """Map each candidate trip_id to (route_id, direction_id, service_id).

    Only trips on ``route_ids`` and not in ``skip_trip_ids`` are returned. If
    trips.txt has no direction_id column, every trip gets direction "0".
    """
    trip_keys: dict[str, tuple[str, str, str]] = {}
    with zin.open("trips.txt") as f:
        cols = _parse_csv_line(f.readline())
        trip_id_idx = cols.index("trip_id")
        route_id_idx = cols.index("route_id")
        dir_idx = cols.index("direction_id") if "direction_id" in cols else -1
        service_idx = cols.index("service_id")
        for line in f:
            parts = _parse_csv_line(line)
            if not parts:
                continue
            route_id = parts[route_id_idx].strip('"')
            if route_id not in route_ids:
                continue
            trip_id = parts[trip_id_idx].strip('"')
            if trip_id in skip_trip_ids:
                continue
            direction = parts[dir_idx].strip('"') if dir_idx >= 0 else "0"
            trip_keys[trip_id] = (route_id, direction, parts[service_idx].strip('"'))
    return trip_keys


def _freq_trip_patterns(
    zin: zipfile.ZipFile, trip_keys: dict[str, tuple[str, str, str]]
) -> dict[str, tuple[tuple[str, ...], int]]:
    """Return trip_id -> (ordered stop_id pattern, first departure in seconds).

    GTFS only requires stop_sequence to increase within a trip; it need not
    start at 0 or be contiguous, and rows need not be in order. Rows are
    therefore sorted by stop_sequence. The first departure is taken from the
    lowest-sequence row with a parseable departure_time. Rows whose
    stop_sequence is not an integer are ignored. Trips with no parseable
    departure at all are omitted.
    """
    stops_by_trip: dict[str, list[tuple[int, str, int | None]]] = defaultdict(list)
    with zin.open("stop_times.txt") as f:
        cols = _parse_csv_line(f.readline())
        trip_id_idx = cols.index("trip_id")
        dep_idx = cols.index("departure_time")
        seq_idx = cols.index("stop_sequence")
        stop_id_idx = cols.index("stop_id")
        for line in f:
            parts = _parse_csv_line(line)
            if not parts:
                continue
            trip_id = parts[trip_id_idx].strip('"')
            if trip_id not in trip_keys:
                continue
            try:
                seq = int(parts[seq_idx].strip('"'))
            except ValueError:
                continue
            stops_by_trip[trip_id].append(
                (seq, parts[stop_id_idx].strip('"'), _parse_gtfs_time(parts[dep_idx]))
            )
    patterns: dict[str, tuple[tuple[str, ...], int]] = {}
    for trip_id in list(stops_by_trip):
        # pop() releases each trip's rows as soon as it is summarised.
        rows = stops_by_trip.pop(trip_id)
        rows.sort(key=lambda row: row[0])
        first_dep = next((dep for _, _, dep in rows if dep is not None), None)
        if first_dep is None:
            continue
        patterns[trip_id] = (tuple(stop for _, stop, _ in rows), first_dep)
    return patterns


def _freq_plan(
    groups: dict[tuple, list[tuple[str, int]]], max_headway_secs: int
) -> tuple[set[str], list[tuple[str, int, int, int]]]:
    """Pick groups regular enough to collapse into one frequency-based trip.

    A group qualifies when it has >= 4 trips and >= 3 positive departure gaps,
    a median gap in [30 s, max_headway_secs], and stdev/mean of the gaps
    <= 0.5. Each qualifying group yields one entry
    (template_trip_id, start_secs, end_secs, headway_secs):
    - the template is the earliest trip;
    - end_secs is the last departure plus the median gap;
    - headway_secs is the median rounded to whole minutes, with a 60 s minimum.
    All other trips in the group are returned for deletion.
    """
    trips_to_remove: set[str] = set()
    entries: list[tuple[str, int, int, int]] = []
    for trips in groups.values():
        if len(trips) < 4:
            continue
        trips.sort(key=lambda trip: trip[1])
        headways = [b[1] - a[1] for a, b in zip(trips, trips[1:]) if b[1] > a[1]]
        if len(headways) < 3:
            continue
        median_hw = statistics.median(headways)
        if median_hw > max_headway_secs or median_hw < 30:
            continue
        # Every gap is > 0 here, so the mean is strictly positive.
        if statistics.stdev(headways) / statistics.mean(headways) > 0.5:
            continue
        template_trip_id, start_secs = trips[0]
        end_secs = trips[-1][1] + int(median_hw)
        headway_rounded = max(60, round(median_hw / 60) * 60)
        entries.append((template_trip_id, start_secs, end_secs, headway_rounded))
        trips_to_remove.update(trip_id for trip_id, _ in trips[1:])
    return trips_to_remove, entries


def _freq_copy_without_trips(
    zin: zipfile.ZipFile,
    info: zipfile.ZipInfo,
    zout: zipfile.ZipFile,
    drop: set[str],
) -> None:
    """Stream a trip_id-keyed member into ``zout``, omitting rows in ``drop``.

    Goes through a local temp file so multi-GB members never sit in memory.
    """
    tmp = tempfile.NamedTemporaryFile(
        mode="wb",
        delete=False,
        suffix=".txt",
        dir=local_tmp_dir(),
    )
    try:
        with tmp, zin.open(info) as f:
            header = f.readline()
            trip_id_idx = _parse_csv_line(header).index("trip_id")
            tmp.write(header)
            for line in f:
                parts = _parse_csv_line(line)
                if parts and parts[trip_id_idx].strip('"') not in drop:
                    tmp.write(line)
        zout.write(tmp.name, info.filename)
    finally:
        os.unlink(tmp.name)


def _freq_frequencies_csv(
    existing_header: list[str],
    existing_rows: list[list[str]],
    entries: list[tuple[str, int, int, int]],
) -> str:
    """Build frequencies.txt: the input's rows (if any), then ``entries``.

    Columns are the input header plus any missing standard columns. Input rows
    are padded with "", which GTFS treats as the default (e.g. exact_times=0).
    New rows are written with exact_times=0.
    """
    standard = ["trip_id", "start_time", "end_time", "headway_secs", "exact_times"]
    cols = list(existing_header) + [c for c in standard if c not in existing_header]
    rows: list[list[str]] = [cols]
    rows.extend(row + [""] * (len(cols) - len(row)) for row in existing_rows)
    for trip_id, start, end, headway in entries:
        values = {
            "trip_id": trip_id,
            "start_time": _secs_to_gtfs_time(start),
            "end_time": _secs_to_gtfs_time(end),
            "headway_secs": str(headway),
            "exact_times": "0",
        }
        rows.append([values.get(c, "") for c in cols])
    return _format_csv_rows(rows)


def convert_high_freq_to_frequency_based(
    src: Path, dst: Path, *, max_headway_minutes: int = 15
) -> None:
    """Convert high-frequency metro/tram trips to frequency-based GTFS entries.

    Candidate trips run on routes with route_type 0 (tram) or 1
    (metro/subway). They are grouped by
    (route_id, direction_id, service_id, ordered stop pattern). A sufficiently
    regular group (see _freq_plan; median headway <= ``max_headway_minutes``)
    is collapsed into its earliest trip plus one frequencies.txt row, and the
    group's other trips are removed from trips.txt and stop_times.txt. R5's
    RAPTOR produces smoother percentile results for frequency-based services,
    matching the "just turn up" reality of high-frequency metro/tram services.

    Grouping on the full stop pattern, not just the first stop, stops
    branches and short workings that share an origin from being merged into
    a single template. Such a merge would delete every trip that does not
    follow the template's path.

    Trips already listed in the input's frequencies.txt are not converted,
    and the input's frequencies.txt rows are kept; new rows are appended.

    The result is written to ``<dst>.partial`` and renamed onto ``dst`` on
    success. Does nothing if ``dst`` already exists.

    Raises:
        RuntimeError: target trips exist but none has a usable first stop.
    """
    if dst.exists():
        print(f"Frequency-converted GTFS already exists: {dst}")
        return
    print("Converting high-frequency services to frequency-based...")
    max_headway_secs = max_headway_minutes * 60
    partial = dst.with_name(dst.name + ".partial")

    with zipfile.ZipFile(src, "r") as zin:
        # Step 1: metro/tram routes
        target_route_ids = _freq_target_route_ids(zin)
        if not target_route_ids:
            print(" No metro/tram routes found, copying unchanged")
            shutil.copy2(src, partial)
            os.replace(partial, dst)
            return
        print(f" Found {len(target_route_ids)} metro/tram routes")

        # Trips that are already frequency-based must not be treated as a
        # single departure (and possibly deleted).
        existing_header, existing_rows = _freq_read_existing_frequencies(zin)
        already_frequency_based: set[str] = set()
        if "trip_id" in existing_header:
            tid_idx = existing_header.index("trip_id")
            already_frequency_based = {
                row[tid_idx] for row in existing_rows if tid_idx < len(row)
            }

        # Step 2: candidate trips and their grouping keys
        trip_group_key = _freq_target_trips(
            zin, target_route_ids, already_frequency_based
        )
        print(f" Found {len(trip_group_key)} trips on target routes")

        # Step 3: stop pattern and first departure per candidate trip
        trip_patterns = _freq_trip_patterns(zin, trip_group_key)
        if trip_group_key and not trip_patterns:
            raise RuntimeError(
                "convert_high_freq_to_frequency_based found no first stops for "
                f"{len(trip_group_key)} target trips; stop_times.txt may be malformed "
                "or stop_sequence parsing failed"
            )

    # Step 4: group identical stop patterns and choose groups to collapse
    groups: dict[tuple, list[tuple[str, int]]] = defaultdict(list)
    for trip_id, (pattern, first_dep) in trip_patterns.items():
        route_id, direction, service_id = trip_group_key[trip_id]
        groups[(route_id, direction, service_id, pattern)].append((trip_id, first_dep))
    trips_to_remove, frequency_entries = _freq_plan(groups, max_headway_secs)
    print(f" Converted {len(frequency_entries)} trip groups to frequency-based")
    print(f" Removing {len(trips_to_remove)} redundant trips")
    print(f" Created {len(frequency_entries)} frequency entries")

    # Step 5: write the modified feed atomically
    try:
        with (
            zipfile.ZipFile(src, "r") as zin,
            zipfile.ZipFile(partial, "w", zipfile.ZIP_DEFLATED) as zout,
        ):
            for info in zin.infolist():
                if info.filename in ("trips.txt", "stop_times.txt"):
                    _freq_copy_without_trips(zin, info, zout, trips_to_remove)
                elif info.filename == "frequencies.txt":
                    continue  # rewritten below, original rows preserved
                else:
                    zout.writestr(info, zin.read(info))
            zout.writestr(
                "frequencies.txt",
                _freq_frequencies_csv(existing_header, existing_rows, frequency_entries),
            )
        os.replace(partial, dst)
    except BaseException:
        partial.unlink(missing_ok=True)
        raise
    print(f" Saved to {dst}")
def _gtfs_has_data_row(z: zipfile.ZipFile, filename: str) -> bool:
    """Report whether member ``filename`` of ``z`` has any non-blank record past its header."""
    with z.open(filename) as member:
        next(member, None)  # discard the header line (if any)
        return any(_parse_csv_line(raw) for raw in member)
def _calendar_active_in_window(
    z: zipfile.ZipFile, names: set[str], window_start: int, window_end: int
) -> bool:
    """True if calendar.txt/calendar_dates.txt have service in [start, end].

    Dates are compared as YYYYMMDD integers.

    A calendar.txt row counts when some date in the overlap of its
    [start_date, end_date] range with the window falls on a weekday whose
    flag is "1". If the file has no weekday columns at all, overlap alone is
    enough. Rows with an inverted range (start after end) never count.

    A calendar_dates.txt row counts when it adds service (exception_type=1,
    or the column is absent on that row) on a date inside the window.
    Removals (exception_type=2) are not subtracted, so this is a lenient
    "is anything scheduled" check, not an exact service count.

    A calendar.txt whose header lacks start_date/end_date is skipped rather
    than failing the whole check, so calendar_dates.txt can still vouch for
    the feed.
    """
    weekdays = (
        "monday",
        "tuesday",
        "wednesday",
        "thursday",
        "friday",
        "saturday",
        "sunday",
    )

    def to_date(yyyymmdd: int) -> dt.date | None:
        try:
            return dt.date(yyyymmdd // 10000, yyyymmdd // 100 % 100, yyyymmdd % 100)
        except ValueError:
            return None

    def weekday_in_range(lo: int, hi: int, active: set[int]) -> bool:
        """True if some date in [lo, hi] has date.weekday() in ``active``."""
        lo_date, hi_date = to_date(lo), to_date(hi)
        if lo_date is None or hi_date is None:
            return True  # unparseable date: fall back to the plain overlap test
        span = (hi_date - lo_date).days
        if span >= 6:
            return True  # 7+ consecutive days include every weekday
        return any(
            (lo_date + dt.timedelta(days=k)).weekday() in active
            for k in range(span + 1)
        )

    if "calendar.txt" in names:
        with z.open("calendar.txt") as f:
            cols = _parse_csv_line(f.readline())
            if "start_date" in cols and "end_date" in cols:
                start_idx = cols.index("start_date")
                end_idx = cols.index("end_date")
                # (column index, date.weekday() number: Monday=0) per weekday column.
                day_cols = [
                    (cols.index(day), wd) for wd, day in enumerate(weekdays) if day in cols
                ]
                for line in f:
                    parts = _parse_csv_line(line)
                    if not parts:
                        continue
                    try:
                        start = int(parts[start_idx].strip('"'))
                        end = int(parts[end_idx].strip('"'))
                    except (ValueError, IndexError):
                        continue
                    lo, hi = max(start, window_start), min(end, window_end)
                    if lo > hi:
                        continue  # no overlap with the window
                    if not day_cols:
                        return True
                    active = {
                        wd
                        for idx, wd in day_cols
                        if idx < len(parts) and parts[idx].strip('"') == "1"
                    }
                    if active and weekday_in_range(lo, hi, active):
                        return True
    if "calendar_dates.txt" in names:
        with z.open("calendar_dates.txt") as f:
            cols = _parse_csv_line(f.readline())
            try:
                date_idx = cols.index("date")
                exc_idx = cols.index("exception_type")
            except ValueError:
                return False
            for line in f:
                parts = _parse_csv_line(line)
                if not parts:
                    continue
                try:
                    date = int(parts[date_idx].strip('"'))
                except (ValueError, IndexError):
                    continue
                if exc_idx < len(parts) and parts[exc_idx].strip('"') != "1":
                    continue
                if window_start <= date <= window_end:
                    return True
    return False
def validate_gtfs_feed(path: Path, feed_name: str, *, today: dt.date | None = None) -> None:
    """Sanity-check a produced/downloaded GTFS zip; raise RuntimeError if dead.

    Guards against silently shipping a feed that contributes zero service (as
    the old TfL dump did: 2010 calendars, empty/0,0 stop coordinates). Checks:
    (a) calendar.txt/calendar_dates.txt have at least one service active
        within [today, today + GTFS_CALENDAR_LOOKAHEAD_DAYS];
    (b) stops.txt is non-empty and >= GTFS_MIN_VALID_STOP_FRACTION of stops
        have plausible UK coordinates (lat 49-61, lon -9..2.5, not 0,0);
    (c) routes.txt, trips.txt, stop_times.txt and stops.txt exist and each
        has data rows.

    Args:
        path: GTFS zip to check.
        feed_name: Human-readable name used in messages.
        today: Start of the calendar window; defaults to the current date.

    Raises:
        RuntimeError: any check fails; the message names the feed and reason.
    """
    from typing import NoReturn

    if today is None:
        today = dt.date.today()
    window_start = int(today.strftime("%Y%m%d"))
    window_end = int(
        (today + dt.timedelta(days=GTFS_CALENDAR_LOOKAHEAD_DAYS)).strftime("%Y%m%d")
    )

    # NoReturn lets type checkers see that code after fail() is unreachable
    # (e.g. lat_idx/lon_idx are always bound below).
    def fail(reason: str) -> NoReturn:
        raise RuntimeError(
            f"GTFS validation failed for feed '{feed_name}' ({path}): {reason}"
        )

    print(f"Validating GTFS feed '{feed_name}'...")
    if not path.exists() or not zipfile.is_zipfile(path):
        fail("not a valid zip file")
    with zipfile.ZipFile(path) as z:
        names = set(z.namelist())
        # (c) core files present and non-empty
        for required in ("routes.txt", "trips.txt", "stop_times.txt", "stops.txt"):
            if required not in names:
                fail(f"missing {required}")
            if not _gtfs_has_data_row(z, required):
                fail(f"{required} has no data rows")
        # (a) at least one service active in the routing window
        if "calendar.txt" not in names and "calendar_dates.txt" not in names:
            fail("has neither calendar.txt nor calendar_dates.txt")
        if not _calendar_active_in_window(z, names, window_start, window_end):
            fail(
                f"no service active between {window_start} and {window_end}; "
                "the feed's calendars are stale/expired and it would contribute "
                "zero service to routing"
            )
        # (b) stops have plausible UK coordinates
        total_stops = 0
        valid_stops = 0
        with z.open("stops.txt") as f:
            cols = _parse_csv_line(f.readline())
            try:
                lat_idx = cols.index("stop_lat")
                lon_idx = cols.index("stop_lon")
            except ValueError:
                fail("stops.txt is missing stop_lat/stop_lon columns")
            for line in f:
                parts = _parse_csv_line(line)
                if not parts:
                    continue
                total_stops += 1
                try:
                    lat = float(parts[lat_idx].strip('"'))
                    lon = float(parts[lon_idx].strip('"'))
                except (ValueError, IndexError):
                    continue  # empty/garbage coordinate → invalid
                if lat == 0.0 and lon == 0.0:
                    continue
                if (
                    UK_LAT_RANGE[0] <= lat <= UK_LAT_RANGE[1]
                    and UK_LON_RANGE[0] <= lon <= UK_LON_RANGE[1]
                ):
                    valid_stops += 1
    if total_stops == 0:
        fail("stops.txt has no stops")
    fraction = valid_stops / total_stops
    if fraction < GTFS_MIN_VALID_STOP_FRACTION:
        fail(
            f"only {valid_stops}/{total_stops} stops "
            f"({fraction:.1%}) have plausible UK coordinates "
            f"(lat {UK_LAT_RANGE[0]}-{UK_LAT_RANGE[1]}, "
            f"lon {UK_LON_RANGE[0]}..{UK_LON_RANGE[1]}, non-null, not 0,0); "
            f"need >= {GTFS_MIN_VALID_STOP_FRACTION:.0%}"
        )
    print(
        f" OK: service active in window, {valid_stops}/{total_stops} stops "
        f"({fraction:.1%}) with plausible UK coordinates"
    )
def download_national_rail_cif(raw_dir: Path) -> Path | None:
"""Download National Rail CIF timetable (requires credentials)."""
dest = raw_dir / "national_rail_cif.zip"
if dest.exists():
print(f"National Rail CIF already exists: {dest}")
return dest
# Free National Rail Open Data account; env vars override the baked-in default.
email = os.environ.get("NATIONAL_RAIL_EMAIL", "schmelczerandras@gmail.com")
password = os.environ.get("NATIONAL_RAIL_PASSWORD", "z8^b!4GhCS8kj1Vp")
if not email or not password:
print(
"Warning: NATIONAL_RAIL_EMAIL/NATIONAL_RAIL_PASSWORD not set, skipping national rail"
)
return None
print("Authenticating with National Rail Open Data...")
auth_data = urllib.parse.urlencode(
{"username": email, "password": password}
).encode()
auth_req = urllib.request.Request(
NR_AUTH_URL,
data=auth_data,
headers={
"User-Agent": USER_AGENT,
"Content-Type": "application/x-www-form-urlencoded",
},
)
with urllib.request.urlopen(auth_req) as resp:
token_data = json.loads(resp.read())
token = token_data["token"]
print(" Authenticated successfully")
print("Downloading National Rail CIF timetable...")
_download_http(
NR_TIMETABLE_URL,
dest,
desc="national_rail_cif.zip",
headers={"X-Auth-Token": token},
)
return dest
def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
    """Fix R5-incompatible entries in dtd2mysql-generated National Rail GTFS.

    Fixes:
    - Stops with pickup_type=1 and drop_off_type=1 (pass-through) → normal stops.
      R5 builds TripPatterns from the full stop sequence but may build shorter
      TripSchedules when stops are non-boarding, causing ArrayIndexOutOfBoundsException.
    - Removes stop_times referencing stops not in stops.txt.
    - Removes trips with backwards departure times (from trips.txt and stop_times.txt).
    - Converts route_type=714 (rail replacement bus) to 3 (bus) for R5 compatibility.
    - Removes non-standard links.txt file.
    - Renumbers stop_sequence to 0-based (R5/BODS convention).
    - Fixes bogus coordinates (lat < 0) on Irish CIE stations.
    - Caps calendar.txt end_date values with year > 2099 to 20991231.

    Backwards-time detection and stop_sequence renumbering both assume
    stop_times.txt lists each trip's rows contiguously and in stop order.
    This is presumably true of dtd2mysql output; verify if the exporter changes.

    The zip is built at ``<dst>.partial`` and renamed onto ``dst`` only on
    success. An existing ``dst`` short-circuits this function, so a
    half-written file must never land there.
    """
    if dst.exists():
        print(f"Cleaned National Rail GTFS already exists: {dst}")
        return
    print("Cleaning National Rail GTFS for R5 compatibility...")

    # First pass: collect valid stop IDs and find trips with backwards times.
    stop_ids: set[str] = set()
    bad_trip_ids: set[str] = set()
    with zipfile.ZipFile(src, "r") as zin:
        with zin.open("stops.txt") as f:
            stop_id_idx = _parse_csv_line(f.readline()).index("stop_id")
            for line in f:
                parts = _parse_csv_line(line)
                if parts:
                    stop_ids.add(parts[stop_id_idx])
        with zin.open("stop_times.txt") as f:
            st_cols = _parse_csv_line(f.readline())
            trip_id_idx = st_cols.index("trip_id")
            dep_idx = st_cols.index("departure_time")
            prev_trip = ""
            prev_dep_secs = -1
            for line in f:
                parts = _parse_csv_line(line)
                if not parts:
                    continue
                trip_id = parts[trip_id_idx].strip('"')
                if trip_id != prev_trip:
                    prev_trip = trip_id
                    prev_dep_secs = -1
                dep_secs = _parse_gtfs_time(parts[dep_idx])
                if dep_secs is None:
                    continue  # unparseable time: neither compared nor remembered
                if dep_secs < prev_dep_secs:
                    bad_trip_ids.add(trip_id)
                prev_dep_secs = dep_secs
    print(f" Found {len(bad_trip_ids)} trips with backwards travel times")

    # Counters mutated by the per-member generators below.
    stats = {
        "passthrough": 0,
        "orphan": 0,
        "bad_trip": 0,
        "renumbered": 0,
        "coords": 0,
        "route_types": 0,
    }

    def stop_times_lines(f):
        """Yield cleaned stop_times.txt lines (header first)."""
        header = f.readline()
        cols = _parse_csv_line(header)
        trip_id_idx = cols.index("trip_id")
        stop_id_idx = cols.index("stop_id")
        seq_idx = cols.index("stop_sequence")
        pickup_idx = cols.index("pickup_type") if "pickup_type" in cols else -1
        dropoff_idx = cols.index("drop_off_type") if "drop_off_type" in cols else -1
        yield header
        prev_trip = ""
        seq_counter = 0
        for line in f:
            parts = _parse_csv_line(line)
            if not parts:
                continue
            trip_id = parts[trip_id_idx].strip('"')
            if trip_id in bad_trip_ids:
                stats["bad_trip"] += 1
                continue
            if parts[stop_id_idx].strip('"') not in stop_ids:
                stats["orphan"] += 1
                continue
            if (
                pickup_idx >= 0
                and dropoff_idx >= 0
                and parts[pickup_idx].strip('"') == "1"
                and parts[dropoff_idx].strip('"') == "1"
            ):
                parts[pickup_idx] = "0"
                parts[dropoff_idx] = "0"
                stats["passthrough"] += 1
            # Renumber after filtering so each kept trip is gap-free from 0.
            if trip_id != prev_trip:
                prev_trip = trip_id
                seq_counter = 0
            else:
                seq_counter += 1
            old_seq = parts[seq_idx].strip('"')
            parts[seq_idx] = str(seq_counter)
            if old_seq != parts[seq_idx]:
                stats["renumbered"] += 1
            yield _format_csv_row(parts)

    def stops_lines(f):
        """Yield stops.txt lines, moving lat < 0 stops to a neutral UK point."""
        header = f.readline()
        cols = _parse_csv_line(header)
        lat_idx = cols.index("stop_lat")
        lon_idx = cols.index("stop_lon")
        yield header
        for line in f:
            parts = _parse_csv_line(line)
            if not parts:
                continue
            try:
                lat = float(parts[lat_idx])
            except ValueError:
                lat = None
            # Bogus Irish CIE coordinates (South Atlantic).
            if lat is not None and lat < 0:
                parts[lat_idx] = "54.0"
                parts[lon_idx] = "-2.0"
                stats["coords"] += 1
            yield _format_csv_row(parts)

    def routes_lines(f):
        """Yield routes.txt lines with route_type 714 mapped to 3."""
        header = f.readline()
        rt_idx = _parse_csv_line(header).index("route_type")
        yield header
        for line in f:
            parts = _parse_csv_line(line)
            if not parts:
                continue
            if parts[rt_idx].strip('"') == "714":
                parts[rt_idx] = "3"
                stats["route_types"] += 1
            yield _format_csv_row(parts)

    def trips_lines(f):
        """Yield trips.txt lines (unchanged bytes) minus backwards-time trips."""
        header = f.readline()
        trip_id_idx = _parse_csv_line(header).index("trip_id")
        yield header
        for line in f:
            parts = _parse_csv_line(line)
            if parts and parts[trip_id_idx].strip('"') not in bad_trip_ids:
                yield line

    def calendar_lines(f):
        """Yield calendar.txt lines with end_date years > 2099 capped."""
        header = f.readline()
        end_idx = _parse_csv_line(header).index("end_date")
        yield header
        for line in f:
            parts = _parse_csv_line(line)
            if not parts:
                continue
            date_val = parts[end_idx].strip('"')
            if len(date_val) == 8:
                try:
                    if int(date_val[:4]) > 2099:
                        parts[end_idx] = "20991231"
                except ValueError:
                    pass
            yield _format_csv_row(parts)

    rewriters = {
        "stop_times.txt": stop_times_lines,
        "stops.txt": stops_lines,
        "routes.txt": routes_lines,
        "trips.txt": trips_lines,
        "calendar.txt": calendar_lines,
    }

    def rewrite_member(zin, info, zout, lines_fn) -> None:
        """Stream a member through ``lines_fn`` via a temp file into ``zout``."""
        tmp = tempfile.NamedTemporaryFile(
            mode="wb",
            delete=False,
            suffix=".txt",
            dir=local_tmp_dir(),
        )
        try:
            with tmp, zin.open(info) as f:
                tmp.writelines(lines_fn(f))
            zout.write(tmp.name, info.filename)
        finally:
            os.unlink(tmp.name)

    # Second pass: write the cleaned zip.
    partial = dst.with_name(dst.name + ".partial")
    try:
        with (
            zipfile.ZipFile(src, "r") as zin,
            zipfile.ZipFile(partial, "w", zipfile.ZIP_DEFLATED) as zout,
        ):
            for info in zin.infolist():
                if info.filename == "links.txt":
                    continue  # non-standard file
                lines_fn = rewriters.get(info.filename)
                if lines_fn is None:
                    zout.writestr(info, zin.read(info))
                else:
                    rewrite_member(zin, info, zout, lines_fn)
        os.replace(partial, dst)
    except BaseException:
        partial.unlink(missing_ok=True)
        raise

    print(f" Pass-through stops fixed: {stats['passthrough']}")
    print(f" Orphan stop references removed: {stats['orphan']}")
    print(f" Bad trip stop_times removed: {stats['bad_trip']}")
    print(f" Stop sequences renumbered: {stats['renumbered']}")
    print(f" Bogus coordinates fixed: {stats['coords']}")
    print(f" Route types 714→3 fixed: {stats['route_types']}")
    print(f" Saved to {dst}")
def _docker_run_dtd2mysql(
    network: str, db_container: str, volumes: list[str], args: list[str]
) -> None:
    """Run dtd2mysql in a throwaway ``node:20`` container alongside MariaDB.

    The container joins ``network`` and is given DATABASE_HOSTNAME=
    ``db_container``, root/root credentials and database ``dtd`` through
    environment variables. Each ``volumes`` entry is passed verbatim to
    ``docker run -v``.

    ``args`` are shell-quoted with ``shlex.join`` before being appended to the
    ``bash -c`` command. An argument containing spaces or shell
    metacharacters therefore reaches dtd2mysql as one literal argument
    instead of being split or interpreted by bash.

    Raises:
        subprocess.CalledProcessError: the container exited non-zero.
    """
    import shlex

    cmd = [
        "docker",
        "run",
        "--rm",
        "--network",
        network,
        "-e",
        f"DATABASE_HOSTNAME={db_container}",
        "-e",
        "DATABASE_USERNAME=root",
        "-e",
        "DATABASE_PASSWORD=root",
        "-e",
        "DATABASE_NAME=dtd",
    ]
    for v in volumes:
        cmd.extend(["-v", v])
    # Install zip (needed for --gtfs-zip), then run dtd2mysql.
    # NOTE(review): `npx --yes dtd2mysql` fetches whatever version is current
    # at run time; consider pinning (dtd2mysql@<version>) for reproducibility.
    inner = (
        "apt-get update -qq && apt-get install -y -qq zip > /dev/null 2>&1 && npx --yes dtd2mysql "
        + shlex.join(args)
    )
    cmd.extend(["node:20", "bash", "-c", inner])
    subprocess.run(cmd, check=True)
def convert_national_rail_to_gtfs(raw_dir: Path, output_dir: Path) -> Path:
    """Convert National Rail CIF to GTFS using dtd2mysql + MariaDB Docker.

    Runs both MariaDB and dtd2mysql as Docker containers on a shared network,
    since Docker port forwarding is not available in all environments.
    Expects ``raw_dir/national_rail_cif.zip``. The intermediate export is
    ``raw_dir/national_rail_gtfs_raw.zip``, reused if present. The result is
    then cleaned for R5 compatibility into ``output_dir/national_rail_gtfs.zip``.

    Returns:
        Path to the cleaned National Rail GTFS zip.

    Raises:
        subprocess.CalledProcessError: a Docker step failed.
        RuntimeError: MariaDB did not become ready within ~60 s.
    """
    dest = output_dir / "national_rail_gtfs.zip"
    if dest.exists():
        print(f"National Rail GTFS already exists: {dest}")
        return dest
    raw_dest = raw_dir / "national_rail_gtfs_raw.zip"
    if not raw_dest.exists():
        db_container = "propertymap-mariadb-temp"
        network = "propertymap-dtd-net"
        print("Creating Docker network and starting MariaDB...")
        # A run killed before its `finally` (e.g. SIGKILL) can leave the
        # container behind, and `docker run --name` would then fail on the
        # name conflict. Best-effort removal; failures are ignored.
        subprocess.run(["docker", "rm", "-f", db_container], capture_output=True)
        # Ignored on failure: the network may already exist from a prior run.
        subprocess.run(["docker", "network", "create", network], capture_output=True)
        try:
            # Started inside the try so the network is removed even if this fails.
            subprocess.run(
                [
                    "docker",
                    "run",
                    "-d",
                    "--name",
                    db_container,
                    "--network",
                    network,
                    "-e",
                    "MARIADB_ROOT_PASSWORD=root",
                    "-e",
                    "MARIADB_DATABASE=dtd",
                    "mariadb:latest",
                ],
                check=True,
            )
            print(" Waiting for MariaDB to be ready...")
            for _ in range(30):
                result = subprocess.run(
                    [
                        "docker",
                        "exec",
                        db_container,
                        "mariadb",
                        "-uroot",
                        "-proot",
                        "-e",
                        "SELECT 1",
                    ],
                    capture_output=True,
                )
                if result.returncode == 0:
                    break
                time.sleep(2)
            else:
                raise RuntimeError("MariaDB did not become ready in time")
            raw_abs = str(raw_dir.resolve())
            print("Importing CIF timetable into MariaDB...")
            _docker_run_dtd2mysql(
                network,
                db_container,
                volumes=[f"{raw_abs}:/data:ro"],
                args=["--timetable", "/data/national_rail_cif.zip"],
            )
            print("Exporting GTFS from MariaDB...")
            _docker_run_dtd2mysql(
                network,
                db_container,
                volumes=[f"{raw_abs}:/output"],
                args=["--gtfs-zip", "/output/national_rail_gtfs_raw.zip"],
            )
        except BaseException:
            # A failed export may leave a partial zip; drop it so the next run
            # regenerates it instead of cleaning a truncated file.
            raw_dest.unlink(missing_ok=True)
            raise
        finally:
            print("Cleaning up Docker resources...")
            subprocess.run(["docker", "stop", db_container], capture_output=True)
            subprocess.run(["docker", "rm", db_container], capture_output=True)
            subprocess.run(["docker", "network", "rm", network], capture_output=True)
    # Clean the raw GTFS for R5 compatibility
    clean_national_rail_gtfs(raw_dest, dest)
    return dest
def main() -> None:
    """CLI entry point: build every input R5 needs under ``--output``.

    Order: OSM PBF, then BODS GTFS (download → clean → frequency-convert →
    validate), then National Rail (download CIF → GTFS → validate), and
    finally a size listing of the top-level output files. Missing National
    Rail credentials abort the run with RuntimeError.
    """
    cli = argparse.ArgumentParser(
        description="Download and prepare transit network data for R5 routing engine"
    )
    cli.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Output directory for transit data",
    )
    output_dir: Path = cli.parse_args().output
    raw_dir = output_dir / "raw"
    raw_dir.mkdir(parents=True, exist_ok=True)

    # BODS already covers England bus/tram/ferry plus London Underground, DLR,
    # London Tramlink and the IFS Cloud Cable Car, so no TfL feed is needed.
    download_osm_pbf(raw_dir)
    bods_cleaned = raw_dir / "bods_gtfs_cleaned.zip"
    bods_final = output_dir / "bods_gtfs.zip"
    clean_gtfs(download_bods_gtfs(raw_dir), bods_cleaned)
    convert_high_freq_to_frequency_based(bods_cleaned, bods_final)
    validate_gtfs_feed(bods_final, "BODS GTFS")

    # Heavy rail is mandatory: without it every train commute is modelled as
    # bus-only. A missing timetable is therefore a hard error.
    if download_national_rail_cif(raw_dir) is None:
        raise RuntimeError(
            "National Rail timetable was not downloaded — set "
            "NATIONAL_RAIL_EMAIL / NATIONAL_RAIL_PASSWORD (register free at "
            "https://opendata.nationalrail.co.uk/). National Rail heavy rail is "
            "required; without it the transit network models every train journey "
            "as bus-only and overstates commute times."
        )
    validate_gtfs_feed(
        convert_national_rail_to_gtfs(raw_dir, output_dir), "National Rail GTFS"
    )

    # Summary of top-level, non-hidden output files.
    print()
    print("Transit data ready for R5:")
    for artefact in sorted(output_dir.iterdir()):
        if artefact.is_dir() or artefact.name.startswith("."):
            continue
        print(f" {artefact.name}: {artefact.stat().st_size / (1024 * 1024):.1f} MB")
    print()
    for hint in (
        "IMPORTANT: If you previously built a network from London-only data,",
        "delete the stale cache before running R5:",
        " rm -f property-data/r5-network/network.dat",
    ):
        print(hint)


if __name__ == "__main__":
    main()