"""Hierarchical shrinkage and spatial smoothing for sector-level estimates."""

from bisect import bisect_right
from itertools import pairwise
from typing import Callable, TypeVar

import numpy as np
from scipy.spatial import KDTree

from pipeline.transform.price_estimation.utils import SHRINKAGE_K

V = TypeVar("V")

SPATIAL_NEIGHBORS = 5
SPATIAL_BLEND_K = 30
# Sectors whose own-data weight n / (n + SPATIAL_BLEND_K) exceeds this are left
# unsmoothed. Relaxed from 0.95 so higher-volume cells (n ~270-570) that still
# carry single-year noise get a light spatial blend, complementing the
# temporal smoothness prior.
SPATIAL_SKIP_SELF_WEIGHT = 0.90

# Hard band on a sector's per-year index move RELATIVE to its parent (the
# national index), enforced by winsorize_steps after spatial smoothing. The
# support-scaled temporal smoothness prior still under-penalises years
# identified by only 1-2 repeat-sale pairs in thin early histories, leaving
# artefacts like a x9.7 single-year jump (log +2.27, sector "M3 1"
# 1998->1999). A sector may genuinely outpace the nation -- regeneration, new
# transport links -- but those stories play out over multiple years, not as a
# one-year x9.7 step. +/-0.40 log/yr (~x1.5 in a year) relative to the
# national move keeps every plausible genuine sector-level divergence while
# clamping thin-year data artefacts.
MAX_STEP_DEVIATION_PER_YEAR = 0.40


def _base_value(index: dict[int, float], base_year: int) -> float:
    """Value of an index dict at `base_year`, with forward/back-fill for gaps.

    Each repeat-sales dict is anchored to 0 at its OWN earliest year, so its
    values are log-levels relative to that origin. To express it on a common
    origin we need its value at the shared `base_year`:

    - exact hit: use it directly;
    - base_year before the dict's history: back-fill, i.e. the earliest known
      value (which is 0.0 by construction). We cannot observe the level move
      between the global base and a later-starting cell, so we assume none,
      matching forward_fill's back-fill convention;
    - base_year inside a gap / after history: forward-fill the most recent
      prior value.

    An empty index yields 0.0.
    """
    if base_year in index:
        return index[base_year]
    if not index:
        return 0.0
    years = sorted(index)
    # Position of the last year <= base_year. It is -1 when base_year precedes
    # the whole history; clamping to 0 then back-fills from the earliest year.
    pos = bisect_right(years, base_year) - 1
    return index[years[max(pos, 0)]]


def lift_onto_parent(
    child: dict[int, float], parent: dict[int, float]
) -> dict[int, float]:
    """Lift a child index onto its parent's base before blending the two.

    solve_robust_index anchors every cell to log-index 0 at its OWN earliest
    year, so a cell with a shorter history sits on a later origin than its
    (wider) parent. Combining them key-by-key would average
    level-incompatible numbers (a sector measured from 2008 blended with a
    district measured from 1996). We add the parent's accumulated level at the
    child's first year, so ``child[start] == parent[start]``: the child's own
    year-to-year moves are layered on top of the parent's growth up to that
    point -- the same assumption shrinkage already makes for years the child
    lacks.

    Re-basing on each cell's OWN earliest year (rather than the global base,
    which the child cannot observe) is what makes this effective: subtracting
    the child's value at the global base is always 0 and changes nothing.

    The shift is a single constant added to every year of the child, so the
    child's own year-to-year differences are preserved.

    PRECONDITION for the downstream estimate to be unaffected within the
    child's range: the parent's year coverage must be a superset of the
    child's. This holds throughout build_index, where each parent aggregates
    a superset of its children's sale pairs, so shrink_dicts blends every
    child year against a present parent year and the constant shift cancels
    in a within-range (current - sale) difference; only comparisons that span
    the child's start year (e.g. a sale predating the cell's own data)
    change. If a caller violates the precondition (a child year the parent
    lacks), shrink_dicts passes that year through unshrunk and the
    cancellation no longer holds.

    Returns `child` itself (not a copy) when either dict is empty or the
    computed offset is exactly zero; otherwise a new shifted dict.
    """
    if not child or not parent:
        return child
    child_start = min(child)
    offset = _base_value(parent, child_start) - child[child_start]
    if offset == 0.0:
        return child
    return {y: v + offset for y, v in child.items()}


