perfect-postcode/pipeline/download/pois.py
2026-03-15 21:22:28 +00:00

159 lines
4.4 KiB
Python

import argparse
from pathlib import Path
from tempfile import mkdtemp
import osmium
import polars as pl
from shapely.geometry import Point
from tqdm import tqdm
from pipeline.utils.england_geometry import (
ENGLAND_BBOX_EAST,
ENGLAND_BBOX_NORTH,
ENGLAND_BBOX_SOUTH,
ENGLAND_BBOX_WEST,
load_england_polygon,
)
# Number of buffered POI rows that triggers a flush to a temporary parquet
# batch file (see POIHandler._add_poi).
BATCH_SIZE = 50_000

# Categories with fewer rows than this are dropped from the final output.
# NOTE: the name is misspelled ("OCCURENCE"); it is kept because main()
# refers to it under this spelling.
MIN_OCCURENCE_COUNT = 20

# OSM tag keys that mark a node as a POI. Each key present on a node yields
# one "key/value" category string.
POI_TAG_KEYS: list[str] = [
    "amenity",
    "building",
    "craft",
    "emergency",
    "healthcare",
    "leisure",
    "office",
    "shop",
    "tourism",
    "public_transport",
]
class POIHandler(osmium.SimpleHandler):
    """Stream OSM nodes and spill England POIs to parquet batch files.

    Every node that carries at least one tag key from ``POI_TAG_KEYS`` and
    lies inside the England polygon produces one row per matching key.
    Rows are buffered in memory and written to
    ``<tmp_dir>/batch_NNNNN.parquet`` whenever ``BATCH_SIZE`` rows have
    accumulated. Call ``_flush_batch()`` once after the input has been
    applied to write the final, partial batch.
    """

    def __init__(self, progress: tqdm, tmp_dir: Path, england_polygon) -> None:
        """Set up an empty buffer.

        Args:
            progress: Progress bar advanced once per node seen.
            tmp_dir: Existing directory that receives the batch files.
            england_polygon: Object exposing ``contains(point)``; presumably
                the shapely geometry returned by ``load_england_polygon``.
        """
        super().__init__()
        self._batch: list[dict] = []  # rows waiting to be written
        self._tmp_dir = tmp_dir
        self._batch_num = 0  # index of the next batch file
        self.poi_count = 0  # total rows emitted so far (before any filtering)
        self._progress = progress
        self._england = england_polygon

    def _in_england(self, lat: float, lon: float) -> bool:
        """Return True when the coordinate lies inside the England polygon."""
        # Fast bbox pre-filter, then precise polygon check
        if not (
            ENGLAND_BBOX_SOUTH <= lat <= ENGLAND_BBOX_NORTH
            and ENGLAND_BBOX_WEST <= lon <= ENGLAND_BBOX_EAST
        ):
            return False
        # Point takes (x, y), i.e. (lon, lat).
        return self._england.contains(Point(lon, lat))

    def _match_tags(self, tags: osmium.osm.TagList) -> list[str]:
        """Return one ``"key/value"`` category per POI tag key present."""
        return [f"{key}/{tags[key]}" for key in POI_TAG_KEYS if key in tags]

    def _get_name(self, tags: osmium.osm.TagList) -> str:
        """Return ``name:en`` if set, else ``name``, else an empty string."""
        return tags.get("name:en", tags.get("name", ""))

    def _flush_batch(self) -> None:
        """Write the buffered rows to the next batch file and clear the buffer.

        Does nothing when the buffer is empty, so no empty files are created.
        """
        if not self._batch:
            return
        df = pl.DataFrame(self._batch)
        out = self._tmp_dir / f"batch_{self._batch_num:05d}.parquet"
        df.write_parquet(out)
        self._batch_num += 1
        self._batch.clear()

    def _add_poi(
        self,
        osm_id: str,
        tags: osmium.osm.TagList,
        category: str,
        lat: float,
        lng: float,
    ) -> None:
        """Buffer one POI row and flush once ``BATCH_SIZE`` rows are queued."""
        self._batch.append(
            {
                "id": osm_id,
                "name": self._get_name(tags),
                "category": category,
                "lat": lat,
                "lng": lng,
            }
        )
        self.poi_count += 1
        # refresh=False: let tqdm redraw on its own schedule.
        self._progress.set_postfix(pois=f"{self.poi_count:,}", refresh=False)
        if len(self._batch) >= BATCH_SIZE:
            self._flush_batch()

    def _tick(self) -> None:
        """Advance the progress bar by one processed element."""
        self._progress.update(1)

    def node(self, n: osmium.osm.Node) -> None:
        """Handle one OSM node: emit a row per matching POI tag key."""
        self._tick()

        # Tag matching is a few key lookups, while _in_england may run a
        # polygon containment test, so do the cheap check first. The set of
        # emitted rows is unchanged; only the order of the filters differs.
        categories = self._match_tags(n.tags)
        if not categories:
            return

        location = n.location
        # Location.valid is a method in pyosmium. Referencing it without
        # calling it yields a bound method, which is always truthy, so the
        # guard would never fire and .lat/.lon could raise on a bad location.
        if not location.valid():
            return

        lat, lon = location.lat, location.lon
        if not self._in_england(lat, lon):
            return

        osm_id = f"n{n.id}"
        for category in categories:
            self._add_poi(osm_id, n.tags, category, lat, lon)
