This commit is contained in:
Andras Schmelczer 2026-03-15 21:22:28 +00:00
parent 479ef92236
commit c38d654ac7
44 changed files with 2526 additions and 701 deletions

View file

@ -56,7 +56,9 @@ NR_TIMETABLE_URL = "https://opendata.nationalrail.co.uk/api/staticfeeds/3.0/time
USER_AGENT = "property-map-pipeline/1.0 (https://github.com)"
def _download_http(url: str, dest: Path, *, desc: str, headers: dict | None = None) -> None:
def _download_http(
url: str, dest: Path, *, desc: str, headers: dict | None = None
) -> None:
"""Stream-download a URL to a file with progress bar."""
dest.parent.mkdir(parents=True, exist_ok=True)
tmp = dest.with_suffix(dest.suffix + ".tmp")
@ -117,9 +119,10 @@ def clean_gtfs(src: Path, dst: Path) -> None:
return
print("Cleaning GTFS for R5 compatibility...")
with zipfile.ZipFile(src, "r") as zin, zipfile.ZipFile(
dst, "w", zipfile.ZIP_DEFLATED
) as zout:
with (
zipfile.ZipFile(src, "r") as zin,
zipfile.ZipFile(dst, "w", zipfile.ZIP_DEFLATED) as zout,
):
for info in zin.infolist():
if info.filename == "stop_times.txt":
dropped = 0
@ -127,7 +130,9 @@ def clean_gtfs(src: Path, dst: Path) -> None:
header = f.readline()
header_str = header.decode("utf-8").strip()
cols = header_str.split(",")
arr_idx = cols.index("arrival_time") if "arrival_time" in cols else -1
arr_idx = (
cols.index("arrival_time") if "arrival_time" in cols else -1
)
dep_idx = (
cols.index("departure_time") if "departure_time" in cols else -1
)
@ -179,7 +184,9 @@ def clean_gtfs(src: Path, dst: Path) -> None:
year = int(date_val[:4])
if year > 2100:
parts[i] = "20991231"
print(f" feed_info: capped end_date {date_val} → 20991231")
print(
f" feed_info: capped end_date {date_val} → 20991231"
)
fixed_lines.append(",".join(parts))
zout.writestr("feed_info.txt", "\n".join(fixed_lines) + "\n")
else:
@ -334,7 +341,9 @@ def convert_high_freq_to_frequency_based(
end_secs = trips[-1][1] + int(median_hw)
headway_rounded = max(60, round(median_hw / 60) * 60)
frequency_entries.append((template_trip_id, start_secs, end_secs, headway_rounded))
frequency_entries.append(
(template_trip_id, start_secs, end_secs, headway_rounded)
)
for trip_id, _ in trips[1:]:
trips_to_remove.add(trip_id)
groups_converted += 1
@ -344,9 +353,10 @@ def convert_high_freq_to_frequency_based(
print(f" Created {len(frequency_entries)} frequency entries")
# Step 5: Write modified GTFS
with zipfile.ZipFile(src, "r") as zin, zipfile.ZipFile(
dst, "w", zipfile.ZIP_DEFLATED
) as zout:
with (
zipfile.ZipFile(src, "r") as zin,
zipfile.ZipFile(dst, "w", zipfile.ZIP_DEFLATED) as zout,
):
for info in zin.infolist():
if info.filename == "trips.txt":
with zin.open(info) as f:
@ -466,15 +476,22 @@ def download_national_rail_cif(raw_dir: Path) -> Path | None:
email = os.environ.get("NATIONAL_RAIL_EMAIL")
password = os.environ.get("NATIONAL_RAIL_PASSWORD")
if not email or not password:
print("Warning: NATIONAL_RAIL_EMAIL/NATIONAL_RAIL_PASSWORD not set, skipping national rail")
print(
"Warning: NATIONAL_RAIL_EMAIL/NATIONAL_RAIL_PASSWORD not set, skipping national rail"
)
return None
print("Authenticating with National Rail Open Data...")
auth_data = urllib.parse.urlencode({"username": email, "password": password}).encode()
auth_data = urllib.parse.urlencode(
{"username": email, "password": password}
).encode()
auth_req = urllib.request.Request(
NR_AUTH_URL,
data=auth_data,
headers={"User-Agent": USER_AGENT, "Content-Type": "application/x-www-form-urlencoded"},
headers={
"User-Agent": USER_AGENT,
"Content-Type": "application/x-www-form-urlencoded",
},
)
with urllib.request.urlopen(auth_req) as resp:
token_data = json.loads(resp.read())
@ -565,9 +582,10 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
coords_fixed = 0
route_types_fixed = 0
with zipfile.ZipFile(src, "r") as zin, zipfile.ZipFile(
dst, "w", zipfile.ZIP_DEFLATED
) as zout:
with (
zipfile.ZipFile(src, "r") as zin,
zipfile.ZipFile(dst, "w", zipfile.ZIP_DEFLATED) as zout,
):
for info in zin.infolist():
# Skip non-standard links.txt
if info.filename == "links.txt":
@ -581,8 +599,12 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
trip_id_idx = cols.index("trip_id")
stop_id_idx = cols.index("stop_id")
seq_idx = cols.index("stop_sequence")
pickup_idx = cols.index("pickup_type") if "pickup_type" in cols else -1
dropoff_idx = cols.index("drop_off_type") if "drop_off_type" in cols else -1
pickup_idx = (
cols.index("pickup_type") if "pickup_type" in cols else -1
)
dropoff_idx = (
cols.index("drop_off_type") if "drop_off_type" in cols else -1
)
tmp = tempfile.NamedTemporaryFile(
mode="wb", delete=False, suffix=".txt"
@ -769,16 +791,27 @@ def _docker_run_dtd2mysql(
) -> None:
"""Run dtd2mysql in a Node.js container on the same Docker network as MariaDB."""
cmd = [
"docker", "run", "--rm", "--network", network,
"-e", f"DATABASE_HOSTNAME={db_container}",
"-e", "DATABASE_USERNAME=root",
"-e", "DATABASE_PASSWORD=root",
"-e", "DATABASE_NAME=dtd",
"docker",
"run",
"--rm",
"--network",
network,
"-e",
f"DATABASE_HOSTNAME={db_container}",
"-e",
"DATABASE_USERNAME=root",
"-e",
"DATABASE_PASSWORD=root",
"-e",
"DATABASE_NAME=dtd",
]
for v in volumes:
cmd.extend(["-v", v])
# Install zip (needed for --gtfs-zip) then run dtd2mysql
inner = "apt-get update -qq && apt-get install -y -qq zip > /dev/null 2>&1 && npx --yes dtd2mysql " + " ".join(args)
inner = (
"apt-get update -qq && apt-get install -y -qq zip > /dev/null 2>&1 && npx --yes dtd2mysql "
+ " ".join(args)
)
cmd.extend(["node:20", "bash", "-c", inner])
subprocess.run(cmd, check=True)
@ -805,11 +838,17 @@ def convert_national_rail_to_gtfs(raw_dir: Path, output_dir: Path) -> Path:
subprocess.run(["docker", "network", "create", network], capture_output=True)
subprocess.run(
[
"docker", "run", "-d",
"--name", db_container,
"--network", network,
"-e", "MARIADB_ROOT_PASSWORD=root",
"-e", "MARIADB_DATABASE=dtd",
"docker",
"run",
"-d",
"--name",
db_container,
"--network",
network,
"-e",
"MARIADB_ROOT_PASSWORD=root",
"-e",
"MARIADB_DATABASE=dtd",
"mariadb:latest",
],
check=True,
@ -820,7 +859,16 @@ def convert_national_rail_to_gtfs(raw_dir: Path, output_dir: Path) -> Path:
print(" Waiting for MariaDB to be ready...")
for attempt in range(30):
result = subprocess.run(
["docker", "exec", db_container, "mariadb", "-uroot", "-proot", "-e", "SELECT 1"],
[
"docker",
"exec",
db_container,
"mariadb",
"-uroot",
"-proot",
"-e",
"SELECT 1",
],
capture_output=True,
)
if result.returncode == 0:
@ -833,14 +881,16 @@ def convert_national_rail_to_gtfs(raw_dir: Path, output_dir: Path) -> Path:
print("Importing CIF timetable into MariaDB...")
_docker_run_dtd2mysql(
network, db_container,
network,
db_container,
volumes=[f"{raw_abs}:/data:ro"],
args=["--timetable", "/data/national_rail_cif.zip"],
)
print("Exporting GTFS from MariaDB...")
_docker_run_dtd2mysql(
network, db_container,
network,
db_container,
volumes=[f"{raw_abs}:/output"],
args=["--gtfs-zip", "/output/national_rail_gtfs_raw.zip"],
)