272 lines
9.9 KiB
Python
272 lines
9.9 KiB
Python
"""Stage 1: Repeat-Sales Price Index
|
||
|
||
Builds a hierarchical Case-Shiller repeat-sales price index from historical
|
||
transaction data. Solves WLS regression per postcode sector, district, area,
|
||
and nationally, then applies Bayesian shrinkage toward parent geographies.
|
||
|
||
Output: price_index.parquet with columns: sector, year, log_index, n_pairs
|
||
"""
|
||
|
||
import argparse
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import polars as pl
|
||
from scipy.sparse import csc_matrix
|
||
from scipy.sparse.linalg import lsqr
|
||
from tqdm import tqdm
|
||
|
||
# Fewest repeat-sale pairs a geography needs before an index is solved for it.
MIN_PAIRS: int = 5

# Shrinkage strength: the larger this is, the harder a geography's index is
# pulled toward its parent geography.
SHRINKAGE_K: int = 50

# Pairs whose |log_ratio| exceeds this are dropped as outliers
# (exp(2.5) is a price change of more than 12x).
OUTLIER_THRESHOLD: float = 2.5
|
||
|
||
|
||
def extract_pairs(input_path: Path) -> pl.DataFrame:
    """Build one row per consecutive pair of sales of the same property.

    Reads the ``Postcode`` and ``historical_prices`` columns of the parquet
    file at ``input_path``. ``historical_prices`` is consumed as a list of
    structs with ``year`` and ``price`` fields; each entry is paired with the
    entry that follows it in the list.

    Args:
        input_path: Parquet file to scan.

    Returns:
        A DataFrame with columns ``sector``, ``year1``, ``price1``, ``year2``,
        ``price2``, ``log_ratio`` (``log(price2 / price1)``) and ``weight``
        (``1 / sqrt(year2 - year1)``). Pairs with a non-positive price, with
        ``year2 <= year1``, or with ``|log_ratio| > OUTLIER_THRESHOLD`` are
        excluded.
    """
    postcode = pl.col("Postcode")
    history = pl.col("historical_prices")

    print("Loading historical prices...")
    # Sector = postcode without its final two characters, whitespace-trimmed.
    sector = postcode.str.slice(0, postcode.str.len_chars() - 2).str.strip_chars()
    properties = pl.scan_parquet(input_path).select("Postcode", "historical_prices")
    properties = properties.filter(postcode.is_not_null(), history.list.len() >= 2)
    # Materialised here so the property count can be reported before pairing.
    df = properties.with_columns(sector.alias("sector")).collect()
    print(f" {len(df):,} properties with 2+ transactions")

    print("Extracting consecutive pairs...")
    earlier = pl.col("from_txn")
    later = pl.col("to_txn")
    price_ratio = pl.col("price2").cast(pl.Float64) / pl.col("price1").cast(pl.Float64)
    gap_years = (pl.col("year2") - pl.col("year1")).cast(pl.Float64)

    # Two aligned views of each history: all-but-last and all-but-first.
    # Exploding them together yields one row per adjacent pair.
    paired = df.lazy().with_columns(
        history.list.slice(0, history.list.len() - 1).alias("from_txn"),
        history.list.slice(1).alias("to_txn"),
    )
    paired = paired.explode("from_txn", "to_txn")
    paired = paired.with_columns(
        earlier.struct.field("year").alias("year1"),
        earlier.struct.field("price").alias("price1"),
        later.struct.field("year").alias("year2"),
        later.struct.field("price").alias("price2"),
    )
    paired = paired.select("sector", "year1", "price1", "year2", "price2")
    paired = paired.filter(
        pl.col("price1") > 0,
        pl.col("price2") > 0,
        pl.col("year2") > pl.col("year1"),
    )
    paired = paired.with_columns(
        price_ratio.log().alias("log_ratio"),
        (1.0 / gap_years.sqrt()).alias("weight"),
    )
    paired = paired.filter(pl.col("log_ratio").abs() <= OUTLIER_THRESHOLD)

    pairs = paired.collect()
    print(f" {len(pairs):,} consecutive pairs extracted")
    return pairs
