"""Hierarchical shrinkage and spatial smoothing for sector-level estimates."""

from typing import Callable, TypeVar

import numpy as np
from scipy.spatial import KDTree

from pipeline.transform.price_estimation.utils import SHRINKAGE_K

V = TypeVar("V")

# Number of nearest neighbours queried for each sparse sector in
# spatial_smooth.
SPATIAL_NEIGHBORS = 5
# Pseudo-count in the self weight n / (n + SPATIAL_BLEND_K) used by
# spatial_smooth.
SPATIAL_BLEND_K = 30


def _base_value(index: dict[int, float], base_year: int) -> float:
    """Value of an index dict at `base_year`, with forward/back-fill for gaps.

    Each repeat-sales dict is anchored to 0 at its OWN earliest year, so its
    values are log-levels relative to that origin. To express it on a common
    origin we need its value at the shared `base_year`:

    - exact hit: use it directly;
    - base_year before the dict's history: back-fill, i.e. the earliest known
      value (which is 0.0 by construction). We cannot observe the level move
      between the global base and a later-starting cell, so we assume none,
      matching forward_fill's back-fill convention;
    - base_year inside a gap / after history: forward-fill the most recent
      prior value.

    An empty `index` yields 0.0.
    """
    if base_year in index:
        return index[base_year]
    if not index:
        return 0.0
    first_year = min(index)
    if base_year < first_year:
        return index[first_year]
    # base_year is absent from the dict and not before its start, so at
    # least `first_year` lies strictly before it: the max() is never empty.
    return index[max(y for y in index if y < base_year)]


def lift_onto_parent(
    child: dict[int, float], parent: dict[int, float]
) -> dict[int, float]:
    """Lift a child index onto its parent's base before blending the two.

    solve_robust_index anchors every cell to log-index 0 at its OWN earliest
    year, so a cell with a shorter history sits on a later origin than its
    (wider) parent. Combining them key-by-key would average
    level-incompatible numbers (a sector measured from 2008 blended with a
    district measured from 1996). We add the parent's accumulated level at
    the child's first year, so ``child[start] == parent[start]``: the child's
    own year-to-year moves are layered on top of the parent's growth up to
    that point -- the same assumption shrinkage already makes for years the
    child lacks.

    Re-basing on each cell's OWN earliest year (rather than the global base,
    which the child cannot observe) is what makes this effective: subtracting
    the child's value at the global base is always 0 and changes nothing. The
    shift is a single constant added to every year of the child, so the
    child's own year-to-year differences are preserved.

    PRECONDITION for the downstream estimate to be unaffected within the
    child's range: the parent's year coverage must be a superset of the
    child's. This holds throughout build_index, where each parent aggregates
    a superset of its children's sale pairs, so shrink_dicts blends every
    child year against a present parent year and the constant shift cancels
    in a within-range (current - sale) difference; only comparisons that span
    the child's start year (e.g. a sale predating the cell's own data)
    change. If a caller violates the precondition (a child year the parent
    lacks), shrink_dicts passes that year through unshrunk and the
    cancellation no longer holds.

    Returns `child` itself (not a copy) when either dict is empty or the
    offset is exactly zero; otherwise a new dict.
    """
    if not child or not parent:
        return child
    child_start = min(child)
    offset = _base_value(parent, child_start) - child[child_start]
    if offset == 0.0:
        return child
    return {y: v + offset for y, v in child.items()}


def shrink_dicts(raw: dict, parent: dict, n: int) -> dict:
    """Shrink dict values toward parent using n/(n+k) weighting.

    Works for any dict keyed by year or category. A key present on only one
    side is blended with itself, i.e. passed through effectively unchanged.

    The result lists `raw`'s keys first, then the parent-only keys, each in
    insertion order, so the output ordering is reproducible regardless of
    hash randomisation.
    """
    w = n / (n + SHRINKAGE_K)
    result = {}
    # dict.fromkeys gives an ordered, de-duplicated union of the two key sets.
    for key in dict.fromkeys([*raw, *parent]):
        r = raw[key] if key in raw else parent[key]
        p = parent[key] if key in parent else raw[key]
        result[key] = w * r + (1 - w) * p
    return result


def hierarchical_shrinkage(
    sector_vals: dict[str, V],
    sector_n: dict[str, int],
    district_vals: dict[str, V],
    district_n: dict[str, int],
    area_vals: dict[str, V],
    area_n: dict[str, int],
    top_level: V,
    all_sectors: list[str],
    sector_to_dist: dict[str, str],
    dist_to_area: dict[str, str],
    shrink_fn: Callable[[V, V, int], V],
    lift_fn: Callable[[V, V], V] | None = None,
) -> dict[str, V]:
    """Top-down hierarchical shrinkage: area->top, district->area, sector->district.

    `top_level` is the ultimate fallback value (e.g. national shrunk toward
    hedonic, or just national).
    `shrink_fn(raw, parent, n)` blends raw toward parent.
    `lift_fn(raw, parent)`, if given, re-bases raw onto its parent before
    blending (see lift_onto_parent); pass None for category-keyed dicts where
    re-basing is meaningless.

    A cell whose parent has no shrunk value falls back to `top_level` as its
    parent. Sectors listed in `all_sectors` that have no value of their own
    receive their district's shrunk value, else their area's, else
    `top_level`.

    Raises KeyError if a key of `area_vals`, `district_vals` or `sector_vals`
    is missing from the matching `*_n` count dict.
    """

    def combine(raw: V, parent: V, n: int) -> V:
        if lift_fn is not None:
            raw = lift_fn(raw, parent)
        return shrink_fn(raw, parent, n)

    # Area -> top level
    area_shrunk = {
        area: combine(val, top_level, area_n[area])
        for area, val in area_vals.items()
    }

    # District -> area
    district_shrunk = {}
    for dist, val in district_vals.items():
        parent = area_shrunk.get(dist_to_area.get(dist, ""), top_level)
        district_shrunk[dist] = combine(val, parent, district_n[dist])

    # Sector -> district
    sector_shrunk = {}
    for sec, val in sector_vals.items():
        parent = district_shrunk.get(sector_to_dist.get(sec, ""), top_level)
        sector_shrunk[sec] = combine(val, parent, sector_n[sec])

    # Fill sectors without their own values
    for sec in all_sectors:
        if sec in sector_shrunk:
            continue
        d = sector_to_dist.get(sec, "")
        a = dist_to_area.get(d, "")
        sector_shrunk[sec] = district_shrunk.get(
            d, area_shrunk.get(a, top_level)
        )

    return sector_shrunk


