"""Estimate per-area renovation premiums from repeat-sale residuals.

For each repeat-sale pair, compute the residual log return that remains after
removing the return predicted by the price index. Pairs where renovation
events occurred between the two sales should have systematically higher
residuals. A weighted least-squares (WLS) regression estimates the log-premium
per event type, followed by hierarchical shrinkage
(sector -> district -> area -> national) and spatial smoothing across
neighbouring sectors.

Output: renovation_premium.parquet — one row per
sector × type_group × event_type, with columns ``log_premium`` and
``n_reno_pairs``.
"""

import argparse
import math
from pathlib import Path

import numpy as np
import polars as pl
from scipy.spatial import KDTree

from pipeline.transform._price_utils import (
    SHRINKAGE_K,
    TYPE_GROUPS,
    extract_centroids,
    hierarchy_keys,
    sector_expr,
    type_group_expr,
)

# Half-life of the renovation effect, in the same units as the sale/event
# "year" fields: an event t years before the second sale contributes
# exp(-DECAY_RATE * t) to that pair's indicator.
HALF_LIFE = 10.0
DECAY_RATE = math.log(2) / HALF_LIFE
# Pairs whose absolute log price ratio exceeds this are dropped as outliers.
OUTLIER_THRESHOLD = 3.0
# Minimum number of pairs required to fit a regression for a group.
MIN_PAIRS = 10
# Number of nearest sectors used when spatially smoothing a sector.
SPATIAL_NEIGHBORS = 5
# Pseudo-count: a sector keeps weight n / (n + SPATIAL_BLEND_K) on its own value.
SPATIAL_BLEND_K = 30

EVENT_TYPES = ["Extension", "Renovation", "Remodeling"]


def _event_col(event_type: str) -> str:
    """Return the indicator column name for an event type (e.g. ``has_extension``)."""
    return f"has_{event_type.lower()}"


def _any_event_expr() -> pl.Expr:
    """Return an expression that is True where any event indicator is positive."""
    return pl.any_horizontal([pl.col(_event_col(et)) > 0 for et in EVENT_TYPES])


def _load_consecutive_pairs(input_path: Path) -> pl.DataFrame:
    """Load properties with 2+ sales and expand them into consecutive sale pairs.

    Each property's ``historical_prices`` list is turned into (sale i, sale i+1)
    pairs. Pairs with a non-positive price, a non-increasing year, or an
    absolute log price ratio above OUTLIER_THRESHOLD are dropped.
    NOTE(review): assumes ``historical_prices`` is in chronological order —
    out-of-order entries are simply filtered out by ``year2 > year1``.
    """
    df = (
        pl.scan_parquet(input_path)
        # "Property type" is selected for type_group_expr() (presumably its input).
        .select("Postcode", "historical_prices", "Property type", "renovation_history")
        .filter(
            pl.col("Postcode").is_not_null(),
            pl.col("historical_prices").list.len() >= 2,
        )
        .with_columns(sector_expr(), type_group_expr())
        .collect()
    )
    print(f" {len(df):,} properties with 2+ transactions")

    return (
        df.lazy()
        .with_columns(
            # All sales but the last, zipped with all sales but the first.
            pl.col("historical_prices")
            .list.slice(0, pl.col("historical_prices").list.len() - 1)
            .alias("from_txn"),
            pl.col("historical_prices").list.slice(1).alias("to_txn"),
        )
        .explode("from_txn", "to_txn")
        .with_columns(
            pl.col("from_txn").struct.field("year").alias("year1"),
            pl.col("from_txn").struct.field("price").alias("price1"),
            pl.col("to_txn").struct.field("year").alias("year2"),
            pl.col("to_txn").struct.field("price").alias("price2"),
        )
        .select(
            "sector",
            "type_group",
            "year1",
            "price1",
            "year2",
            "price2",
            "renovation_history",
        )
        .filter(
            pl.col("price1") > 0,
            pl.col("price2") > 0,
            pl.col("year2") > pl.col("year1"),
        )
        .with_columns(
            (pl.col("price2").cast(pl.Float64) / pl.col("price1").cast(pl.Float64))
            .log()
            .alias("log_ratio"),
        )
        .filter(pl.col("log_ratio").abs() <= OUTLIER_THRESHOLD)
        .collect()
    )


