perfect-postcode/pipeline/download/pois.py

171 lines
4.8 KiB
Python

import argparse
import tempfile
import urllib.request
from pathlib import Path
from tempfile import mkdtemp
import osmium
import polars as pl
from tqdm import tqdm
from pathlib import Path
# Number of POI rows POIHandler buffers in memory before writing them out as a
# single parquet batch file.
BATCH_SIZE = 50_000
# Minimum number of rows a category must have to be kept in the final output
# (main() filters with `count >= MIN_OCCURENCE_COUNT`).
# NOTE(review): "OCCURENCE" is a misspelling of "OCCURRENCE"; the name is left
# unchanged because other code in this file references it.
MIN_OCCURENCE_COUNT = 20
# URL of the Geofabrik extract fetched by download_pbf().
GEOFABRIK_GB_URL = (
"https://download.geofabrik.de/europe/great-britain-latest.osm.pbf"
)
# Bounding box applied by POIHandler._in_uk(): nodes whose lat/lon fall outside
# these inclusive limits are skipped. Values are presumably decimal degrees
# (WGS84), matching what osmium reports for node locations — confirm.
UK_BBOX_WEST = -7.57
UK_BBOX_SOUTH = 49.96
UK_BBOX_EAST = 1.68
UK_BBOX_NORTH = 58.64
# OSM tag keys inspected on every node. Each key present on a node produces one
# POI row whose category is the string "<key>/<value>".
POI_TAG_KEYS: list[str] = [
"amenity",
"building",
"craft",
"emergency",
"healthcare",
"leisure",
"office",
"shop",
"tourism",
"public_transport",
]
def download_pbf(pbf_file: Path) -> None:
    """Download the Geofabrik Great Britain extract to *pbf_file*.

    The response body is streamed in 1 MiB chunks into a sibling file with a
    ``.pbf.tmp`` suffix and only moved onto *pbf_file* once the transfer has
    completed, so an interrupted run never leaves a partial file at the final
    path. The temporary file is removed if anything goes wrong.

    Args:
        pbf_file: Destination path. Missing parent directories are created.

    Raises:
        OSError: If the request or a file operation fails
            (``urllib.error.URLError`` is an ``OSError`` subclass), or if the
            number of bytes received differs from the ``Content-Length``
            announced by the server.
    """
    pbf_file.parent.mkdir(parents=True, exist_ok=True)
    tmp = pbf_file.with_suffix(".pbf.tmp")
    print(f"Downloading {GEOFABRIK_GB_URL}")
    try:
        with (
            tqdm(unit="B", unit_scale=True, desc="Downloading") as bar,
            urllib.request.urlopen(GEOFABRIK_GB_URL) as resp,
            open(tmp, "wb") as f,
        ):
            length = resp.headers.get("Content-Length")
            expected = int(length) if length else None
            if expected is not None:
                bar.total = expected
            received = 0
            while chunk := resp.read(1 << 20):
                f.write(chunk)
                received += len(chunk)
                bar.update(len(chunk))
        # A connection dropped mid-transfer can look like a normal end of
        # stream; refuse to publish a truncated extract.
        if expected is not None and received != expected:
            raise OSError(
                f"Incomplete download: received {received} of {expected} bytes"
            )
        # replace() overwrites an existing destination on every platform,
        # whereas rename() fails on Windows when the target exists.
        tmp.replace(pbf_file)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a partial temp file
        # behind. The original exception is re-raised unchanged.
        tmp.unlink(missing_ok=True)
        raise
    print(f"Saved to {pbf_file}")
