135 lines
4.4 KiB
Python
135 lines
4.4 KiB
Python
import numpy as np
|
|
import polars as pl
|
|
|
|
from pipeline.transform.price_estimation import index as index_mod
|
|
from pipeline.transform.price_estimation.index import (
|
|
compute_indices_for_level,
|
|
solve_robust_index,
|
|
)
|
|
|
|
|
|
def _pairs_from_path(
    true_levels: dict[int, float],
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Build adjacent-year repeat-sale pairs that exactly trace a known path.

    Each consecutive pair's log_ratio is the difference of the true log-levels,
    so the solver should recover the levels exactly (relative to the min year).

    Args:
        true_levels: Mapping of year to true log-level. Insertion order is
            irrelevant because the years are sorted before pairing.

    Returns:
        ``(year1, year2, log_ratio, weight)`` arrays with one entry per
        consecutive pair of sorted years. Years are ``int32``; log-ratios and
        weights are ``float64``; every weight is 1.0. A mapping with fewer
        than two years yields four empty arrays.
    """
    years = sorted(true_levels)
    # Offsetting the sorted years by one position pairs each year with its
    # immediate successor in the path.
    starts, ends = years[:-1], years[1:]
    log_ratios = [true_levels[b] - true_levels[a] for a, b in zip(starts, ends)]
    return (
        np.array(starts, dtype=np.int32),
        np.array(ends, dtype=np.int32),
        np.array(log_ratios, dtype=np.float64),
        np.ones(len(log_ratios), dtype=np.float64),
    )
|
|
|
|
|
|
def test_solver_recovers_contiguous_path():
    """A contiguous price path is recovered as log-levels relative to min_year.

    Proves the IRLS solver is correct (and unchanged) for contiguous data: the
    spacing-aware penalty reduces to the standard [1,-2,1] for unit spacing.
    """
    years = range(2010, 2021)
    true = {y: 0.04 * (y - 2010) for y in years}  # smooth (zero curvature) ramp
    # Replicate each adjacent pair three times so MIN_PAIRS is comfortably met.
    y1, y2, lr, w = (np.tile(arr, 3) for arr in _pairs_from_path(true))

    idx = solve_robust_index(y1, y2, lr, w)

    assert idx[2010] == 0.0  # baseline anchor
    for y in years:
        # Levels are compared relative to the 2010 baseline.
        assert abs(idx[y] - (true[y] - true[2010])) < 1e-3
|
|
|
|
|
|
def test_gap_spanning_level_jump_is_not_smoothed_into_a_ramp():
    """FIX #5: a sharp true level jump across a multi-year gap is preserved.

    Coverage is 2000,2001,2002 then 2015,2016 with cross-gap pairs encoding a
    sharp jump at the gap. The uniform [1,-2,1] curvature penalty treats
    (beta_2002, beta_2015, beta_2016) as three adjacent years and over-penalizes
    the genuine level jump, biasing beta_2015 down toward a smooth ramp. The
    spacing-aware second difference relaxes the penalty across the gap.
    """
    # True log-levels relative to min_year (2000 anchored at 0).
    true = {
        2000: 0.0,
        2001: 0.05,
        2002: 0.10,
        2015: 1.10,  # sharp +1.0 jump across the gap
        2016: 1.15,
    }

    repeats = 4  # identical copies of every pair
    year_pairs = [
        # In-segment adjacent pairs.
        (2000, 2001),
        (2001, 2002),
        (2015, 2016),
        # Cross-gap pairs consistent with the sharp jump.
        (2002, 2015),
        (2002, 2016),
    ]
    # Each pair is repeated consecutively, so the copies of one pair stay
    # adjacent in the arrays.
    pairs = [(a, b) for a, b in year_pairs for _ in range(repeats)]

    y1 = np.array([a for a, _ in pairs], dtype=np.int32)
    y2 = np.array([b for _, b in pairs], dtype=np.int32)
    lr = np.array([true[b] - true[a] for a, b in pairs], dtype=np.float64)
    w = np.ones(len(pairs), dtype=np.float64)

    # Use a strong penalty to make the smoothing bias obvious. The module-level
    # constant is restored in ``finally`` so a failing solve cannot leak the
    # override into other tests.
    original = index_mod.TEMPORAL_SMOOTHNESS_LAMBDA
    index_mod.TEMPORAL_SMOOTHNESS_LAMBDA = 1.0
    try:
        idx = solve_robust_index(y1, y2, lr, w)
    finally:
        index_mod.TEMPORAL_SMOOTHNESS_LAMBDA = original

    assert idx[2000] == 0.0  # baseline anchor
    # beta_2015 must stay near its true post-gap level, not get dragged down by a
    # spurious curvature penalty that treats the gap as a single-year step.
    assert abs(idx[2015] - true[2015]) < 0.05
|
|
|
|
|
|
def test_n_pairs_counts_only_cross_year_pairs():
    """FIX #12: n_pairs counts only cross-year (year2 != year1) pairs.

    Same-year pairs carry zero index information and must not inflate the
    shrinkage weight.
    """

    def make_pairs(group, year1, year2, n):
        """Return ``n`` identical pair rows for ``group`` spanning year1 -> year2."""
        return [
            {
                "grp": group,
                "year1": year1,
                "year2": year2,
                "log_ratio": 0.03 * (year2 - year1),
                "weight": 1.0,
            }
            for _ in range(n)
        ]

    # 8 genuine cross-year pairs spanning enough years for a valid solve, plus 3
    # zero-information same-year pairs that must not be counted.
    rows = [
        *make_pairs("g", 2010, 2011, 4),
        *make_pairs("g", 2011, 2012, 4),
        *make_pairs("g", 2012, 2012, 3),  # same-year, zero info
    ]

    pairs = pl.DataFrame(rows)
    indices, n_pairs = compute_indices_for_level(pairs, "grp")

    assert "g" in indices
    assert n_pairs["g"] == 8  # not 11
|