POI vibe coding
This commit is contained in:
parent
d4fe881ef4
commit
5c39f31283
6 changed files with 5137 additions and 198 deletions
0
pipeline/download/pois/__init__.py
Normal file
0
pipeline/download/pois/__init__.py
Normal file
198
pipeline/download/pois/__main__.py
Normal file
198
pipeline/download/pois/__main__.py
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
import json
|
||||
import shutil
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
from tempfile import mkdtemp
|
||||
|
||||
import osmium
|
||||
import polars as pl
|
||||
from tqdm import tqdm
|
||||
|
||||
from .config import (
|
||||
GB_PBF_FILE,
|
||||
GEOFABRIK_GB_URL,
|
||||
OUTPUT_FILE,
|
||||
POI_TAG_KEYS,
|
||||
UK_BBOX_EAST,
|
||||
UK_BBOX_NORTH,
|
||||
UK_BBOX_SOUTH,
|
||||
UK_BBOX_WEST,
|
||||
)
|
||||
|
||||
BATCH_SIZE = 50_000
|
||||
|
||||
|
||||
def download_pbf() -> None:
|
||||
GB_PBF_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = GB_PBF_FILE.with_suffix(".pbf.tmp")
|
||||
print(f"Downloading {GEOFABRIK_GB_URL}")
|
||||
|
||||
with (
|
||||
tqdm(unit="B", unit_scale=True, desc="Downloading") as bar,
|
||||
urllib.request.urlopen(GEOFABRIK_GB_URL) as resp,
|
||||
open(tmp, "wb") as f,
|
||||
):
|
||||
length = resp.headers.get("Content-Length")
|
||||
if length:
|
||||
bar.total = int(length)
|
||||
while chunk := resp.read(1 << 20):
|
||||
f.write(chunk)
|
||||
bar.update(len(chunk))
|
||||
|
||||
tmp.rename(GB_PBF_FILE)
|
||||
print(f"Saved to {GB_PBF_FILE}")
|
||||
|
||||
|
||||
class POIHandler(osmium.SimpleHandler):
|
||||
"""Streams OSM data, filters to UK bbox, extracts matching POIs in batches."""
|
||||
|
||||
def __init__(self, progress: tqdm, tmp_dir: Path) -> None:
|
||||
super().__init__()
|
||||
self._batch: list[dict] = []
|
||||
self._tmp_dir = tmp_dir
|
||||
self._batch_num = 0
|
||||
self.poi_count = 0
|
||||
self._progress = progress
|
||||
|
||||
def _in_uk(self, lat: float, lon: float) -> bool:
|
||||
return (
|
||||
UK_BBOX_SOUTH <= lat <= UK_BBOX_NORTH
|
||||
and UK_BBOX_WEST <= lon <= UK_BBOX_EAST
|
||||
)
|
||||
|
||||
def _match_tags(self, tags: osmium.osm.TagList) -> str | None:
|
||||
parts = [tags[key] for key in POI_TAG_KEYS if key in tags]
|
||||
return " / ".join(parts) if parts else None
|
||||
|
||||
def _get_name(self, tags: osmium.osm.TagList) -> str:
|
||||
return tags.get("name:en", tags.get("name", ""))
|
||||
|
||||
def _tags_to_json(self, tags: osmium.osm.TagList) -> str:
|
||||
return json.dumps({tag.k: tag.v for tag in tags})
|
||||
|
||||
def _flush_batch(self) -> None:
|
||||
if not self._batch:
|
||||
return
|
||||
df = pl.DataFrame(self._batch)
|
||||
out = self._tmp_dir / f"batch_{self._batch_num:05d}.parquet"
|
||||
df.write_parquet(out)
|
||||
self._batch_num += 1
|
||||
self._batch.clear()
|
||||
|
||||
def _add_poi(
|
||||
self, osm_id: str, tags: osmium.osm.TagList, category: str, lat: float, lng: float
|
||||
) -> None:
|
||||
self._batch.append(
|
||||
{
|
||||
"id": osm_id,
|
||||
"name": self._get_name(tags),
|
||||
"category": category,
|
||||
"lat": lat,
|
||||
"lng": lng,
|
||||
"osm_tags": self._tags_to_json(tags),
|
||||
}
|
||||
)
|
||||
self.poi_count += 1
|
||||
self._progress.set_postfix(pois=f"{self.poi_count:,}", refresh=False)
|
||||
if len(self._batch) >= BATCH_SIZE:
|
||||
self._flush_batch()
|
||||
|
||||
def _tick(self) -> None:
|
||||
self._progress.update(1)
|
||||
|
||||
def node(self, n: osmium.osm.Node) -> None:
|
||||
self._tick()
|
||||
if not n.location.valid:
|
||||
return
|
||||
lat, lon = n.location.lat, n.location.lon
|
||||
if not self._in_uk(lat, lon):
|
||||
return
|
||||
category = self._match_tags(n.tags)
|
||||
if category:
|
||||
self._add_poi(f"n{n.id}", n.tags, category, lat, lon)
|
||||
|
||||
def way(self, w: osmium.osm.Way) -> None:
|
||||
self._tick()
|
||||
category = self._match_tags(w.tags)
|
||||
if not category:
|
||||
return
|
||||
|
||||
lats = []
|
||||
lons = []
|
||||
for node in w.nodes:
|
||||
try:
|
||||
lats.append(node.location.lat)
|
||||
lons.append(node.location.lon)
|
||||
except osmium.InvalidLocationError:
|
||||
continue
|
||||
|
||||
if not lats:
|
||||
return
|
||||
|
||||
centroid_lat = sum(lats) / len(lats)
|
||||
centroid_lng = sum(lons) / len(lons)
|
||||
|
||||
if not self._in_uk(centroid_lat, centroid_lng):
|
||||
return
|
||||
|
||||
self._add_poi(f"w{w.id}", w.tags, category, centroid_lat, centroid_lng)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if not GB_PBF_FILE.exists():
|
||||
download_pbf()
|
||||
|
||||
print(f"=== POI Extraction from {GB_PBF_FILE} ===")
|
||||
print(
|
||||
f"UK bbox: ({UK_BBOX_WEST}, {UK_BBOX_SOUTH}, {UK_BBOX_EAST}, {UK_BBOX_NORTH})"
|
||||
)
|
||||
print(f"Tag keys: {POI_TAG_KEYS}")
|
||||
print()
|
||||
|
||||
if OUTPUT_FILE.exists():
|
||||
print("POIs are up-to-date")
|
||||
return
|
||||
|
||||
tmp_dir = Path(mkdtemp(prefix="pois_"))
|
||||
try:
|
||||
with tqdm(
|
||||
unit=" elements",
|
||||
unit_scale=True,
|
||||
desc="Streaming",
|
||||
smoothing=0.05,
|
||||
mininterval=1.0,
|
||||
) as progress:
|
||||
handler = POIHandler(progress, tmp_dir)
|
||||
handler.apply_file(str(GB_PBF_FILE), locations=True)
|
||||
handler._flush_batch() # write any remaining POIs
|
||||
|
||||
print(f"Extracted {handler.poi_count:,} POIs")
|
||||
|
||||
batch_files = sorted(tmp_dir.glob("batch_*.parquet"))
|
||||
if not batch_files:
|
||||
print("No POIs found.")
|
||||
return
|
||||
|
||||
df = pl.concat([pl.scan_parquet(f) for f in batch_files])
|
||||
|
||||
OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.sink_parquet(OUTPUT_FILE)
|
||||
print(f"Saved to {OUTPUT_FILE}")
|
||||
|
||||
print("\n=== Summary ===")
|
||||
print(f"Total POIs: {handler.poi_count:,}")
|
||||
print("\nPOIs by category:")
|
||||
category_counts = (
|
||||
df.group_by("category")
|
||||
.agg(pl.len().alias("count"))
|
||||
.sort("count", descending=True)
|
||||
.collect()
|
||||
)
|
||||
for row in category_counts.iter_rows(named=True):
|
||||
print(f" {row['category']}: {row['count']:,}")
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
34
pipeline/download/pois/config.py
Normal file
34
pipeline/download/pois/config.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
from pathlib import Path
|
||||
|
||||
DATA_DIR = Path("./data_sources")
|
||||
GB_PBF_FILE = DATA_DIR / "great-britain-latest.osm.pbf"
|
||||
OUTPUT_FILE = DATA_DIR / "uk_pois.parquet"
|
||||
|
||||
GEOFABRIK_GB_URL = (
|
||||
"https://download.geofabrik.de/europe/great-britain-latest.osm.pbf"
|
||||
)
|
||||
|
||||
# UK bounding box (west, south, east, north) — used for way centroid filtering
|
||||
UK_BBOX_WEST = -7.57
|
||||
UK_BBOX_SOUTH = 49.96
|
||||
UK_BBOX_EAST = 1.68
|
||||
UK_BBOX_NORTH = 58.64
|
||||
|
||||
# OSM tag keys that indicate a POI. Any element with one of these keys is kept,
|
||||
# regardless of the specific value. When multiple keys match, their values are
|
||||
# concatenated with " / ".
|
||||
POI_TAG_KEYS: list[str] = [
|
||||
"amenity",
|
||||
"shop",
|
||||
"leisure",
|
||||
"tourism",
|
||||
"railway",
|
||||
"aeroway",
|
||||
"highway",
|
||||
"public_transport",
|
||||
"station",
|
||||
"building",
|
||||
"military",
|
||||
"emergency",
|
||||
"healthcare",
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue