"""Extract points of interest (POIs) inside England from an OSM PBF file.

The PBF is streamed once. Every node or area carrying one of ``POI_TAG_KEYS``
whose location lies inside the England boundary polygon becomes one row per
matching ``key/value`` category. Rows are spilled to temporary parquet batches
and finally merged into a single output file, keeping only categories that
occur at least ``MIN_OCCURENCE_COUNT`` times.
"""

import argparse
import logging
import shutil
from pathlib import Path
from tempfile import mkdtemp

import osmium
import polars as pl
from shapely import make_valid
from shapely.errors import GEOSException
from shapely.geometry import Point
from shapely.wkb import loads as load_wkb
from tqdm import tqdm

from pipeline.local_temp import local_tmp_dir
from pipeline.utils.england_geometry import (
    ENGLAND_BBOX_EAST,
    ENGLAND_BBOX_NORTH,
    ENGLAND_BBOX_SOUTH,
    ENGLAND_BBOX_WEST,
    load_england_polygon,
)

logger = logging.getLogger(__name__)

# Number of POI rows buffered in memory before being spilled to a batch file.
BATCH_SIZE = 50_000
# Categories with fewer rows than this are dropped from the final output.
# (Name kept as-is, including its spelling, because other modules may import it.)
MIN_OCCURENCE_COUNT = 20

# OSM tag keys whose presence turns an element into a POI.
POI_TAG_KEYS: list[str] = [
    "amenity",
    "building",
    "craft",
    "emergency",
    "healthcare",
    "leisure",
    "office",
    "shop",
    "tourism",
    "public_transport",
]

# The only ``building/*`` categories that are kept for polygonal elements.
AREA_BUILDING_CATEGORIES = {"building/church", "building/university"}

# Column layout of a POI row; used to write a well-formed file when nothing
# was extracted. Mirrors the dict built in ``POIHandler._add_poi``.
_POI_SCHEMA = {
    "id": pl.String,
    "name": pl.String,
    "category": pl.String,
    "lat": pl.Float64,
    "lng": pl.Float64,
}


def _representative_lat_lon(geom, england_polygon) -> tuple[float, float] | None:
    """Return a ``(lat, lon)`` inside *geom*, or ``None`` if it is unusable.

    ``None`` is returned when the geometry is empty (before or after repair)
    or when its representative point is not contained in *england_polygon*.
    """
    if geom.is_empty:
        return None
    if not geom.is_valid:
        # Repair self-intersections etc.; the repaired geometry may be empty.
        geom = make_valid(geom)
        if geom.is_empty:
            return None
    point = geom.representative_point()
    if not england_polygon.contains(point):
        return None
    return point.y, point.x


class POIHandler(osmium.SimpleHandler):
    """Osmium handler that collects England POIs into parquet batch files.

    Attributes:
        poi_count: Number of POI rows emitted so far (one per element/category).
        skipped_areas: Number of areas whose multipolygon could not be built.
    """

    def __init__(self, progress: tqdm, tmp_dir: Path, england_polygon) -> None:
        """Create a handler writing batches into *tmp_dir*.

        Args:
            progress: Progress bar advanced once per processed element.
            tmp_dir: Existing directory that receives ``batch_NNNNN.parquet``.
            england_polygon: Geometry used for the containment test.
        """
        super().__init__()
        self._batch: list[dict] = []
        self._tmp_dir = tmp_dir
        self._batch_num = 0
        self.poi_count = 0
        self.skipped_areas = 0
        self._progress = progress
        self._england = england_polygon
        self._wkb_factory = osmium.geom.WKBFactory()

    def _in_england(self, lat: float, lon: float) -> bool:
        """Return whether the coordinate lies inside the England polygon."""
        # Fast bbox pre-filter, then precise polygon check
        if not (
            ENGLAND_BBOX_SOUTH <= lat <= ENGLAND_BBOX_NORTH
            and ENGLAND_BBOX_WEST <= lon <= ENGLAND_BBOX_EAST
        ):
            return False
        return self._england.contains(Point(lon, lat))

    def _match_tags(
        self, tags: osmium.osm.TagList, *, polygonal: bool = False
    ) -> list[str]:
        """Return the ``key/value`` categories present in *tags*.

        With ``polygonal=True``, ``building/*`` categories are dropped unless
        they are listed in ``AREA_BUILDING_CATEGORIES``.
        """
        categories = [f"{key}/{tags[key]}" for key in POI_TAG_KEYS if key in tags]
        if not polygonal:
            return categories
        return [
            category
            for category in categories
            if not category.startswith("building/")
            or category in AREA_BUILDING_CATEGORIES
        ]

    def _get_name(self, tags: osmium.osm.TagList) -> str:
        """Return ``name:en`` if set, else ``name``, else an empty string."""
        return tags.get("name:en", tags.get("name", ""))

    def _flush_batch(self) -> None:
        """Write buffered rows to the next batch file; no-op when empty."""
        if not self._batch:
            return
        df = pl.DataFrame(self._batch)
        out = self._tmp_dir / f"batch_{self._batch_num:05d}.parquet"
        df.write_parquet(out)
        self._batch_num += 1
        self._batch.clear()

    def _add_poi(
        self,
        osm_id: str,
        tags: osmium.osm.TagList,
        category: str,
        lat: float,
        lng: float,
    ) -> None:
        """Buffer one POI row and flush once ``BATCH_SIZE`` rows are pending."""
        self._batch.append(
            {
                "id": osm_id,
                "name": self._get_name(tags),
                "category": category,
                "lat": lat,
                "lng": lng,
            }
        )
        self.poi_count += 1
        # refresh=False: the bar redraws on its own schedule, not per POI.
        self._progress.set_postfix(pois=f"{self.poi_count:,}", refresh=False)
        if len(self._batch) >= BATCH_SIZE:
            self._flush_batch()

    def _point_from_area(self, area: osmium.osm.Area) -> tuple[float, float] | None:
        """Return a ``(lat, lon)`` inside *area* and England, else ``None``.

        Areas whose multipolygon WKB cannot be built or parsed are counted in
        ``skipped_areas`` and logged, not raised.
        """
        try:
            # The factory output is decoded as hex-encoded WKB.
            geom = load_wkb(self._wkb_factory.create_multipolygon(area), hex=True)
        except (RuntimeError, GEOSException, ValueError) as exc:
            self.skipped_areas += 1
            logger.warning(
                "Failed to build multipolygon WKB for area orig_id=%s (%s)",
                getattr(area, "orig_id", lambda: "?")(),
                type(exc).__name__,
                exc_info=True,
            )
            return None
        return _representative_lat_lon(geom, self._england)

    def _tick(self) -> None:
        """Advance the progress bar by one element."""
        self._progress.update(1)

    def node(self, n: osmium.osm.Node) -> None:
        """Emit one POI row per matching category for a node inside England."""
        self._tick()
        if not n.location.valid:
            return
        lat, lon = n.location.lat, n.location.lon
        if not self._in_england(lat, lon):
            return
        for category in self._match_tags(n.tags):
            self._add_poi(f"n{n.id}", n.tags, category, lat, lon)

    def area(self, a: osmium.osm.Area) -> None:
        """Emit one POI row per matching category for an area inside England."""
        self._tick()
        # Match tags first so geometry is only assembled for candidate areas.
        categories = self._match_tags(a.tags, polygonal=True)
        if not categories:
            return
        point = self._point_from_area(a)
        if point is None:
            return
        lat, lon = point
        for category in categories:
            self._add_poi(f"a{a.id}", a.tags, category, lat, lon)


