171 lines
4.8 KiB
Python
171 lines
4.8 KiB
Python
import argparse
|
|
import tempfile
|
|
import urllib.request
|
|
from pathlib import Path
|
|
from tempfile import mkdtemp
|
|
|
|
import osmium
|
|
import polars as pl
|
|
from tqdm import tqdm
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
# Number of POI rows buffered in memory before they are written out as one
# parquet batch file.
BATCH_SIZE = 50_000

# Categories with fewer rows than this are filtered out of the final output.
MIN_OCCURENCE_COUNT = 20

# Source of the Great Britain OpenStreetMap extract.
GEOFABRIK_GB_URL = "https://download.geofabrik.de/europe/great-britain-latest.osm.pbf"

# Bounding box that node coordinates are tested against: latitude limits
# first, then longitude limits.
UK_BBOX_SOUTH = 49.96
UK_BBOX_NORTH = 58.64
UK_BBOX_WEST = -7.57
UK_BBOX_EAST = 1.68

# OSM tag keys that mark a node as a POI.  A node produces one
# "<key>/<value>" category per key it carries, in this order.
POI_TAG_KEYS: list[str] = [
    "amenity", "building", "craft", "emergency", "healthcare",
    "leisure", "office", "shop", "tourism", "public_transport",
]
|
|
|
|
|
|
|
|
def download_pbf(pbf_file: Path, url: str | None = None) -> None:
    """Download an OSM PBF extract to *pbf_file*.

    The response body is streamed in 1 MiB chunks into a sibling temporary
    file (``<stem>.pbf.tmp``).  Only once the whole body has arrived is that
    file moved over *pbf_file*, so a failed or interrupted download never
    leaves a truncated file at the final path, and the partial temporary file
    is removed.

    Args:
        pbf_file: Destination path; missing parent directories are created.
        url: Address to download from.  ``None`` (the default) means
            ``GEOFABRIK_GB_URL``.

    Raises:
        urllib.error.URLError: If the request fails.
        OSError: If the file cannot be written, or if the number of bytes
            received differs from the ``Content-Length`` announced by the
            server.
    """
    if url is None:
        url = GEOFABRIK_GB_URL

    pbf_file.parent.mkdir(parents=True, exist_ok=True)
    tmp = pbf_file.with_suffix(".pbf.tmp")
    print(f"Downloading {url}")

    try:
        with (
            tqdm(unit="B", unit_scale=True, desc="Downloading") as bar,
            urllib.request.urlopen(url) as resp,
            open(tmp, "wb") as f,
        ):
            length = resp.headers.get("Content-Length")
            expected = int(length) if length else None
            if expected is not None:
                bar.total = expected

            received = 0
            while chunk := resp.read(1 << 20):
                f.write(chunk)
                received += len(chunk)
                bar.update(len(chunk))

        # A connection dropped mid-transfer can look like a normal end of
        # stream; compare against the announced size before publishing.
        if expected is not None and received != expected:
            raise OSError(
                f"Incomplete download: got {received} of {expected} bytes from {url}"
            )

        # replace() overwrites an existing destination on every platform,
        # whereas rename() fails on Windows when the target already exists.
        tmp.replace(pbf_file)
    except BaseException:
        # Also covers Ctrl-C: never leave a partial temporary file behind.
        tmp.unlink(missing_ok=True)
        raise

    print(f"Saved to {pbf_file}")
|
|
|
|
|
|
class POIHandler(osmium.SimpleHandler):
    """Collect point-of-interest nodes from an OSM stream into parquet batches.

    Every node inside the UK bounding box yields one row per key of
    ``POI_TAG_KEYS`` that it carries.  Rows are buffered in memory and written
    to ``tmp_dir`` as ``batch_NNNNN.parquet`` each time ``BATCH_SIZE`` rows have
    accumulated.  After the stream ends, ``_flush_batch()`` must be called once
    more to write the final, partially filled batch.

    Attributes:
        poi_count: Number of rows recorded so far (buffered or already written).
    """

    def __init__(self, progress: tqdm, tmp_dir: Path) -> None:
        """Create a handler.

        Args:
            progress: Progress bar advanced by one for every node seen; its
                postfix shows the running POI count.
            tmp_dir: Existing directory that receives the batch files.
        """
        super().__init__()
        self._batch: list[dict] = []
        self._tmp_dir = tmp_dir
        self._batch_num = 0
        self.poi_count = 0
        self._progress = progress

    def _in_uk(self, lat: float, lon: float) -> bool:
        """Return True if (*lat*, *lon*) lies within the UK bounding box (inclusive)."""
        return (
            UK_BBOX_SOUTH <= lat <= UK_BBOX_NORTH
            and UK_BBOX_WEST <= lon <= UK_BBOX_EAST
        )

    def _match_tags(self, tags: osmium.osm.TagList) -> list[str]:
        """Return a ``"key/value"`` category for each POI tag key present in *tags*."""
        return [f"{key}/{tags[key]}" for key in POI_TAG_KEYS if key in tags]

    def _get_name(self, tags: osmium.osm.TagList) -> str:
        """Return the ``name:en`` tag, else the ``name`` tag, else an empty string."""
        return tags.get("name:en", tags.get("name", ""))

    def _flush_batch(self) -> None:
        """Write the buffered rows to the next batch file and empty the buffer.

        Does nothing when the buffer is empty, so no empty batch files are
        produced.
        """
        if not self._batch:
            return
        df = pl.DataFrame(self._batch)
        out = self._tmp_dir / f"batch_{self._batch_num:05d}.parquet"
        df.write_parquet(out)
        self._batch_num += 1
        self._batch.clear()

    def _add_poi(
        self, osm_id: str, tags: osmium.osm.TagList, category: str, lat: float, lng: float
    ) -> None:
        """Buffer one POI row and flush once the buffer holds ``BATCH_SIZE`` rows.

        Args:
            osm_id: Identifier stored in the ``id`` column.
            tags: Tags of the OSM object; only used to look up its name.
            category: ``"key/value"`` category stored in the ``category`` column.
            lat: Latitude stored in the ``lat`` column.
            lng: Longitude stored in the ``lng`` column.
        """
        self._batch.append(
            {
                "id": osm_id,
                "name": self._get_name(tags),
                "category": category,
                "lat": lat,
                "lng": lng,
            }
        )
        self.poi_count += 1
        # refresh=False: let the bar redraw on its own schedule instead of
        # forcing a redraw for every single POI.
        self._progress.set_postfix(pois=f"{self.poi_count:,}", refresh=False)
        if len(self._batch) >= BATCH_SIZE:
            self._flush_batch()

    def _tick(self) -> None:
        """Advance the progress bar by one element."""
        self._progress.update(1)

    def node(self, n: osmium.osm.Node) -> None:
        """Handle one node: record a POI row for each matching tag key.

        Nodes without a valid location or outside the UK bounding box are
        skipped.
        """
        self._tick()
        # Location.valid is a method; it has to be called.  The bare attribute
        # is a bound method object, which is always truthy, so a check without
        # the parentheses would never skip anything.
        if not n.location.valid():
            return
        lat, lon = n.location.lat, n.location.lon
        if not self._in_uk(lat, lon):
            return
        for category in self._match_tags(n.tags):
            self._add_poi(f"n{n.id}", n.tags, category, lat, lon)