def _join_log_index(
    pairs: pl.DataFrame,
    index: pl.DataFrame,
    keys: list[str],
    year_col: str,
    alias: str,
) -> pl.DataFrame:
    """Left-join ``index.log_index`` onto ``pairs`` at ``year_col`` as column ``alias``."""
    return pairs.join(
        index.select(*keys, "year", pl.col("log_index").alias(alias)),
        left_on=[*keys, year_col],
        right_on=[*keys, "year"],
        how="left",
    )


def _add_index_residuals(pairs: pl.DataFrame, index_path: Path) -> pl.DataFrame:
    """Add ``residual`` (log_ratio minus index return) and ``weight`` columns.

    If the index has a ``type_group`` column, the type-specific index is used
    and the ``type_group == "All"`` index fills its gaps. Otherwise the index
    is joined on sector and year only.
    ``weight`` is 1 / sqrt(year2 - year1).
    """
    index = pl.read_parquet(index_path)

    if "type_group" in index.columns:
        idx_typed = index.filter(pl.col("type_group") != "All")
        idx_all = index.filter(pl.col("type_group") == "All")
        for year_col, n in (("year1", "1"), ("year2", "2")):
            pairs = _join_log_index(
                pairs, idx_typed, ["sector", "type_group"], year_col, f"li{n}_typed"
            )
            pairs = _join_log_index(pairs, idx_all, ["sector"], year_col, f"li{n}_all")
        pairs = pairs.with_columns(
            pl.col("li1_typed").fill_null(pl.col("li1_all")).alias("_li1"),
            pl.col("li2_typed").fill_null(pl.col("li2_all")).alias("_li2"),
        )
    else:
        pairs = _join_log_index(pairs, index, ["sector"], "year1", "_li1")
        pairs = _join_log_index(pairs, index, ["sector"], "year2", "_li2")

    # Predicted market log return between the two sales. If the index is
    # missing at either end, the difference is null and no market adjustment
    # is applied. Filling each end with 0 separately would mix a single
    # absolute index level into the return and create spurious residuals.
    index_return = (pl.col("_li2") - pl.col("_li1")).fill_null(0.0)
    pairs = pairs.with_columns(
        (pl.col("log_ratio") - index_return).alias("residual"),
        (1.0 / (pl.col("year2") - pl.col("year1")).cast(pl.Float64).sqrt()).alias(
            "weight"
        ),
    )

    temp_cols = [c for c in pairs.columns if c.startswith(("_li", "li1_", "li2_"))]
    return pairs.drop(temp_cols)