|
||
|
||
|
||
def solve_wls_index(
    years1: np.ndarray,
    years2: np.ndarray,
    log_ratios: np.ndarray,
    weights: np.ndarray,
    min_pairs: "int | None" = None,
) -> dict[int, float]:
    """Solve the weighted repeat-sales regression for one set of pairs.

    Model: ``log(P2 / P1) = beta[year2] - beta[year1]``. Every equation is
    scaled by its weight before the least-squares solve (callers in this file
    pass ``1 / sqrt(year gap)``). The earliest year present is the base year
    and is pinned to ``beta = 0``, so it gets no column in the design matrix.

    Args:
        years1: Year of the earlier sale of each pair (integer years).
        years2: Year of the later sale of each pair (integer years).
        log_ratios: ``log(price2 / price1)`` of each pair.
        weights: Row weight of each pair.
        min_pairs: Minimum number of pairs required to solve. ``None`` (the
            default) uses the module-level ``MIN_PAIRS``.

    Returns:
        Mapping ``year -> log index`` in ascending year order, with the base
        year mapped to ``0.0``. Empty when there are too few pairs or when
        only one distinct year is present.
    """
    years1 = np.asarray(years1)
    years2 = np.asarray(years2)
    log_ratios = np.asarray(log_ratios, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)

    n_rows = len(years1)
    required = MIN_PAIRS if min_pairs is None else min_pairs
    if n_rows == 0 or n_rows < required:
        return {}

    # Sorted distinct years; element 0 is the base year.
    all_years = np.union1d(years1, years2)
    n_cols = len(all_years) - 1
    if n_cols == 0:
        return {}

    # Column of a year = its rank among the distinct years, minus one, so the
    # base year lands on -1 and is masked out below.
    cols_to = np.searchsorted(all_years, years2) - 1
    cols_from = np.searchsorted(all_years, years1) - 1
    has_to = cols_to >= 0
    has_from = cols_from >= 0

    # Each row carries +w in the column of year2 and -w in the column of year1.
    rows = np.arange(n_rows)
    row_idx = np.concatenate([rows[has_to], rows[has_from]])
    col_idx = np.concatenate([cols_to[has_to], cols_from[has_from]])
    data = np.concatenate([weights[has_to], -weights[has_from]])

    A = csc_matrix((data, (row_idx, col_idx)), shape=(n_rows, n_cols))
    b = log_ratios * weights

    # NOTE(review): years that no chain of pairs links back to the base year
    # are not pinned down by the data; their level is whatever lsqr returns.
    betas = lsqr(A, b, atol=1e-10, btol=1e-10)[0]

    index = {int(all_years[0]): 0.0}
    for year, beta in zip(all_years[1:], betas):
        index[int(year)] = float(beta)
    return index
|
||
|
||
|
||
def compute_indices_for_level(
    pairs: pl.DataFrame, group_col: str
) -> tuple[dict[str, dict[int, float]], dict[str, int]]:
    """Solve a raw (unshrunk) index for every group of one geographic level.

    Args:
        pairs: Pair table containing ``year1``, ``year2``, ``log_ratio``,
            ``weight`` and the ``group_col`` column.
        group_col: Name of the column to group by (e.g. ``"sector"``).

    Returns:
        A 2-tuple ``(indices, n_pairs_map)``:

        * ``indices`` maps group key -> ``{year: log_index}``.
        * ``n_pairs_map`` maps group key -> number of pairs used.

        Both contain only the groups for which ``solve_wls_index`` returned a
        non-empty index.
    """
    groups = pairs.group_by(group_col).agg(
        pl.col("year1"), pl.col("year2"), pl.col("log_ratio"), pl.col("weight"),
    )

    indices: dict[str, dict[int, float]] = {}
    n_pairs_map: dict[str, int] = {}
    for row in tqdm(groups.iter_rows(named=True), total=len(groups), desc=f" Solving {group_col}"):
        key = row[group_col]
        # After the aggregation each cell holds the whole group's values.
        y1 = np.array(row["year1"], dtype=np.int32)
        y2 = np.array(row["year2"], dtype=np.int32)
        lr = np.array(row["log_ratio"], dtype=np.float64)
        w = np.array(row["weight"], dtype=np.float64)

        idx = solve_wls_index(y1, y2, lr, w)
        if idx:
            indices[key] = idx
            n_pairs_map[key] = len(y1)
    return indices, n_pairs_map
