Codex changes

This commit is contained in:
Andras Schmelczer 2026-05-04 16:19:09 +01:00
parent 0bae902e08
commit d4dde21ad2
46 changed files with 4953 additions and 966 deletions

View file

@ -22,6 +22,92 @@ STOP_TYPES = {
}
OUTPUT_COLUMNS = ["id", "name", "category", "lat", "lng"]
def canonical_station_name_expr(name_col: str = "name") -> pl.Expr:
    """Build an expression that maps station names to a canonical merge key.

    Entrance and transport-mode variants of one station (for example "Bank",
    "Bank Underground Station" and "Bank DLR Station") all reduce to the same
    lowercase key.

    Args:
        name_col: Name of the column holding the raw station names.

    Returns:
        A Polars expression yielding the normalized key for each row.
    """
    # Applied strictly in this order: the suffix patterns below rely on the
    # lowercase, punctuation-free text these steps produce.
    cleanup_steps = (
        (r"\([^)]*\)", " "),  # parenthesised qualifiers
        (r"['`]", ""),  # apostrophes/backticks, so "earl's" becomes "earls"
        (r"&", " and "),
        (r"[^a-z0-9]+", " "),  # any remaining punctuation or symbols
    )
    # Mode-specific suffixes go first so the generic "station"/"stop" rule
    # does not leave a dangling mode word behind.
    suffix_patterns = (
        r"\s+(underground|tube|dlr|metro|rail|railway)\s+station$",
        r"\s+tram\s+stop$",
        r"\s+(station|stop)$",
    )

    key = pl.col(name_col).str.to_lowercase()
    for pattern, replacement in cleanup_steps:
        key = key.str.replace_all(pattern, replacement)
    key = key.str.replace_all(r"\s+", " ").str.strip_chars()
    for pattern in suffix_patterns:
        key = key.str.replace_all(pattern, "")
    return key.str.strip_chars()
def _has_locality() -> pl.Expr:
    """Return a predicate that is true where ``locality`` is non-null and non-empty."""
    locality = pl.col("locality")
    return locality.is_not_null() & (locality != "")
def _deduplicate_tube_partition(
    df: pl.DataFrame, group_cols: list[str]
) -> pl.DataFrame:
    """Collapse rows that share ``group_cols`` into a single row per group.

    Each merged row takes the ``id`` and ``name`` of the row with the
    shortest name, the first ``category`` in the group, and the mean of
    ``lat`` and ``lng``.

    Args:
        df: Rows to merge; must contain the output columns plus ``group_cols``.
        group_cols: Columns that identify one physical station.

    Returns:
        A frame restricted to ``OUTPUT_COLUMNS``. An empty input yields an
        empty frame with explicit String/Float64 column types.
    """
    if df.is_empty():
        # Return a typed empty frame directly instead of aggregating nothing.
        return pl.DataFrame(
            schema={
                "id": pl.String,
                "name": pl.String,
                "category": pl.String,
                "lat": pl.Float64,
                "lng": pl.Float64,
            }
        )

    shortest_name_first = pl.col("name").str.len_chars()
    representative = [
        pl.col(column).sort_by(shortest_name_first).first()
        for column in ("id", "name")
    ]
    centroid = [pl.col(column).mean() for column in ("lat", "lng")]

    merged = df.group_by(group_cols).agg(
        *representative,
        pl.col("category").first(),
        *centroid,
    )
    return merged.select(OUTPUT_COLUMNS)
