This commit is contained in:
Andras Schmelczer 2026-03-12 22:11:00 +00:00
parent 14a3555cf1
commit 7e92bf112e
34 changed files with 1214437 additions and 224 deletions

View file

@ -1,4 +1,4 @@
"""Count POIs within a radius of properties, optimized via postcode deduplication."""
"""Count POIs within a radius of properties, optimised via postcode deduplication."""
import numpy as np
import polars as pl
@ -6,6 +6,49 @@ import polars as pl
from .haversine import haversine_km
def _build_poi_grid(
    pois: pl.DataFrame, grid_size: float = 0.05
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict[tuple[int, int], np.ndarray]]:
    """Build a spatial grid index over the POIs.

    Each POI is assigned to the cell ``(floor(lat / grid_size),
    floor(lng / grid_size))``; the index maps every occupied cell to the row
    positions of the POIs that fall into it.

    Args:
        pois: Frame with ``lat``, ``lng`` and ``category`` columns.
        grid_size: Cell edge length, in the same unit as the coordinates
            (presumably degrees -- callers describe 0.05 as roughly 5.5 km).

    Returns:
        ``(lats, lngs, cats, grid)`` where the first three are the POI columns
        as NumPy arrays and ``grid`` maps ``(lat_cell, lng_cell)`` -- plain
        Python ints -- to an ``int32`` array of POI row indices in ascending
        order.
    """
    poi_lats = pois["lat"].to_numpy()
    poi_lngs = pois["lng"].to_numpy()
    poi_cats = pois["category"].to_numpy()

    # Convert the cell coordinates to Python ints once, up front: iterating
    # plain lists is cheaper than indexing NumPy scalars one at a time, and the
    # resulting keys are genuine ``tuple[int, int]`` values.
    cell_rows = np.floor(poi_lats / grid_size).astype(np.int32).tolist()
    cell_cols = np.floor(poi_lngs / grid_size).astype(np.int32).tolist()

    members: dict[tuple[int, int], list[int]] = {}
    for poi_idx, cell in enumerate(zip(cell_rows, cell_cols)):
        members.setdefault(cell, []).append(poi_idx)

    # Freeze each cell's member list into an array so callers can use it
    # directly for fancy indexing into the coordinate arrays.
    poi_grid = {
        cell: np.array(indices, dtype=np.int32) for cell, indices in members.items()
    }
    return poi_lats, poi_lngs, poi_cats, poi_grid
def _get_nearby_indices(
    pc_lat: float, pc_lon: float, poi_grid: dict, grid_size: float = 0.05
) -> np.ndarray | None:
    """Collect POI indices from the 3x3 block of grid cells around a point.

    The point's own cell is ``(floor(pc_lat / grid_size),
    floor(pc_lon / grid_size))``; that cell and its eight neighbours are
    looked up in ``poi_grid``.

    Returns:
        The index arrays of all occupied cells in the block, concatenated in
        row-major order (latitude offset outer, longitude offset inner), or
        ``None`` when none of the nine cells is present in ``poi_grid``.
    """
    centre_row = int(np.floor(pc_lat / grid_size))
    centre_col = int(np.floor(pc_lon / grid_size))
    offsets = (-1, 0, 1)

    occupied = [
        poi_grid[(centre_row + d_row, centre_col + d_col)]
        for d_row in offsets
        for d_col in offsets
        if (centre_row + d_row, centre_col + d_col) in poi_grid
    ]
    return np.concatenate(occupied) if occupied else None