|
||
|
||
|
||
def shrink_index(
    raw: dict[int, float],
    parent: dict[int, float],
    n_pairs: int,
    k: "float | None" = None,
) -> dict[int, float]:
    """Blend a raw index with its parent index (shrinkage toward the parent).

    The raw index gets weight ``w = n_pairs / (n_pairs + k)`` and the parent
    gets ``1 - w``. A year present in only one of the two indices takes that
    index's value unchanged.

    Args:
        raw: ``{year: log_index}`` of the geography itself.
        parent: ``{year: log_index}`` of the parent geography.
        n_pairs: Number of pairs behind ``raw``.
        k: Shrinkage strength; larger means closer to the parent. ``None``
            (the default) uses the module-level ``SHRINKAGE_K``.
            ``n_pairs + k`` must be non-zero.

    Returns:
        ``{year: log_index}`` over the union of both year sets, in ascending
        year order.
    """
    strength = SHRINKAGE_K if k is None else k
    w = n_pairs / (n_pairs + strength)

    # NOTE(review): raw and parent are each zeroed at their own first year.
    # When those base years differ, years copied straight from one side sit on
    # a different level than blended years -- confirm this is intended.
    result = {}
    for y in sorted(raw.keys() | parent.keys()):
        if y not in raw:
            result[y] = parent[y]
        elif y not in parent:
            result[y] = raw[y]
        else:
            result[y] = w * raw[y] + (1 - w) * parent[y]
    return result
|
||
|
||
|
||
def forward_fill_index(index: dict[int, float], min_year: int, max_year: int) -> dict[int, float]:
    """Return a gap-free copy of ``index`` over ``min_year..max_year`` inclusive.

    A year missing from ``index`` repeats the most recent value seen inside
    the range; years before the first available value get ``0.0``. Entries of
    ``index`` outside the range are ignored.
    """
    carried = 0.0
    dense = {}
    for year in range(min_year, max_year + 1):
        # Keep carrying the previous value unless this year has its own.
        carried = index.get(year, carried)
        dense[year] = carried
    return dense
