165 lines
5.2 KiB
Python
165 lines
5.2 KiB
Python
"""Download NaPTAN data and extract transport stop POIs (rail, Tube, bus, ferry, airport, taxi)."""
|
||
|
||
import argparse
|
||
import io
|
||
import urllib.request
|
||
from pathlib import Path
|
||
|
||
import polars as pl
|
||
|
||
NAPTAN_CSV_URL = "https://naptan.api.dft.gov.uk/v1/access-nodes?dataFormat=csv"

# NaPTAN ``StopType`` code -> POI category label. Only rows whose StopType is
# a key here are kept. TMU and MET share the "Tube station" label so that the
# Tube-specific deduplication treats them as one category.
STOP_TYPES: dict[str, str] = dict(
    AIR="Airport",
    FTD="Ferry",
    RSE="Rail station",
    BCT="Bus stop",
    BCE="Bus station",
    TXR="Taxi rank",
    TMU="Tube station",
    MET="Tube station",
)

# Column order of every deduplicated frame and of the written parquet file.
OUTPUT_COLUMNS: list[str] = ["id", "name", "category", "lat", "lng"]
|
||
|
||
|
||
def canonical_station_name_expr(name_col: str = "name") -> pl.Expr:
    """Build an expression that turns station names into a merge key.

    The name is lower-cased, parenthesised text and apostrophes are removed,
    ``&`` becomes ``and``, and every other run of characters outside
    ``[a-z0-9]`` becomes a single space (so non-ASCII letters act as
    separators). Trailing transport-mode suffixes such as
    "underground station", "tram stop", or a bare "station"/"stop" are then
    stripped, so entrance and mode variants of one station share a key.

    Args:
        name_col: Column holding the raw station name.

    Returns:
        A polars string expression with the normalized name.
    """
    # Character-level cleanup, applied in order.
    cleanup_rules = (
        (r"\([^)]*\)", " "),
        (r"['’`]", ""),
        (r"&", " and "),
        (r"[^a-z0-9]+", " "),
        (r"\s+", " "),
    )
    # Suffix removal, applied in order after the name has been trimmed.
    suffix_rules = (
        (r"\s+(underground|tube|dlr|metro|rail|railway)\s+station$", ""),
        (r"\s+tram\s+stop$", ""),
        (r"\s+(station|stop)$", ""),
    )

    key = pl.col(name_col).str.to_lowercase()
    for pattern, replacement in cleanup_rules:
        key = key.str.replace_all(pattern, replacement)
    # Trim before the `$`-anchored suffix rules so trailing blanks cannot hide a suffix.
    key = key.str.strip_chars()
    for pattern, replacement in suffix_rules:
        key = key.str.replace_all(pattern, replacement)
    return key.str.strip_chars()
|
||
|
||
|
||
def _has_locality() -> pl.Expr:
    """Return a boolean expression that is True where ``locality`` is set and non-empty."""
    locality = pl.col("locality")
    is_present = locality.is_not_null()
    is_non_empty = locality != ""
    return is_present & is_non_empty
|
||
|
||
|
||
def _deduplicate_tube_partition(
    df: pl.DataFrame, group_cols: list[str]
) -> pl.DataFrame:
    """Collapse Tube rows that share ``group_cols`` into one row per group.

    Within each group, the row with the shortest ``name`` supplies both
    ``id`` and ``name``. Ties are broken by ``name`` and then ``id``.
    ``lat``/``lng`` are averaged over every row in the group.

    Args:
        df: Tube rows containing ``OUTPUT_COLUMNS`` plus every column named
            in ``group_cols``.
        group_cols: Columns whose combined value identifies one station.

    Returns:
        A frame with exactly ``OUTPUT_COLUMNS``. An empty input yields an
        empty frame with fixed dtypes, so callers can ``pl.concat`` it with
        non-empty results.
    """
    if df.is_empty():
        return pl.DataFrame(
            schema={
                "id": pl.String,
                "name": pl.String,
                "category": pl.String,
                "lat": pl.Float64,
                "lng": pl.Float64,
            }
        )

    # `id` and `name` are sorted independently, so both must use the SAME
    # total ordering for their `first()` to land on the same row. Sorting by
    # name length alone leaves ties unordered (sort_by does not maintain
    # order by default), so `id` could come from a different row than `name`
    # and the pick could differ between runs.
    representative_order = [
        pl.col("name").str.len_chars(),
        pl.col("name"),
        pl.col("id"),
    ]
    return (
        df.group_by(group_cols)
        .agg(
            pl.col("id").sort_by(representative_order).first(),
            pl.col("name").sort_by(representative_order).first(),
            pl.col("category").first(),
            pl.col("lat").mean(),
            pl.col("lng").mean(),
        )
        .select(OUTPUT_COLUMNS)
    )