def main() -> None:
    """Extract England POIs from an OSM PBF file into one parquet file.

    Command-line arguments:
        --output: Destination parquet file.
        --pbf: Input OSM PBF file.
        --boundary: England boundary GeoJSON file.

    The PBF is streamed through ``POIHandler``, which spills rows to batch
    files in a temporary directory. The batches are then combined lazily,
    categories with fewer than ``MIN_OCCURENCE_COUNT`` rows are dropped, and
    the result is written to ``--output``. The temporary directory is removed
    afterwards, whether or not the run succeeded.

    Raises:
        SystemExit: If no POI was extracted, so there is nothing to write.
    """
    import shutil  # local import: only needed for temp-dir cleanup

    parser = argparse.ArgumentParser(
        description="Download and extract POIs from OpenStreetMap"
    )
    parser.add_argument(
        "--output", type=Path, required=True, help="Output parquet file path"
    )
    parser.add_argument("--pbf", type=Path, required=True, help="Path to OSM PBF file")
    parser.add_argument(
        "--boundary",
        type=Path,
        required=True,
        help="England boundary GeoJSON file",
    )
    args = parser.parse_args()
    pbf_file = args.pbf

    print(f"Tag keys: {POI_TAG_KEYS}")
    england_polygon = load_england_polygon(args.boundary)

    # mkdtemp leaves deletion to the caller, so clean up in ``finally``.
    tmp_dir = Path(mkdtemp(prefix="pois_"))
    try:
        with tqdm(
            unit=" elements",
            unit_scale=True,
            desc="Streaming",
            smoothing=0.05,
            mininterval=1.0,
        ) as progress:
            handler = POIHandler(progress, tmp_dir, england_polygon)
            handler.apply_file(str(pbf_file), locations=True)
            handler._flush_batch()  # write any remaining POIs
        print(f"Extracted {handler.poi_count:,} POIs")

        batch_files = sorted(tmp_dir.glob("batch_*.parquet"))
        if not batch_files:
            # pl.concat() rejects an empty list; fail with a clear message.
            raise SystemExit(
                f"No POIs were extracted from {pbf_file}; "
                f"nothing to write to {args.output}"
            )
        df = pl.concat([pl.scan_parquet(f) for f in batch_files])

        # Only keep categories with enough occurrences
        valid_categories = (
            df.group_by("category")
            .agg(pl.len().alias("count"))
            .filter(pl.col("count") >= MIN_OCCURENCE_COUNT)
        )
        df = df.join(valid_categories.select("category"), on="category", how="semi")

        # NOTE: this is the count before the category filter is applied.
        print(f"Total POIs: {handler.poi_count:,}")

        # The frame is lazy: the batch files are read here, so the temporary
        # directory must still exist until this call returns.
        df.sink_parquet(args.output)
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

    print(f"Saved to {args.output}")


if __name__ == "__main__":
    main()