140 lines
4.5 KiB
Python
140 lines
4.5 KiB
Python
"""Hierarchical shrinkage and spatial smoothing for sector-level estimates."""
|
|
|
|
from typing import Callable, TypeVar
|
|
|
|
import numpy as np
|
|
from scipy.spatial import KDTree
|
|
|
|
from pipeline.transform.price_estimation.utils import SHRINKAGE_K
|
|
|
|
# Generic value type threaded through the shrinkage / smoothing helpers
# (a dict when used with shrink_dicts / blend_dicts).
V = TypeVar("V")

# How many nearest other sectors spatial_smooth considers for each sector
# (the KDTree is queried for this many + 1 so the sector itself can be dropped).
SPATIAL_NEIGHBORS: int = 5
# Pseudo-count k in spatial_smooth's self-weight n / (n + k): a sector needs
# roughly this many observations before its own value outweighs its neighbours.
SPATIAL_BLEND_K: int = 30
|
|
|
|
|
|
def shrink_dicts(raw: dict, parent: dict, n: int) -> dict:
    """Pull each value of ``raw`` toward ``parent`` with weight n / (n + SHRINKAGE_K).

    The output covers the union of both key sets. For a key present in only one
    of the two dicts, that dict's value is used on both sides of the blend.
    Works for any dict keyed by year or category.
    """
    raw_weight = n / (n + SHRINKAGE_K)
    parent_weight = 1 - raw_weight
    blended = {}
    for key in raw.keys() | parent.keys():
        # Every key comes from the union, so at least one side always has it.
        own = raw[key] if key in raw else parent[key]
        prior = parent[key] if key in parent else raw[key]
        blended[key] = raw_weight * own + parent_weight * prior
    return blended
|
|
|
|
|
|
def hierarchical_shrinkage(
    sector_vals: dict[str, V],
    sector_n: dict[str, int],
    district_vals: dict[str, V],
    district_n: dict[str, int],
    area_vals: dict[str, V],
    area_n: dict[str, int],
    top_level: V,
    all_sectors: list[str],
    sector_to_dist: dict[str, str],
    dist_to_area: dict[str, str],
    shrink_fn: Callable[[V, V, int], V],
) -> dict[str, V]:
    """Top-down hierarchical shrinkage: area->top, district->area, sector->district.

    `top_level` is the ultimate fallback value (e.g. national shrunk toward hedonic,
    or just national). `shrink_fn(raw, parent, n)` blends raw toward parent.

    Each level is shrunk toward the nearest ancestor that has an estimate:
      * areas are shrunk toward `top_level`;
      * districts are shrunk toward their area (via `dist_to_area`), or toward
        `top_level` when that area has no estimate;
      * sectors are shrunk toward their district (via `sector_to_dist`), then
        their area, then `top_level`, whichever is found first.
    Sectors listed in `all_sectors` without their own values get that same
    nearest-ancestor value directly.

    Returns:
        Shrunk value for every key of `sector_vals` plus every entry of
        `all_sectors`.

    Raises:
        KeyError: if a key of `area_vals` / `district_vals` / `sector_vals` is
            missing from the matching `area_n` / `district_n` / `sector_n`.
    """
    # Area -> top level
    area_shrunk = {
        area: shrink_fn(val, top_level, area_n[area]) for area, val in area_vals.items()
    }

    # District -> area (top level when the area has no estimate)
    district_shrunk = {}
    for dist, val in district_vals.items():
        parent = area_shrunk.get(dist_to_area.get(dist, ""), top_level)
        district_shrunk[dist] = shrink_fn(val, parent, district_n[dist])

    def nearest_ancestor(sec: str):
        """Return the closest available estimate above `sec`: district, area, then top."""
        dist = sector_to_dist.get(sec, "")
        if dist in district_shrunk:
            return district_shrunk[dist]
        # Previously sectors with data skipped this area step and fell straight to
        # top_level; using the same chain as the fill step keeps both paths consistent.
        return area_shrunk.get(dist_to_area.get(dist, ""), top_level)

    # Sector -> nearest ancestor
    sector_shrunk = {
        sec: shrink_fn(val, nearest_ancestor(sec), sector_n[sec])
        for sec, val in sector_vals.items()
    }

    # Fill sectors without their own values
    for sec in all_sectors:
        if sec not in sector_shrunk:
            sector_shrunk[sec] = nearest_ancestor(sec)

    return sector_shrunk
