296 lines
8.6 KiB
Python
296 lines
8.6 KiB
Python
"""Download NaPTAN data and extract railway/metro station POIs."""
|
||
|
||
import argparse
|
||
import io
|
||
import math
|
||
import re
|
||
import urllib.request
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
import polars as pl
|
||
|
||
# Endpoint the NaPTAN access-node data is downloaded from, requested as CSV.
NAPTAN_CSV_URL = "https://naptan.api.dft.gov.uk/v1/access-nodes?dataFormat=csv"
# Category label whose rows are merged at station level (normalized name plus
# proximity) instead of by exact name/category/locality.
TUBE_STATION_CATEGORY = "Tube station"
# Maximum separation, in degrees, for two same-named tube records to be treated
# as one station. Longitude differences are scaled by cos(latitude) before the
# comparison, so this behaves like a radius of roughly 1.1 km on the ground.
TUBE_STATION_MERGE_RADIUS_DEGREES = 0.01

# StopType codes kept from the CSV, mapped to the category label written to the
# output. Rows whose StopType is not a key here are dropped. Both "TMU" and
# "MET" map to the same "Tube station" category.
STOP_TYPES = {
    "AIR": "Airport",
    "FTD": "Ferry",
    "RSE": "Rail station",
    "BCT": "Bus stop",
    "BCE": "Bus station",
    "TXR": "Taxi rank",
    "TMU": "Tube station",
    "MET": "Tube station",
}

# Columns (and their order) of every frame returned by the deduplication helpers.
OUTPUT_COLUMNS = ["id", "name", "category", "lat", "lng"]
|
||
|
||
|
||
def canonical_station_name(name: str | None) -> str:
|
||
"""Normalize station names so entrances/transport-mode variants collapse."""
|
||
if not name:
|
||
return ""
|
||
|
||
normalized = name.lower()
|
||
normalized = re.sub(r"\([^)]*\)", " ", normalized)
|
||
normalized = re.sub(r"['’`]", "", normalized)
|
||
normalized = normalized.replace("&", " and ")
|
||
normalized = re.sub(r"[^a-z0-9]+", " ", normalized)
|
||
words = normalized.split()
|
||
|
||
suffixes = (
|
||
("underground", "station"),
|
||
("tube", "station"),
|
||
("dlr", "station"),
|
||
("metro", "station"),
|
||
("tram", "stop"),
|
||
("rail", "station"),
|
||
("railway", "station"),
|
||
("station",),
|
||
("stop",),
|
||
)
|
||
while True:
|
||
suffix = next(
|
||
(suffix for suffix in suffixes if words[-len(suffix) :] == list(suffix)),
|
||
None,
|
||
)
|
||
if suffix is None:
|
||
break
|
||
del words[-len(suffix) :]
|
||
|
||
return " ".join(words)
|
||
|
||
|
||
def canonical_station_name_expr(name_col: str = "name") -> pl.Expr:
    """Build a polars expression that normalizes station names.

    The expression lower-cases the column, removes parenthesised fragments
    and apostrophes, turns ``&`` into ``and``, collapses every other run of
    non-alphanumeric characters into a single space, and trims the ends.
    Each trailing-suffix pattern is then applied once, in order.

    Args:
        name_col: Name of the string column to normalize.

    Returns:
        An expression yielding the normalized name.
    """
    cleanup_steps = (
        (r"\([^)]*\)", " "),
        (r"['’`]", ""),
        (r"&", " and "),
        (r"[^a-z0-9]+", " "),
        (r"\s+", " "),
    )
    # Mode-specific phrases come first so the generic pattern only sees what
    # is left after they have been removed.
    suffix_patterns = (
        r"\s+(underground|tube|dlr|metro|rail|railway)\s+station$",
        r"\s+tram\s+stop$",
        r"\s+(station|stop)$",
    )

    result = pl.col(name_col).str.to_lowercase()
    for pattern, replacement in cleanup_steps:
        result = result.str.replace_all(pattern, replacement)
    result = result.str.strip_chars()

    for pattern in suffix_patterns:
        result = result.str.replace_all(pattern, "")

    return result.str.strip_chars()
|
||
|
||
|
||
def _has_locality() -> pl.Expr:
    """Return a predicate that is true where ``locality`` is neither null nor empty."""
    locality = pl.col("locality")
    return locality.is_not_null() & (locality != "")
|
||
|
||
|
||
def _empty_output_frame() -> pl.DataFrame:
    """Return a zero-row frame with the output columns and their dtypes.

    The columns are ``id``, ``name`` and ``category`` (strings) followed by
    ``lat`` and ``lng`` (64-bit floats).
    """
    column_types = {
        "id": pl.String,
        "name": pl.String,
        "category": pl.String,
        "lat": pl.Float64,
        "lng": pl.Float64,
    }
    return pl.DataFrame(
        {
            column: pl.Series([], dtype=dtype)
            for column, dtype in column_types.items()
        }
    )