def spatial_smooth(
    sector_values: dict[str, V],
    centroids: dict[str, tuple[float, float]],
    counts: dict[str, int],
    blend_fn: Callable[[V, list[V], float, list[float]], V],
) -> dict[str, V]:
    """Blend sparse sector values with K nearest neighbors via KDTree.

    Each sector that has a centroid gets a self weight
    ``n / (n + SPATIAL_BLEND_K)`` with ``n = counts.get(sector, 0)``; the
    remaining weight is shared among its nearest neighbours in proportion to
    inverse distance. Sectors whose self weight exceeds 0.90 are left
    untouched, as are sectors without a centroid. Neighbour values are always
    read from the unsmoothed input, so the outcome does not depend on
    iteration order.

    `blend_fn(self_val, neighbor_vals, self_w, neighbor_ws)` produces the
    blended value (see blend_dicts).

    Always returns a new dict; `sector_values` is never returned or mutated,
    including when too few sectors have centroids for smoothing to run.
    """
    result = dict(sector_values)

    sectors_with_coords = [s for s in sector_values if s in centroids]
    if len(sectors_with_coords) < SPATIAL_NEIGHBORS + 1:
        return result

    coords = np.array([centroids[s] for s in sectors_with_coords])

    # Scale longitude by cos(mean_lat) for approximate Euclidean distance
    # (column 0 is treated as latitude, column 1 as longitude).
    mean_lat = np.mean(coords[:, 0])
    scale = np.cos(np.radians(mean_lat))
    scaled_coords = np.column_stack([coords[:, 0], coords[:, 1] * scale])

    tree = KDTree(scaled_coords)

    for i, sec in enumerate(sectors_with_coords):
        n = counts.get(sec, 0)
        self_w = n / (n + SPATIAL_BLEND_K)
        if self_w > 0.90:
            # Enough data, skip smoothing. Relaxed from 0.95 so higher-volume
            # cells (n ~270-570) that still carry single-year noise get a
            # light spatial blend, complementing the temporal smoothness
            # prior.
            continue

        dists, idxs = tree.query(scaled_coords[i], k=SPATIAL_NEIGHBORS + 1)

        inv_dists = []
        neighbor_vals = []
        for d, j in zip(dists, idxs):
            # Exclude the sector itself by index rather than by its position
            # in the result, and drop co-located points (distance 0), which
            # have no finite inverse-distance weight.
            if j != i and d > 0:
                inv_dists.append(1.0 / d)
                neighbor_vals.append(sector_values[sectors_with_coords[j]])

        if not neighbor_vals:
            continue

        total_inv = sum(inv_dists)
        nbr_w = 1.0 - self_w
        neighbor_ws = [iw / total_inv * nbr_w for iw in inv_dists]

        result[sec] = blend_fn(
            sector_values[sec], neighbor_vals, self_w, neighbor_ws
        )

    return result


def blend_dicts(
    self_val: dict,
    neighbor_vals: list[dict],
    self_w: float,
    neighbor_ws: list[float],
) -> dict:
    """Blend dict values by weighted sum across all keys.

    For each key in the union of all dicts, the result is the weighted sum of
    the values of the contributors that actually hold that key, rescaled so
    those contributors carry the full total weight. A contributor lacking a
    key is treated as "no information" rather than as a value of 0.0, which
    would drag the blend toward zero. When every contributor holds the key
    the result is the plain weighted sum.

    If the contributors holding a key have zero combined weight (e.g. only a
    zero-weight self value has it), the unweighted mean of their values,
    scaled by the total weight, is used instead.

    Keys appear in first-seen order (self first, then neighbours in order).
    """
    contributors = [(self_val, self_w), *zip(neighbor_vals, neighbor_ws)]
    total_w = sum(w for _, w in contributors)

    # Ordered, de-duplicated union of every key seen.
    all_keys = dict.fromkeys(self_val)
    for nv in neighbor_vals:
        all_keys.update(dict.fromkeys(nv))

    result = {}
    for k in all_keys:
        present = [(d[k], w) for d, w in contributors if k in d]
        val = sum((w * v for v, w in present), 0.0)
        present_w = sum(w for _, w in present)
        # Rescale only when some weight belongs to contributors that lack
        # the key; otherwise the plain weighted sum is returned untouched.
        if present and present_w != total_w:
            if present_w != 0:
                val *= total_w / present_w
            else:
                mean_val = sum(v for v, _ in present) / len(present)
                val = total_w * mean_val
        result[k] = val
    return result