|
|
|
|
|
|
def spatial_smooth(
    sector_values: dict[str, V],
    centroids: dict[str, tuple[float, float]],
    counts: dict[str, int],
    blend_fn: Callable[[V, list[V], float, list[float]], V],
) -> dict[str, V]:
    """Blend sparse sector values with their K nearest neighbours via a KDTree.

    Each sector that has a centroid gets a self-weight ``n / (n + SPATIAL_BLEND_K)``
    with ``n = counts.get(sector, 0)``. Sectors whose self-weight exceeds 0.95 are
    left as-is. Every other sector is replaced by
    ``blend_fn(own_value, neighbour_values, self_w, neighbour_weights)``, where the
    neighbours are the ``SPATIAL_NEIGHBORS`` nearest other sectors and their weights
    are inverse-distance weights rescaled to sum to ``1 - self_w``. Neighbours at
    zero distance (e.g. a duplicated centroid) are skipped, since 1/d is undefined.
    Neighbour values always come from the unsmoothed input, so the result does not
    depend on iteration order.

    Centroids are treated as ``(lat, lon)``: the second coordinate is scaled by
    cos(mean of the first) to approximate Euclidean distances.

    Returns:
        A new dict with the same keys as ``sector_values``. The input dict object is
        never returned, on any path. When fewer than ``SPATIAL_NEIGHBORS + 1``
        sectors have centroids, the copy is returned unsmoothed.
    """
    # Copy up front so every path returns a fresh dict (the early exit used to
    # hand back the caller's own object).
    result = dict(sector_values)

    sectors_with_coords = [s for s in sector_values if s in centroids]
    if len(sectors_with_coords) < SPATIAL_NEIGHBORS + 1:
        return result

    coords = np.array([centroids[s] for s in sectors_with_coords], dtype=float)
    # Scale longitude by cos(mean_lat) for approximate Euclidean distance
    scale = np.cos(np.radians(np.mean(coords[:, 0])))
    scaled_coords = np.column_stack([coords[:, 0], coords[:, 1] * scale])
    tree = KDTree(scaled_coords)

    # Collect only the sectors that actually need smoothing.
    pending = []  # (row in scaled_coords, sector, self weight)
    for i, sec in enumerate(sectors_with_coords):
        n = counts.get(sec, 0)
        self_w = n / (n + SPATIAL_BLEND_K)
        if self_w > 0.95:
            continue  # enough data, skip smoothing
        pending.append((i, sec, self_w))
    if not pending:
        return result

    # One batched query; k + 1 because each point's nearest hit is itself.
    rows = [i for i, _, _ in pending]
    all_dists, all_idxs = tree.query(scaled_coords[rows], k=SPATIAL_NEIGHBORS + 1)

    for (i, sec, self_w), dists, idxs in zip(pending, all_dists, all_idxs):
        inv_dists = []
        neighbor_vals = []
        for d, j in zip(dists, idxs):
            # Drop self explicitly (not by position) plus any zero-distance hit.
            if j == i or d <= 0:
                continue
            inv_dists.append(1.0 / d)
            neighbor_vals.append(sector_values[sectors_with_coords[j]])

        if not neighbor_vals:
            continue

        total_inv = sum(inv_dists)
        nbr_w = 1.0 - self_w
        neighbor_ws = [iw / total_inv * nbr_w for iw in inv_dists]

        result[sec] = blend_fn(sector_values[sec], neighbor_vals, self_w, neighbor_ws)

    return result
|
|
|
|
|
|
def blend_dicts(
    self_val: dict, neighbor_vals: list[dict], self_w: float, neighbor_ws: list[float]
) -> dict:
    """Return the weighted sum of ``self_val`` and ``neighbor_vals``, key by key.

    The result covers the union of all keys. ``self_val`` is scaled by ``self_w``
    and each neighbour dict by its paired weight in ``neighbor_ws``; pairs beyond
    the shorter of the two lists are ignored. A key missing from a dict
    contributes 0.0 for that dict.
    """
    keys = set(self_val).union(*neighbor_vals)
    weighted = list(zip(neighbor_vals, neighbor_ws))
    # sum(..., start) adds the neighbour terms onto the self term in list order.
    return {
        key: sum(
            (w * nv.get(key, 0.0) for nv, w in weighted),
            self_w * self_val.get(key, 0.0),
        )
        for key in keys
    }
|