"""Stage 1: Repeat-Sales Price Index Builds a hierarchical Case-Shiller repeat-sales price index from historical transaction data. Solves WLS regression per postcode sector, district, area, and nationally, then applies Bayesian shrinkage toward parent geographies. Output: price_index.parquet with columns: sector, year, log_index, n_pairs """ import argparse from pathlib import Path import numpy as np import polars as pl from scipy.sparse import csc_matrix from scipy.sparse.linalg import lsqr from tqdm import tqdm MIN_PAIRS = 5 # minimum pairs to compute an index SHRINKAGE_K = 50 # shrinkage parameter: higher = more shrinkage toward parent OUTLIER_THRESHOLD = 2.5 # |log_ratio| > this → drop (>12x price change) def extract_pairs(input_path: Path) -> pl.DataFrame: """Extract consecutive sale pairs from historical_prices.""" print("Loading historical prices...") df = ( pl.scan_parquet(input_path) .select("Postcode", "historical_prices") .filter( pl.col("Postcode").is_not_null(), pl.col("historical_prices").list.len() >= 2, ) .with_columns( pl.col("Postcode").str.slice(0, pl.col("Postcode").str.len_chars() - 2).str.strip_chars().alias("sector"), ) .collect() ) print(f" {len(df):,} properties with 2+ transactions") print("Extracting consecutive pairs...") pairs = ( df.lazy() .with_columns( pl.col("historical_prices").list.slice(0, pl.col("historical_prices").list.len() - 1).alias("from_txn"), pl.col("historical_prices").list.slice(1).alias("to_txn"), ) .explode("from_txn", "to_txn") .with_columns( pl.col("from_txn").struct.field("year").alias("year1"), pl.col("from_txn").struct.field("price").alias("price1"), pl.col("to_txn").struct.field("year").alias("year2"), pl.col("to_txn").struct.field("price").alias("price2"), ) .select("sector", "year1", "price1", "year2", "price2") .filter( pl.col("price1") > 0, pl.col("price2") > 0, pl.col("year2") > pl.col("year1"), ) .with_columns( (pl.col("price2").cast(pl.Float64) / pl.col("price1").cast(pl.Float64)).log().alias("log_ratio"), (1.0 / (pl.col("year2") - pl.col("year1")).cast(pl.Float64).sqrt()).alias("weight"), ) .filter(pl.col("log_ratio").abs() <= OUTLIER_THRESHOLD) .collect() ) print(f" {len(pairs):,} consecutive pairs extracted") return pairs def solve_wls_index(years1: np.ndarray, years2: np.ndarray, log_ratios: np.ndarray, weights: np.ndarray) -> dict[int, float]: """Solve WLS repeat-sales regression for a set of pairs. Model: log(P2/P1) = beta[year2] - beta[year1], weighted by 1/sqrt(gap). Pin beta[min_year] = 0. Returns dict mapping year -> log_index (cumulative). """ if len(years1) < MIN_PAIRS: return {} all_years = np.union1d(years1, years2) min_year = int(all_years.min()) # Map years to column indices, skipping min_year (pinned to 0) col = 0 year_to_col = {} for y in all_years: if int(y) != min_year: year_to_col[int(y)] = col col += 1 n_cols = len(year_to_col) if n_cols == 0: return {} n_rows = len(years1) row_idx = [] col_idx = [] data = [] for i in range(n_rows): y1, y2 = int(years1[i]), int(years2[i]) if y2 in year_to_col: row_idx.append(i) col_idx.append(year_to_col[y2]) data.append(weights[i]) if y1 in year_to_col: row_idx.append(i) col_idx.append(year_to_col[y1]) data.append(-weights[i]) A = csc_matrix((data, (row_idx, col_idx)), shape=(n_rows, n_cols)) b = log_ratios * weights result = lsqr(A, b, atol=1e-10, btol=1e-10) betas = result[0] index = {min_year: 0.0} for year, col in year_to_col.items(): index[year] = float(betas[col]) return index def compute_indices_for_level(pairs: pl.DataFrame, group_col: str) -> dict[str, dict[int, float]]: """Compute raw indices for each geographic group.""" groups = pairs.group_by(group_col).agg( pl.col("year1"), pl.col("year2"), pl.col("log_ratio"), pl.col("weight"), ) indices = {} n_pairs_map = {} for row in tqdm(groups.iter_rows(named=True), total=len(groups), desc=f" Solving {group_col}"): key = row[group_col] y1 = np.array(row["year1"], dtype=np.int32) y2 = np.array(row["year2"], dtype=np.int32) lr = np.array(row["log_ratio"], dtype=np.float64) w = np.array(row["weight"], dtype=np.float64) idx = solve_wls_index(y1, y2, lr, w) if idx: indices[key] = idx n_pairs_map[key] = len(y1) return indices, n_pairs_map def shrink_index(raw: dict[int, float], parent: dict[int, float], n_pairs: int) -> dict[int, float]: """Bayesian shrinkage toward parent index.""" w = n_pairs / (n_pairs + SHRINKAGE_K) result = {} all_years = set(raw.keys()) | set(parent.keys()) for y in all_years: raw_val = raw.get(y, parent.get(y, 0.0)) parent_val = parent.get(y, raw.get(y, 0.0)) result[y] = w * raw_val + (1 - w) * parent_val return result def forward_fill_index(index: dict[int, float], min_year: int, max_year: int) -> dict[int, float]: """Forward-fill missing years so index is continuous.""" filled = {} last_val = 0.0 for y in range(min_year, max_year + 1): if y in index: last_val = index[y] filled[y] = last_val return filled def main(): parser = argparse.ArgumentParser(description="Build repeat-sales price index") parser.add_argument("--input", type=Path, required=True, help="Path to wide.parquet") parser.add_argument("--output", type=Path, required=True, help="Output price_index.parquet") args = parser.parse_args() pairs = extract_pairs(args.input) # Derive geographic hierarchy columns pairs = pairs.with_columns( # district = sector minus trailing digit(s), e.g. "SW1A 1" -> "SW1A" pl.col("sector").str.replace(r"\s+\d+$", "").alias("district"), ).with_columns( # area = leading letters only, e.g. "SW1A" -> "SW" pl.col("district").str.replace(r"\d.*$", "").alias("area"), ) # Solve indices at each level print("\nComputing national index...") pairs_np = pairs.select("year1", "year2", "log_ratio", "weight") national_idx = solve_wls_index( pairs_np["year1"].to_numpy(), pairs_np["year2"].to_numpy(), pairs_np["log_ratio"].to_numpy(), pairs_np["weight"].to_numpy(), ) print(f" National index: {len(national_idx)} years") print("\nComputing area indices...") area_indices, area_pairs = compute_indices_for_level(pairs, "area") print(f" {len(area_indices)} areas with indices") print("\nComputing district indices...") district_indices, district_pairs = compute_indices_for_level(pairs, "district") print(f" {len(district_indices)} districts with indices") print("\nComputing sector indices...") sector_indices, sector_pairs = compute_indices_for_level(pairs, "sector") print(f" {len(sector_indices)} sectors with indices") # Shrink area -> national print("\nApplying hierarchical shrinkage...") for area, idx in tqdm(area_indices.items(), desc=" Area shrinkage"): area_indices[area] = shrink_index(idx, national_idx, area_pairs[area]) # Shrink district -> area for dist, idx in tqdm(district_indices.items(), desc=" District shrinkage"): area = dist.replace(r"\d.*$", "") # Extract area from district (leading letters) area_key = "" for ch in dist: if ch.isalpha(): area_key += ch else: break parent = area_indices.get(area_key, national_idx) district_indices[dist] = shrink_index(idx, parent, district_pairs[dist]) # Shrink sector -> district for sector, idx in tqdm(sector_indices.items(), desc=" Sector shrinkage"): # District = sector minus trailing space+digit dist_key = sector.rsplit(" ", 1)[0] if " " in sector else sector parent = district_indices.get(dist_key, national_idx) sector_indices[sector] = shrink_index(idx, parent, sector_pairs[sector]) # For sectors without enough data, fall back to district/area/national all_sectors = pairs["sector"].unique().to_list() min_year = int(pairs["year1"].min()) max_year = max(int(pairs["year2"].max()), 2025) print(f"\nFilling gaps and forward-filling ({min_year}-{max_year})...") rows = [] for sector in tqdm(all_sectors, desc=" Forward-fill"): if sector in sector_indices: idx = sector_indices[sector] else: # Fall back to district, area, national dist_key = sector.rsplit(" ", 1)[0] if " " in sector else sector area_key = "" for ch in dist_key: if ch.isalpha(): area_key += ch else: break idx = district_indices.get(dist_key, area_indices.get(area_key, national_idx)) n = sector_pairs.get(sector, 0) filled = forward_fill_index(idx, min_year, max_year) for year, log_idx in filled.items(): rows.append((sector, year, log_idx, n)) result = pl.DataFrame( rows, schema={"sector": pl.String, "year": pl.Int32, "log_index": pl.Float64, "n_pairs": pl.Int64}, orient="row", ) result = result.sort("sector", "year") result.write_parquet(args.output) size_mb = args.output.stat().st_size / (1024 * 1024) print(f"\nWrote {args.output} ({size_mb:.1f} MB)") print(f" {result['sector'].n_unique():,} sectors × {max_year - min_year + 1} years = {len(result):,} rows") if __name__ == "__main__": main()