def deduplicate_naptan(df: pl.DataFrame) -> pl.DataFrame:
    """Deduplicate NaPTAN stops, merging Tube stations more aggressively.

    Two passes are applied:

    1. Rows with a locality are merged per exact name/category/locality,
       keeping the first ``id`` and averaging the coordinates. Rows without
       a locality pass through unmerged.
    2. Rows whose category is "Tube station" are merged again on their
       canonical station name, together with the locality when there is one.

    Args:
        df: Stops with the output columns plus a ``locality`` column.

    Returns:
        A frame restricted to ``OUTPUT_COLUMNS``: non-Tube rows followed by
        the merged Tube rows.
    """
    locality_present = _has_locality()
    keep_cols = [*OUTPUT_COLUMNS, "locality"]

    # Pass 1: one record per exact stop name/category/locality.
    exact_merged = (
        df.filter(locality_present)
        .group_by("name", "category", "locality")
        .agg(
            pl.col("id").first(),
            pl.col("lat").mean(),
            pl.col("lng").mean(),
        )
        .select(keep_cols)
    )
    unlocated = df.filter(~locality_present).select(keep_cols)
    first_pass = pl.concat([exact_merged, unlocated])

    # Pass 2: station-level merge, applied to Tube rows only.
    tube = first_pass.filter(pl.col("category") == "Tube station").with_columns(
        canonical_station_name_expr().alias("_station_key")
    )
    other = first_pass.filter(pl.col("category") != "Tube station")

    merged_tube = pl.concat(
        [
            _deduplicate_tube_partition(
                tube.filter(locality_present), ["_station_key", "locality"]
            ),
            _deduplicate_tube_partition(
                tube.filter(~locality_present), ["_station_key"]
            ),
        ]
    )
    return pl.concat([other.select(OUTPUT_COLUMNS), merged_tube])
