223 lines
6.5 KiB
Python
223 lines
6.5 KiB
Python
import argparse
|
|
import logging
|
|
from pathlib import Path
|
|
from tempfile import mkdtemp
|
|
|
|
import osmium
|
|
import polars as pl
|
|
from shapely import make_valid
|
|
from shapely.errors import GEOSException
|
|
from shapely.geometry import Point
|
|
from shapely.wkb import loads as load_wkb
|
|
from tqdm import tqdm
|
|
|
|
from pipeline.utils.england_geometry import (
|
|
ENGLAND_BBOX_EAST,
|
|
ENGLAND_BBOX_NORTH,
|
|
ENGLAND_BBOX_SOUTH,
|
|
ENGLAND_BBOX_WEST,
|
|
load_england_polygon,
|
|
)
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Number of buffered POI rows that triggers writing a parquet batch file.
BATCH_SIZE = 50_000

# Categories with fewer rows than this are filtered out of the final output.
# (The spelling of the name is kept because other code in this file uses it.)
MIN_OCCURENCE_COUNT = 20

# OSM tag keys that mark an element as a POI. Categories are emitted in this
# order, formatted as "<key>/<value>".
POI_TAG_KEYS: list[str] = [
    "amenity", "building", "craft", "emergency", "healthcare",
    "leisure", "office", "shop", "tourism", "public_transport",
]

# The only "building/..." categories that are kept for polygonal features.
AREA_BUILDING_CATEGORIES = {
    "building/church",
    "building/university",
}
|
|
|
|
|
|
def _representative_lat_lon(geom, england_polygon) -> tuple[float, float] | None:
|
|
if geom.is_empty:
|
|
return None
|
|
if not geom.is_valid:
|
|
geom = make_valid(geom)
|
|
if geom.is_empty:
|
|
return None
|
|
point = geom.representative_point()
|
|
lat, lon = point.y, point.x
|
|
if not england_polygon.contains(Point(lon, lat)):
|
|
return None
|
|
return lat, lon
|
|
|
|
|
|
class POIHandler(osmium.SimpleHandler):
    """Collect England POIs from an OSM stream into parquet batch files.

    Nodes and areas whose tags contain any key from ``POI_TAG_KEYS`` yield one
    row per matching ``key/value`` category. Rows are buffered in memory and
    written to ``tmp_dir`` as ``batch_NNNNN.parquet`` whenever the buffer
    reaches ``BATCH_SIZE``; call ``_flush_batch()`` once streaming is finished
    to write whatever is left.

    Attributes:
        poi_count: Number of rows added so far (one per element/category pair).
        skipped_areas: Number of areas whose geometry could not be built.
    """

    def __init__(self, progress: tqdm, tmp_dir: Path, england_polygon) -> None:
        """Set up an empty buffer.

        Args:
            progress: Progress bar ticked once per node/area callback.
            tmp_dir: Directory that receives the parquet batch files.
            england_polygon: Geometry used for containment checks (must offer
                ``contains``).
        """
        super().__init__()
        self._batch: list[dict] = []
        self._tmp_dir = tmp_dir
        self._batch_num = 0
        self.poi_count = 0
        self.skipped_areas = 0
        self._progress = progress
        self._england = england_polygon
        self._wkb_factory = osmium.geom.WKBFactory()

    def _in_england(self, lat: float, lon: float) -> bool:
        """Return True if the coordinate lies inside the England polygon."""
        # Fast bbox pre-filter, then precise polygon check
        if not (
            ENGLAND_BBOX_SOUTH <= lat <= ENGLAND_BBOX_NORTH
            and ENGLAND_BBOX_WEST <= lon <= ENGLAND_BBOX_EAST
        ):
            return False
        return self._england.contains(Point(lon, lat))

    def _match_tags(
        self, tags: osmium.osm.TagList, *, polygonal: bool = False
    ) -> list[str]:
        """Return the ``key/value`` categories present in *tags*.

        Args:
            tags: Tags of the OSM element.
            polygonal: When True, ``building/...`` categories are dropped
                unless they are listed in ``AREA_BUILDING_CATEGORIES``.

        Returns:
            Categories in ``POI_TAG_KEYS`` order; empty if nothing matches.
        """
        categories = [f"{key}/{tags[key]}" for key in POI_TAG_KEYS if key in tags]
        if not polygonal:
            return categories
        return [
            category
            for category in categories
            if not category.startswith("building/")
            or category in AREA_BUILDING_CATEGORIES
        ]

    def _get_name(self, tags: osmium.osm.TagList) -> str:
        """Return ``name:en`` if set, else ``name``, else an empty string."""
        return tags.get("name:en", tags.get("name", ""))

    def _flush_batch(self) -> None:
        """Write the buffered rows to the next batch file and clear the buffer.

        Does nothing when the buffer is empty.
        """
        if not self._batch:
            return
        df = pl.DataFrame(self._batch)
        out = self._tmp_dir / f"batch_{self._batch_num:05d}.parquet"
        df.write_parquet(out)
        self._batch_num += 1
        self._batch.clear()

    def _add_poi(
        self,
        osm_id: str,
        tags: osmium.osm.TagList,
        category: str,
        lat: float,
        lng: float,
    ) -> None:
        """Buffer one POI row and flush once the buffer reaches BATCH_SIZE."""
        self._batch.append(
            {
                "id": osm_id,
                "name": self._get_name(tags),
                "category": category,
                "lat": lat,
                "lng": lng,
            }
        )
        self.poi_count += 1
        self._progress.set_postfix(pois=f"{self.poi_count:,}", refresh=False)
        if len(self._batch) >= BATCH_SIZE:
            self._flush_batch()

    def _point_from_area(self, area: osmium.osm.Area) -> tuple[float, float] | None:
        """Return a ``(lat, lon)`` point inside *area*, or None.

        None is returned when the multipolygon cannot be built or parsed (the
        failure is logged and counted in ``skipped_areas``) or when
        ``_representative_lat_lon`` rejects the geometry.
        """
        try:
            # The factory output is parsed as hex-encoded WKB.
            geom = load_wkb(self._wkb_factory.create_multipolygon(area), hex=True)
        except (RuntimeError, GEOSException, ValueError) as exc:
            self.skipped_areas += 1
            logger.warning(
                "Failed to build multipolygon WKB for area orig_id=%s (%s)",
                getattr(area, "orig_id", lambda: "?")(),
                type(exc).__name__,
                exc_info=True,
            )
            return None
        return _representative_lat_lon(geom, self._england)

    def _tick(self) -> None:
        """Advance the progress bar by one element."""
        self._progress.update(1)

    def node(self, n: osmium.osm.Node) -> None:
        """Handle one node: emit a row per matching category if in England.

        Row ids are the node id prefixed with ``n``.
        """
        self._tick()
        # Match tags first, as area() does: nodes without a POI tag are
        # rejected without paying for the polygon containment test.
        categories = self._match_tags(n.tags)
        if not categories:
            return
        # Location.valid is a method; referencing it without calling it is
        # always truthy, which would let invalid locations through to
        # .lat/.lon.
        if not n.location.valid():
            return
        lat, lon = n.location.lat, n.location.lon
        if not self._in_england(lat, lon):
            return
        for category in categories:
            self._add_poi(f"n{n.id}", n.tags, category, lat, lon)

    def area(self, a: osmium.osm.Area) -> None:
        """Handle one area: emit a row per matching category if in England.

        The row position is the point returned by ``_point_from_area``. Row
        ids are the area id prefixed with ``a``.
        """
        self._tick()
        categories = self._match_tags(a.tags, polygonal=True)
        if not categories:
            return
        point = self._point_from_area(a)
        if point is None:
            return
        lat, lon = point
        for category in categories:
            self._add_poi(f"a{a.id}", a.tags, category, lat, lon)
