181 lines
5 KiB
Python
181 lines
5 KiB
Python
"""Single-pass POI extraction from OSM PBF file using pyosmium."""
|
|
|
|
import json
|
|
import urllib.request
|
|
|
|
import osmium
|
|
import polars as pl
|
|
from tqdm import tqdm
|
|
|
|
from .config import (
|
|
GB_PBF_FILE,
|
|
GEOFABRIK_GB_URL,
|
|
OSM_TAG_MAPPING,
|
|
OUTPUT_FILE,
|
|
TAG_KEYS_TO_CHECK,
|
|
UK_BBOX_EAST,
|
|
UK_BBOX_NORTH,
|
|
UK_BBOX_SOUTH,
|
|
UK_BBOX_WEST,
|
|
)
|
|
|
|
# Rough number of OSM elements in the GB PBF extract. It is only used as the
# progress-bar total in main(), so it does not need to be exact.
|
|
ESTIMATED_ELEMENTS = 500_000_000
|
|
|
|
|
|
def download_pbf() -> None:
    """Download the Great Britain PBF extract from Geofabrik.

    The response is streamed in 1 MiB chunks into a temporary file next to
    ``GB_PBF_FILE`` and moved into place only after the transfer finished,
    so the final path never holds a partially written file.

    Raises:
        OSError: If the server sent a ``Content-Length`` header and the
            number of bytes received differs from it. Network and file
            errors from ``urllib`` / ``open`` propagate unchanged.
    """
    GB_PBF_FILE.parent.mkdir(parents=True, exist_ok=True)
    tmp = GB_PBF_FILE.with_suffix(".pbf.tmp")
    print(f"Downloading {GEOFABRIK_GB_URL}")

    try:
        with (
            tqdm(unit="B", unit_scale=True, desc="Downloading") as bar,
            urllib.request.urlopen(GEOFABRIK_GB_URL) as resp,
            open(tmp, "wb") as f,
        ):
            length = resp.headers.get("Content-Length")
            expected = int(length) if length else None
            if expected is not None:
                bar.total = expected

            received = 0
            while chunk := resp.read(1 << 20):
                f.write(chunk)
                received += len(chunk)
                bar.update(len(chunk))

        # Guard against a short read being mistaken for end-of-stream: an
        # empty chunk ends the loop whether or not the full body arrived.
        if expected is not None and received != expected:
            raise OSError(
                f"Incomplete download: received {received:,} of {expected:,} bytes"
            )
    except BaseException:
        # Do not leave a partial temp file behind (also on Ctrl-C); the
        # original error is re-raised untouched.
        tmp.unlink(missing_ok=True)
        raise

    # replace() overwrites an existing target on every platform, whereas
    # rename() fails on Windows when the destination already exists.
    tmp.replace(GB_PBF_FILE)
    print(f"Saved to {GB_PBF_FILE}")
