perfect-postcode/pipeline/download/naptan.py
2026-05-06 22:40:46 +01:00

296 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Download NaPTAN access nodes and extract deduplicated transport POIs.

Stops are filtered to the stop types listed in ``STOP_TYPES`` (airports,
ferries, rail stations, bus stops/stations, taxi ranks and tube stations),
deduplicated, and written to a parquet file.
"""
import argparse
import io
import math
import re
import urllib.request
from dataclasses import dataclass
from pathlib import Path
import polars as pl
# Endpoint the downloader fetches; the response is parsed as CSV.
NAPTAN_CSV_URL = "https://naptan.api.dft.gov.uk/v1/access-nodes?dataFormat=csv"

# Category label whose rows get station-level (name + area) merging.
TUBE_STATION_CATEGORY = "Tube station"

# Radius, in degrees, within which same-named tube rows count as one station.
TUBE_STATION_MERGE_RADIUS_DEGREES = 0.01

# Maps the ``StopType`` code of a row to the category label written out.
# Rows whose code is not a key here are dropped by the downloader.
STOP_TYPES = {
    "AIR": "Airport",
    "FTD": "Ferry",
    "RSE": "Rail station",
    "BCT": "Bus stop",
    "BCE": "Bus station",
    "TXR": "Taxi rank",
    "TMU": TUBE_STATION_CATEGORY,
    "MET": TUBE_STATION_CATEGORY,
}

# Column order shared by every frame the deduplication helpers return.
OUTPUT_COLUMNS = ["id", "name", "category", "lat", "lng"]
def canonical_station_name(name: str | None) -> str:
"""Normalize station names so entrances/transport-mode variants collapse."""
if not name:
return ""
normalized = name.lower()
normalized = re.sub(r"\([^)]*\)", " ", normalized)
normalized = re.sub(r"['`]", "", normalized)
normalized = normalized.replace("&", " and ")
normalized = re.sub(r"[^a-z0-9]+", " ", normalized)
words = normalized.split()
suffixes = (
("underground", "station"),
("tube", "station"),
("dlr", "station"),
("metro", "station"),
("tram", "stop"),
("rail", "station"),
("railway", "station"),
("station",),
("stop",),
)
while True:
suffix = next(
(suffix for suffix in suffixes if words[-len(suffix) :] == list(suffix)),
None,
)
if suffix is None:
break
del words[-len(suffix) :]
return " ".join(words)
def canonical_station_name_expr(name_col: str = "name") -> pl.Expr:
    """Normalize station names so entrances/transport-mode variants collapse.

    Expression counterpart of ``canonical_station_name``: lower-cases the
    column, drops parenthesised qualifiers, deletes apostrophes (ASCII and
    typographic) and backticks, turns ``&`` into ``and``, collapses every
    other non-alphanumeric run to one space, and strips every trailing
    mode/stop suffix.

    Args:
        name_col: Name of the string column to normalize.

    Returns:
        A polars expression yielding the canonical key for each row.
    """
    mode_words = "underground|tube|dlr|metro|rail|railway"
    one_suffix = rf"(?:(?:{mode_words})\s+station|tram\s+stop|station|stop)"
    # One or more whole-word suffixes running up to the end of the string.
    # The ``^`` alternative lets a name that consists only of suffixes reduce
    # to the empty string, matching the pure-Python implementation.
    trailing_suffixes = rf"(?:(?:^|\s+){one_suffix})+$"

    expr = pl.col(name_col).str.to_lowercase()
    expr = expr.str.replace_all(r"\([^)]*\)", " ")
    # Typographic quotes are written as escapes to keep look-alike characters
    # out of the source; the pattern receives the actual characters.
    expr = expr.str.replace_all("['`\u2018\u2019]", "")
    expr = expr.str.replace_all(r"&", " and ")
    expr = expr.str.replace_all(r"[^a-z0-9]+", " ")
    expr = expr.str.replace_all(r"\s+", " ").str.strip_chars()
    expr = expr.str.replace_all(trailing_suffixes, "")
    return expr.str.strip_chars()
def _has_locality() -> pl.Expr:
    """Return a predicate that is true for rows with a non-empty locality."""
    locality = pl.col("locality")
    is_present = locality.is_not_null()
    is_non_blank = locality != ""
    return is_present & is_non_blank
def _empty_output_frame() -> pl.DataFrame:
    """Return a zero-row frame with the output columns and their dtypes."""
    column_dtypes = {
        "id": pl.String,
        "name": pl.String,
        "category": pl.String,
        "lat": pl.Float64,
        "lng": pl.Float64,
    }
    return pl.DataFrame(
        {
            column: pl.Series([], dtype=dtype)
            for column, dtype in column_dtypes.items()
        }
    )
def station_name_score(name: str) -> tuple[int, int]:
    """Rank a display name; lower tuples are preferred.

    The first element is 1 when the name ends in a mode/stop suffix and 0
    otherwise; the second element is the length of the name.
    """
    generic_endings = (
        " underground station",
        " tube station",
        " dlr station",
        " metro station",
        " tram stop",
        " station",
        " stop",
    )
    lowered = name.lower()
    has_generic_ending = any(lowered.endswith(ending) for ending in generic_endings)
    return (1 if has_generic_ending else 0, len(name))
@dataclass
class StationAccumulator:
    """Running aggregate of the rows merged into one station.

    Coordinates are kept as sums plus a row count so the mean position can be
    updated as further rows are merged in.
    """

    # Identifier and display name of the best-scoring row seen so far.
    id: str
    name: str
    # Category of the row that created the accumulator; merging never changes it.
    category: str
    # Sums of the coordinates of every merged row.
    lat_sum: float
    lng_sum: float
    # Number of rows merged, including the one that created the accumulator.
    count: int = 1

    @property
    def lat(self) -> float:
        """Mean latitude of the merged rows."""
        return self.lat_sum / self.count

    @property
    def lng(self) -> float:
        """Mean longitude of the merged rows."""
        return self.lng_sum / self.count

    def same_area(self, lat: float, lng: float) -> bool:
        """Return whether a point lies within the merge radius of the mean."""
        mean_lat = self.lat
        north_south = mean_lat - lat
        # Scale the longitude difference by cos(latitude) before comparing.
        east_west = (self.lng - lng) * math.cos(math.radians(mean_lat))
        squared_distance = north_south * north_south + east_west * east_west
        return squared_distance <= TUBE_STATION_MERGE_RADIUS_DEGREES**2

    def merge(self, row: dict[str, object]) -> None:
        """Fold a row into the aggregate, adopting its id/name if it scores better."""
        self.lat_sum += float(row["lat"])
        self.lng_sum += float(row["lng"])
        self.count += 1
        candidate = str(row["name"] or "")
        if station_name_score(self.name) > station_name_score(candidate):
            self.id = str(row["id"] or "")
            self.name = candidate
def _station_from_row(row: dict[str, object]) -> StationAccumulator:
    """Start a new accumulator from a single row.

    Falsy ``id``/``name``/``category`` values become empty strings; ``lat``
    and ``lng`` are converted with ``float``.
    """
    text_fields = {key: str(row[key] or "") for key in ("id", "name", "category")}
    return StationAccumulator(
        lat_sum=float(row["lat"]),
        lng_sum=float(row["lng"]),
        **text_fields,
    )
def _deduplicate_tube_stations(df: pl.DataFrame) -> pl.DataFrame:
    """Merge tube rows that share a canonical name and lie in the same area.

    Rows are visited in frame order. A row joins the first earlier station
    that has the same canonical name and reports ``same_area`` for the row's
    coordinates; otherwise it starts a new station. Rows whose canonical name
    is empty are never merged.
    """
    if df.height == 0:
        return _empty_output_frame()

    stations: list[StationAccumulator] = []
    # Canonical name -> positions in ``stations`` that carry that name.
    positions_by_key: dict[str, list[int]] = {}

    for row in df.iter_rows(named=True):
        key = canonical_station_name(str(row["name"] or ""))
        if not key:
            stations.append(_station_from_row(row))
            continue

        for position in positions_by_key.get(key, []):
            station = stations[position]
            if station.same_area(float(row["lat"]), float(row["lng"])):
                station.merge(row)
                break
        else:
            # No nearby station with this name yet: start a new one.
            positions_by_key.setdefault(key, []).append(len(stations))
            stations.append(_station_from_row(row))

    columns = {
        column: [getattr(station, column) for station in stations]
        for column in OUTPUT_COLUMNS
    }
    return pl.DataFrame(columns).select(OUTPUT_COLUMNS)
def _deduplicate_non_tube_stops(df: pl.DataFrame) -> pl.DataFrame:
    """Collapse stops that share an exact name, category and locality.

    Rows with a non-empty ``locality`` are grouped by
    ``(name, category, locality)``; each group keeps its first ``id`` and the
    mean of its coordinates. Rows without a locality are passed through
    unmerged and placed after the grouped rows.

    Args:
        df: Frame with ``id``, ``name``, ``category``, ``lat``, ``lng`` and
            ``locality`` columns.

    Returns:
        A frame restricted to ``OUTPUT_COLUMNS``.
    """
    if df.is_empty():
        return _empty_output_frame()

    has_locality = _has_locality()
    located = df.filter(has_locality)
    unlocated = df.filter(~has_locality)

    # First pass: one record per exact stop name/category/locality.
    frames = []
    if not located.is_empty():
        frames.append(
            # maintain_order keeps groups in first-appearance order; without
            # it the group order is unspecified and identical input can yield
            # differently ordered output from one run to the next.
            located.group_by("name", "category", "locality", maintain_order=True)
            .agg(
                pl.col("id").first(),
                pl.col("lat").mean(),
                pl.col("lng").mean(),
            )
            .select(OUTPUT_COLUMNS)
        )
    if not unlocated.is_empty():
        frames.append(unlocated.select(OUTPUT_COLUMNS))
    if not frames:
        return _empty_output_frame()
    return pl.concat(frames).select(OUTPUT_COLUMNS)
def deduplicate_naptan(df: pl.DataFrame) -> pl.DataFrame:
    """Deduplicate NaPTAN stops, with station-level merging for Tube POIs.

    Rows in the tube-station category are merged by canonical name and area;
    all other rows are merged by exact name, category and locality. The
    non-tube rows come first in the result.
    """
    category = pl.col("category")
    tube_rows = df.filter(category == TUBE_STATION_CATEGORY)
    other_rows = df.filter(category != TUBE_STATION_CATEGORY)
    deduplicated = [
        _deduplicate_non_tube_stops(other_rows),
        _deduplicate_tube_stations(tube_rows),
    ]
    return pl.concat(deduplicated).select(OUTPUT_COLUMNS)
def download_naptan(output: Path) -> None:
    """Download the NaPTAN CSV, deduplicate it and write it as parquet.

    Creates the parent directory of ``output`` if needed, fetches
    ``NAPTAN_CSV_URL``, keeps rows that have parseable coordinates and a stop
    type listed in ``STOP_TYPES``, deduplicates them, and writes the result
    to ``output``. Progress and per-category counts are printed to stdout.

    Args:
        output: Destination path of the parquet file.
    """
    output.parent.mkdir(parents=True, exist_ok=True)
    print(f"Downloading NaPTAN data from {NAPTAN_CSV_URL}")
    with urllib.request.urlopen(NAPTAN_CSV_URL) as resp:
        raw = resp.read()
    print(f"Downloaded {len(raw) / (1024 * 1024):.1f} MB")
    df = (
        # infer_schema_length=0 reads every column as a string; the
        # coordinates are cast explicitly below.
        pl.read_csv(io.BytesIO(raw), infer_schema_length=0)
        .with_columns(
            # strict=False turns unparseable coordinates into nulls, which
            # the drop_nulls below then removes.
            pl.col("Latitude").cast(pl.Float64, strict=False),
            pl.col("Longitude").cast(pl.Float64, strict=False),
        )
        .drop_nulls(subset=["Latitude", "Longitude"])
        .filter(pl.col("StopType").is_in(list(STOP_TYPES)))
        .select(
            pl.col("ATCOCode").alias("id"),
            pl.col("CommonName").alias("name"),
            pl.col("StopType").replace(STOP_TYPES).alias("category"),
            pl.col("Latitude").alias("lat"),
            pl.col("Longitude").alias("lng"),
            pl.col("NptgLocalityCode").alias("locality"),
        )
    )
    before = len(df)
    df = deduplicate_naptan(df)
    # Keep a separator between the two counts; without one they print as a
    # single fused number.
    print(
        f"Deduplicated {before:,} -> {len(df):,} stops "
        "(by name+category+locality; tube stations by normalized name+area)"
    )
    df.write_parquet(output)
    size_mb = output.stat().st_size / (1024 * 1024)
    print(f"Wrote {output} ({size_mb:.1f} MB, {len(df):,} stations)")
    counts = df.group_by("category").len().sort("len", descending=True)
    for row in counts.iter_rows(named=True):
        print(f"  {row['category']}: {row['len']:,}")
def main() -> None:
    """Parse the command line and run the download to the requested path."""
    cli = argparse.ArgumentParser(description="Download NaPTAN station data")
    cli.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Output parquet file path",
    )
    options = cli.parse_args()
    download_naptan(options.output)
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()