|
||
|
||
|
||
def deduplicate_naptan(df: pl.DataFrame) -> pl.DataFrame:
    """Deduplicate NaPTAN stops, merging Tube stations more aggressively.

    Pass 1 collapses rows that have a locality and share an exact
    (name, category, locality). It keeps the first ``id`` and averages the
    coordinates. Rows without a locality pass through unchanged.

    Pass 2 regroups "Tube station" rows by a normalized station key (see
    ``canonical_station_name_expr``), together with the locality when one is
    present.

    Args:
        df: Stops with ``OUTPUT_COLUMNS`` plus a ``locality`` column.

    Returns:
        A frame with ``OUTPUT_COLUMNS``: non-Tube rows first, then Tube rows.
    """
    located = _has_locality()
    keep_cols = [*OUTPUT_COLUMNS, "locality"]

    # Pass 1: exact-name merge, only where a locality disambiguates the name.
    exact_merged = (
        df.filter(located)
        .group_by("name", "category", "locality")
        .agg(
            pl.col("id").first(),
            pl.col("lat").mean(),
            pl.col("lng").mean(),
        )
        .select(keep_cols)
    )
    stops = pl.concat([exact_merged, df.filter(~located).select(keep_cols)])

    # Pass 2: station-level merge for Tube rows, partitioned by whether a
    # locality is known.
    tube = stops.filter(pl.col("category") == "Tube station").with_columns(
        canonical_station_name_expr().alias("_station_key")
    )
    merged_tube = pl.concat(
        [
            _deduplicate_tube_partition(
                tube.filter(located), ["_station_key", "locality"]
            ),
            _deduplicate_tube_partition(tube.filter(~located), ["_station_key"]),
        ]
    )

    non_tube = stops.filter(pl.col("category") != "Tube station").select(
        OUTPUT_COLUMNS
    )
    return pl.concat([non_tube, merged_tube])
|
||
|
||
|
||
def _fetch_naptan_csv(timeout: float) -> bytes:
    """Download the NaPTAN CSV export from ``NAPTAN_CSV_URL`` and return its bytes.

    ``timeout`` is passed to ``urllib.request.urlopen``. It bounds each
    blocking socket operation (connect / read), not the whole transfer.
    """
    print(f"Downloading NaPTAN data from {NAPTAN_CSV_URL}")
    with urllib.request.urlopen(NAPTAN_CSV_URL, timeout=timeout) as resp:
        raw = resp.read()
    print(f"Downloaded {len(raw) / (1024 * 1024):.1f} MB")
    return raw


def _parse_naptan_csv(raw: bytes) -> pl.DataFrame:
    """Parse raw NaPTAN CSV bytes into POI rows.

    Returns a frame with ``OUTPUT_COLUMNS`` plus ``locality``. Only stop types
    listed in ``STOP_TYPES`` and rows with parseable coordinates are kept.
    """
    return (
        # infer_schema_length=0 reads every column as a string; only the
        # coordinates are converted, explicitly, below.
        pl.read_csv(io.BytesIO(raw), infer_schema_length=0)
        .with_columns(
            # strict=False turns unparseable coordinates into nulls, which
            # are dropped next instead of aborting the whole load.
            pl.col("Latitude").cast(pl.Float64, strict=False),
            pl.col("Longitude").cast(pl.Float64, strict=False),
        )
        .drop_nulls(subset=["Latitude", "Longitude"])
        .filter(pl.col("StopType").is_in(list(STOP_TYPES)))
        .select(
            pl.col("ATCOCode").alias("id"),
            pl.col("CommonName").alias("name"),
            pl.col("StopType").replace(STOP_TYPES).alias("category"),
            pl.col("Latitude").alias("lat"),
            pl.col("Longitude").alias("lng"),
            pl.col("NptgLocalityCode").alias("locality"),
        )
    )


def _print_category_counts(df: pl.DataFrame) -> None:
    """Print the number of rows per category, most frequent first."""
    counts = df.group_by("category").len().sort("len", descending=True)
    for row in counts.iter_rows(named=True):
        print(f" {row['category']}: {row['len']:,}")


def download_naptan(output: Path, *, timeout: float = 60.0) -> None:
    """Download NaPTAN, extract and deduplicate stops, and write them to parquet.

    Args:
        output: Destination parquet path. Missing parent directories are
            created.
        timeout: Seconds to wait on each blocking network operation.
            Without a timeout, ``urlopen`` can block forever if the server
            stalls.

    Raises:
        urllib.error.URLError: If the request fails (HTTP errors arrive as
            its ``HTTPError`` subclass). A stalled read raises a socket
            timeout error.
    """
    output.parent.mkdir(parents=True, exist_ok=True)

    df = _parse_naptan_csv(_fetch_naptan_csv(timeout))

    before = len(df)
    df = deduplicate_naptan(df)
    print(
        f"Deduplicated {before:,} → {len(df):,} stops "
        "(by name+category+locality; tube stations by normalized station name)"
    )

    df.write_parquet(output)
    size_mb = output.stat().st_size / (1024 * 1024)
    print(f"Wrote {output} ({size_mb:.1f} MB, {len(df):,} stations)")

    _print_category_counts(df)
|
||
|
||
|
||
def main() -> None:
    """Command-line entry point: parse ``--output`` and run the download."""
    cli = argparse.ArgumentParser(description="Download NaPTAN station data")
    cli.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Output parquet file path",
    )
    options = cli.parse_args()
    download_naptan(options.output)
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|