|
||
|
||
|
||
def main():
    """Command-line entry point: build and write the shrunk price index.

    Steps:
      1. Extract repeat-sale pairs from ``--input``.
      2. Solve raw indices nationally and per area, district and sector.
      3. Shrink each level toward its parent (area -> national,
         district -> area, sector -> district).
      4. Give sectors without their own index the nearest available parent
         index, forward-fill every sector over the full year range, and write
         the result to ``--output``.

    Raises:
        SystemExit: If no usable pairs can be extracted from the input.
    """
    parser = argparse.ArgumentParser(description="Build repeat-sales price index")
    parser.add_argument("--input", type=Path, required=True, help="Path to wide.parquet")
    parser.add_argument("--output", type=Path, required=True, help="Output price_index.parquet")
    parser.add_argument(
        "--max-year",
        type=int,
        default=2025,
        help="Forward-fill the index through at least this year (default: 2025)",
    )
    args = parser.parse_args()

    pairs = extract_pairs(args.input)
    if len(pairs) == 0:
        # Without this guard the year range below would fail on int(None).
        raise SystemExit("No usable repeat-sale pairs found in input; nothing to index.")

    # Derive geographic hierarchy columns
    pairs = pairs.with_columns(
        # district = sector minus trailing digit(s), e.g. "SW1A 1" -> "SW1A"
        pl.col("sector").str.replace(r"\s+\d+$", "").alias("district"),
    ).with_columns(
        # area = leading letters only, e.g. "SW1A" -> "SW"
        pl.col("district").str.replace(r"\d.*$", "").alias("area"),
    )

    # Parent lookups reuse the columns derived above, so the keys always match
    # the keys of the per-level index dicts (no second, hand-rolled parsing).
    sector_to_district = {}
    district_to_area = {}
    for sector, district, area in pairs.select("sector", "district", "area").unique().iter_rows():
        sector_to_district[sector] = district
        district_to_area[district] = area

    # Solve indices at each level
    print("\nComputing national index...")
    national_idx = solve_wls_index(
        pairs["year1"].to_numpy(),
        pairs["year2"].to_numpy(),
        pairs["log_ratio"].to_numpy(),
        pairs["weight"].to_numpy(),
    )
    print(f" National index: {len(national_idx)} years")

    print("\nComputing area indices...")
    area_indices, area_pairs = compute_indices_for_level(pairs, "area")
    print(f" {len(area_indices)} areas with indices")

    print("\nComputing district indices...")
    district_indices, district_pairs = compute_indices_for_level(pairs, "district")
    print(f" {len(district_indices)} districts with indices")

    print("\nComputing sector indices...")
    sector_indices, sector_pairs = compute_indices_for_level(pairs, "sector")
    print(f" {len(sector_indices)} sectors with indices")

    # Shrinkage runs top-down so each level is blended with its already-shrunk
    # parent. A missing parent index falls back to the national index.
    print("\nApplying hierarchical shrinkage...")
    area_indices = {
        area: shrink_index(idx, national_idx, area_pairs[area])
        for area, idx in tqdm(area_indices.items(), desc=" Area shrinkage")
    }

    district_indices = {
        dist: shrink_index(
            idx,
            area_indices.get(district_to_area.get(dist), national_idx),
            district_pairs[dist],
        )
        for dist, idx in tqdm(district_indices.items(), desc=" District shrinkage")
    }

    sector_indices = {
        sector: shrink_index(
            idx,
            district_indices.get(sector_to_district.get(sector), national_idx),
            sector_pairs[sector],
        )
        for sector, idx in tqdm(sector_indices.items(), desc=" Sector shrinkage")
    }

    # For sectors without enough data, fall back to district/area/national
    all_sectors = pairs["sector"].unique().to_list()
    min_year = int(pairs["year1"].min())
    max_year = max(int(pairs["year2"].max()), args.max_year)

    print(f"\nFilling gaps and forward-filling ({min_year}-{max_year})...")
    rows = []
    for sector in tqdm(all_sectors, desc=" Forward-fill"):
        idx = sector_indices.get(sector)
        if idx is None:
            district = sector_to_district.get(sector)
            idx = district_indices.get(district)
            if idx is None:
                idx = area_indices.get(district_to_area.get(district), national_idx)

        # Sectors that borrowed a parent index are reported with n_pairs = 0.
        n = sector_pairs.get(sector, 0)
        filled = forward_fill_index(idx, min_year, max_year)
        rows.extend((sector, year, log_idx, n) for year, log_idx in filled.items())

    result = pl.DataFrame(
        rows,
        schema={"sector": pl.String, "year": pl.Int32, "log_index": pl.Float64, "n_pairs": pl.Int64},
        orient="row",
    )

    result = result.sort("sector", "year")
    result.write_parquet(args.output)
    size_mb = args.output.stat().st_size / (1024 * 1024)
    print(f"\nWrote {args.output} ({size_mb:.1f} MB)")
    print(f" {result['sector'].n_unique():,} sectors × {max_year - min_year + 1} years = {len(result):,} rows")
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|