perfect-postcode/pipeline/download/pois/__main__.py

198 lines
5.6 KiB
Python

import json
import shutil
import urllib.request
from pathlib import Path
from tempfile import mkdtemp
import osmium
import polars as pl
from tqdm import tqdm
from .config import (
GB_PBF_FILE,
GEOFABRIK_GB_URL,
OUTPUT_FILE,
POI_TAG_KEYS,
UK_BBOX_EAST,
UK_BBOX_NORTH,
UK_BBOX_SOUTH,
UK_BBOX_WEST,
)
# Number of POI records buffered in memory before they are written out as one
# parquet batch file (checked after every record that is appended).
BATCH_SIZE = 50_000
def download_pbf() -> None:
    """Download the Geofabrik extract at ``GEOFABRIK_GB_URL`` to ``GB_PBF_FILE``.

    The response is streamed in 1 MiB chunks into a temporary sibling file,
    which is renamed to the final path only after the transfer has been
    verified, so ``GB_PBF_FILE`` never holds a partial download.

    Raises:
        urllib.error.ContentTooShortError: If the server sent a
            ``Content-Length`` header and fewer bytes than that were received.
    """
    # Local import keeps this block self-contained.
    from urllib.error import ContentTooShortError

    GB_PBF_FILE.parent.mkdir(parents=True, exist_ok=True)
    tmp = GB_PBF_FILE.with_suffix(".pbf.tmp")
    print(f"Downloading {GEOFABRIK_GB_URL}")

    expected: int | None = None
    received = 0
    with (
        tqdm(unit="B", unit_scale=True, desc="Downloading") as bar,
        urllib.request.urlopen(GEOFABRIK_GB_URL) as resp,
        open(tmp, "wb") as f,
    ):
        length = resp.headers.get("Content-Length")
        if length:
            expected = int(length)
            bar.total = expected
        while chunk := resp.read(1 << 20):
            f.write(chunk)
            received += len(chunk)
            bar.update(len(chunk))

    # HTTPResponse.read(amt) returns b"" on a prematurely closed connection
    # instead of raising, so a truncated transfer looks like a normal EOF.
    # Check the byte count before promoting the temp file.
    if expected is not None and received < expected:
        tmp.unlink(missing_ok=True)
        raise ContentTooShortError(
            f"retrieval incomplete: got only {received} out of {expected} bytes",
            (str(tmp), resp.headers),
        )

    tmp.rename(GB_PBF_FILE)
    print(f"Saved to {GB_PBF_FILE}")
class POIHandler(osmium.SimpleHandler):
    """Streams OSM data, filters to UK bbox, extracts matching POIs in batches.

    Nodes and ways carrying at least one of ``POI_TAG_KEYS`` are turned into
    flat records. Records are buffered in memory and written to numbered
    parquet files inside ``tmp_dir`` every ``BATCH_SIZE`` records; the caller
    must invoke ``_flush_batch()`` once streaming is done to persist the final
    partial batch.

    Attributes:
        poi_count: Total number of POIs recorded so far.
    """

    def __init__(self, progress: tqdm, tmp_dir: Path) -> None:
        """Create a handler.

        Args:
            progress: Progress bar advanced once per node/way seen.
            tmp_dir: Directory that receives the ``batch_NNNNN.parquet`` files.
        """
        super().__init__()
        self._batch: list[dict] = []
        self._tmp_dir = tmp_dir
        self._batch_num = 0
        self.poi_count = 0
        self._progress = progress

    def _in_uk(self, lat: float, lon: float) -> bool:
        """Return True when the coordinate lies inside the UK bounding box (inclusive)."""
        return (
            UK_BBOX_SOUTH <= lat <= UK_BBOX_NORTH
            and UK_BBOX_WEST <= lon <= UK_BBOX_EAST
        )

    def _match_tags(self, tags: osmium.osm.TagList) -> str | None:
        """Build the category string for an element.

        Returns the values of every ``POI_TAG_KEYS`` key present in ``tags``,
        joined with ``" / "`` in ``POI_TAG_KEYS`` order, or ``None`` when no
        key is present.
        """
        parts = [tags[key] for key in POI_TAG_KEYS if key in tags]
        return " / ".join(parts) if parts else None

    def _get_name(self, tags: osmium.osm.TagList) -> str:
        """Return the ``name:en`` tag, else ``name``, else an empty string."""
        return tags.get("name:en", tags.get("name", ""))

    def _tags_to_json(self, tags: osmium.osm.TagList) -> str:
        """Serialise all tags of an element as a JSON object string."""
        return json.dumps({tag.k: tag.v for tag in tags})

    def _flush_batch(self) -> None:
        """Write the buffered records to the next batch file and empty the buffer.

        Does nothing when the buffer is empty, so it is safe to call at the
        end of a run regardless of how many records are pending.
        """
        if not self._batch:
            return
        df = pl.DataFrame(self._batch)
        # Zero-padded index keeps lexicographic file order equal to write order.
        out = self._tmp_dir / f"batch_{self._batch_num:05d}.parquet"
        df.write_parquet(out)
        self._batch_num += 1
        self._batch.clear()

    def _add_poi(
        self,
        osm_id: str,
        tags: osmium.osm.TagList,
        category: str,
        lat: float,
        lng: float,
    ) -> None:
        """Append one POI record, update the counters, and flush a full batch.

        Tag data is copied into plain strings here, so the record does not
        keep a reference to the osmium object passed to the callback.

        Args:
            osm_id: Identifier prefixed with the element type (``n…`` / ``w…``).
            tags: Tag list of the source element.
            category: Category string produced by ``_match_tags``.
            lat: Latitude of the POI.
            lng: Longitude of the POI.
        """
        self._batch.append(
            {
                "id": osm_id,
                "name": self._get_name(tags),
                "category": category,
                "lat": lat,
                "lng": lng,
                "osm_tags": self._tags_to_json(tags),
            }
        )
        self.poi_count += 1
        # refresh=False: let the bar redraw on its own schedule rather than per POI.
        self._progress.set_postfix(pois=f"{self.poi_count:,}", refresh=False)
        if len(self._batch) >= BATCH_SIZE:
            self._flush_batch()

    def _tick(self) -> None:
        """Advance the progress bar by one processed element."""
        self._progress.update(1)

    def node(self, n: osmium.osm.Node) -> None:
        """Record a tagged node that lies inside the UK bounding box."""
        self._tick()
        # `Location.valid` is a method in pyosmium, so testing the un-called
        # attribute was always truthy and never skipped anything. Handle the
        # invalid case the same way `way` does: via the exception raised on
        # coordinate access.
        try:
            lat, lon = n.location.lat, n.location.lon
        except osmium.InvalidLocationError:
            return
        if not self._in_uk(lat, lon):
            return
        category = self._match_tags(n.tags)
        if category:
            self._add_poi(f"n{n.id}", n.tags, category, lat, lon)

    def way(self, w: osmium.osm.Way) -> None:
        """Record a tagged way at the mean position of its located nodes.

        Tags are checked first so untagged ways skip the coordinate work.
        Nodes whose location is unavailable are ignored; a way with no located
        nodes is dropped. The bounding-box test is applied to the mean point.
        """
        self._tick()
        category = self._match_tags(w.tags)
        if not category:
            return

        lats = []
        lons = []
        for node in w.nodes:
            try:
                lats.append(node.location.lat)
                lons.append(node.location.lon)
            except osmium.InvalidLocationError:
                continue
        if not lats:
            return

        # NOTE(review): this is the plain mean of the node coordinates, not an
        # area centroid; for closed ways the repeated closing node is presumably
        # counted twice - confirm whether that bias matters downstream.
        centroid_lat = sum(lats) / len(lats)
        centroid_lng = sum(lons) / len(lons)
        if not self._in_uk(centroid_lat, centroid_lng):
            return
        self._add_poi(f"w{w.id}", w.tags, category, centroid_lat, centroid_lng)