def winsorize_steps(
    child: dict[int, float],
    parent: dict[int, float],
    max_dev_per_year: float,
) -> dict[int, float]:
    """Clamp a child's per-year index steps to within a band of the parent's.

    For each consecutive pair of solved years (y_prev, y) the child's per-year
    rate r = (child[y] - child[y_prev]) / (y - y_prev) is winsorised into
    [p - max_dev_per_year, p + max_dev_per_year], where p is the parent's
    per-year rate over the same span (via _base_value, so gaps in the parent's
    coverage are forward/back-filled rather than crashing). The series is then
    rebuilt cumulatively from the FIRST year's value, so:

    - the first year's level is preserved;
    - non-outlier steps are preserved exactly (later years simply shift by
      whatever the clamped steps removed);
    - a multi-year gap is judged on its per-year rate, not as one giant
      single-year move, so genuine level changes across gaps survive.

    A child with <2 years has no steps to clamp; an empty parent only occurs
    in degenerate paths (build_index always passes the national index) --
    both are returned unchanged.
    """
    if len(child) < 2 or not parent:
        return child
    years = sorted(child)
    # Parent level at every child year, looked up once instead of twice per
    # step (each lookup forward/back-fills over the parent's own coverage).
    parent_at = {y: _base_value(parent, y) for y in years}
    result = {years[0]: child[years[0]]}
    for y_prev, y in pairwise(years):
        span = y - y_prev
        r = (child[y] - child[y_prev]) / span
        p = (parent_at[y] - parent_at[y_prev]) / span
        r = min(max(r, p - max_dev_per_year), p + max_dev_per_year)
        result[y] = result[y_prev] + r * span
    return result


def shrink_dicts(raw: dict, parent: dict, n: int) -> dict:
    """Shrink dict values toward parent using n/(n+k) weighting.

    Works for any dict keyed by year or category. For each key in the union
    of both dicts the result is ``w * raw + (1 - w) * parent`` with
    ``w = n / (n + SHRINKAGE_K)``. A key present on only one side is passed
    through unchanged (the missing side borrows the present side's value),
    so the output always covers every key of both inputs.
    """
    w = n / (n + SHRINKAGE_K)
    result = {}
    for key in set(raw) | set(parent):
        r = raw.get(key, parent.get(key, 0.0))
        p = parent.get(key, raw.get(key, 0.0))
        result[key] = w * r + (1 - w) * p
    return result


def hierarchical_shrinkage(
    sector_vals: dict[str, V],
    sector_n: dict[str, int],
    district_vals: dict[str, V],
    district_n: dict[str, int],
    area_vals: dict[str, V],
    area_n: dict[str, int],
    top_level: V,
    all_sectors: list[str],
    sector_to_dist: dict[str, str],
    dist_to_area: dict[str, str],
    shrink_fn: Callable[[V, V, int], V],
    lift_fn: Callable[[V, V], V] | None = None,
) -> dict[str, V]:
    """Top-down hierarchical shrinkage: area->top, district->area, sector->district.

    `top_level` is the ultimate fallback value (e.g. national shrunk toward
    hedonic, or just national). `shrink_fn(raw, parent, n)` blends raw toward
    parent. `lift_fn(raw, parent)`, if given, re-bases raw onto its parent
    before blending (see lift_onto_parent); pass None for category-keyed
    dicts where re-basing is meaningless.

    Each level shrinks toward its nearest available shrunk ancestor:

    - area     -> top_level;
    - district -> its area, else top_level;
    - sector   -> its district, else its district's area, else top_level.

    Sectors in `all_sectors` without their own values take that same
    nearest-ancestor value directly (no shrinkage needed).

    Raises KeyError if a key of area_vals / district_vals / sector_vals is
    missing from the matching area_n / district_n / sector_n.
    """

    def combine(raw: V, parent: V, n: int) -> V:
        if lift_fn is not None:
            raw = lift_fn(raw, parent)
        return shrink_fn(raw, parent, n)

    # Area -> top level
    area_shrunk = {
        area: combine(val, top_level, area_n[area]) for area, val in area_vals.items()
    }

    # District -> area
    district_shrunk = {}
    for dist, val in district_vals.items():
        parent = area_shrunk.get(dist_to_area.get(dist, ""), top_level)
        district_shrunk[dist] = combine(val, parent, district_n[dist])

    def nearest_ancestor(dist: str) -> V:
        # A district's shrunk value if it has one, else its area's, else the
        # top level. Shared by both sector paths below so a sector with its own
        # values and one without fall back through the same chain.
        if dist in district_shrunk:
            return district_shrunk[dist]
        return area_shrunk.get(dist_to_area.get(dist, ""), top_level)

    # Sector -> district (or the district's area if the district has no value)
    sector_shrunk = {}
    for sec, val in sector_vals.items():
        parent = nearest_ancestor(sector_to_dist.get(sec, ""))
        sector_shrunk[sec] = combine(val, parent, sector_n[sec])

    # Fill sectors without their own values
    for sec in all_sectors:
        if sec not in sector_shrunk:
            sector_shrunk[sec] = nearest_ancestor(sector_to_dist.get(sec, ""))
    return sector_shrunk


