perfect-postcode/pipeline/transform/price_estimation/shrinkage.py
2026-06-02 20:14:32 +01:00

215 lines
8.1 KiB
Python

"""Hierarchical shrinkage and spatial smoothing for sector-level estimates."""
from typing import Callable, TypeVar
import numpy as np
from scipy.spatial import KDTree
from pipeline.transform.price_estimation.utils import SHRINKAGE_K
# Generic value type carried through shrinkage and smoothing (e.g. a
# year-keyed index dict or a category-keyed dict); the shrink/lift/blend
# callables passed to hierarchical_shrinkage / spatial_smooth operate on it.
V = TypeVar("V")
# Number of nearest neighbouring sectors spatial_smooth blends in. Its KD-tree
# query asks for one extra hit because the sector itself is returned first.
SPATIAL_NEIGHBORS = 5
# Pseudo-count k in spatial_smooth's self weight n / (n + k): a sector with
# n == SPATIAL_BLEND_K observations keeps half of its own value.
SPATIAL_BLEND_K = 30
def _base_value(index: dict[int, float], base_year: int) -> float:
"""Value of an index dict at `base_year`, with forward/back-fill for gaps.
Each repeat-sales dict is anchored to 0 at its OWN earliest year, so its
values are log-levels relative to that origin. To express it on a common
origin we need its value at the shared `base_year`:
- exact hit: use it directly;
- base_year before the dict's history: back-fill, i.e. the earliest known
value (which is 0.0 by construction). We cannot observe the level move
between the global base and a later-starting cell, so we assume none,
matching forward_fill's back-fill convention;
- base_year inside a gap / after history: forward-fill the most recent
prior value.
"""
if base_year in index:
return index[base_year]
years = sorted(index)
if not years or base_year < years[0]:
return index[years[0]] if years else 0.0
prior = [y for y in years if y <= base_year]
return index[prior[-1]]
def lift_onto_parent(
    child: dict[int, float], parent: dict[int, float]
) -> dict[int, float]:
    """Shift `child` by one constant so its first year sits on `parent`'s level.

    solve_robust_index anchors every cell at log-index 0 on its OWN earliest
    year. A cell with a shorter history therefore starts from a later origin
    than its wider parent. Blending the two key-by-key without re-basing would
    mix levels measured from different years (a sector measured from 2008
    averaged with a district measured from 1996).

    The fix adds the parent's accumulated level at the child's first year
    (gap-filled via _base_value) to every child value. At that year the child
    then equals the parent, and the child's own year-to-year moves are layered
    on top of the parent's growth up to that point. Shrinkage already makes the
    same assumption for years the child lacks.

    Re-basing on the child's OWN earliest year, not on the global base, is what
    makes this effective. The child cannot observe the global base, so its
    value there is always 0 and subtracting it would change nothing.

    Because a single constant is added, the child's year-to-year differences
    are preserved. PRECONDITION for the downstream estimate to be unaffected
    within the child's range: the parent's year coverage must be a superset of
    the child's.
    - This holds in build_index, where each parent aggregates a superset of its
      children's sale pairs.
    - So shrink_dicts blends every child year against a present parent year,
      and the constant cancels in any within-range (current - sale) difference.
    - Only comparisons spanning the child's start year change (e.g. a sale that
      predates the cell's own data).
    - If a child year is missing from the parent, shrink_dicts passes that year
      through unshrunk and the cancellation no longer holds.

    Returns `child` itself (not a copy) when either dict is empty or the shift
    is exactly zero; otherwise returns a new dict.
    """
    if child and parent:
        first_year = min(child)
        shift = _base_value(parent, first_year) - child[first_year]
        if shift != 0.0:
            return {year: level + shift for year, level in child.items()}
    return child
def shrink_dicts(raw: dict, parent: dict, n: int) -> dict:
    """Pull each value of `raw` toward `parent` with credibility weight n/(n+k).

    With ``w = n / (n + SHRINKAGE_K)``, each key of the union of both dicts
    maps to ``w * raw[key] + (1 - w) * parent[key]``. Keys may be years or
    categories. A key present on only one side is blended with itself, so its
    value passes through (up to float rounding).
    """
    raw_w = n / (n + SHRINKAGE_K)
    parent_w = 1 - raw_w
    shrunk = {}
    for key in set(raw) | set(parent):
        own = raw[key] if key in raw else parent[key]
        prior = parent[key] if key in parent else raw[key]
        shrunk[key] = raw_w * own + parent_w * prior
    return shrunk
def hierarchical_shrinkage(
    sector_vals: dict[str, V],
    sector_n: dict[str, int],
    district_vals: dict[str, V],
    district_n: dict[str, int],
    area_vals: dict[str, V],
    area_n: dict[str, int],
    top_level: V,
    all_sectors: list[str],
    sector_to_dist: dict[str, str],
    dist_to_area: dict[str, str],
    shrink_fn: Callable[[V, V, int], V],
    lift_fn: Callable[[V, V], V] | None = None,
) -> dict[str, V]:
    """Top-down hierarchical shrinkage: area->top, district->area, sector->district.

    `top_level` is the ultimate fallback value (e.g. national shrunk toward hedonic,
    or just national). `shrink_fn(raw, parent, n)` blends raw toward parent, where
    `n` comes from the matching ``*_n`` dict. Each ``*_n`` must hold a count for
    every key of the matching ``*_vals``, otherwise KeyError is raised.
    `lift_fn(raw, parent)`, if given, re-bases raw onto its parent before blending
    (see lift_onto_parent); pass None for category-keyed dicts where re-basing is
    meaningless.

    Each cell is shrunk toward its nearest existing shrunk ancestor. A sector
    whose district has no value of its own uses its area, then `top_level`.
    This applies to sectors with their own value as well as those being filled,
    so a sector under a value-less district is never shrunk straight to
    `top_level` while its area is known.

    Returns a dict with every key of `sector_vals` plus every sector in
    `all_sectors`. Sectors without their own value take their nearest
    ancestor's shrunk value unchanged.
    """

    def combine(raw: V, parent: V, n: int) -> V:
        if lift_fn is not None:
            raw = lift_fn(raw, parent)
        return shrink_fn(raw, parent, n)

    # Area -> top level
    area_shrunk: dict[str, V] = {
        area: combine(val, top_level, area_n[area])
        for area, val in area_vals.items()
    }

    def area_level(dist: str) -> V:
        """Shrunk value of `dist`'s area, else `top_level`."""
        return area_shrunk.get(dist_to_area.get(dist, ""), top_level)

    # District -> area
    district_shrunk: dict[str, V] = {
        dist: combine(val, area_level(dist), district_n[dist])
        for dist, val in district_vals.items()
    }

    def district_level(sec: str) -> V:
        """Shrunk value of `sec`'s district, else of its area, else `top_level`."""
        dist = sector_to_dist.get(sec, "")
        if dist in district_shrunk:
            return district_shrunk[dist]
        return area_level(dist)

    # Sector -> district (or the nearest existing ancestor)
    sector_shrunk: dict[str, V] = {
        sec: combine(val, district_level(sec), sector_n[sec])
        for sec, val in sector_vals.items()
    }
    # Fill sectors without their own values from the same ancestor chain
    for sec in all_sectors:
        if sec not in sector_shrunk:
            sector_shrunk[sec] = district_level(sec)
    return sector_shrunk