def count_pois_per_postcode(
postcodes_df: pl.DataFrame,
pois: pl.DataFrame,
@ -22,31 +65,9 @@ def count_pois_per_postcode(
n_pois = len(pois)
print(f" {n_postcodes:,} postcodes, {n_pois:,} POIs")
# Build spatial grid for POIs (0.05 degree cells ~5.5km)
grid_size = 0.05
print(" Building POI spatial grid...")
# Convert to numpy arrays
poi_lats = pois["lat"].to_numpy()
poi_lngs = pois["lng"].to_numpy()
poi_cats = pois["category"].to_numpy()
# Compute grid coordinates for all POIs
poi_grid_lats = np.floor(poi_lats / grid_size).astype(np.int32)
poi_grid_lngs = np.floor(poi_lngs / grid_size).astype(np.int32)
# Build grid cell lookup using numpy indexing
poi_grid = {}
for i in range(n_pois):
key = (poi_grid_lats[i], poi_grid_lngs[i])
if key not in poi_grid:
poi_grid[key] = []
poi_grid[key].append(i)
# Convert grid values to numpy arrays for faster indexing
for key in poi_grid:
poi_grid[key] = np.array(poi_grid[key], dtype=np.int32)
poi_lats, poi_lngs, poi_cats, poi_grid = _build_poi_grid(pois, grid_size)
print(f" POI grid has {len(poi_grid):,} occupied cells")
# Pre-compute category masks
@ -81,38 +102,18 @@ def count_pois_per_postcode(
# Process batch
for i in range(start_idx, end_idx):
pc_lat = pc_lats[i]
pc_lon = pc_lons[i]
# Find grid cells to check (3x3 grid)
grid_lat = int(np.floor(pc_lat / grid_size))
grid_lng = int(np.floor(pc_lon / grid_size))
# Collect nearby POI indices
nearby_indices = []
for dlat in [-1, 0, 1]:
for dlng in [-1, 0, 1]:
cell_key = (grid_lat + dlat, grid_lng + dlng)
if cell_key in poi_grid:
nearby_indices.append(poi_grid[cell_key])
if not nearby_indices:
nearby = _get_nearby_indices(pc_lats[i], pc_lons[i], poi_grid, grid_size)
if nearby is None:
continue
# Concatenate all nearby POI indices
nearby = np.concatenate(nearby_indices)
distances = haversine_km(poi_lats[nearby], poi_lngs[nearby], pc_lats[i], pc_lons[i])
# Vectorized distance calculation for all nearby POIs
distances = haversine_km(poi_lats[nearby], poi_lngs[nearby], pc_lat, pc_lon)
# Filter by radius
within_mask = distances <= radius_km
within_indices = nearby[within_mask]
if len(within_indices) == 0:
continue
# Count by category group using pre-computed masks
for group, cat_mask in category_masks.items():
result_counts[group][i] = cat_mask[within_indices].sum()
@ -124,3 +125,71 @@ def count_pois_per_postcode(
result = pl.DataFrame(result_data)
print(" Completed POI counting")
return result
def min_distance_per_postcode(
    postcodes_df: pl.DataFrame,
    pois: pl.DataFrame,
    groups: dict[str, list[str]],
) -> pl.DataFrame:
    """Compute, for every postcode, the distance (km) to the closest POI per group.

    Args:
        postcodes_df: Frame with ``postcode``, ``lat`` and ``lon`` columns.
        pois: Frame of POIs, indexed via ``_build_poi_grid``.
        groups: Mapping of group name to the POI categories it contains.

    Returns:
        A frame with a ``postcode`` column plus one float32
        ``<group>_nearest_km`` column per group. A value is NaN when no POI of
        that group is found among the grid cells searched around the postcode
        (the original note puts this search range at ~5.5km).
    """
    print("Computing minimum POI distances per postcode...")

    n_postcodes = len(postcodes_df)
    n_pois = len(pois)
    print(f" {n_postcodes:,} postcodes, {n_pois:,} POIs")

    grid_size = 0.05
    print(" Building POI spatial grid...")
    poi_lats, poi_lngs, poi_cats, poi_grid = _build_poi_grid(pois, grid_size)
    print(f" POI grid has {len(poi_grid):,} occupied cells")

    # One boolean mask over all POIs per group, computed once up front.
    category_masks = {}
    for group, categories in groups.items():
        category_masks[group] = np.isin(poi_cats, categories)
        print(f" {group}: {category_masks[group].sum():,} POIs")

    pc_lats = postcodes_df["lat"].to_numpy()
    pc_lons = postcodes_df["lon"].to_numpy()
    pc_codes = postcodes_df["postcode"].to_list()

    # NaN marks "nothing found"; entries are overwritten as matches turn up.
    nearest_km = {
        group: np.full(n_postcodes, np.nan, dtype=np.float32) for group in groups
    }

    batch_size = 50000
    n_batches = (n_postcodes + batch_size - 1) // batch_size
    print(f" Processing {n_postcodes:,} postcodes in {n_batches} batches...")

    # Batching only drives the progress output; postcodes are independent.
    for batch_idx, start_idx in enumerate(range(0, n_postcodes, batch_size)):
        end_idx = min(start_idx + batch_size, n_postcodes)
        if batch_idx % 5 == 0:
            print(
                f" Batch {batch_idx + 1}/{n_batches}: postcodes {start_idx:,} - {end_idx:,}"
            )

        for row in range(start_idx, end_idx):
            lat = pc_lats[row]
            lon = pc_lons[row]

            nearby = _get_nearby_indices(lat, lon, poi_grid, grid_size)
            if nearby is None:
                continue

            distances = haversine_km(poi_lats[nearby], poi_lngs[nearby], lat, lon)

            for group, cat_mask in category_masks.items():
                in_group = cat_mask[nearby]
                if in_group.any():
                    nearest_km[group][row] = distances[in_group].min()

    result = pl.DataFrame(
        {
            "postcode": pc_codes,
            **{f"{group}_nearest_km": nearest_km[group] for group in groups},
        }
    )
    print(" Completed minimum distance computation")
    return result