def _add_event_indicators(pairs: pl.DataFrame) -> pl.DataFrame:
    """Add one time-decayed ``has_<event>`` column per entry of EVENT_TYPES.

    Only events with year1 < event_year <= year2 count. For each pair and
    event type, the most recent such event contributes
    exp(-DECAY_RATE * (year2 - event_year)). Pairs with no qualifying event
    get 0.0. Event types not listed in EVENT_TYPES are ignored.
    """
    # Sector/year columns repeat across pairs, so a row index is the only
    # unique per-pair join key.
    pairs = pairs.with_row_index("_pair_idx").with_columns(
        [pl.lit(0.0).alias(_event_col(et)) for et in EVENT_TYPES]
    )

    has_reno = pairs.filter(
        pl.col("renovation_history").is_not_null()
        & (pl.col("renovation_history").list.len() > 0)
    )
    if has_reno.is_empty():
        return pairs.drop("_pair_idx")

    events = (
        has_reno.select("_pair_idx", "year1", "year2", "renovation_history")
        .explode("renovation_history")
        .with_columns(
            pl.col("renovation_history").struct.field("year").alias("event_year"),
            pl.col("renovation_history").struct.field("event").alias("event_type"),
        )
        # Only events between the two sales.
        .filter(
            (pl.col("event_year") > pl.col("year1"))
            & (pl.col("event_year") <= pl.col("year2"))
        )
    )
    if events.is_empty():
        return pairs.drop("_pair_idx")

    latest = (
        events.group_by("_pair_idx", "event_type", "year2")
        .agg(pl.col("event_year").max().alias("latest_event_year"))
        .with_columns(
            (
                -DECAY_RATE
                * (pl.col("year2") - pl.col("latest_event_year")).cast(pl.Float64)
            )
            .exp()
            .alias("decayed_indicator"),
        )
    )

    # Pivot to wide format: one indicator column per event type.
    for et in EVENT_TYPES:
        et_data = latest.filter(pl.col("event_type") == et)
        if et_data.is_empty():
            continue
        col_name = _event_col(et)
        tmp = f"_{col_name}"
        pairs = (
            pairs.join(
                et_data.select("_pair_idx", pl.col("decayed_indicator").alias(tmp)),
                on="_pair_idx",
                how="left",
            )
            .with_columns(pl.col(tmp).fill_null(0.0).alias(col_name))
            .drop(tmp)
        )

    return pairs.drop("_pair_idx")


def extract_pairs_with_events(input_path: Path, index_path: Path) -> pl.DataFrame:
    """Extract repeat-sale pairs with renovation events and index residuals.

    Args:
        input_path: parquet with Postcode, historical_prices, Property type and
            renovation_history columns.
        index_path: price-index parquet with sector, year, log_index (and
            optionally type_group) columns.

    Returns:
        One row per consecutive sale pair with columns sector, type_group,
        year1, price1, year2, price2, log_ratio, residual, weight, one
        ``has_<event>`` column per EVENT_TYPES entry, district and area.
    """
    print("Extracting repeat-sale pairs with renovation events...")
    pairs = _load_consecutive_pairs(input_path)
    print(f" {len(pairs):,} repeat-sale pairs")

    pairs = _add_index_residuals(pairs, index_path)
    pairs = _add_event_indicators(pairs)

    # Hierarchy columns, derived from the sector code by regex.
    # NOTE(review): main() builds the sector->district->area maps with
    # hierarchy_keys(); confirm both derivations agree, otherwise shrinkage
    # parents will not be found.
    pairs = pairs.with_columns(
        pl.col("sector").str.replace(r"\s+\d+$", "").alias("district"),
    ).with_columns(
        pl.col("district").str.replace(r"\d.*$", "").alias("area"),
    )

    n_reno = pairs.filter(_any_event_expr()).height
    share = n_reno / len(pairs) * 100 if len(pairs) else 0.0
    print(f" {n_reno:,} pairs with renovation events ({share:.1f}%)")

    # renovation_history has been folded into the indicator columns.
    return pairs.drop("renovation_history")


def wls_regression(
    residuals: np.ndarray,
    weights: np.ndarray,
    X: np.ndarray,
) -> np.ndarray:
    """Weighted least squares: residual ~ X (X must include the intercept column).

    Rows are scaled by sqrt(weights) instead of building an N×N diagonal
    weight matrix. Returns the minimum-norm ``lstsq`` solution, or zeros if
    the solver raises LinAlgError.
    """
    sqrt_w = np.sqrt(weights)[:, np.newaxis]
    Xw = X * sqrt_w
    yw = residuals * sqrt_w.ravel()
    try:
        betas = np.linalg.lstsq(Xw, yw, rcond=None)[0]
    except np.linalg.LinAlgError:
        betas = np.zeros(X.shape[1])
    return betas