def _write_filtered_output(batch_files: list[Path], output: Path) -> int:
    """Merge *batch_files* into *output*, dropping rare categories.

    Only categories with at least ``MIN_OCCURENCE_COUNT`` rows are kept.
    When there are no batch files an empty parquet file with the POI columns
    is written instead of failing on an empty concat.

    Returns:
        Number of rows written to *output*.
    """
    if not batch_files:
        pl.DataFrame(schema=_POI_SCHEMA).write_parquet(output)
        return 0

    pois = pl.concat([pl.scan_parquet(f) for f in batch_files])

    # Only keep categories with enough occurrences
    valid_categories = (
        pois.group_by("category")
        .agg(pl.len().alias("count"))
        .filter(pl.col("count") >= MIN_OCCURENCE_COUNT)
    )
    pois = pois.join(valid_categories.select("category"), on="category", how="semi")
    pois.sink_parquet(output)

    # Count what was actually written so the summary reflects the filter.
    return pl.scan_parquet(output).select(pl.len()).collect().item()


def main() -> None:
    """CLI entry point: stream the PBF, then write the filtered POI parquet."""
    parser = argparse.ArgumentParser(
        description="Download and extract POIs from OpenStreetMap"
    )
    parser.add_argument(
        "--output", type=Path, required=True, help="Output parquet file path"
    )
    parser.add_argument("--pbf", type=Path, required=True, help="Path to OSM PBF file")
    parser.add_argument(
        "--boundary",
        type=Path,
        required=True,
        help="England boundary GeoJSON file",
    )
    args = parser.parse_args()

    pbf_file = args.pbf
    print(f"Tag keys: {POI_TAG_KEYS}")

    england_polygon = load_england_polygon(args.boundary)

    tmp_dir = Path(mkdtemp(prefix="pois_", dir=local_tmp_dir()))
    try:
        with tqdm(
            unit=" elements",
            unit_scale=True,
            desc="Streaming",
            smoothing=0.05,
            mininterval=1.0,
        ) as progress:
            handler = POIHandler(progress, tmp_dir, england_polygon)
            handler.apply_file(str(pbf_file), locations=True)
            handler._flush_batch()  # write any remaining POIs

        print(f"Extracted {handler.poi_count:,} POIs")
        if handler.skipped_areas:
            logger.warning(
                "Skipped %d areas due to geometry assembly errors",
                handler.skipped_areas,
            )

        batch_files = sorted(tmp_dir.glob("batch_*.parquet"))
        if not batch_files:
            logger.warning("No POIs were extracted; writing an empty output file")
        written = _write_filtered_output(batch_files, args.output)

        print(f"Total POIs: {written:,}")
        print(f"Saved to {args.output}")
    finally:
        # The batch files are only intermediates; remove them on success and
        # on failure so repeated runs do not fill the temp volume.
        shutil.rmtree(tmp_dir, ignore_errors=True)


if __name__ == "__main__":
    main()