Refactor and other improvements

This commit is contained in:
Andras Schmelczer 2026-02-08 18:25:58 +00:00
parent 04a78e7bfe
commit 6c90cf3c0f
47 changed files with 2705 additions and 1568 deletions

View file

@@ -0,0 +1,272 @@
"""Stage 1: Repeat-Sales Price Index
Builds a hierarchical Case-Shiller repeat-sales price index from historical
transaction data. Solves WLS regression per postcode sector, district, area,
and nationally, then applies Bayesian shrinkage toward parent geographies.
Output: price_index.parquet with columns: sector, year, log_index, n_pairs
"""
import argparse
from pathlib import Path
import numpy as np
import polars as pl
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import lsqr
from tqdm import tqdm
# Tuning constants for the index build.
# A group with fewer sale pairs than this gets no regression index of its own
# (solve_wls_index returns an empty dict for it).
MIN_PAIRS = 5  # minimum pairs to compute an index
# Shrinkage weight on a group's own index is n_pairs / (n_pairs + SHRINKAGE_K).
SHRINKAGE_K = 50  # shrinkage parameter: higher = more shrinkage toward parent
# Pairs are dropped when |log(price2 / price1)| exceeds this; exp(2.5) is ~12.2.
OUTLIER_THRESHOLD = 2.5  # |log_ratio| > this → drop (>12x price change)
def extract_pairs(input_path: Path) -> pl.DataFrame:
    """Build one row per pair of consecutive sales of the same property.

    Reads the ``Postcode`` and ``historical_prices`` columns from the parquet
    file at ``input_path``, keeps properties with a non-null postcode and at
    least two recorded transactions, and pairs each transaction with the next
    one in the list.

    Returns:
        A DataFrame with columns ``sector``, ``year1``, ``price1``, ``year2``,
        ``price2``, ``log_ratio`` (log of price2 / price1) and ``weight``
        (1 / sqrt(year2 - year1)). Pairs with a non-positive price, a
        non-increasing year, or ``|log_ratio| > OUTLIER_THRESHOLD`` are dropped.
    """
    postcode = pl.col("Postcode")
    history = pl.col("historical_prices")

    print("Loading historical prices...")
    # Sector = postcode with its last two characters removed, then trimmed.
    sector = postcode.str.slice(0, postcode.str.len_chars() - 2).str.strip_chars()
    properties = (
        pl.scan_parquet(input_path)
        .select("Postcode", "historical_prices")
        .filter(
            postcode.is_not_null(),
            history.list.len() >= 2,
        )
        .with_columns(sector.alias("sector"))
        .collect()
    )
    print(f" {len(properties):,} properties with 2+ transactions")

    print("Extracting consecutive pairs...")
    # Two aligned views of the transaction list: all-but-last and all-but-first.
    # Exploding them together yields (earlier sale, next sale) on each row.
    earlier_sales = history.list.slice(0, history.list.len() - 1).alias("from_txn")
    later_sales = history.list.slice(1).alias("to_txn")

    # Unpack the transaction structs into year1/price1 and year2/price2.
    unpacked = [
        pl.col(source).struct.field(field).alias(f"{field}{suffix}")
        for source, suffix in (("from_txn", 1), ("to_txn", 2))
        for field in ("year", "price")
    ]

    price1, price2 = pl.col("price1"), pl.col("price2")
    year1, year2 = pl.col("year1"), pl.col("year2")
    log_ratio = (price2.cast(pl.Float64) / price1.cast(pl.Float64)).log()
    weight = 1.0 / (year2 - year1).cast(pl.Float64).sqrt()

    pairs = (
        properties.lazy()
        .with_columns(earlier_sales, later_sales)
        .explode("from_txn", "to_txn")
        .with_columns(unpacked)
        .select("sector", "year1", "price1", "year2", "price2")
        .filter(
            price1 > 0,
            price2 > 0,
            year2 > year1,
        )
        .with_columns(
            log_ratio.alias("log_ratio"),
            weight.alias("weight"),
        )
        .filter(pl.col("log_ratio").abs() <= OUTLIER_THRESHOLD)
        .collect()
    )
    print(f" {len(pairs):,} consecutive pairs extracted")
    return pairs