def compute_premiums_for_group(df: pl.DataFrame) -> dict[str, float]:
    """Run a WLS regression for one group of pairs.

    Fits residual ~ 1 + has_extension + has_renovation + has_remodeling,
    weighted by the ``weight`` column.

    Returns:
        {event_type: log_premium} for the event types that actually occur in
        the group. Event types with no occurrences are omitted, so callers fall
        back to the parent level, instead of reporting an uninformed 0.0.
        Returns {} when the group has fewer than MIN_PAIRS pairs, or its
        summed indicators are below 1.0.
    """
    n = len(df)
    if n < MIN_PAIRS:
        return {}

    indicators = (
        df.select([_event_col(et) for et in EVENT_TYPES]).to_numpy().astype(np.float64)
    )
    # Require at least one full (undecayed) event's worth of signal in total.
    if indicators.sum() < 1.0:
        return {}

    # An all-zero column carries no information about that event type, so it
    # is left out of the design matrix rather than reported as a 0.0 premium.
    present = indicators.any(axis=0)
    X = np.column_stack([np.ones(n), indicators[:, present]])

    residuals = df["residual"].to_numpy().astype(np.float64)
    weights = df["weight"].to_numpy().astype(np.float64)
    betas = wls_regression(residuals, weights, X)

    # betas[0] is the intercept; the rest follow the order of the kept columns.
    present_types = [et for et, keep in zip(EVENT_TYPES, present) if keep]
    return {et: float(b) for et, b in zip(present_types, betas[1:])}


def compute_premiums_for_level(
    pairs: pl.DataFrame, group_col: str
) -> tuple[dict, dict]:
    """Compute premiums per group at a given hierarchy level.

    Returns:
        (premiums, n_reno_pairs), both keyed by group value and containing
        only groups whose regression produced a result.
        premiums[key] = {event_type: log_premium}.
        n_reno_pairs[key] = number of pairs with any positive event indicator.
    """
    premiums = {}
    n_reno_pairs = {}
    for (key_val,), group_df in pairs.group_by(group_col):
        result = compute_premiums_for_group(group_df)
        if not result:
            continue
        premiums[key_val] = result
        n_reno_pairs[key_val] = group_df.filter(_any_event_expr()).height
    return premiums, n_reno_pairs


def shrink_premium(
    raw: dict[str, float], parent: dict[str, float], n: int
) -> dict[str, float]:
    """Shrink raw premiums toward the parent level.

    The raw value gets weight n / (n + SHRINKAGE_K). An event type missing
    from one side takes the other side's value (0.0 if missing from both).
    """
    w = n / (n + SHRINKAGE_K)
    result = {}
    for et in EVENT_TYPES:
        r = raw.get(et, parent.get(et, 0.0))
        p = parent.get(et, raw.get(et, 0.0))
        result[et] = w * r + (1 - w) * p
    return result


def apply_shrinkage(
    sector_prem,
    sector_n,
    district_prem,
    district_n,
    area_prem,
    area_n,
    national_prem,
    national_n,
    all_sectors,
    sector_to_dist,
    dist_to_area,
):
    """Top-down hierarchical shrinkage: area -> national, district -> area, sector -> district.

    Each level's parent is the nearest shrunk ancestor that has an estimate,
    ending at ``national_prem``. Sectors in ``all_sectors`` without their own
    estimate receive that ancestor's premiums directly. ``national_n`` is
    accepted for interface symmetry but is not used.

    Returns:
        {sector: {event_type: log_premium}} for every sector in all_sectors
        plus every sector in sector_prem.
    """
    # Area -> national
    area_shrunk = {
        area: shrink_premium(prem, national_prem, area_n.get(area, 0))
        for area, prem in area_prem.items()
    }

    # District -> area (-> national when the area has no estimate)
    district_shrunk = {}
    for dist, prem in district_prem.items():
        parent = area_shrunk.get(dist_to_area.get(dist, ""), national_prem)
        district_shrunk[dist] = shrink_premium(prem, parent, district_n.get(dist, 0))

    def sector_parent(sec):
        # Nearest shrunk ancestor: district, else area, else national.
        d = sector_to_dist.get(sec, "")
        a = dist_to_area.get(d, "")
        return district_shrunk.get(d, area_shrunk.get(a, national_prem))

    # Sector -> nearest ancestor
    sector_shrunk = {
        sec: shrink_premium(prem, sector_parent(sec), sector_n.get(sec, 0))
        for sec, prem in sector_prem.items()
    }

    # Sectors without their own estimate inherit from the nearest ancestor.
    for sec in all_sectors:
        if sec not in sector_shrunk:
            sector_shrunk[sec] = sector_parent(sec)

    return sector_shrunk


