perfect-postcode/pipeline/download/pois.py
2026-05-17 13:52:11 +01:00

223 lines
6.5 KiB
Python

import argparse
import logging
import shutil
from pathlib import Path
from tempfile import mkdtemp

import osmium
import polars as pl
from shapely import make_valid
from shapely.errors import GEOSException
from shapely.geometry import Point
from shapely.wkb import loads as load_wkb
from tqdm import tqdm

from pipeline.utils.england_geometry import (
    ENGLAND_BBOX_EAST,
    ENGLAND_BBOX_NORTH,
    ENGLAND_BBOX_SOUTH,
    ENGLAND_BBOX_WEST,
    load_england_polygon,
)
logger = logging.getLogger(__name__)

# Number of POI rows buffered in memory before they are flushed to a
# temporary parquet batch file.
BATCH_SIZE = 50_000

# Categories with fewer than this many rows are dropped from the final output.
MIN_OCCURRENCE_COUNT = 20
# Backward-compatible alias for the original (misspelled) name.
MIN_OCCURENCE_COUNT = MIN_OCCURRENCE_COUNT

# OSM tag keys that mark an element as a POI. Every key present on an element
# produces one "<key>/<value>" category for it.
POI_TAG_KEYS: list[str] = [
    "amenity",
    "building",
    "craft",
    "emergency",
    "healthcare",
    "leisure",
    "office",
    "shop",
    "tourism",
    "public_transport",
]

# For polygonal (area) features, only these building/* categories are kept;
# every other building/* tag on an area is ignored.
AREA_BUILDING_CATEGORIES = {"building/church", "building/university"}
def _representative_lat_lon(geom, england_polygon) -> tuple[float, float] | None:
if geom.is_empty:
return None
if not geom.is_valid:
geom = make_valid(geom)
if geom.is_empty:
return None
point = geom.representative_point()
lat, lon = point.y, point.x
if not england_polygon.contains(Point(lon, lat)):
return None
return lat, lon
class POIHandler(osmium.SimpleHandler):
    """Stream OSM nodes and areas, collecting tagged POIs that lie in England.

    Every element whose tags contain any of ``POI_TAG_KEYS`` yields one row per
    matching category, with columns ``id``, ``name``, ``category``, ``lat`` and
    ``lng``. Rows are buffered and written to numbered parquet files in
    *tmp_dir* whenever ``BATCH_SIZE`` rows accumulate. Callers must call
    ``_flush_batch()`` after ``apply_file`` to write the final partial batch.
    """

    def __init__(self, progress: tqdm, tmp_dir: Path, england_polygon) -> None:
        super().__init__()
        self._batch: list[dict] = []
        self._tmp_dir = tmp_dir
        self._batch_num = 0
        # Number of (element, category) rows emitted so far.
        self.poi_count = 0
        # Areas dropped because their geometry could not be built or processed.
        self.skipped_areas = 0
        self._progress = progress
        self._england = england_polygon
        self._wkb_factory = osmium.geom.WKBFactory()

    def _in_england(self, lat: float, lon: float) -> bool:
        """Return True if ``(lat, lon)`` lies inside the England polygon."""
        # Fast bbox pre-filter, then precise polygon check
        if not (
            ENGLAND_BBOX_SOUTH <= lat <= ENGLAND_BBOX_NORTH
            and ENGLAND_BBOX_WEST <= lon <= ENGLAND_BBOX_EAST
        ):
            return False
        return self._england.contains(Point(lon, lat))

    def _match_tags(
        self, tags: osmium.osm.TagList, *, polygonal: bool = False
    ) -> list[str]:
        """Return ``"<key>/<value>"`` categories for the POI keys present in *tags*.

        Categories follow ``POI_TAG_KEYS`` order. When *polygonal* is True,
        building/* categories are kept only if listed in
        ``AREA_BUILDING_CATEGORIES``.
        """
        categories = [f"{key}/{tags[key]}" for key in POI_TAG_KEYS if key in tags]
        if not polygonal:
            return categories
        # Presumably this avoids turning every building footprint into a POI.
        return [
            category
            for category in categories
            if not category.startswith("building/")
            or category in AREA_BUILDING_CATEGORIES
        ]

    def _get_name(self, tags: osmium.osm.TagList) -> str:
        """Return ``name:en`` if present, else ``name``, else an empty string."""
        return tags.get("name:en", tags.get("name", ""))

    def _flush_batch(self) -> None:
        """Write buffered rows to the next numbered parquet file and clear the buffer."""
        if not self._batch:
            return
        df = pl.DataFrame(self._batch)
        out = self._tmp_dir / f"batch_{self._batch_num:05d}.parquet"
        df.write_parquet(out)
        self._batch_num += 1
        self._batch.clear()

    def _add_poi(
        self,
        osm_id: str,
        tags: osmium.osm.TagList,
        category: str,
        lat: float,
        lng: float,
    ) -> None:
        """Buffer one POI row, flushing to disk once ``BATCH_SIZE`` rows are held."""
        self._batch.append(
            {
                "id": osm_id,
                "name": self._get_name(tags),
                "category": category,
                "lat": lat,
                "lng": lng,
            }
        )
        self.poi_count += 1
        self._progress.set_postfix(pois=f"{self.poi_count:,}", refresh=False)
        if len(self._batch) >= BATCH_SIZE:
            self._flush_batch()

    def _skip_area(self, area: osmium.osm.Area, reason: str, exc: Exception) -> None:
        """Count *area* as skipped and log *reason*, with the active traceback."""
        self.skipped_areas += 1
        logger.warning(
            "%s for area orig_id=%s (%s)",
            reason,
            getattr(area, "orig_id", lambda: "?")(),
            type(exc).__name__,
            exc_info=True,
        )

    def _point_from_area(self, area: osmium.osm.Area) -> tuple[float, float] | None:
        """Return an in-England ``(lat, lon)`` for *area*, or None.

        Geometry failures are counted in ``skipped_areas`` and logged rather
        than raised. An exception escaping this handler would abort the whole
        ``apply_file`` run.
        """
        try:
            geom = load_wkb(self._wkb_factory.create_multipolygon(area), hex=True)
        except (
            osmium.InvalidLocationError,
            RuntimeError,
            GEOSException,
            ValueError,
        ) as exc:
            self._skip_area(area, "Failed to build multipolygon WKB", exc)
            return None
        try:
            # make_valid / representative_point can also fail on degenerate input.
            return _representative_lat_lon(geom, self._england)
        except (GEOSException, ValueError) as exc:
            self._skip_area(area, "Failed to compute representative point", exc)
            return None

    def _tick(self) -> None:
        """Advance the progress bar by one processed OSM element."""
        self._progress.update(1)

    def node(self, n: osmium.osm.Node) -> None:
        """Emit one row per POI category for a tagged node located in England."""
        self._tick()
        # Match tags first (as area() does). The tag check is cheap, while the
        # polygon test is expensive. Most nodes are expected to carry no POI tags.
        categories = self._match_tags(n.tags)
        if not categories:
            return
        location = n.location
        # `valid` is a method. The bare attribute is always truthy, so it must
        # be called, otherwise `.lat`/`.lon` would raise on invalid locations.
        if not location.valid():
            return
        lat, lon = location.lat, location.lon
        if not self._in_england(lat, lon):
            return
        for category in categories:
            self._add_poi(f"n{n.id}", n.tags, category, lat, lon)

    def area(self, a: osmium.osm.Area) -> None:
        """Emit one row per POI category for an area whose representative point is in England."""
        self._tick()
        categories = self._match_tags(a.tags, polygonal=True)
        if not categories:
            return
        point = self._point_from_area(a)
        if point is None:
            return
        lat, lon = point
        for category in categories:
            self._add_poi(f"a{a.id}", a.tags, category, lat, lon)