|
|
|
|
|
|
def main() -> None:
    """Download the Great Britain extract, pull out POIs and write them to parquet.

    Steps:
      1. Obtain the PBF file, downloading it unless it is already present in
         the cache directory.
      2. Stream it through ``POIHandler``, which spills rows to temporary
         parquet batch files.
      3. Drop categories with fewer than ``MIN_OCCURENCE_COUNT`` rows and
         write the remaining rows to ``--output``.

    Command-line options:
        --output     Path of the parquet file to write (required).
        --cache-dir  Directory in which the downloaded PBF is kept and reused
                     between runs.  Without it the PBF is stored in a
                     temporary directory that is deleted on exit.

    Raises:
        SystemExit: If no POI was extracted, so there is nothing to write.
    """
    # Function-scope import keeps this block self-contained.
    from contextlib import ExitStack

    parser = argparse.ArgumentParser(description="Download and extract POIs from OpenStreetMap")
    parser.add_argument("--output", type=Path, required=True, help="Output parquet file path")
    parser.add_argument(
        "--cache-dir",
        type=Path,
        default=None,
        help="Directory where the downloaded PBF is kept between runs "
        "(default: a temporary directory removed on exit)",
    )
    args = parser.parse_args()

    with ExitStack() as stack:
        if args.cache_dir is not None:
            cache_dir = args.cache_dir
        else:
            # A fresh temporary directory never contains a cached file; the
            # download below always runs in this case.
            cache_dir = Path(stack.enter_context(tempfile.TemporaryDirectory()))
        pbf_file = cache_dir / "great-britain-latest.osm.pbf"

        if not pbf_file.exists():
            download_pbf(pbf_file)
        else:
            print(f"Using cached PBF file at {pbf_file}")

        print(f"Tag keys: {POI_TAG_KEYS}")

        # Managed by the stack so the batch files are removed on exit, also
        # when an error occurs.  They must outlive sink_parquet() below,
        # because the batches are only scanned lazily until then.
        tmp_dir = Path(stack.enter_context(tempfile.TemporaryDirectory(prefix="pois_")))
        with tqdm(
            unit=" elements",
            unit_scale=True,
            desc="Streaming",
            smoothing=0.05,
            mininterval=1.0,
        ) as progress:
            handler = POIHandler(progress, tmp_dir)
            # No locations=True: nodes carry their own coordinates, and the
            # node-location index it enables is only needed to resolve way
            # geometry, which this handler never looks at.
            handler.apply_file(str(pbf_file))
            handler._flush_batch()  # write any remaining POIs

        print(f"Extracted {handler.poi_count:,} POIs")

        batch_files = sorted(tmp_dir.glob("batch_*.parquet"))
        if not batch_files:
            # pl.concat() rejects an empty list; fail with a clear message.
            raise SystemExit("No POIs were extracted; nothing to write")
        df = pl.concat([pl.scan_parquet(f) for f in batch_files])

        # Only keep categories with enough occurrences
        valid_categories = (
            df.group_by("category")
            .agg(pl.len().alias("count"))
            .filter(pl.col("count") >= MIN_OCCURENCE_COUNT)
        )
        df = df.join(valid_categories.select("category"), on="category", how="semi")

        args.output.parent.mkdir(parents=True, exist_ok=True)
        df.sink_parquet(args.output)

    # Report what was actually written (after the category filter) rather
    # than repeating the pre-filter count printed above.
    kept = pl.scan_parquet(args.output).select(pl.len()).collect().item()
    print(f"Total POIs: {kept:,}")
    print(f"Saved to {args.output}")
|
|
|
|
|
|
# Script entry point: run the extraction only when executed directly, not on
# import.  main() returns None, so the process exits with status 0.
if __name__ == "__main__":
    raise SystemExit(main())
|