def spatial_smooth(
    sector_premiums: dict[str, dict[str, float]],
    centroids: dict[str, tuple[float, float]],
    n_reno_map: dict[str, int],
) -> dict[str, dict[str, float]]:
    """Blend sparse sector premiums with their K nearest neighbours.

    A sector with n renovation pairs (from n_reno_map) keeps weight
    n / (n + SPATIAL_BLEND_K) on its own premium. The remainder is spread over
    up to SPATIAL_NEIGHBORS nearest sectors, by inverse distance. Some sectors
    are left unchanged:
    - sectors without a centroid;
    - sectors whose self-weight exceeds 0.95;
    - sectors with no usable neighbour.
    Neighbour values always come from the unsmoothed input. If there are too
    few sectors with coordinates, the input is returned as-is.
    """
    sectors_with_coords = [s for s in sector_premiums if s in centroids]
    if len(sectors_with_coords) < SPATIAL_NEIGHBORS + 1:
        return sector_premiums

    coords = np.array([centroids[s] for s in sectors_with_coords])
    # Scale the second coordinate by cos(mean of the first). This is an
    # equirectangular approximation. NOTE(review): assumes centroids are
    # (lat, lon) in degrees — confirm against extract_centroids.
    scale = np.cos(np.radians(np.mean(coords[:, 0])))
    scaled_coords = np.column_stack([coords[:, 0], coords[:, 1] * scale])
    tree = KDTree(scaled_coords)

    result = dict(sector_premiums)
    for i, sec in enumerate(sectors_with_coords):
        n = n_reno_map.get(sec, 0)
        self_w = n / (n + SPATIAL_BLEND_K)
        if self_w > 0.95:
            continue

        # Ask for k+1 hits because the closest one is normally the sector itself.
        dists, idxs = tree.query(scaled_coords[i], k=SPATIAL_NEIGHBORS + 1)
        inv_dists = []
        neighbor_prems = []
        for d, j in zip(dists[1:], idxs[1:]):
            ns = sectors_with_coords[j]
            # Zero-distance hits (self or co-located sectors) would get infinite weight.
            if d > 0 and ns in sector_premiums:
                inv_dists.append(1.0 / d)
                neighbor_prems.append(sector_premiums[ns])
        if not neighbor_prems:
            continue

        nbr_w = 1.0 - self_w
        total_inv = sum(inv_dists)
        ws = [iw / total_inv * nbr_w for iw in inv_dists]

        blended = {}
        for et in EVENT_TYPES:
            val = self_w * sector_premiums[sec].get(et, 0.0)
            for nbr, w in zip(neighbor_prems, ws):
                val += w * nbr.get(et, 0.0)
            blended[et] = val
        result[sec] = blended

    return result


