perfect-postcode/pipeline/transform/price_estimation/shrinkage.py
2026-02-15 22:39:54 +00:00

140 lines
4.5 KiB
Python

"""Hierarchical shrinkage and spatial smoothing for sector-level estimates."""
from typing import Callable, Optional, TypeVar

import numpy as np
from scipy.spatial import KDTree

from pipeline.transform.price_estimation.utils import SHRINKAGE_K
# Generic per-unit value type. How two values are combined is delegated to the
# shrink_fn / blend_fn callbacks passed into the helpers below.
V = TypeVar("V")

# Nearest neighbours (not counting the sector itself) blended in spatial_smooth.
SPATIAL_NEIGHBORS: int = 5
# Pseudo-count k in spatial_smooth's self-weight n / (n + k): the more
# observations a sector has, the less its neighbours influence it.
SPATIAL_BLEND_K: int = 30
def shrink_dicts(
    raw: dict, parent: dict, n: int, k: Optional[float] = None
) -> dict:
    """Shrink dict values toward ``parent`` using n/(n+k) weighting.

    Works for any dict keyed by year or category. For every key present in
    either dict the result is ``w * raw[key] + (1 - w) * parent[key]`` with
    ``w = n / (n + k)``. When a key exists on only one side, both operands fall
    back to the value that is present, so the key effectively keeps it.

    Args:
        raw: Values estimated from the unit's own data.
        parent: Values of the parent level to shrink toward.
        n: Observation count behind ``raw``; a larger ``n`` trusts ``raw`` more.
        k: Shrinkage pseudo-count. Defaults to the module-level ``SHRINKAGE_K``.

    Returns:
        A new dict. Keys appear in ``raw`` order followed by keys found only in
        ``parent``, so the output order does not depend on the hash seed (a
        ``set`` union of ``str`` keys would).
    """
    if k is None:
        k = SHRINKAGE_K
    w = n / (n + k)
    # Ordered union of keys: raw's keys first, then parent-only keys.
    keys = dict.fromkeys([*raw, *parent])
    result = {}
    for key in keys:
        r = raw.get(key, parent.get(key, 0.0))
        p = parent.get(key, raw.get(key, 0.0))
        result[key] = w * r + (1 - w) * p
    return result
def hierarchical_shrinkage(
    sector_vals: dict[str, V],
    sector_n: dict[str, int],
    district_vals: dict[str, V],
    district_n: dict[str, int],
    area_vals: dict[str, V],
    area_n: dict[str, int],
    top_level: V,
    all_sectors: list[str],
    sector_to_dist: dict[str, str],
    dist_to_area: dict[str, str],
    shrink_fn: Callable[[V, V, int], V],
) -> dict[str, V]:
    """Shrink sector estimates top-down: area -> top, district -> area, sector -> district.

    Each level is blended toward its already-shrunk parent via
    ``shrink_fn(raw, parent, n)``. A unit whose parent is unmapped or has no
    value of its own falls back to ``top_level``, the ultimate fallback (e.g.
    national shrunk toward hedonic, or just national).

    Sectors in ``all_sectors`` without their own value inherit their district's
    shrunk value, else their area's, else ``top_level``. Those filled-in entries
    are the ancestor objects themselves, not copies.

    Returns:
        Mapping sector -> value: shrunk ``sector_vals`` entries first (in their
        order), then filled-in sectors in ``all_sectors`` order.

    Raises:
        KeyError: if a unit in ``area_vals``/``district_vals``/``sector_vals``
            has no count in the matching ``*_n`` mapping.
    """
    area_shrunk = {
        area: shrink_fn(raw, top_level, area_n[area])
        for area, raw in area_vals.items()
    }

    district_shrunk = {
        dist: shrink_fn(
            raw,
            area_shrunk.get(dist_to_area.get(dist, ""), top_level),
            district_n[dist],
        )
        for dist, raw in district_vals.items()
    }

    result = {
        sec: shrink_fn(
            raw,
            district_shrunk.get(sector_to_dist.get(sec, ""), top_level),
            sector_n[sec],
        )
        for sec, raw in sector_vals.items()
    }

    # Data-less sectors take the nearest shrunk ancestor.
    for sec in all_sectors:
        if sec in result:
            continue
        dist = sector_to_dist.get(sec, "")
        area_or_top = area_shrunk.get(dist_to_area.get(dist, ""), top_level)
        result[sec] = district_shrunk.get(dist, area_or_top)

    return result