def main() -> None:
    """Extract POIs from the GB PBF extract into ``OUTPUT_FILE``.

    Steps:
      1. Download the PBF if it is not already on disk.
      2. Return early when ``OUTPUT_FILE`` already exists.
      3. Stream the PBF through ``POIHandler``, which writes parquet batches
         into a throw-away temp directory.
      4. Concatenate the batches lazily, write the combined parquet, and
         print a per-category summary.

    The combined file is first written next to ``OUTPUT_FILE`` under a
    temporary name and then moved into place, so an interrupted run cannot
    leave a truncated ``OUTPUT_FILE`` that a later run would treat as
    up-to-date.
    """
    if not GB_PBF_FILE.exists():
        download_pbf()

    print(f"=== POI Extraction from {GB_PBF_FILE} ===")
    print(
        f"UK bbox: ({UK_BBOX_WEST}, {UK_BBOX_SOUTH}, {UK_BBOX_EAST}, {UK_BBOX_NORTH})"
    )
    print(f"Tag keys: {POI_TAG_KEYS}")
    print()

    if OUTPUT_FILE.exists():
        print("POIs are up-to-date")
        return

    tmp_dir = Path(mkdtemp(prefix="pois_"))
    # Staged in the destination directory so the final replace stays on one
    # filesystem.
    staged_output = OUTPUT_FILE.with_name(OUTPUT_FILE.name + ".tmp")
    try:
        with tqdm(
            unit=" elements",
            unit_scale=True,
            desc="Streaming",
            smoothing=0.05,
            mininterval=1.0,
        ) as progress:
            handler = POIHandler(progress, tmp_dir)
            # locations=True asks pyosmium to attach node locations to ways,
            # which POIHandler.way reads to compute its mean position.
            handler.apply_file(str(GB_PBF_FILE), locations=True)
            handler._flush_batch()  # write any remaining POIs
        print(f"Extracted {handler.poi_count:,} POIs")

        batch_files = sorted(tmp_dir.glob("batch_*.parquet"))
        if not batch_files:
            print("No POIs found.")
            return

        # Lazy scan + sink streams the batches into the output instead of
        # materialising every POI in memory at once.
        df = pl.concat([pl.scan_parquet(f) for f in batch_files])
        OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
        df.sink_parquet(staged_output)
        staged_output.replace(OUTPUT_FILE)
        print(f"Saved to {OUTPUT_FILE}")

        print("\n=== Summary ===")
        print(f"Total POIs: {handler.poi_count:,}")
        print("\nPOIs by category:")
        # The lazy frame still points at the batch files, which live until the
        # `finally` below removes the temp directory.
        category_counts = (
            df.group_by("category")
            .agg(pl.len().alias("count"))
            .sort("count", descending=True)
            .collect()
        )
        for row in category_counts.iter_rows(named=True):
            print(f"  {row['category']}: {row['count']:,}")
    finally:
        # After a successful replace the staged file no longer exists; on any
        # failure this removes the partial output.
        staged_output.unlink(missing_ok=True)
        shutil.rmtree(tmp_dir, ignore_errors=True)
# Run the extraction only when this module is executed as a script
# (e.g. via ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()