def main() -> None:
    """Extract POIs from an OSM PBF file and write them to a parquet file.

    Streams the PBF through ``POIHandler``, which spills rows to temporary
    parquet batches. Only categories with at least ``MIN_OCCURENCE_COUNT``
    rows are kept, and the result is written to ``--output``. The temporary
    batch directory is always removed afterwards, even on failure.
    """
    parser = argparse.ArgumentParser(
        description="Download and extract POIs from OpenStreetMap"
    )
    parser.add_argument(
        "--output", type=Path, required=True, help="Output parquet file path"
    )
    parser.add_argument("--pbf", type=Path, required=True, help="Path to OSM PBF file")
    parser.add_argument(
        "--boundary",
        type=Path,
        required=True,
        help="England boundary GeoJSON file",
    )
    args = parser.parse_args()
    print(f"Tag keys: {POI_TAG_KEYS}")
    england_polygon = load_england_polygon(args.boundary)
    tmp_dir = Path(mkdtemp(prefix="pois_"))
    try:
        with tqdm(
            unit=" elements",
            unit_scale=True,
            desc="Streaming",
            smoothing=0.05,
            mininterval=1.0,
        ) as progress:
            handler = POIHandler(progress, tmp_dir, england_polygon)
            handler.apply_file(str(args.pbf), locations=True)
            handler._flush_batch()  # write any remaining POIs
        print(f"Extracted {handler.poi_count:,} POIs")
        if handler.skipped_areas:
            logger.warning(
                "Skipped %d areas due to geometry assembly errors",
                handler.skipped_areas,
            )

        args.output.parent.mkdir(parents=True, exist_ok=True)
        batch_files = sorted(tmp_dir.glob("batch_*.parquet"))
        if batch_files:
            df = pl.concat([pl.scan_parquet(f) for f in batch_files])
            # Only keep categories with enough occurrences
            valid_categories = (
                df.group_by("category")
                .agg(pl.len().alias("count"))
                .filter(pl.col("count") >= MIN_OCCURENCE_COUNT)
            )
            df = df.join(valid_categories.select("category"), on="category", how="semi")
            # The lazy plan reads the batch files, so it must be sunk before the
            # temporary directory is removed in `finally`.
            df.sink_parquet(args.output)
        else:
            # pl.concat([]) raises. Write an empty file with the batch schema instead.
            pl.DataFrame(
                schema={
                    "id": pl.String,
                    "name": pl.String,
                    "category": pl.String,
                    "lat": pl.Float64,
                    "lng": pl.Float64,
                }
            ).write_parquet(args.output)
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

    # Report what was actually written (after category filtering), not the raw
    # extraction count. This reads only the parquet metadata.
    total = pl.scan_parquet(args.output).select(pl.len()).collect().item()
    print(f"Total POIs: {total:,}")
    print(f"Saved to {args.output}")


if __name__ == "__main__":
    main()