def spatial_smooth(
    sector_values: dict[str, V],
    centroids: dict[str, tuple[float, float]],
    counts: dict[str, int],
    blend_fn: Callable[[V, list[V], float, list[float]], V],
    n_neighbors: Optional[int] = None,
    blend_k: Optional[float] = None,
    max_self_w: float = 0.95,
) -> dict[str, V]:
    """Blend sparse sector values with their nearest neighbours via a KDTree.

    Each sector with a centroid gets a self-weight ``n / (n + blend_k)`` from
    its observation count. Sectors whose self-weight exceeds ``max_self_w``
    are left untouched. The rest are blended with up to ``n_neighbors`` nearest
    sectors, and the remaining weight ``1 - self_w`` is split in proportion to
    inverse distance. Neighbour values are always read from the unsmoothed
    input, so the result does not depend on processing order.

    Args:
        sector_values: Value per sector.
        centroids: ``(lat, lon)`` per sector, in degrees. Sectors without an
            entry are passed through unchanged and never act as neighbours.
        counts: Observation count per sector (missing -> 0).
        blend_fn: ``blend_fn(self_val, neighbor_vals, self_w, neighbor_ws)``.
        n_neighbors: Neighbours per sector, excluding itself. Defaults to
            ``SPATIAL_NEIGHBORS``.
        blend_k: Pseudo-count in the self-weight. Defaults to
            ``SPATIAL_BLEND_K``.
        max_self_w: Sectors with a self-weight above this skip smoothing.

    Returns:
        A new dict (never the input object) with the same keys as
        ``sector_values``.

    Raises:
        ValueError: if ``n_neighbors < 1`` or ``blend_k <= 0``.
    """
    if n_neighbors is None:
        n_neighbors = SPATIAL_NEIGHBORS
    if blend_k is None:
        blend_k = SPATIAL_BLEND_K
    if n_neighbors < 1:
        raise ValueError(f"n_neighbors must be >= 1, got {n_neighbors}")
    if blend_k <= 0:
        raise ValueError(f"blend_k must be > 0, got {blend_k}")

    # Always return a fresh dict so callers never alias the input mapping.
    result = dict(sector_values)
    sectors_with_coords = [s for s in sector_values if s in centroids]
    # The query asks for self + n_neighbors points, so at least that many must exist.
    if len(sectors_with_coords) < n_neighbors + 1:
        return result

    coords = np.array([centroids[s] for s in sectors_with_coords], dtype=float)
    # Scale longitude by cos(mean_lat) so Euclidean distance on (lat, lon)
    # approximates ground distance across the covered region.
    scale = np.cos(np.radians(np.mean(coords[:, 0])))
    scaled_coords = np.column_stack([coords[:, 0], coords[:, 1] * scale])
    tree = KDTree(scaled_coords)

    # Only sectors with too little data of their own are smoothed.
    pending = []
    for i, sec in enumerate(sectors_with_coords):
        n = counts.get(sec, 0)
        self_w = n / (n + blend_k)
        if self_w <= max_self_w:
            pending.append((i, sec, self_w))
    if not pending:
        return result

    # One batched query for all pending sectors instead of one call per sector.
    rows = np.array([i for i, _, _ in pending])
    all_dists, all_idxs = tree.query(scaled_coords[rows], k=n_neighbors + 1)

    for (i, sec, self_w), dists, idxs in zip(pending, all_dists, all_idxs):
        inv_dists = []
        neighbor_vals = []
        for d, j in zip(dists, idxs):
            # Exclude the sector itself (by index, not by position) and any
            # co-located sector, whose inverse distance would be infinite.
            if j != i and d > 0:
                inv_dists.append(1.0 / d)
                neighbor_vals.append(sector_values[sectors_with_coords[j]])
        if not neighbor_vals:
            continue
        total_inv = sum(inv_dists)
        nbr_w = 1.0 - self_w
        neighbor_ws = [iw / total_inv * nbr_w for iw in inv_dists]
        result[sec] = blend_fn(sector_values[sec], neighbor_vals, self_w, neighbor_ws)

    return result
def blend_dicts(
    self_val: dict, neighbor_vals: list[dict], self_w: float, neighbor_ws: list[float]
) -> dict:
    """Blend dict values by a weighted sum across the union of all keys.

    ``result[key] = self_w * self_val[key] + sum_i(neighbor_ws[i] * neighbor_vals[i][key])``,
    where a key missing from any input contributes 0.0 for that input.

    NOTE(review): for year-keyed level data (e.g. prices), a neighbour lacking a
    year pulls the blend toward 0 rather than being ignored. Confirm that this
    is intended for every caller.

    Returns:
        A new dict. Keys appear in ``self_val`` order followed by neighbour
        keys in first-seen order, so the output order does not depend on the
        hash seed.

    Raises:
        ValueError: if ``neighbor_vals`` and ``neighbor_ws`` differ in length.
            ``zip`` would otherwise silently drop contributions.
    """
    if len(neighbor_vals) != len(neighbor_ws):
        raise ValueError(
            f"neighbor_vals ({len(neighbor_vals)}) and neighbor_ws "
            f"({len(neighbor_ws)}) must have the same length"
        )
    # Ordered union of keys across self and all neighbours.
    keys = dict.fromkeys(self_val)
    for nv in neighbor_vals:
        keys.update(dict.fromkeys(nv))
    result = {}
    for key in keys:
        total = self_w * self_val.get(key, 0.0)
        for nv, w in zip(neighbor_vals, neighbor_ws):
            total += w * nv.get(key, 0.0)
        result[key] = total
    return result