def main():
    """CLI entry point: estimate premiums per type group and write the parquet output."""
    parser = argparse.ArgumentParser(
        description="Estimate renovation premiums from repeat-sale residuals"
    )
    parser.add_argument(
        "--input", type=Path, required=True, help="Path to wide.parquet"
    )
    parser.add_argument(
        "--index", type=Path, required=True, help="Path to price_index.parquet"
    )
    parser.add_argument(
        "--output", type=Path, required=True, help="Output renovation_premium.parquet"
    )
    args = parser.parse_args()

    pairs = extract_pairs_with_events(args.input, args.index)
    centroids = extract_centroids(args.input)

    # Precompute hierarchy
    all_sectors = pairs["sector"].unique().to_list()
    sector_to_dist = {}
    dist_to_area = {}
    for s in all_sectors:
        d, a = hierarchy_keys(s)
        sector_to_dist[s] = d
        dist_to_area[d] = a

    all_type_groups = ["All"] + TYPE_GROUPS
    rows = []
    for tg in all_type_groups:
        print(f"\n--- {tg} ---")
        typed = pairs if tg == "All" else pairs.filter(pl.col("type_group") == tg)
        if len(typed) < MIN_PAIRS:
            print(f" Skipping (only {len(typed)} pairs)")
            continue
        print(f" {len(typed):,} pairs")

        # National
        national_prem = compute_premiums_for_group(typed)
        if not national_prem:
            print(" No renovation pairs at national level, skipping")
            continue
        national_reno = typed.filter(_any_event_expr()).height
        print(
            " National premiums: "
            + ", ".join(
                f"{et}: {v:.4f} ({math.exp(v) - 1:.1%})"
                for et, v in national_prem.items()
            )
        )

        # Per-level
        print(" Computing per-level premiums:")
        area_prem, area_n = compute_premiums_for_level(typed, "area")
        district_prem, district_n = compute_premiums_for_level(typed, "district")
        sector_prem, sector_n = compute_premiums_for_level(typed, "sector")
        print(
            f" {len(area_prem)} areas, {len(district_prem)} districts, {len(sector_prem)} sectors with data"
        )

        # Shrinkage
        print(" Applying shrinkage...")
        sector_shrunk = apply_shrinkage(
            sector_prem,
            sector_n,
            district_prem,
            district_n,
            area_prem,
            area_n,
            national_prem,
            national_reno,
            all_sectors,
            sector_to_dist,
            dist_to_area,
        )

        # Spatial smoothing
        print(" Spatial smoothing...")
        sector_smoothed = spatial_smooth(sector_shrunk, centroids, sector_n)

        # One row per sector × event type. n_reno_pairs is 0 for sectors that
        # had no regression of their own.
        for sec in all_sectors:
            prem = sector_smoothed.get(sec, national_prem)
            n = sector_n.get(sec, 0)
            for et in EVENT_TYPES:
                rows.append((sec, tg, et, prem.get(et, 0.0), n))

    result = pl.DataFrame(
        rows,
        schema={
            "sector": pl.String,
            "type_group": pl.String,
            "event_type": pl.String,
            "log_premium": pl.Float64,
            "n_reno_pairs": pl.Int64,
        },
        orient="row",
    ).sort("type_group", "sector", "event_type")

    result.write_parquet(args.output)
    size_mb = args.output.stat().st_size / (1024 * 1024)
    print(f"\nWrote {args.output} ({size_mb:.1f} MB)")
    # Count the type groups actually written; skipped groups contribute no rows.
    print(
        f" {result['sector'].n_unique():,} sectors x {result['type_group'].n_unique()} types x {len(EVENT_TYPES)} events = {len(result):,} rows"
    )

    # Summary: unweighted mean across sectors of the final "All" premiums
    # (not the national regression estimate printed earlier).
    print("\nNational premium summary:")
    national = (
        result.filter(pl.col("type_group") == "All")
        .group_by("event_type")
        .agg(
            pl.col("log_premium").mean().alias("mean_log_premium"),
        )
    )
    for row in national.iter_rows(named=True):
        et = row["event_type"]
        lp = row["mean_log_premium"]
        print(f" {et}: log_premium={lp:.4f} ({math.exp(lp) - 1:.1%} price uplift)")


if __name__ == "__main__":
    main()