215 lines
8.1 KiB
Python
215 lines
8.1 KiB
Python
"""Hierarchical shrinkage and spatial smoothing for sector-level estimates."""
|
|
|
|
from typing import Callable, TypeVar
|
|
|
|
import numpy as np
|
|
from scipy.spatial import KDTree
|
|
|
|
from pipeline.transform.price_estimation.utils import SHRINKAGE_K
|
|
|
|
# Generic value type for hierarchical_shrinkage and spatial_smooth; each
# level/sector holds one V (e.g. a year-keyed dict when used together with
# shrink_dicts / blend_dicts).
V = TypeVar("V")

# Number of nearest neighbours (excluding the sector itself) that
# spatial_smooth blends into a sparse sector. spatial_smooth also skips
# smoothing entirely when fewer than SPATIAL_NEIGHBORS + 1 sectors have
# centroids.
SPATIAL_NEIGHBORS = 5

# Pseudo-count k in spatial_smooth's self weight n / (n + k): a sector with
# n observations keeps n / (n + SPATIAL_BLEND_K) of its own value.
SPATIAL_BLEND_K = 30
|
|
|
|
|
|
def _base_value(index: dict[int, float], base_year: int) -> float:
|
|
"""Value of an index dict at `base_year`, with forward/back-fill for gaps.
|
|
|
|
Each repeat-sales dict is anchored to 0 at its OWN earliest year, so its
|
|
values are log-levels relative to that origin. To express it on a common
|
|
origin we need its value at the shared `base_year`:
|
|
- exact hit: use it directly;
|
|
- base_year before the dict's history: back-fill, i.e. the earliest known
|
|
value (which is 0.0 by construction). We cannot observe the level move
|
|
between the global base and a later-starting cell, so we assume none,
|
|
matching forward_fill's back-fill convention;
|
|
- base_year inside a gap / after history: forward-fill the most recent
|
|
prior value.
|
|
"""
|
|
if base_year in index:
|
|
return index[base_year]
|
|
years = sorted(index)
|
|
if not years or base_year < years[0]:
|
|
return index[years[0]] if years else 0.0
|
|
prior = [y for y in years if y <= base_year]
|
|
return index[prior[-1]]
|
|
|
|
|
|
def lift_onto_parent(
    child: dict[int, float], parent: dict[int, float]
) -> dict[int, float]:
    """Re-base ``child`` onto ``parent``'s origin so the two can be blended.

    solve_robust_index anchors every cell at log-index 0 in its OWN earliest
    year. A cell with a shorter history than its (wider) parent therefore
    sits on a later origin. Blending the two key-by-key without correction
    would mix level-incompatible numbers, for example a sector measured from
    2008 with a district measured from 1996.

    This function adds one constant to every child year. After the shift, at
    the child's first year, the child equals the parent's level there (read
    with ``_base_value``'s gap-filling rules). The child's own year-to-year
    moves then ride on top of the parent's growth up to that year. Shrinkage
    already makes the same assumption for years the child lacks.

    The re-basing uses the child's OWN earliest year, not a global base year.
    The child has no observations at the global base, so its value there is
    always 0, and subtracting it would change nothing.

    Because the shift is a constant, the child's year-to-year differences are
    preserved.

    PRECONDITION: for downstream estimates within the child's range to be
    unaffected, the parent must cover every child year. build_index satisfies
    this, since each parent aggregates a superset of its children's sale
    pairs. shrink_dicts then blends every child year against a present parent
    year, and the shift cancels in any within-range (current - sale)
    difference. Only comparisons that span the child's start year move, for
    example a sale predating the cell's own data. If a caller breaks the
    precondition, shrink_dicts passes the parent-less years through unshrunk
    and the cancellation no longer holds.

    Returns:
        ``child`` itself (not a copy) when either dict is empty or when no
        shift is needed; otherwise a new dict holding the shifted values.
    """
    if child and parent:
        first_year = min(child)
        shift = _base_value(parent, first_year) - child[first_year]
        if shift != 0.0:
            return {year: level + shift for year, level in child.items()}
    return child