def solve_wls_index(years1: np.ndarray, years2: np.ndarray, log_ratios: np.ndarray, weights: np.ndarray) -> dict[int, float]:
"""Solve WLS repeat-sales regression for a set of pairs.
Model: log(P2/P1) = beta[year2] - beta[year1], weighted by 1/sqrt(gap).
Pin beta[min_year] = 0.
Returns dict mapping year -> log_index (cumulative).
"""
if len(years1) < MIN_PAIRS:
return {}
all_years = np.union1d(years1, years2)
min_year = int(all_years.min())
# Map years to column indices, skipping min_year (pinned to 0)
col = 0
year_to_col = {}
for y in all_years:
if int(y) != min_year:
year_to_col[int(y)] = col
col += 1
n_cols = len(year_to_col)
if n_cols == 0:
return {}
n_rows = len(years1)
row_idx = []
col_idx = []
data = []
for i in range(n_rows):
y1, y2 = int(years1[i]), int(years2[i])
if y2 in year_to_col:
row_idx.append(i)
col_idx.append(year_to_col[y2])
data.append(weights[i])
if y1 in year_to_col:
row_idx.append(i)
col_idx.append(year_to_col[y1])
data.append(-weights[i])
A = csc_matrix((data, (row_idx, col_idx)), shape=(n_rows, n_cols))
b = log_ratios * weights
result = lsqr(A, b, atol=1e-10, btol=1e-10)
betas = result[0]
index = {min_year: 0.0}
for year, col in year_to_col.items():
index[year] = float(betas[col])
return index
def compute_indices_for_level(
    pairs: pl.DataFrame, group_col: str
) -> tuple[dict[str, dict[int, float]], dict[str, int]]:
    """Fit a raw (unshrunk) repeat-sales index for every group of ``group_col``.

    Args:
        pairs: Sale pairs with columns ``year1``, ``year2``, ``log_ratio``,
            ``weight`` and ``group_col``.
        group_col: Name of the column to group by (``"area"``, ``"district"``
            and ``"sector"`` are the values used in this file).

    Returns:
        A tuple ``(indices, n_pairs_map)``. ``indices`` maps group key ->
        {year: log_index}; ``n_pairs_map`` maps group key -> number of pairs
        in that group. Groups for which ``solve_wls_index`` returns an empty
        index are left out of both dicts.
    """
    groups = pairs.group_by(group_col).agg(
        pl.col("year1"), pl.col("year2"), pl.col("log_ratio"), pl.col("weight"),
    )
    indices: dict[str, dict[int, float]] = {}
    n_pairs_map: dict[str, int] = {}
    for row in tqdm(groups.iter_rows(named=True), total=len(groups), desc=f" Solving {group_col}"):
        key = row[group_col]
        y1 = np.array(row["year1"], dtype=np.int32)
        y2 = np.array(row["year2"], dtype=np.int32)
        lr = np.array(row["log_ratio"], dtype=np.float64)
        w = np.array(row["weight"], dtype=np.float64)
        idx = solve_wls_index(y1, y2, lr, w)
        # An empty index means the group could not be fitted; skip it so
        # callers can fall back to a parent geography.
        if idx:
            indices[key] = idx
            n_pairs_map[key] = len(y1)
    return indices, n_pairs_map
def shrink_index(raw: dict[int, float], parent: dict[int, float], n_pairs: int) -> dict[int, float]:
"""Bayesian shrinkage toward parent index."""
w = n_pairs / (n_pairs + SHRINKAGE_K)
result = {}
all_years = set(raw.keys()) | set(parent.keys())
for y in all_years:
raw_val = raw.get(y, parent.get(y, 0.0))
parent_val = parent.get(y, raw.get(y, 0.0))
result[y] = w * raw_val + (1 - w) * parent_val
return result
def forward_fill_index(index: dict[int, float], min_year: int, max_year: int) -> dict[int, float]:
    """Return a gap-free index covering every year from min_year to max_year.

    A year missing from ``index`` takes the value of the most recent earlier
    year inside the range; years before the first available value are 0.0.
    Entries of ``index`` outside ``[min_year, max_year]`` are ignored.
    """
    continuous = {}
    carried = 0.0
    year = min_year
    while year <= max_year:
        # Keep the previous value unless this year has its own.
        carried = index.get(year, carried)
        continuous[year] = carried
        year += 1
    return continuous