|
||
|
||
|
||
def station_name_score(name: str) -> tuple[int, int]:
    """Score a display name; a smaller tuple is the better name.

    Args:
        name: Candidate station name.

    Returns:
        ``(penalty, length)``. ``penalty`` is 1 when the lower-cased name ends
        with one of the generic transport suffixes and 0 otherwise. ``length``
        is ``len(name)``, so shorter names win when penalties tie.
    """
    generic_endings = (
        " underground station",
        " tube station",
        " dlr station",
        " metro station",
        " tram stop",
        " station",
        " stop",
    )
    penalty = 1 if name.lower().endswith(generic_endings) else 0
    return (penalty, len(name))
|
||
|
||
|
||
@dataclass
class StationAccumulator:
    """Running aggregate of the records merged into one station.

    Coordinates are stored as sums plus a record count, so the position
    reported by ``lat``/``lng`` is the mean of every merged record.
    """

    id: str
    name: str
    category: str
    lat_sum: float
    lng_sum: float
    count: int = 1

    @property
    def lat(self) -> float:
        """Mean latitude of the merged records."""
        return self.lat_sum / self.count

    @property
    def lng(self) -> float:
        """Mean longitude of the merged records."""
        return self.lng_sum / self.count

    def same_area(self, lat: float, lng: float) -> bool:
        """Return True when the point lies within the merge radius of this station.

        The longitude difference is scaled by the cosine of the station's
        current latitude before the squared distance is compared against the
        squared ``TUBE_STATION_MERGE_RADIUS_DEGREES``.
        """
        centre_lat = self.lat
        north_south = centre_lat - lat
        east_west = (self.lng - lng) * math.cos(math.radians(centre_lat))
        squared_distance = north_south * north_south + east_west * east_west
        return squared_distance <= TUBE_STATION_MERGE_RADIUS_DEGREES**2

    def merge(self, row: dict[str, object]) -> None:
        """Fold another record into this station.

        The record's coordinates are added to the running sums. Its id and
        name replace the current ones only when its name scores strictly
        better according to ``station_name_score``.
        """
        self.lat_sum += float(row["lat"])
        self.lng_sum += float(row["lng"])
        self.count += 1

        candidate = str(row["name"] or "")
        if station_name_score(self.name) <= station_name_score(candidate):
            return
        self.id = str(row["id"] or "")
        self.name = candidate
|
||
|
||
|
||
def _station_from_row(row: dict[str, object]) -> StationAccumulator:
    """Start a new accumulator from a single record.

    Falsy ``id``/``name``/``category`` values become empty strings. The
    coordinates seed the running sums, and ``count`` keeps its default of 1.
    """
    station_id = str(row["id"] or "")
    display_name = str(row["name"] or "")
    category = str(row["category"] or "")
    latitude = float(row["lat"])
    longitude = float(row["lng"])
    return StationAccumulator(
        id=station_id,
        name=display_name,
        category=category,
        lat_sum=latitude,
        lng_sum=longitude,
    )
|
||
|
||
|
||
def _deduplicate_tube_stations(df: pl.DataFrame) -> pl.DataFrame:
    """Merge tube records that share a normalized name and lie close together.

    Rows are processed in frame order. A row joins the first existing station
    with the same ``canonical_station_name`` key whose current mean position
    is within the merge radius; otherwise it starts a new station. Rows whose
    key is empty are kept as individual stations and never merged.

    Args:
        df: Frame with at least ``id``, ``name``, ``category``, ``lat`` and
            ``lng`` columns.

    Returns:
        One row per merged station with the ``OUTPUT_COLUMNS`` columns;
        ``lat``/``lng`` are the means of the merged records.
    """
    if len(df) == 0:
        return _empty_output_frame()

    stations: list[StationAccumulator] = []
    # Normalized name -> positions in `stations` that carry that name.
    slots_by_key: dict[str, list[int]] = {}

    for record in df.iter_rows(named=True):
        key = canonical_station_name(str(record["name"] or ""))
        if not key:
            stations.append(_station_from_row(record))
            continue

        match = None
        for slot in slots_by_key.get(key, []):
            if stations[slot].same_area(float(record["lat"]), float(record["lng"])):
                match = slot
                break

        if match is not None:
            stations[match].merge(record)
            continue

        new_slot = len(stations)
        stations.append(_station_from_row(record))
        slots_by_key.setdefault(key, []).append(new_slot)

    # OUTPUT_COLUMNS names match the accumulator attributes one-to-one.
    columns = {
        column: [getattr(station, column) for station in stations]
        for column in OUTPUT_COLUMNS
    }
    return pl.DataFrame(columns).select(OUTPUT_COLUMNS)