|
|
|
|
|
|
def shrink_dicts(raw: dict, parent: dict, n: int) -> dict:
|
|
"""Shrink dict values toward parent using n/(n+k) weighting.
|
|
|
|
Works for any dict keyed by year or category.
|
|
"""
|
|
w = n / (n + SHRINKAGE_K)
|
|
result = {}
|
|
for key in set(raw) | set(parent):
|
|
r = raw.get(key, parent.get(key, 0.0))
|
|
p = parent.get(key, raw.get(key, 0.0))
|
|
result[key] = w * r + (1 - w) * p
|
|
return result
|
|
|
|
|
|
def hierarchical_shrinkage(
|
|
sector_vals: dict[str, V],
|
|
sector_n: dict[str, int],
|
|
district_vals: dict[str, V],
|
|
district_n: dict[str, int],
|
|
area_vals: dict[str, V],
|
|
area_n: dict[str, int],
|
|
top_level: V,
|
|
all_sectors: list[str],
|
|
sector_to_dist: dict[str, str],
|
|
dist_to_area: dict[str, str],
|
|
shrink_fn: Callable[[V, V, int], V],
|
|
lift_fn: Callable[[V, V], V] | None = None,
|
|
) -> dict[str, V]:
|
|
"""Top-down hierarchical shrinkage: area->top, district->area, sector->district.
|
|
|
|
`top_level` is the ultimate fallback value (e.g. national shrunk toward hedonic,
|
|
or just national). `shrink_fn(raw, parent, n)` blends raw toward parent.
|
|
`lift_fn(raw, parent)`, if given, re-bases raw onto its parent before blending
|
|
(see lift_onto_parent); pass None for category-keyed dicts where re-basing is
|
|
meaningless.
|
|
"""
|
|
|
|
def combine(raw: V, parent: V, n: int) -> V:
|
|
if lift_fn is not None:
|
|
raw = lift_fn(raw, parent)
|
|
return shrink_fn(raw, parent, n)
|
|
|
|
# Area -> top level
|
|
area_shrunk = {}
|
|
for area, val in area_vals.items():
|
|
area_shrunk[area] = combine(val, top_level, area_n[area])
|
|
|
|
# District -> area
|
|
district_shrunk = {}
|
|
for dist, val in district_vals.items():
|
|
a = dist_to_area.get(dist, "")
|
|
parent = area_shrunk.get(a, top_level)
|
|
district_shrunk[dist] = combine(val, parent, district_n[dist])
|
|
|
|
# Sector -> district
|
|
sector_shrunk = {}
|
|
for sec, val in sector_vals.items():
|
|
d = sector_to_dist.get(sec, "")
|
|
parent = district_shrunk.get(d, top_level)
|
|
sector_shrunk[sec] = combine(val, parent, sector_n[sec])
|
|
|
|
# Fill sectors without their own values
|
|
for sec in all_sectors:
|
|
if sec not in sector_shrunk:
|
|
d = sector_to_dist.get(sec, "")
|
|
a = dist_to_area.get(d, "")
|
|
sector_shrunk[sec] = district_shrunk.get(d, area_shrunk.get(a, top_level))
|
|
|
|
return sector_shrunk
|
|
|
|
|
|
def spatial_smooth(
|
|
sector_values: dict[str, V],
|
|
centroids: dict[str, tuple[float, float]],
|
|
counts: dict[str, int],
|
|
blend_fn: Callable[[V, list[V], float, list[float]], V],
|
|
) -> dict[str, V]:
|
|
"""Blend sparse sector values with K nearest neighbors via KDTree."""
|
|
sectors_with_coords = [s for s in sector_values if s in centroids]
|
|
if len(sectors_with_coords) < SPATIAL_NEIGHBORS + 1:
|
|
return sector_values
|
|
|
|
coords = np.array([centroids[s] for s in sectors_with_coords])
|
|
# Scale longitude by cos(mean_lat) for approximate Euclidean distance
|
|
mean_lat = np.mean(coords[:, 0])
|
|
scale = np.cos(np.radians(mean_lat))
|
|
scaled_coords = np.column_stack([coords[:, 0], coords[:, 1] * scale])
|
|
tree = KDTree(scaled_coords)
|
|
|
|
result = dict(sector_values)
|
|
for i, sec in enumerate(sectors_with_coords):
|
|
n = counts.get(sec, 0)
|
|
self_w = n / (n + SPATIAL_BLEND_K)
|
|
if self_w > 0.90:
|
|
# Enough data, skip smoothing. Relaxed from 0.95 so higher-volume
|
|
# cells (n ~270-570) that still carry single-year noise get a light
|
|
# spatial blend, complementing the temporal smoothness prior.
|
|
continue
|
|
|
|
dists, idxs = tree.query(scaled_coords[i], k=SPATIAL_NEIGHBORS + 1)
|
|
# Skip self (index 0, distance ~0)
|
|
neighbor_dists = dists[1:]
|
|
neighbor_idxs = idxs[1:]
|
|
|
|
inv_dists = []
|
|
neighbor_vals = []
|
|
for d, j in zip(neighbor_dists, neighbor_idxs):
|
|
ns = sectors_with_coords[j]
|
|
if d > 0 and ns in sector_values:
|
|
inv_dists.append(1.0 / d)
|
|
neighbor_vals.append(sector_values[ns])
|
|
|
|
if not neighbor_vals:
|
|
continue
|
|
|
|
total_inv = sum(inv_dists)
|
|
nbr_w = 1.0 - self_w
|
|
neighbor_ws = [iw / total_inv * nbr_w for iw in inv_dists]
|
|
|
|
result[sec] = blend_fn(sector_values[sec], neighbor_vals, self_w, neighbor_ws)
|
|
|
|
return result
|
|
|
|
|
|
def blend_dicts(
    self_val: dict, neighbor_vals: list[dict], self_w: float, neighbor_ws: list[float]
) -> dict:
    """Blend dicts key-by-key as a weighted sum of ``self_val`` and neighbours.

    For every key present in ``self_val`` or in any neighbour dict, the
    result is ``self_w * self_val[key] + sum(w_i * neighbor_vals[i][key])``,
    summed in that order. A dict lacking the key contributes 0.0 for it.
    Weights are NOT renormalised, so a key missing from some inputs is pulled
    toward 0.0.

    Returns:
        A new dict whose keys appear in a deterministic order: ``self_val``'s
        keys first, then keys first seen in each neighbour, in list order.

    Raises:
        ValueError: if ``neighbor_vals`` and ``neighbor_ws`` differ in length,
            which would otherwise silently drop inputs.
    """
    if len(neighbor_vals) != len(neighbor_ws):
        raise ValueError(
            f"blend_dicts got {len(neighbor_vals)} neighbor dicts "
            f"but {len(neighbor_ws)} weights"
        )
    weighted = list(zip(neighbor_vals, neighbor_ws))

    # Ordered union of keys. Unlike a set, this does not depend on string
    # hash randomisation, so category-keyed output is reproducible across runs.
    all_keys = dict.fromkeys(self_val)
    for nv in neighbor_vals:
        all_keys.update(dict.fromkeys(nv))

    result = {}
    for k in all_keys:
        val = self_w * self_val.get(k, 0.0)
        for nv, w in weighted:
            val += w * nv.get(k, 0.0)
        result[k] = val
    return result
|