def download_naptan(output: Path) -> None:
output.parent.mkdir(parents=True, exist_ok=True)
@ -50,24 +136,12 @@ def download_naptan(output: Path) -> None:
)
before = len(df)
df = deduplicate_naptan(df)
# Deduplicate: one record per name+category+locality
# (merges entrances, bus stop pairs on opposite sides of the road, etc.)
has_loc = df.filter(
pl.col("locality").is_not_null() & (pl.col("locality") != "")
print(
f"Deduplicated {before:,}{len(df):,} stops "
"(by name+category+locality; tube stations by normalized station name)"
)
no_loc = df.filter(
pl.col("locality").is_null() | (pl.col("locality") == "")
)
cols = ["id", "name", "category", "lat", "lng"]
deduped = has_loc.group_by("name", "category", "locality").agg(
pl.col("id").first(),
pl.col("lat").mean(),
pl.col("lng").mean(),
)
df = pl.concat([deduped.select(cols), no_loc.select(cols)])
print(f"Deduplicated {before:,}{len(df):,} stops (by name+category+locality)")
df.write_parquet(output)
size_mb = output.stat().st_size / (1024 * 1024)

View file

@ -0,0 +1,71 @@
import polars as pl
import pytest
from pipeline.download.naptan import canonical_station_name_expr, deduplicate_naptan
def test_canonical_station_name_expr_normalizes_transport_suffixes():
    """Transport-mode suffixes, parentheticals and apostrophes are stripped."""
    raw_names = [
        "Bank",
        "Bank Underground Station",
        "Bank DLR Station",
        "Pleasure Beach (Blackpool Tramway)",
        "Earl's Court Tube Station",
    ]
    expected_keys = [
        "bank",
        "bank",
        "bank",
        "pleasure beach",
        "earls court",
    ]

    frame = pl.DataFrame({"name": raw_names})
    keys = frame.select(canonical_station_name_expr().alias("key"))["key"].to_list()

    assert keys == expected_keys
def test_deduplicate_naptan_merges_tube_station_variants_by_locality():
    """Tube name variants merge within a locality but not across localities."""
    loc1_lats = [51.5129, 51.5134, 51.5132]
    stops = pl.DataFrame(
        {
            "id": ["bank", "bank-lu", "bank-dlr", "other-bank"],
            "name": [
                "Bank",
                "Bank Underground Station",
                "Bank DLR Station",
                "Bank Underground Station",
            ],
            "category": ["Tube station"] * 4,
            "lat": [*loc1_lats, 55.0140],
            "lng": [-0.0889, -0.0890, -0.0885, -1.6781],
            "locality": ["LOC1", "LOC1", "LOC1", "LOC2"],
        }
    )

    # Sorting by latitude puts the merged LOC1 station first.
    merged = deduplicate_naptan(stops).sort("lat")
    merged_lats = merged["lat"].to_list()

    assert len(merged) == 2
    assert merged["name"].to_list() == ["Bank", "Bank Underground Station"]
    assert merged_lats[0] == pytest.approx((51.5129 + 51.5134 + 51.5132) / 3)
def test_deduplicate_naptan_does_not_merge_missing_locality_bus_stops():
    """Same-named bus stops without a locality stay as separate rows."""
    columns = {
        "id": ["a", "b"],
        "name": ["High Street", "High Street"],
        "category": ["Bus stop", "Bus stop"],
        "lat": [51.5, 52.5],
        "lng": [-0.1, -1.1],
        "locality": [None, None],
    }

    deduped = deduplicate_naptan(pl.DataFrame(columns))

    assert len(deduped) == 2

View file

@ -19,6 +19,8 @@ Output directory: property-data/transit/
"""
import argparse
import csv
import io
import json
import os
import shutil
@ -108,6 +110,30 @@ def download_bods_gtfs(output_dir: Path) -> Path:
return dest
def _parse_csv_line(line: bytes | str) -> list[str]:
"""Parse a single GTFS CSV record."""
if isinstance(line, bytes):
line = line.decode("utf-8", errors="replace")
line = line.rstrip("\r\n")
if not line:
return []
return next(csv.reader([line]))
def _format_csv_row(parts: list[str]) -> bytes:
"""Serialize one GTFS CSV row with stable LF line endings."""
output = io.StringIO()
csv.writer(output, lineterminator="\n").writerow(parts)
return output.getvalue().encode("utf-8")
def _format_csv_rows(rows: list[list[str]]) -> str:
output = io.StringIO()
writer = csv.writer(output, lineterminator="\n")
writer.writerows(rows)
return output.getvalue()
def clean_gtfs(src: Path, dst: Path) -> None:
"""Fix R5-incompatible entries in GTFS.
@ -128,8 +154,7 @@ def clean_gtfs(src: Path, dst: Path) -> None:
dropped = 0
with zin.open(info) as f:
header = f.readline()
header_str = header.decode("utf-8").strip()
cols = header_str.split(",")
cols = _parse_csv_line(header)
arr_idx = (
cols.index("arrival_time") if "arrival_time" in cols else -1
)
@ -143,10 +168,9 @@ def clean_gtfs(src: Path, dst: Path) -> None:
tmp.write(header)
for line in f:
line_str = line.decode("utf-8", errors="replace").strip()
if not line_str:
parts = _parse_csv_line(line)
if not parts:
continue
parts = line_str.split(",")
skip = False
for idx in [arr_idx, dep_idx]:
if 0 <= idx < len(parts):
@ -171,12 +195,13 @@ def clean_gtfs(src: Path, dst: Path) -> None:
elif info.filename == "feed_info.txt":
data = zin.read(info).decode("utf-8")
lines = data.strip().split("\n")
header_line = lines[0]
feed_cols = header_line.split(",")
fixed_lines = [header_line]
for line in lines[1:]:
parts = line.split(",")
rows = list(csv.reader(io.StringIO(data)))
if not rows:
zout.writestr("feed_info.txt", data)
continue
feed_cols = rows[0]
fixed_rows = [feed_cols]
for parts in rows[1:]:
for i, col_name in enumerate(feed_cols):
if "end_date" in col_name.lower() and i < len(parts):
date_val = parts[i].strip('"')
@ -187,8 +212,8 @@ def clean_gtfs(src: Path, dst: Path) -> None:
print(
f" feed_info: capped end_date {date_val} → 20991231"
)
fixed_lines.append(",".join(parts))
zout.writestr("feed_info.txt", "\n".join(fixed_lines) + "\n")
fixed_rows.append(parts)
zout.writestr("feed_info.txt", _format_csv_rows(fixed_rows))
else:
zout.writestr(info, zin.read(info))
@ -237,12 +262,11 @@ def convert_high_freq_to_frequency_based(
# Step 1: Find metro/tram route IDs
target_route_ids: set[str] = set()
with zin.open("routes.txt") as f:
header = f.readline().decode("utf-8").strip()
cols = header.split(",")
cols = _parse_csv_line(f.readline())
route_id_idx = cols.index("route_id")
rt_idx = cols.index("route_type")
for line in f:
parts = line.decode("utf-8", errors="replace").strip().split(",")
parts = _parse_csv_line(line)
if not parts:
continue
route_type = parts[rt_idx].strip('"')
@ -259,14 +283,13 @@ def convert_high_freq_to_frequency_based(
# Step 2: Map target trips to grouping keys
trip_group_key: dict[str, tuple[str, str, str]] = {}
with zin.open("trips.txt") as f:
header = f.readline().decode("utf-8").strip()
cols = header.split(",")
cols = _parse_csv_line(f.readline())
trip_id_idx = cols.index("trip_id")
route_id_idx = cols.index("route_id")
dir_idx = cols.index("direction_id") if "direction_id" in cols else -1
service_idx = cols.index("service_id")
for line in f:
parts = line.decode("utf-8", errors="replace").strip().split(",")
parts = _parse_csv_line(line)
if not parts:
continue
route_id = parts[route_id_idx].strip('"')
@ -282,14 +305,13 @@ def convert_high_freq_to_frequency_based(
trip_first_dep: dict[str, int] = {}
trip_first_stop: dict[str, str] = {}
with zin.open("stop_times.txt") as f:
header = f.readline().decode("utf-8").strip()
cols = header.split(",")
cols = _parse_csv_line(f.readline())
trip_id_idx = cols.index("trip_id")
dep_idx = cols.index("departure_time")
seq_idx = cols.index("stop_sequence")
stop_id_idx = cols.index("stop_id")
for line in f:
parts = line.decode("utf-8", errors="replace").strip().split(",")
parts = _parse_csv_line(line)
if not parts:
continue
trip_id = parts[trip_id_idx].strip('"')
@ -361,8 +383,7 @@ def convert_high_freq_to_frequency_based(
if info.filename == "trips.txt":
with zin.open(info) as f:
header = f.readline()
header_str = header.decode("utf-8").strip()
cols = header_str.split(",")
cols = _parse_csv_line(header)
trip_id_idx = cols.index("trip_id")
tmp = tempfile.NamedTemporaryFile(
@ -370,9 +391,7 @@ def convert_high_freq_to_frequency_based(
)
tmp.write(header)
for line in f:
parts = (
line.decode("utf-8", errors="replace").strip().split(",")
)
parts = _parse_csv_line(line)
if not parts:
continue
if parts[trip_id_idx].strip('"') not in trips_to_remove:
@ -384,8 +403,7 @@ def convert_high_freq_to_frequency_based(
elif info.filename == "stop_times.txt":
with zin.open(info) as f:
header = f.readline()
header_str = header.decode("utf-8").strip()
cols = header_str.split(",")
cols = _parse_csv_line(header)
trip_id_idx = cols.index("trip_id")
tmp = tempfile.NamedTemporaryFile(
@ -393,9 +411,7 @@ def convert_high_freq_to_frequency_based(
)
tmp.write(header)
for line in f:
parts = (
line.decode("utf-8", errors="replace").strip().split(",")
)
parts = _parse_csv_line(line)
if not parts:
continue
if parts[trip_id_idx].strip('"') not in trips_to_remove:
@ -535,25 +551,23 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
with zipfile.ZipFile(src, "r") as zin:
# Load valid stop IDs
with zin.open("stops.txt") as f:
header = f.readline().decode("utf-8").strip()
stop_id_idx = header.split(",").index("stop_id")
lat_idx = header.split(",").index("stop_lat")
cols = _parse_csv_line(f.readline())
stop_id_idx = cols.index("stop_id")
for line in f:
parts = line.decode("utf-8", errors="replace").strip().split(",")
parts = _parse_csv_line(line)
if parts:
stop_ids.add(parts[stop_id_idx])
# Find trips with backwards travel times
with zin.open("stop_times.txt") as f:
st_header = f.readline().decode("utf-8").strip()
st_cols = st_header.split(",")
st_cols = _parse_csv_line(f.readline())
trip_id_idx = st_cols.index("trip_id")
dep_idx = st_cols.index("departure_time")
prev_trip = ""
prev_dep_secs = -1
for line in f:
parts = line.decode("utf-8", errors="replace").strip().split(",")
parts = _parse_csv_line(line)
if not parts:
continue
trip_id = parts[trip_id_idx].strip('"')
@ -594,8 +608,7 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
if info.filename == "stop_times.txt":
with zin.open(info) as f:
header = f.readline()
header_str = header.decode("utf-8").strip()
cols = header_str.split(",")
cols = _parse_csv_line(header)
trip_id_idx = cols.index("trip_id")
stop_id_idx = cols.index("stop_id")
seq_idx = cols.index("stop_sequence")
@ -614,10 +627,9 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
prev_trip = ""
seq_counter = 0
for line in f:
line_str = line.decode("utf-8", errors="replace").strip()
if not line_str:
parts = _parse_csv_line(line)
if not parts:
continue
parts = line_str.split(",")
trip_id = parts[trip_id_idx].strip('"')
stop_id = parts[stop_id_idx].strip('"')
@ -651,7 +663,7 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
if old_seq != str(seq_counter):
seqs_renumbered += 1
tmp.write((",".join(parts) + "\n").encode("utf-8"))
tmp.write(_format_csv_row(parts))
tmp.close()
zout.write(tmp.name, "stop_times.txt")
@ -660,8 +672,7 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
elif info.filename == "stops.txt":
with zin.open(info) as f:
header = f.readline()
header_str = header.decode("utf-8").strip()
cols = header_str.split(",")
cols = _parse_csv_line(header)
lat_idx = cols.index("stop_lat")
lon_idx = cols.index("stop_lon")
@ -671,10 +682,9 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
tmp.write(header)
for line in f:
line_str = line.decode("utf-8", errors="replace").strip()
if not line_str:
parts = _parse_csv_line(line)
if not parts:
continue
parts = line_str.split(",")
try:
lat = float(parts[lat_idx])
# Fix bogus Irish CIE coordinates (South Atlantic)
@ -685,7 +695,7 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
coords_fixed += 1
except ValueError:
pass
tmp.write((",".join(parts) + "\n").encode("utf-8"))
tmp.write(_format_csv_row(parts))
tmp.close()
zout.write(tmp.name, "stops.txt")
@ -694,8 +704,7 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
elif info.filename == "routes.txt":
with zin.open(info) as f:
header = f.readline()
header_str = header.decode("utf-8").strip()
cols = header_str.split(",")
cols = _parse_csv_line(header)
rt_idx = cols.index("route_type")
tmp = tempfile.NamedTemporaryFile(
@ -704,14 +713,13 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
tmp.write(header)
for line in f:
line_str = line.decode("utf-8", errors="replace").strip()
if not line_str:
parts = _parse_csv_line(line)
if not parts:
continue
parts = line_str.split(",")
if parts[rt_idx].strip('"') == "714":
parts[rt_idx] = "3"
route_types_fixed += 1
tmp.write((",".join(parts) + "\n").encode("utf-8"))
tmp.write(_format_csv_row(parts))
tmp.close()
zout.write(tmp.name, "routes.txt")
@ -721,8 +729,7 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
# Remove trips that have backwards travel times
with zin.open(info) as f:
header = f.readline()
header_str = header.decode("utf-8").strip()
cols = header_str.split(",")
cols = _parse_csv_line(header)
trip_id_idx = cols.index("trip_id")
tmp = tempfile.NamedTemporaryFile(
@ -731,10 +738,9 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
tmp.write(header)
for line in f:
line_str = line.decode("utf-8", errors="replace").strip()
if not line_str:
parts = _parse_csv_line(line)
if not parts:
continue
parts = line_str.split(",")
if parts[trip_id_idx].strip('"') not in bad_trip_ids:
tmp.write(line)
@ -746,8 +752,7 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
# Cap end_date year to 2099
with zin.open(info) as f:
header = f.readline()
header_str = header.decode("utf-8").strip()
cols = header_str.split(",")
cols = _parse_csv_line(header)
end_idx = cols.index("end_date")
tmp = tempfile.NamedTemporaryFile(
@ -756,10 +761,9 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
tmp.write(header)
for line in f:
line_str = line.decode("utf-8", errors="replace").strip()
if not line_str:
parts = _parse_csv_line(line)
if not parts:
continue
parts = line_str.split(",")
date_val = parts[end_idx].strip('"')
if len(date_val) == 8:
try:
@ -768,7 +772,7 @@ def clean_national_rail_gtfs(src: Path, dst: Path) -> None:
parts[end_idx] = "20991231"
except ValueError:
pass
tmp.write((",".join(parts) + "\n").encode("utf-8"))
tmp.write(_format_csv_row(parts))
tmp.close()
zout.write(tmp.name, "calendar.txt")