|
|
|
|
|
|
def main() -> None:
    """Extract POIs from an OSM PBF file and write them to one parquet file.

    Command-line arguments:
        --output: Path of the parquet file to write.
        --pbf: Path of the OSM PBF file to read.
        --boundary: England boundary GeoJSON file.

    The PBF is streamed through ``POIHandler``, which writes intermediate
    batch files to a temporary directory. The batches are then combined,
    categories with fewer than ``MIN_OCCURENCE_COUNT`` rows are dropped, and
    the result is written to ``--output``. The temporary directory is removed
    afterwards, including when an error occurs.
    """
    # Function-scope import: only this entry point needs it.
    from tempfile import TemporaryDirectory

    parser = argparse.ArgumentParser(
        description="Download and extract POIs from OpenStreetMap"
    )
    parser.add_argument(
        "--output", type=Path, required=True, help="Output parquet file path"
    )
    parser.add_argument("--pbf", type=Path, required=True, help="Path to OSM PBF file")
    parser.add_argument(
        "--boundary",
        type=Path,
        required=True,
        help="England boundary GeoJSON file",
    )
    args = parser.parse_args()

    pbf_file = args.pbf
    print(f"Tag keys: {POI_TAG_KEYS}")

    england_polygon = load_england_polygon(args.boundary)

    # The context manager deletes the batch files even if streaming or the
    # final write fails; everything that reads them must stay inside it.
    with TemporaryDirectory(prefix="pois_") as tmp_name:
        tmp_dir = Path(tmp_name)
        with tqdm(
            unit=" elements",
            unit_scale=True,
            desc="Streaming",
            smoothing=0.05,
            mininterval=1.0,
        ) as progress:
            handler = POIHandler(progress, tmp_dir, england_polygon)
            handler.apply_file(str(pbf_file), locations=True)
            handler._flush_batch()  # write any remaining POIs

        print(f"Extracted {handler.poi_count:,} POIs")
        if handler.skipped_areas:
            logger.warning(
                "Skipped %d areas due to geometry assembly errors",
                handler.skipped_areas,
            )

        batch_files = sorted(tmp_dir.glob("batch_*.parquet"))
        if not batch_files:
            # pl.concat rejects an empty list, so write an empty file with
            # the same columns the batch files would have had.
            logger.warning("No POIs were extracted; writing an empty output file")
            pl.DataFrame(
                schema={
                    "id": pl.Utf8,
                    "name": pl.Utf8,
                    "category": pl.Utf8,
                    "lat": pl.Float64,
                    "lng": pl.Float64,
                }
            ).write_parquet(args.output)
        else:
            df = pl.concat([pl.scan_parquet(f) for f in batch_files])

            # Only keep categories with enough occurrences
            valid_categories = (
                df.group_by("category")
                .agg(pl.len().alias("count"))
                .filter(pl.col("count") >= MIN_OCCURENCE_COUNT)
            )
            df = df.join(
                valid_categories.select("category"), on="category", how="semi"
            )
            df.sink_parquet(args.output)

    # Report what was actually written (after the category filter), not the
    # pre-filter count already printed above.
    kept = pl.scan_parquet(args.output).select(pl.len()).collect().item()
    print(f"Total POIs: {kept:,}")
    print(f"Saved to {args.output}")


if __name__ == "__main__":
    main()
|