def main():
    """Build the hierarchical repeat-sales index and write it to parquet.

    Steps: extract sale pairs from ``--input``; fit raw indices nationally
    and per area, district and sector; shrink each level toward its parent
    (area -> national, district -> area, sector -> district); then write one
    row per (sector, year) with columns sector, year, log_index, n_pairs to
    ``--output``. Sectors without an index of their own use the nearest
    available ancestor's index and report ``n_pairs = 0``.

    Raises:
        SystemExit: If no valid sale pairs could be extracted.
    """
    parser = argparse.ArgumentParser(description="Build repeat-sales price index")
    parser.add_argument("--input", type=Path, required=True, help="Path to wide.parquet")
    parser.add_argument("--output", type=Path, required=True, help="Output price_index.parquet")
    args = parser.parse_args()

    pairs = extract_pairs(args.input)
    if len(pairs) == 0:
        # Without pairs there is no year range to build an index over.
        raise SystemExit("No valid repeat-sale pairs found; cannot build a price index.")

    # Derive geographic hierarchy columns
    pairs = pairs.with_columns(
        # district = sector minus trailing digit(s), e.g. "SW1A 1" -> "SW1A"
        pl.col("sector").str.replace(r"\s+\d+$", "").alias("district"),
    ).with_columns(
        # area = leading letters only, e.g. "SW1A" -> "SW"
        pl.col("district").str.replace(r"\d.*$", "").alias("area"),
    )

    # Parent lookups come from the very columns the indices are grouped by,
    # so a child's parent key always matches a key used in the group-by.
    sector_to_district: dict[str, str] = {}
    district_to_area: dict[str, str] = {}
    for sector, district, area in pairs.select("sector", "district", "area").unique().iter_rows():
        sector_to_district[sector] = district
        district_to_area[district] = area

    # Solve indices at each level
    print("\nComputing national index...")
    national_idx = solve_wls_index(
        pairs["year1"].to_numpy(),
        pairs["year2"].to_numpy(),
        pairs["log_ratio"].to_numpy(),
        pairs["weight"].to_numpy(),
    )
    print(f" National index: {len(national_idx)} years")

    print("\nComputing area indices...")
    area_indices, area_pairs = compute_indices_for_level(pairs, "area")
    print(f" {len(area_indices)} areas with indices")

    print("\nComputing district indices...")
    district_indices, district_pairs = compute_indices_for_level(pairs, "district")
    print(f" {len(district_indices)} districts with indices")

    print("\nComputing sector indices...")
    sector_indices, sector_pairs = compute_indices_for_level(pairs, "sector")
    print(f" {len(sector_indices)} sectors with indices")

    # Shrinkage runs top-down so each level is pulled toward an
    # already-shrunk parent. A missing parent falls back to national.
    print("\nApplying hierarchical shrinkage...")
    area_indices = {
        area: shrink_index(idx, national_idx, area_pairs[area])
        for area, idx in tqdm(area_indices.items(), desc=" Area shrinkage")
    }
    district_indices = {
        dist: shrink_index(
            idx,
            area_indices.get(district_to_area[dist], national_idx),
            district_pairs[dist],
        )
        for dist, idx in tqdm(district_indices.items(), desc=" District shrinkage")
    }
    sector_indices = {
        sector: shrink_index(
            idx,
            district_indices.get(sector_to_district[sector], national_idx),
            sector_pairs[sector],
        )
        for sector, idx in tqdm(sector_indices.items(), desc=" Sector shrinkage")
    }

    min_year = int(pairs["year1"].min())
    # The output always extends to at least 2025; years after the last
    # observed sale are forward-filled.
    max_year = max(int(pairs["year2"].max()), 2025)
    print(f"\nFilling gaps and forward-filling ({min_year}-{max_year})...")

    rows = []
    for sector, district in tqdm(sector_to_district.items(), desc=" Forward-fill"):
        # For sectors without enough data, fall back to district/area/national
        if sector in sector_indices:
            idx = sector_indices[sector]
        elif district in district_indices:
            idx = district_indices[district]
        else:
            idx = area_indices.get(district_to_area[district], national_idx)
        n = sector_pairs.get(sector, 0)
        filled = forward_fill_index(idx, min_year, max_year)
        rows.extend((sector, year, log_idx, n) for year, log_idx in filled.items())

    result = pl.DataFrame(
        rows,
        schema={"sector": pl.String, "year": pl.Int32, "log_index": pl.Float64, "n_pairs": pl.Int64},
        orient="row",
    )
    result = result.sort("sector", "year")
    result.write_parquet(args.output)

    size_mb = args.output.stat().st_size / (1024 * 1024)
    print(f"\nWrote {args.output} ({size_mb:.1f} MB)")
    print(f" {result['sector'].n_unique():,} sectors × {max_year - min_year + 1} years = {len(result):,} rows")
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()