|
||
|
||
|
||
def _deduplicate_non_tube_stops(df: pl.DataFrame) -> pl.DataFrame:
    """Collapse non-tube stops that share name, category and locality.

    Rows with a non-empty locality are grouped by ``(name, category,
    locality)``. Each group keeps its first ``id`` and the mean of its
    coordinates. Rows without a locality are passed through untouched.

    Args:
        df: Frame with ``id``, ``name``, ``category``, ``lat``, ``lng`` and
            ``locality`` columns.

    Returns:
        A frame with the ``OUTPUT_COLUMNS`` columns: grouped rows first, in
        order of first appearance, followed by the rows without a locality.
    """
    if len(df) == 0:
        return _empty_output_frame()

    with_locality = df.filter(_has_locality())
    without_locality = df.filter(~_has_locality())

    # First pass: one record per exact stop name/category/locality.
    frames = []
    if len(with_locality) > 0:
        frames.append(
            # maintain_order=True keeps groups in first-appearance order, so
            # repeated runs on the same input produce the same output file.
            with_locality.group_by(
                "name", "category", "locality", maintain_order=True
            )
            .agg(
                pl.col("id").first(),
                pl.col("lat").mean(),
                pl.col("lng").mean(),
            )
            .select(OUTPUT_COLUMNS)
        )
    if len(without_locality) > 0:
        frames.append(without_locality.select(OUTPUT_COLUMNS))

    if not frames:
        return _empty_output_frame()

    return pl.concat(frames).select(OUTPUT_COLUMNS)
|
||
|
||
|
||
def deduplicate_naptan(df: pl.DataFrame) -> pl.DataFrame:
    """Deduplicate NaPTAN stops, merging tube records at station level.

    Rows whose ``category`` equals ``TUBE_STATION_CATEGORY`` go through
    ``_deduplicate_tube_stations``; rows with any other category go through
    ``_deduplicate_non_tube_stops``.

    Args:
        df: Frame with ``id``, ``name``, ``category``, ``lat``, ``lng`` and
            ``locality`` columns.

    Returns:
        Non-tube results followed by tube results, limited to
        ``OUTPUT_COLUMNS``.
    """
    # NOTE(review): a row with a null category would presumably satisfy
    # neither comparison and be dropped - confirm callers always set it.
    category = pl.col("category")
    tube_rows = df.filter(category == TUBE_STATION_CATEGORY)
    other_rows = df.filter(category != TUBE_STATION_CATEGORY)

    deduped_other = _deduplicate_non_tube_stops(other_rows)
    deduped_tube = _deduplicate_tube_stations(tube_rows)

    combined = pl.concat([deduped_other, deduped_tube])
    return combined.select(OUTPUT_COLUMNS)
|
||
|
||
|
||
def download_naptan(output: Path, *, timeout: float = 60.0) -> None:
    """Download the NaPTAN CSV, deduplicate the kept stops and write parquet.

    Only rows with parseable coordinates and a ``StopType`` listed in
    ``STOP_TYPES`` are kept. Progress and a per-category summary are printed
    to stdout.

    Args:
        output: Destination parquet path; missing parent directories are
            created.
        timeout: Seconds any single blocking network operation may take
            before the download is aborted. Without it a stalled connection
            would block forever.

    Raises:
        urllib.error.URLError: If the request fails (a stalled read surfaces
            as a socket timeout error instead).
    """
    output.parent.mkdir(parents=True, exist_ok=True)

    print(f"Downloading NaPTAN data from {NAPTAN_CSV_URL}")
    with urllib.request.urlopen(NAPTAN_CSV_URL, timeout=timeout) as resp:
        raw = resp.read()

    print(f"Downloaded {len(raw) / (1024 * 1024):.1f} MB")

    df = (
        # infer_schema_length=0 reads every column as a string; only the
        # coordinates are cast, and unparseable values become null.
        pl.read_csv(io.BytesIO(raw), infer_schema_length=0)
        .with_columns(
            pl.col("Latitude").cast(pl.Float64, strict=False),
            pl.col("Longitude").cast(pl.Float64, strict=False),
        )
        .drop_nulls(subset=["Latitude", "Longitude"])
        .filter(pl.col("StopType").is_in(list(STOP_TYPES)))
        .select(
            pl.col("ATCOCode").alias("id"),
            pl.col("CommonName").alias("name"),
            pl.col("StopType").replace(STOP_TYPES).alias("category"),
            pl.col("Latitude").alias("lat"),
            pl.col("Longitude").alias("lng"),
            pl.col("NptgLocalityCode").alias("locality"),
        )
    )

    before = len(df)
    df = deduplicate_naptan(df)

    print(
        f"Deduplicated {before:,} → {len(df):,} stops "
        "(by name+category+locality; tube stations by normalized name+area)"
    )

    df.write_parquet(output)
    size_mb = output.stat().st_size / (1024 * 1024)
    print(f"Wrote {output} ({size_mb:.1f} MB, {len(df):,} stations)")

    counts = df.group_by("category").len().sort("len", descending=True)
    for row in counts.iter_rows(named=True):
        print(f" {row['category']}: {row['len']:,}")
|
||
|
||
|
||
def main() -> None:
    """Parse the command line and run the download.

    Requires ``--output``, the parquet path to write.
    """
    cli = argparse.ArgumentParser(description="Download NaPTAN station data")
    cli.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Output parquet file path",
    )
    options = cli.parse_args()
    download_naptan(options.output)
|
||
|
||
|
||
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|