def spatial_smooth(
    sector_values: dict[str, V],
    centroids: dict[str, tuple[float, float]],
    counts: dict[str, int],
    blend_fn: Callable[[V, list[V], float, list[float]], V],
) -> dict[str, V]:
    """Blend sparse sector values with their K nearest neighbours via a KDTree.

    Each sector that has a centroid gets a self weight
    ``n / (n + SPATIAL_BLEND_K)``, where ``n = counts.get(sector, 0)``.
    - Sectors whose self weight exceeds 0.90 keep their value.
    - The others are replaced by
      ``blend_fn(own, neighbour_vals, self_w, neighbour_ws)``, using up to
      SPATIAL_NEIGHBORS nearest other sectors.
    - Neighbour weights are inverse-distance shares of ``1 - self_w``.
    - Neighbour values are always read from the unsmoothed input, so the result
      does not depend on processing order.

    Distances are planar on ``(c0, c1 * cos(mean(c0)))``, so centroids are
    treated as (latitude, longitude) in degrees. NOTE(review): presumably the
    caller's convention; confirm against the centroid builder.

    Always returns a new dict: a shallow copy of `sector_values` with smoothed
    entries replaced. Sectors without a centroid are copied through unchanged.
    If fewer than SPATIAL_NEIGHBORS + 1 sectors have centroids, nothing is
    smoothed.
    """
    located = [s for s in sector_values if s in centroids]
    if len(located) < SPATIAL_NEIGHBORS + 1:
        # Too few points for a full neighbour query. Still return a copy so the
        # caller never aliases its input (the main path below always copies).
        return dict(sector_values)

    coords = np.array([centroids[s] for s in located])
    # Scale longitude by cos(mean_lat) for approximate Euclidean distance
    mean_lat = np.mean(coords[:, 0])
    scale = np.cos(np.radians(mean_lat))
    scaled_coords = np.column_stack([coords[:, 0], coords[:, 1] * scale])
    tree = KDTree(scaled_coords)

    # Select the sectors that need smoothing first, so the tree is queried once
    # (vectorised) rather than once per sector inside the Python loop.
    pending: list[tuple[int, str, float]] = []
    for i, sec in enumerate(located):
        n = counts.get(sec, 0)
        self_w = n / (n + SPATIAL_BLEND_K)
        if self_w > 0.90:
            # Enough data, skip smoothing. Relaxed from 0.95 so higher-volume
            # cells (n ~270-570) that still carry single-year noise get a light
            # spatial blend, complementing the temporal smoothness prior.
            continue
        pending.append((i, sec, self_w))

    result = dict(sector_values)
    if not pending:
        return result

    all_dists, all_idxs = tree.query(
        scaled_coords[[i for i, _, _ in pending]], k=SPATIAL_NEIGHBORS + 1
    )
    for (_, sec, self_w), dists, idxs in zip(pending, all_dists, all_idxs):
        inv_dists = []
        neighbor_vals = []
        # Skip the first hit (the sector itself, distance 0). The d > 0 test
        # also drops any other sector sharing an identical centroid.
        for d, j in zip(dists[1:], idxs[1:]):
            if d > 0:
                inv_dists.append(1.0 / d)
                neighbor_vals.append(sector_values[located[j]])
        if not neighbor_vals:
            continue
        total_inv = sum(inv_dists)
        nbr_w = 1.0 - self_w
        neighbor_ws = [iw / total_inv * nbr_w for iw in inv_dists]
        result[sec] = blend_fn(sector_values[sec], neighbor_vals, self_w, neighbor_ws)
    return result
def blend_dicts(
    self_val: dict, neighbor_vals: list[dict], self_w: float, neighbor_ws: list[float]
) -> dict:
    """Blend dict values key-by-key as a weighted sum of self and neighbours.

    `self_w` weights `self_val` and ``neighbor_ws[i]`` weights
    ``neighbor_vals[i]``; spatial_smooth passes weights that sum to 1. The
    result has the union of all input keys.

    - A key present in every input is the plain weighted sum.
    - A key missing from some inputs is blended only over the inputs that have
      it, with their weights renormalised. Zero-filling the gap would instead
      drag index levels or premia toward 0. This mirrors shrink_dicts, which
      passes a one-sided key through unchanged.
    - If every input holding the key has zero weight, their plain mean is used.
    """
    # (dict, weight) pairs; zip() truncates to the shorter of the two lists.
    sources = [(self_val, self_w), *zip(neighbor_vals, neighbor_ws)]
    all_keys: set = set(self_val).union(*neighbor_vals)
    result = {}
    for k in all_keys:
        present = [(d[k], w) for d, w in sources if k in d]
        # Summed in source order (self first), as a running weighted sum.
        weighted = sum(w * v for v, w in present)
        if len(present) == len(sources):
            result[k] = weighted
            continue
        total_w = sum(w for _, w in present)
        if total_w > 0:
            result[k] = weighted / total_w
        elif present:
            # Only zero-weight inputs carry this key: pass their mean through.
            result[k] = sum(v for v, _ in present) / len(present)
        else:
            # Only reachable when neighbor_vals is longer than neighbor_ws.
            result[k] = 0.0
    return result