perfect-postcode/pipeline/transform/renovation_premium.py
2026-02-15 22:39:53 +00:00

572 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Estimate per-area renovation premiums from repeat-sale residuals.

For each consecutive repeat-sale pair, compute the residual left after removing
the price-index predicted return (log_index at the second sale minus log_index
at the first). Pairs where renovation events occurred between the two sales
should have systematically higher residuals. A weighted least-squares
regression of residual on time-decayed event indicators estimates a
log-premium per event type; estimates are then shrunk hierarchically
(sector -> district -> area -> national) and spatially smoothed across
nearby sectors.

Output: renovation_premium.parquet with columns
sector, type_group, event_type, log_premium, n_reno_pairs
(one row per sector x type_group x event_type).
"""
import argparse
import math
from pathlib import Path
import numpy as np
import polars as pl
from scipy.spatial import KDTree
from pipeline.transform._price_utils import (
SHRINKAGE_K,
TYPE_GROUPS,
extract_centroids,
hierarchy_keys,
sector_expr,
type_group_expr,
)
# --- Renovation indicator decay -------------------------------------------
# An event's indicator is exp(-DECAY_RATE * years_since_event), measured at
# the second sale; with DECAY_RATE = ln(2) / HALF_LIFE it halves every
# HALF_LIFE years.
HALF_LIFE = 10.0
DECAY_RATE = math.log(2) / HALF_LIFE

# --- Pair / group filters -------------------------------------------------
# Pairs whose |log(price2 / price1)| exceeds this are discarded as outliers.
OUTLIER_THRESHOLD = 3.0
# Minimum number of pairs before a group's regression is attempted.
MIN_PAIRS = 10

# --- Spatial smoothing ----------------------------------------------------
# Number of nearest neighbouring sectors blended into a sparse sector.
SPATIAL_NEIGHBORS = 5
# Pseudo-count: a sector keeps weight n / (n + SPATIAL_BLEND_K) on its own value.
SPATIAL_BLEND_K = 30

# Event labels matched against the `event` field of renovation_history.
EVENT_TYPES = [
    "Extension",
    "Renovation",
    "Remodeling",
]
def _build_repeat_sale_pairs(input_path: Path) -> pl.DataFrame:
    """Explode each property's sale history into consecutive (sale1, sale2) pairs.

    Keeps pairs with positive prices, a strictly later second sale, and
    |log(price2 / price1)| <= OUTLIER_THRESHOLD.
    """
    df = (
        pl.scan_parquet(input_path)
        .select("Postcode", "historical_prices", "Property type", "renovation_history")
        .filter(
            pl.col("Postcode").is_not_null(),
            pl.col("historical_prices").list.len() >= 2,
        )
        .with_columns(sector_expr(), type_group_expr())
        .collect()
    )
    print(f" {len(df):,} properties with 2+ transactions")

    pairs = (
        df.lazy()
        # from_txn = all but the last sale, to_txn = all but the first; exploding
        # both in lockstep yields consecutive pairs.
        .with_columns(
            pl.col("historical_prices")
            .list.slice(0, pl.col("historical_prices").list.len() - 1)
            .alias("from_txn"),
            pl.col("historical_prices").list.slice(1).alias("to_txn"),
        )
        .explode("from_txn", "to_txn")
        .with_columns(
            pl.col("from_txn").struct.field("year").alias("year1"),
            pl.col("from_txn").struct.field("price").alias("price1"),
            pl.col("to_txn").struct.field("year").alias("year2"),
            pl.col("to_txn").struct.field("price").alias("price2"),
        )
        .select(
            "sector",
            "type_group",
            "year1",
            "price1",
            "year2",
            "price2",
            "renovation_history",
        )
        .filter(
            pl.col("price1") > 0,
            pl.col("price2") > 0,
            pl.col("year2") > pl.col("year1"),
        )
        .with_columns(
            (pl.col("price2").cast(pl.Float64) / pl.col("price1").cast(pl.Float64))
            .log()
            .alias("log_ratio"),
        )
        .filter(pl.col("log_ratio").abs() <= OUTLIER_THRESHOLD)
        .collect()
    )
    print(f" {len(pairs):,} repeat-sale pairs")
    return pairs


def _attach_index_residuals(pairs: pl.DataFrame, index: pl.DataFrame) -> pl.DataFrame:
    """Add `residual` (log_ratio minus index return) and `weight` columns.

    When the index has a `type_group` column, the type-specific index is
    preferred and the "All" index is the fallback at each sale year.
    """
    if "type_group" in index.columns:
        idx_typed = index.filter(pl.col("type_group") != "All")
        idx_all = index.filter(pl.col("type_group") == "All")
        for suffix, year_col in (("1", "year1"), ("2", "year2")):
            pairs = pairs.join(
                idx_typed.select(
                    "sector",
                    "type_group",
                    "year",
                    pl.col("log_index").alias(f"li{suffix}_typed"),
                ),
                left_on=["sector", "type_group", year_col],
                right_on=["sector", "type_group", "year"],
                how="left",
            ).join(
                idx_all.select(
                    "sector", "year", pl.col("log_index").alias(f"li{suffix}_all")
                ),
                left_on=["sector", year_col],
                right_on=["sector", "year"],
                how="left",
            )
        pairs = pairs.with_columns(
            pl.col("li1_typed").fill_null(pl.col("li1_all")).alias("_li1"),
            pl.col("li2_typed").fill_null(pl.col("li2_all")).alias("_li2"),
        )
    else:
        for suffix, year_col in (("1", "year1"), ("2", "year2")):
            pairs = pairs.join(
                index.select(
                    "sector", "year", pl.col("log_index").alias(f"_li{suffix}")
                ),
                left_on=["sector", year_col],
                right_on=["sector", "year"],
                how="left",
            )

    pairs = pairs.with_columns(
        # residual = log_ratio - (index2 - index1). The index *return* is
        # treated as 0 when either endpoint is missing. Filling each level
        # separately would subtract a whole log-index level when only one
        # year is covered.
        (
            pl.col("log_ratio")
            - (pl.col("_li2") - pl.col("_li1")).fill_null(0.0)
        ).alias("residual"),
        # Longer holding periods get less weight: weight = 1 / sqrt(year gap).
        (1.0 / (pl.col("year2") - pl.col("year1")).cast(pl.Float64).sqrt()).alias(
            "weight"
        ),
    )
    temp_cols = [c for c in pairs.columns if c.startswith(("_li", "li1_", "li2_"))]
    return pairs.drop(temp_cols)


def _attach_renovation_indicators(pairs: pl.DataFrame) -> pl.DataFrame:
    """Add one time-decayed `has_<event>` column per entry of EVENT_TYPES.

    Only events with year1 < event_year <= year2 count. For each pair and event
    type, the most recent such event contributes
    exp(-DECAY_RATE * (year2 - event_year)). Pairs without a qualifying event
    get 0.0.
    """
    # Row index gives a unique pair id (sector/type/year keys are not unique).
    pairs = pairs.with_row_index("_pair_idx").with_columns(
        [pl.lit(0.0).alias(f"has_{et.lower()}") for et in EVENT_TYPES]
    )

    has_reno = pairs.filter(
        pl.col("renovation_history").is_not_null()
        & (pl.col("renovation_history").list.len() > 0)
    )
    if has_reno.is_empty():
        return pairs.drop("_pair_idx")

    reno_exploded = (
        has_reno.select("_pair_idx", "year1", "year2", "renovation_history")
        .explode("renovation_history")
        .with_columns(
            pl.col("renovation_history").struct.field("year").alias("event_year"),
            pl.col("renovation_history").struct.field("event").alias("event_type"),
        )
        # Only events between the two sales
        .filter(
            (pl.col("event_year") > pl.col("year1"))
            & (pl.col("event_year") <= pl.col("year2"))
        )
    )
    if reno_exploded.is_empty():
        return pairs.drop("_pair_idx")

    latest_events = (
        reno_exploded.group_by("_pair_idx", "event_type", "year2")
        .agg(pl.col("event_year").max().alias("latest_event_year"))
        .with_columns(
            (
                -DECAY_RATE
                * (pl.col("year2") - pl.col("latest_event_year")).cast(pl.Float64)
            )
            .exp()
            .alias("decayed_indicator"),
        )
    )

    # Pivot to wide format: one left join per event type on the unique pair id.
    for et in EVENT_TYPES:
        et_data = latest_events.filter(pl.col("event_type") == et)
        if et_data.is_empty():
            continue
        col_name = f"has_{et.lower()}"
        pairs = (
            pairs.join(
                et_data.select(
                    "_pair_idx",
                    pl.col("decayed_indicator").alias(f"_{col_name}"),
                ),
                on="_pair_idx",
                how="left",
            )
            .with_columns(
                pl.col(f"_{col_name}").fill_null(0.0).alias(col_name),
            )
            .drop(f"_{col_name}")
        )
    return pairs.drop("_pair_idx")


def extract_pairs_with_events(input_path: Path, index_path: Path) -> pl.DataFrame:
    """Extract repeat-sale pairs with renovation events and index residuals.

    Args:
        input_path: Parquet with Postcode, historical_prices (list of
            {year, price} structs), Property type and renovation_history
            (list of {year, event} structs).
        index_path: Price index parquet with sector, year, log_index and,
            optionally, type_group.

    Returns:
        One row per consecutive repeat-sale pair with sector, type_group,
        year1, price1, year2, price2, log_ratio, residual, weight,
        has_<event> (one per EVENT_TYPES entry), district and area.
    """
    print("Extracting repeat-sale pairs with renovation events...")
    pairs = _build_repeat_sale_pairs(input_path)
    pairs = _attach_index_residuals(pairs, pl.read_parquet(index_path))
    pairs = _attach_renovation_indicators(pairs)

    # Add hierarchy columns
    pairs = pairs.with_columns(
        pl.col("sector").str.replace(r"\s+\d+$", "").alias("district"),
    ).with_columns(
        pl.col("district").str.replace(r"\d.*$", "").alias("area"),
    )

    reno_mask = pl.any_horizontal(
        [pl.col(f"has_{et.lower()}") > 0 for et in EVENT_TYPES]
    )
    n_reno = pairs.filter(reno_mask).height
    # Guard: no pairs may survive filtering.
    share = n_reno / len(pairs) * 100 if len(pairs) else 0.0
    print(f" {n_reno:,} pairs with renovation events ({share:.1f}%)")

    # renovation_history is no longer needed once indicators are computed
    return pairs.drop("renovation_history")
def wls_regression(
    residuals: np.ndarray,
    weights: np.ndarray,
    X: np.ndarray,
) -> np.ndarray:
    """Solve the weighted least-squares problem ``residuals ~ X``.

    X is expected to already contain its intercept column. Rows are scaled by
    sqrt(weight), which is equivalent to WLS without building an N x N
    diagonal weight matrix. If the SVD-based solver fails to converge, a
    zero coefficient vector of length ``X.shape[1]`` is returned.
    """
    root_w = np.sqrt(weights)
    design = X * root_w[:, np.newaxis]
    target = residuals * root_w
    try:
        solution, *_ = np.linalg.lstsq(design, target, rcond=None)
    except np.linalg.LinAlgError:
        return np.zeros(X.shape[1])
    return solution
def compute_premiums_for_group(df: pl.DataFrame) -> dict[str, float]:
    """Run WLS regression for a group, return {event_type: log_premium}.

    Model: residual ~ intercept + sum_e beta_e * has_<e>, weighted by `weight`.

    Returns an empty dict in two cases: the group has fewer than MIN_PAIRS
    rows, or the summed renovation indicators are below 1.0. Event types that
    never occur in the group (all-zero indicator column) are omitted. Their
    coefficient is not identified; lstsq would report the min-norm value 0.0,
    and callers such as shrink_premium would mistake it for a real zero
    premium instead of falling back to the parent level.
    """
    n = len(df)
    if n < MIN_PAIRS:
        return {}

    residuals = df["residual"].to_numpy().astype(np.float64)
    weights = df["weight"].to_numpy().astype(np.float64)
    # One indicator column per event type, in EVENT_TYPES order.
    indicators = np.column_stack(
        [df[f"has_{et.lower()}"].to_numpy().astype(np.float64) for et in EVENT_TYPES]
    )

    # Require some renovation signal in this group
    if indicators.sum() < 1.0:
        return {}

    X = np.column_stack([np.ones(n), indicators])
    betas = wls_regression(residuals, weights, X)

    # betas[0] is the intercept; betas[j + 1] belongs to EVENT_TYPES[j].
    support = indicators.sum(axis=0)
    return {
        et: float(betas[j + 1])
        for j, et in enumerate(EVENT_TYPES)
        if support[j] > 0.0
    }
def compute_premiums_for_level(
    pairs: pl.DataFrame, group_col: str
) -> tuple[dict, dict]:
    """Fit renovation premiums separately for every value of ``group_col``.

    Returns two dicts keyed by group value:
      - premiums: {group: {event_type: log_premium}}, only for groups where
        compute_premiums_for_group produced a result;
      - n_reno_pairs: {group: number of pairs with any has_* indicator > 0},
        for every group.
    """
    premiums = {}
    n_reno_pairs = {}
    for (key_val, *_), group_df in pairs.group_by(group_col):
        fitted = compute_premiums_for_group(group_df)
        if fitted:
            premiums[key_val] = fitted
        any_event = (
            (group_df["has_extension"].to_numpy() > 0)
            | (group_df["has_renovation"].to_numpy() > 0)
            | (group_df["has_remodeling"].to_numpy() > 0)
        )
        n_reno_pairs[key_val] = int(np.count_nonzero(any_event))
    return premiums, n_reno_pairs
def shrink_premium(
    raw: dict[str, float], parent: dict[str, float], n: int
) -> dict[str, float]:
    """Shrink raw premiums toward the parent level's premiums.

    The raw estimate gets weight n / (n + SHRINKAGE_K). A side missing an
    event type borrows the other side's value, and 0.0 is used when neither
    has it.
    """
    self_weight = n / (n + SHRINKAGE_K)
    shrunk = {}
    for event in EVENT_TYPES:
        own = raw[event] if event in raw else parent.get(event, 0.0)
        prior = parent[event] if event in parent else raw.get(event, 0.0)
        shrunk[event] = self_weight * own + (1 - self_weight) * prior
    return shrunk
def apply_shrinkage(
    sector_prem: dict,
    sector_n: dict,
    district_prem: dict,
    district_n: dict,
    area_prem: dict,
    area_n: dict,
    national_prem: dict[str, float],
    national_n: int,
    all_sectors: list,
    sector_to_dist: dict,
    dist_to_area: dict,
) -> dict:
    """Top-down hierarchical shrinkage for premiums.

    Shrinks area -> national, district -> area and sector -> district, each
    weighted by that level's renovation-pair count. Every sector in
    all_sectors ends up with a premium dict. For a sector, the parent is the
    first available of: its shrunk district, its shrunk area, the national
    premiums. The same chain is used both for sectors with their own
    estimate and for sectors without one. `national_n` is accepted for
    interface compatibility but not used.
    """
    # Area -> national
    area_shrunk = {
        area: shrink_premium(prem, national_prem, area_n.get(area, 0))
        for area, prem in area_prem.items()
    }

    # District -> area
    district_shrunk = {}
    for dist, prem in district_prem.items():
        parent = area_shrunk.get(dist_to_area.get(dist, ""), national_prem)
        district_shrunk[dist] = shrink_premium(prem, parent, district_n.get(dist, 0))

    def sector_parent(sec):
        # District first, then area, then national, so no hierarchy level is
        # skipped when a district has no estimate of its own.
        d = sector_to_dist.get(sec, "")
        if d in district_shrunk:
            return district_shrunk[d]
        return area_shrunk.get(dist_to_area.get(d, ""), national_prem)

    # Sector -> district (or nearest available ancestor)
    sector_shrunk = {
        sec: shrink_premium(prem, sector_parent(sec), sector_n.get(sec, 0))
        for sec, prem in sector_prem.items()
    }

    # Sectors without their own estimate inherit the nearest ancestor's premiums
    for sec in all_sectors:
        if sec not in sector_shrunk:
            sector_shrunk[sec] = sector_parent(sec)
    return sector_shrunk
def spatial_smooth(
    sector_premiums: dict[str, dict[str, float]],
    centroids: dict[str, tuple[float, float]],
    n_reno_map: dict[str, int],
) -> dict[str, dict[str, float]]:
    """Blend sparse sector premiums with their K nearest neighbours.

    A sector keeps weight n / (n + SPATIAL_BLEND_K) on its own premiums, where
    n comes from n_reno_map. The remainder is spread over up to
    SPATIAL_NEIGHBORS neighbours by inverse distance. Sectors with own
    weight > 0.95 are left unchanged, and so are sectors without a centroid.
    Neighbour values are always read from the unsmoothed input, so the
    result does not depend on iteration order.

    Assumes each centroid is (lat, lon): the second coordinate is scaled by
    cos(mean latitude) before building the KD-tree. TODO confirm against
    extract_centroids.
    """
    located = [s for s in sector_premiums if s in centroids]
    if len(located) < SPATIAL_NEIGHBORS + 1:
        return sector_premiums

    coords = np.array([centroids[s] for s in located])
    lon_scale = np.cos(np.radians(np.mean(coords[:, 0])))
    planar = np.column_stack([coords[:, 0], coords[:, 1] * lon_scale])
    tree = KDTree(planar)
    # One batched query for all sectors. Row i matches tree.query(planar[i]),
    # and position 0 is normally the point itself.
    all_dists, all_idxs = tree.query(planar, k=SPATIAL_NEIGHBORS + 1)

    smoothed = dict(sector_premiums)
    for i, sec in enumerate(located):
        n = n_reno_map.get(sec, 0)
        own_w = n / (n + SPATIAL_BLEND_K)
        if own_w > 0.95:
            continue

        inv_dists = []
        neighbour_prems = []
        for dist, j in zip(all_dists[i][1:], all_idxs[i][1:]):
            neighbour = located[j]
            # Co-located points (distance 0) have no finite inverse distance.
            if dist > 0 and neighbour in sector_premiums:
                inv_dists.append(1.0 / dist)
                neighbour_prems.append(sector_premiums[neighbour])
        if not neighbour_prems:
            continue

        inv_total = sum(inv_dists)
        spare_w = 1.0 - own_w
        shares = [iw / inv_total * spare_w for iw in inv_dists]
        own_prem = sector_premiums[sec]
        mixed = {}
        for et in EVENT_TYPES:
            acc = own_w * own_prem.get(et, 0.0)
            for prem, share in zip(neighbour_prems, shares):
                acc += share * prem.get(et, 0.0)
            mixed[et] = acc
        smoothed[sec] = mixed
    return smoothed
def main():
    """CLI entry point: estimate renovation premiums and write them to parquet.

    For "All" and each entry of TYPE_GROUPS the pipeline is:
      1. fit national premiums;
      2. fit per-area, per-district and per-sector premiums;
      3. apply hierarchical shrinkage;
      4. apply spatial smoothing;
      5. emit one row per sector x event type.
    Type groups with too few pairs, or no renovation signal nationally, are
    skipped.
    """
    parser = argparse.ArgumentParser(
        description="Estimate renovation premiums from repeat-sale residuals"
    )
    parser.add_argument(
        "--input", type=Path, required=True, help="Path to wide.parquet"
    )
    parser.add_argument(
        "--index", type=Path, required=True, help="Path to price_index.parquet"
    )
    parser.add_argument(
        "--output", type=Path, required=True, help="Output renovation_premium.parquet"
    )
    args = parser.parse_args()

    pairs = extract_pairs_with_events(args.input, args.index)
    centroids = extract_centroids(args.input)

    # Precompute hierarchy
    all_sectors = pairs["sector"].unique().to_list()
    sector_to_dist = {}
    dist_to_area = {}
    for s in all_sectors:
        d, a = hierarchy_keys(s)
        sector_to_dist[s] = d
        dist_to_area[d] = a

    all_type_groups = ["All"] + TYPE_GROUPS
    national_by_tg: dict[str, dict[str, float]] = {}
    rows = []
    for tg in all_type_groups:
        print(f"\n--- {tg} ---")
        typed = pairs if tg == "All" else pairs.filter(pl.col("type_group") == tg)
        if len(typed) < MIN_PAIRS:
            print(f" Skipping (only {len(typed)} pairs)")
            continue
        print(f" {len(typed):,} pairs")

        # National
        national_prem = compute_premiums_for_group(typed)
        national_reno = typed.filter(
            (pl.col("has_extension") > 0)
            | (pl.col("has_renovation") > 0)
            | (pl.col("has_remodeling") > 0)
        ).height
        if not national_prem:
            print(" No renovation pairs at national level, skipping")
            continue
        national_by_tg[tg] = national_prem
        print(
            " National premiums: "
            + ", ".join(
                f"{et}: {v:.4f} ({math.exp(v) - 1:.1%})"
                for et, v in national_prem.items()
            )
        )

        # Per-level
        print(" Computing per-level premiums:")
        area_prem, area_n = compute_premiums_for_level(typed, "area")
        district_prem, district_n = compute_premiums_for_level(typed, "district")
        sector_prem, sector_n = compute_premiums_for_level(typed, "sector")
        print(
            f" {len(area_prem)} areas, {len(district_prem)} districts, {len(sector_prem)} sectors with data"
        )

        # Shrinkage
        print(" Applying shrinkage...")
        sector_shrunk = apply_shrinkage(
            sector_prem,
            sector_n,
            district_prem,
            district_n,
            area_prem,
            area_n,
            national_prem,
            national_reno,
            all_sectors,
            sector_to_dist,
            dist_to_area,
        )

        # Spatial smoothing
        print(" Spatial smoothing...")
        sector_smoothed = spatial_smooth(sector_shrunk, centroids, sector_n)

        # Collect rows
        for sec in all_sectors:
            prem = sector_smoothed.get(sec, national_prem)
            n = sector_n.get(sec, 0)
            for et in EVENT_TYPES:
                rows.append((sec, tg, et, prem.get(et, 0.0), n))

    result = pl.DataFrame(
        rows,
        schema={
            "sector": pl.String,
            "type_group": pl.String,
            "event_type": pl.String,
            "log_premium": pl.Float64,
            "n_reno_pairs": pl.Int64,
        },
        orient="row",
    ).sort("type_group", "sector", "event_type")

    # Create the destination directory so a fresh output path does not fail.
    args.output.parent.mkdir(parents=True, exist_ok=True)
    result.write_parquet(args.output)
    size_mb = args.output.stat().st_size / (1024 * 1024)
    print(f"\nWrote {args.output} ({size_mb:.1f} MB)")
    # Count the type groups actually written, not all attempted (some may be skipped).
    print(
        f" {result['sector'].n_unique():,} sectors x {result['type_group'].n_unique()} types x {len(EVENT_TYPES)} events = {len(result):,} rows"
    )

    # The true national WLS estimate for the "All" type group.
    print("\nNational premium summary:")
    for et, lp in national_by_tg.get("All", {}).items():
        print(f" {et}: log_premium={lp:.4f} ({math.exp(lp) - 1:.1%} price uplift)")

    # Unweighted mean of the final (shrunk + smoothed) sector premiums. This is
    # a different statistic from the national estimate above.
    print("\nMean sector premium summary (type_group=All):")
    sector_means = (
        result.filter(pl.col("type_group") == "All")
        .group_by("event_type")
        .agg(
            pl.col("log_premium").mean().alias("mean_log_premium"),
        )
        .sort("event_type")
    )
    for row in sector_means.iter_rows(named=True):
        et = row["event_type"]
        lp = row["mean_log_premium"]
        print(f" {et}: log_premium={lp:.4f} ({math.exp(lp) - 1:.1%} price uplift)")