def spatial_smooth(
    sector_values: dict[str, V],
    centroids: dict[str, tuple[float, float]],
    counts: dict[str, int],
    blend_fn: Callable[[V, list[V], float, list[float]], V],
) -> dict[str, V]:
    """Blend sparse sector values with K nearest neighbors via KDTree.

    Each sector with a centroid is blended with up to SPATIAL_NEIGHBORS
    nearest sectors. Its own weight is ``n / (n + SPATIAL_BLEND_K)``, with
    ``n = counts.get(sec, 0)``. The remaining weight is split across
    neighbors in proportion to inverse distance. Sectors whose own weight
    exceeds SPATIAL_SKIP_SELF_WEIGHT are left as-is. Neighbor values are
    always read from the ORIGINAL `sector_values`, so the result does not
    depend on iteration order.

    Sectors without a centroid pass through unchanged. If fewer than
    SPATIAL_NEIGHBORS + 1 sectors have centroids, `sector_values` itself is
    returned; otherwise a new dict.

    NOTE(review): centroids are treated as (lat, lon) in degrees -- column 0
    drives the cos-latitude longitude scaling below. Confirm against callers.
    """
    sectors_with_coords = [s for s in sector_values if s in centroids]
    if len(sectors_with_coords) < SPATIAL_NEIGHBORS + 1:
        return sector_values

    coords = np.array([centroids[s] for s in sectors_with_coords])
    # Scale longitude by cos(mean_lat) for approximate Euclidean distance
    mean_lat = np.mean(coords[:, 0])
    scale = np.cos(np.radians(mean_lat))
    scaled_coords = np.column_stack([coords[:, 0], coords[:, 1] * scale])
    tree = KDTree(scaled_coords)

    result = dict(sector_values)
    for i, sec in enumerate(sectors_with_coords):
        n = counts.get(sec, 0)
        self_w = n / (n + SPATIAL_BLEND_K)
        if self_w > SPATIAL_SKIP_SELF_WEIGHT:
            # Enough data, skip smoothing.
            continue

        dists, idxs = tree.query(scaled_coords[i], k=SPATIAL_NEIGHBORS + 1)
        inv_dists = []
        neighbor_vals = []
        # Skip the first hit (self, distance ~0). Zero-distance hits are also
        # dropped, both to exclude self and to avoid 1/0 for co-located sectors.
        for d, j in zip(dists[1:], idxs[1:]):
            if d > 0:
                inv_dists.append(1.0 / d)
                neighbor_vals.append(sector_values[sectors_with_coords[j]])
        if not neighbor_vals:
            continue

        total_inv = sum(inv_dists)
        nbr_w = 1.0 - self_w
        neighbor_ws = [iw / total_inv * nbr_w for iw in inv_dists]
        result[sec] = blend_fn(sector_values[sec], neighbor_vals, self_w, neighbor_ws)
    return result


def blend_dicts(
    self_val: dict, neighbor_vals: list[dict], self_w: float, neighbor_ws: list[float]
) -> dict:
    """Blend dict values key-by-key as a weighted average.

    `self_val` carries weight `self_w`; `neighbor_vals[i]` carries
    `neighbor_ws[i]` (extra entries of the longer list are ignored, as with
    zip). Every key in the union of the weighted dicts appears in the output.

    Only the dicts that actually contain a key contribute to it:

    - When every dict has the key (the usual case), the result is the plain
      weighted sum.
    - When some dicts lack it, the present dicts' weights are renormalised to
      sum to 1. A missing key is thus "no information" rather than an implicit
      0.0, which for log-level index dicts would drag the blend toward the
      origin level. This mirrors shrink_dicts, which passes a key through from
      whichever side has it.
    - If the dicts holding a key carry zero total weight, their unweighted
      mean is used.
    """
    weighted = [(self_val, self_w), *zip(neighbor_vals, neighbor_ws)]
    all_keys: set = set()
    for d, _ in weighted:
        all_keys |= set(d)

    result = {}
    for k in all_keys:
        val = 0.0
        total_w = 0.0
        present = []
        for d, w in weighted:
            if k in d:
                val += w * d[k]
                total_w += w
                present.append(d[k])
        if len(present) == len(weighted):
            result[k] = val
        elif total_w > 0:
            result[k] = val / total_w
        else:
            result[k] = sum(present) / len(present)
    return result