"""Hierarchical shrinkage and spatial smoothing for sector-level estimates.""" from typing import Callable, TypeVar import numpy as np from scipy.spatial import KDTree from pipeline.transform.price_estimation.utils import SHRINKAGE_K V = TypeVar("V") SPATIAL_NEIGHBORS = 5 SPATIAL_BLEND_K = 30 def shrink_dicts(raw: dict, parent: dict, n: int) -> dict: """Shrink dict values toward parent using n/(n+k) weighting. Works for any dict keyed by year or category. """ w = n / (n + SHRINKAGE_K) result = {} for key in set(raw) | set(parent): r = raw.get(key, parent.get(key, 0.0)) p = parent.get(key, raw.get(key, 0.0)) result[key] = w * r + (1 - w) * p return result def hierarchical_shrinkage( sector_vals: dict[str, V], sector_n: dict[str, int], district_vals: dict[str, V], district_n: dict[str, int], area_vals: dict[str, V], area_n: dict[str, int], top_level: V, all_sectors: list[str], sector_to_dist: dict[str, str], dist_to_area: dict[str, str], shrink_fn: Callable[[V, V, int], V], ) -> dict[str, V]: """Top-down hierarchical shrinkage: area->top, district->area, sector->district. `top_level` is the ultimate fallback value (e.g. national shrunk toward hedonic, or just national). `shrink_fn(raw, parent, n)` blends raw toward parent. """ # Area -> top level area_shrunk = {} for area, val in area_vals.items(): area_shrunk[area] = shrink_fn(val, top_level, area_n[area]) # District -> area district_shrunk = {} for dist, val in district_vals.items(): a = dist_to_area.get(dist, "") parent = area_shrunk.get(a, top_level) district_shrunk[dist] = shrink_fn(val, parent, district_n[dist]) # Sector -> district sector_shrunk = {} for sec, val in sector_vals.items(): d = sector_to_dist.get(sec, "") parent = district_shrunk.get(d, top_level) sector_shrunk[sec] = shrink_fn(val, parent, sector_n[sec]) # Fill sectors without their own values for sec in all_sectors: if sec not in sector_shrunk: d = sector_to_dist.get(sec, "") a = dist_to_area.get(d, "") sector_shrunk[sec] = district_shrunk.get(d, area_shrunk.get(a, top_level)) return sector_shrunk def spatial_smooth( sector_values: dict[str, V], centroids: dict[str, tuple[float, float]], counts: dict[str, int], blend_fn: Callable[[V, list[V], float, list[float]], V], ) -> dict[str, V]: """Blend sparse sector values with K nearest neighbors via KDTree.""" sectors_with_coords = [s for s in sector_values if s in centroids] if len(sectors_with_coords) < SPATIAL_NEIGHBORS + 1: return sector_values coords = np.array([centroids[s] for s in sectors_with_coords]) # Scale longitude by cos(mean_lat) for approximate Euclidean distance mean_lat = np.mean(coords[:, 0]) scale = np.cos(np.radians(mean_lat)) scaled_coords = np.column_stack([coords[:, 0], coords[:, 1] * scale]) tree = KDTree(scaled_coords) result = dict(sector_values) for i, sec in enumerate(sectors_with_coords): n = counts.get(sec, 0) self_w = n / (n + SPATIAL_BLEND_K) if self_w > 0.95: continue # enough data, skip smoothing dists, idxs = tree.query(scaled_coords[i], k=SPATIAL_NEIGHBORS + 1) # Skip self (index 0, distance ~0) neighbor_dists = dists[1:] neighbor_idxs = idxs[1:] inv_dists = [] neighbor_vals = [] for d, j in zip(neighbor_dists, neighbor_idxs): ns = sectors_with_coords[j] if d > 0 and ns in sector_values: inv_dists.append(1.0 / d) neighbor_vals.append(sector_values[ns]) if not neighbor_vals: continue total_inv = sum(inv_dists) nbr_w = 1.0 - self_w neighbor_ws = [iw / total_inv * nbr_w for iw in inv_dists] result[sec] = blend_fn(sector_values[sec], neighbor_vals, self_w, neighbor_ws) return result def blend_dicts( self_val: dict, neighbor_vals: list[dict], self_w: float, neighbor_ws: list[float] ) -> dict: """Blend dict values by weighted sum across all keys.""" all_keys: set = set(self_val) for nv in neighbor_vals: all_keys |= set(nv) result = {} for k in all_keys: val = self_w * self_val.get(k, 0.0) for nv, w in zip(neighbor_vals, neighbor_ws): val += w * nv.get(k, 0.0) result[k] = val return result