class POIHandler(osmium.SimpleHandler):
    """Osmium handler that turns tagged nodes into POI rows.

    Every node inside the UK bounding box yields one row per key from
    ``POI_TAG_KEYS`` that is present on the node. Rows are buffered in memory
    and written as numbered parquet files into ``tmp_dir`` whenever the buffer
    reaches ``BATCH_SIZE``. After streaming has finished, the caller must call
    ``_flush_batch()`` once more to persist the final partial batch.
    """

    def __init__(self, progress: tqdm, tmp_dir: Path) -> None:
        """Create a handler.

        Args:
            progress: Progress bar advanced once per node callback and
                annotated with the running POI count.
            tmp_dir: Existing directory that receives the batch parquet files.
        """
        super().__init__()
        self._batch: list[dict] = []
        self._tmp_dir = tmp_dir
        self._batch_num = 0
        # Total number of POI rows emitted so far (across all batches).
        self.poi_count = 0
        self._progress = progress

    def _in_uk(self, lat: float, lon: float) -> bool:
        """Return True if (lat, lon) lies within the inclusive UK bounding box."""
        return (
            UK_BBOX_SOUTH <= lat <= UK_BBOX_NORTH
            and UK_BBOX_WEST <= lon <= UK_BBOX_EAST
        )

    def _match_tags(self, tags: osmium.osm.TagList) -> list[str]:
        """Return a ``"key/value"`` category for each POI tag key in *tags*."""
        return [f"{key}/{tags[key]}" for key in POI_TAG_KEYS if key in tags]

    def _get_name(self, tags: osmium.osm.TagList) -> str:
        """Return ``name:en`` if present, else ``name``, else an empty string."""
        return tags.get("name:en", tags.get("name", ""))

    def _flush_batch(self) -> None:
        """Write the buffered rows to the next batch file and empty the buffer.

        Does nothing when the buffer is empty, so no zero-row files are
        produced.
        """
        if not self._batch:
            return
        df = pl.DataFrame(self._batch)
        out = self._tmp_dir / f"batch_{self._batch_num:05d}.parquet"
        df.write_parquet(out)
        self._batch_num += 1
        self._batch.clear()

    def _add_poi(
        self,
        osm_id: str,
        tags: osmium.osm.TagList,
        category: str,
        lat: float,
        lng: float,
    ) -> None:
        """Buffer one POI row and flush once the buffer reaches BATCH_SIZE.

        Args:
            osm_id: Identifier stored in the ``id`` column.
            tags: Tags of the source element; used to look up the name.
            category: ``"key/value"`` category string for this row.
            lat: Latitude stored in the ``lat`` column.
            lng: Longitude stored in the ``lng`` column.
        """
        self._batch.append(
            {
                "id": osm_id,
                "name": self._get_name(tags),
                "category": category,
                "lat": lat,
                "lng": lng,
            }
        )
        self.poi_count += 1
        # refresh=False: only update the postfix text; the bar redraws on its
        # own schedule.
        self._progress.set_postfix(pois=f"{self.poi_count:,}", refresh=False)
        if len(self._batch) >= BATCH_SIZE:
            self._flush_batch()

    def _tick(self) -> None:
        """Advance the progress bar by one processed element."""
        self._progress.update(1)

    def node(self, n: osmium.osm.Node) -> None:
        """Handle one OSM node: emit a POI row for each matching tag key."""
        self._tick()
        location = n.location
        # Location.valid is a method. Referencing it without calling it yields
        # a bound method, which is always truthy, so the guard would never
        # skip an invalid location and reading .lat/.lon would raise instead.
        if not location.valid():
            return
        lat, lon = location.lat, location.lon
        if not self._in_uk(lat, lon):
            return
        for category in self._match_tags(n.tags):
            self._add_poi(f"n{n.id}", n.tags, category, lat, lon)
def main() -> None:
    """Command-line entry point: download the GB extract and write POIs.

    Steps:
      1. Download the Geofabrik PBF into a throwaway directory.
      2. Stream it through ``POIHandler``, which writes parquet batch files.
      3. Lazily concatenate the batches, keep only categories that occur at
         least ``MIN_OCCURENCE_COUNT`` times, and write the result to
         ``--output``.

    Both the PBF and the batch files live in temporary directories that are
    removed when this function returns, whether or not it succeeded.

    Raises:
        SystemExit: If no POIs were extracted (there is nothing to write), or
            from argument parsing.
    """
    parser = argparse.ArgumentParser(description="Download and extract POIs from OpenStreetMap")
    parser.add_argument("--output", type=Path, required=True, help="Output parquet file path")
    args = parser.parse_args()

    # The batch directory must outlive the final sink_parquet() call because
    # the lazy scans only read the batch files at that point.
    with tempfile.TemporaryDirectory(prefix="pois_") as batch_root:
        tmp_dir = Path(batch_root)

        # The PBF is only needed while streaming, so its directory is scoped
        # tightly and the file is deleted before aggregation starts. The
        # directory is freshly created, so the file can never already exist
        # and is always downloaded.
        with tempfile.TemporaryDirectory() as cache_dir:
            pbf_file = Path(cache_dir) / "great-britain-latest.osm.pbf"
            download_pbf(pbf_file)
            print(f"Tag keys: {POI_TAG_KEYS}")
            with tqdm(
                unit=" elements",
                unit_scale=True,
                desc="Streaming",
                smoothing=0.05,
                mininterval=1.0,
            ) as progress:
                handler = POIHandler(progress, tmp_dir)
                # NOTE(review): locations=True caches node locations for use
                # by ways; POIHandler only defines node(), so this may be
                # unnecessary memory overhead — confirm before removing.
                handler.apply_file(str(pbf_file), locations=True)
                handler._flush_batch()  # write any remaining POIs

        print(f"Extracted {handler.poi_count:,} POIs")
        batch_files = sorted(tmp_dir.glob("batch_*.parquet"))
        if not batch_files:
            # pl.concat() rejects an empty list; fail with a clear message.
            raise SystemExit("No POIs were extracted; nothing to write")
        pois = pl.concat([pl.scan_parquet(f) for f in batch_files])

        # Only keep categories with enough occurrences
        valid_categories = (
            pois.group_by("category")
            .agg(pl.len().alias("count"))
            .filter(pl.col("count") >= MIN_OCCURENCE_COUNT)
        )
        pois = pois.join(valid_categories.select("category"), on="category", how="semi")

        # This is the number of rows extracted, i.e. before the category
        # filter above is applied.
        print(f"Total POIs: {handler.poi_count:,}")
        args.output.parent.mkdir(parents=True, exist_ok=True)
        pois.sink_parquet(args.output)
    print(f"Saved to {args.output}")


if __name__ == "__main__":
    main()