|
|
|
|
|
|
class POIHandler(osmium.SimpleHandler):
    """Streams OSM data, filters to UK bbox, extracts matching POIs.

    Nodes are tested at their own coordinates; ways are tested at the mean of
    their resolvable node coordinates. Every match is appended to
    ``self.pois`` as a dict with the keys ``id``, ``name``, ``category``,
    ``lat``, ``lng`` and ``osm_tags``.
    """

    def __init__(self, progress: tqdm) -> None:
        """Create a handler that reports progress on *progress*."""
        super().__init__()
        self.pois: list[dict] = []
        self._poi_count = 0
        self._progress = progress

    def _in_uk(self, lat: float, lon: float) -> bool:
        """Return True when (lat, lon) lies inside the UK bbox, edges included."""
        return (
            UK_BBOX_SOUTH <= lat <= UK_BBOX_NORTH
            and UK_BBOX_WEST <= lon <= UK_BBOX_EAST
        )

    def _match_tags(self, tags: osmium.osm.TagList) -> str | None:
        """Return the category for the first matching tag, or None.

        Keys are tried in the iteration order of ``TAG_KEYS_TO_CHECK``; a key
        matches when the element carries it with a value listed for that key.
        """
        for key in TAG_KEYS_TO_CHECK:
            value = tags.get(key)
            if value is not None and value in TAG_KEYS_TO_CHECK[key]:
                # NOTE(review): presumably every (key, value) accepted above
                # exists in OSM_TAG_MAPPING; otherwise this raises KeyError.
                return OSM_TAG_MAPPING[(key, value)]
        return None

    def _get_name(self, tags: osmium.osm.TagList) -> str:
        """Return ``name:en`` if present, else ``name``, else an empty string."""
        return tags.get("name:en", tags.get("name", ""))

    def _tags_to_json(self, tags: osmium.osm.TagList) -> str:
        """Serialise all tags of an element as a JSON object string."""
        return json.dumps({tag.k: tag.v for tag in tags})

    def _add_poi(
        self, osm_id: str, tags: osmium.osm.TagList, category: str, lat: float, lng: float
    ) -> None:
        """Record one POI and update the POI counter shown on the progress bar."""
        self.pois.append(
            {
                "id": osm_id,
                "name": self._get_name(tags),
                "category": category,
                "lat": lat,
                "lng": lng,
                "osm_tags": self._tags_to_json(tags),
            }
        )
        self._poi_count += 1
        # refresh=False: do not force a redraw for every single POI.
        self._progress.set_postfix(pois=f"{self._poi_count:,}", refresh=False)

    def _tick(self) -> None:
        """Advance the progress bar by one processed element."""
        self._progress.update(1)

    def node(self, n: osmium.osm.Node) -> None:
        """Handle a node: keep it when it is inside the bbox and matches a tag."""
        self._tick()
        # Location.valid is a method. Testing the bare attribute is always
        # truthy, so the guard never fired and an invalid location would
        # raise osmium.InvalidLocationError on the .lat / .lon access below.
        if not n.location.valid():
            return
        lat, lon = n.location.lat, n.location.lon
        if not self._in_uk(lat, lon):
            return
        category = self._match_tags(n.tags)
        if category:
            self._add_poi(f"n{n.id}", n.tags, category, lat, lon)

    def way(self, w: osmium.osm.Way) -> None:
        """Handle a way: keep it when it matches a tag and its mean point is in the bbox."""
        self._tick()
        category = self._match_tags(w.tags)
        if not category:
            return

        lats = []
        lons = []
        for node in w.nodes:
            # Node refs whose location could not be resolved are skipped.
            try:
                lats.append(node.location.lat)
                lons.append(node.location.lon)
            except osmium.InvalidLocationError:
                continue

        if not lats:
            return

        # Plain mean of the node coordinates, not an area-weighted centroid.
        # NOTE(review): a closed way lists its first node again at the end, so
        # that vertex is counted twice here - confirm whether the bias matters.
        centroid_lat = sum(lats) / len(lats)
        centroid_lng = sum(lons) / len(lons)

        if not self._in_uk(centroid_lat, centroid_lng):
            return

        self._add_poi(f"w{w.id}", w.tags, category, centroid_lat, centroid_lng)
|
|
|
|
|
|
def main() -> None:
    """Extract UK POIs from the GB PBF file and store them as parquet.

    Fetches the PBF first when it is not on disk, streams it through
    ``POIHandler``, writes the result to ``OUTPUT_FILE`` and prints a
    per-category summary. Returns early, writing nothing, when no POI matched.
    """
    pbf_missing = not GB_PBF_FILE.exists()
    if pbf_missing:
        download_pbf()

    print(f"=== POI Extraction from {GB_PBF_FILE} ===")
    print(
        f"UK bbox: ({UK_BBOX_WEST}, {UK_BBOX_SOUTH}, {UK_BBOX_EAST}, {UK_BBOX_NORTH})"
    )
    print(f"Categories: {len(OSM_TAG_MAPPING)}")
    print()

    # The total is only an estimate, so the bar may end below or above 100%.
    bar_options = {
        "total": ESTIMATED_ELEMENTS,
        "unit": " elements",
        "unit_scale": True,
        "desc": "Streaming",
        "smoothing": 0.05,
        "mininterval": 1.0,
    }
    with tqdm(**bar_options) as progress:
        handler = POIHandler(progress)
        handler.apply_file(str(GB_PBF_FILE), locations=True)

    extracted = handler.pois
    print(f"Extracted {len(extracted):,} POIs")

    if not extracted:
        print("No POIs found.")
        return

    frame = pl.DataFrame(extracted)

    OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
    frame.write_parquet(OUTPUT_FILE)
    print(f"Saved to {OUTPUT_FILE}")

    print("\n=== Summary ===")
    print(f"Total POIs: {len(frame):,}")
    print("\nPOIs by category:")
    per_category = frame.group_by("category").agg(pl.len().alias("count"))
    ranked = per_category.sort("count", descending=True)
    for entry in ranked.iter_rows(named=True):
        print(f" {entry['category']}: {entry['count']:,}")
|
|
|
|
|
|
# Run the extraction only when this module is executed, not when it is imported.
# NOTE(review): the module uses a relative import (``from .config import ...``),
# so it presumably has to be started with ``python -m <package>.<module>``.
if __name__ == "__main__":
    main()
|