perfect-postcode/pipeline/transform/price_index.py

272 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Stage 1: Repeat-Sales Price Index
Builds a hierarchical Case-Shiller repeat-sales price index from historical
transaction data. Solves WLS regression per postcode sector, district, area,
and nationally, then applies Bayesian shrinkage toward parent geographies.
Output: price_index.parquet with columns: sector, year, log_index, n_pairs
"""
import argparse
import re
from pathlib import Path

import numpy as np
import polars as pl
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import lsqr
from tqdm import tqdm
# Fewest sale pairs a geography must have before its own index is solved.
MIN_PAIRS: int = 5
# Shrinkage strength k in w = n / (n + k): a geography with n pairs keeps weight
# w on its own index, so a higher k pulls harder toward the parent index.
SHRINKAGE_K: int = 50
# Pairs with |log(price2 / price1)| above this are dropped as outliers
# (exp(2.5) is about 12.2, i.e. more than a ~12x rise or fall between sales).
OUTLIER_THRESHOLD: float = 2.5
def extract_pairs(input_path: Path) -> pl.DataFrame:
    """Turn each property's sale history into consecutive (earlier, later) sale pairs.

    Reads ``Postcode`` and ``historical_prices`` from the parquet file at
    ``input_path``. It keeps only properties that have a postcode and at least
    two recorded sales. The sector is the postcode with its last two
    characters removed, then whitespace-stripped (e.g. "SW1A 1AA" -> "SW1A 1").

    Each adjacent pair of list entries becomes one row. A row is dropped when:
    - either price is not positive;
    - the later entry's year is not strictly after the earlier one;
    - |log(price2 / price1)| exceeds OUTLIER_THRESHOLD.

    NOTE(review): "consecutive" relies on the stored list order of
    ``historical_prices``. This assumes entries are chronological; TODO
    confirm against whatever produces wide.parquet.

    Returns:
        DataFrame with columns sector, year1, price1, year2, price2,
        log_ratio and weight (= 1 / sqrt(year2 - year1)).
    """
    print("Loading historical prices...")
    postcode = pl.col("Postcode")
    history = pl.col("historical_prices")
    properties = (
        pl.scan_parquet(input_path)
        .select("Postcode", "historical_prices")
        .filter(postcode.is_not_null(), history.list.len() >= 2)
        .with_columns(
            postcode.str.slice(0, postcode.str.len_chars() - 2)
            .str.strip_chars()
            .alias("sector"),
        )
        .collect()
    )
    print(f" {len(properties):,} properties with 2+ transactions")

    print("Extracting consecutive pairs...")
    earlier = pl.col("from_txn")
    later = pl.col("to_txn")
    price1 = pl.col("price1").cast(pl.Float64)
    price2 = pl.col("price2").cast(pl.Float64)
    year_gap = (pl.col("year2") - pl.col("year1")).cast(pl.Float64)
    pair_plan = (
        properties.lazy()
        # Line every sale up with the next one: all-but-last vs all-but-first.
        .with_columns(
            history.list.slice(0, history.list.len() - 1).alias("from_txn"),
            history.list.slice(1).alias("to_txn"),
        )
        .explode("from_txn", "to_txn")
        .with_columns(
            earlier.struct.field("year").alias("year1"),
            earlier.struct.field("price").alias("price1"),
            later.struct.field("year").alias("year2"),
            later.struct.field("price").alias("price2"),
        )
        .select("sector", "year1", "price1", "year2", "price2")
        .filter(
            pl.col("price1") > 0,
            pl.col("price2") > 0,
            pl.col("year2") > pl.col("year1"),
        )
        .with_columns(
            (price2 / price1).log().alias("log_ratio"),
            (1.0 / year_gap.sqrt()).alias("weight"),
        )
        .filter(pl.col("log_ratio").abs() <= OUTLIER_THRESHOLD)
    )
    pairs = pair_plan.collect()
    print(f" {len(pairs):,} consecutive pairs extracted")
    return pairs
def solve_wls_index(years1: np.ndarray, years2: np.ndarray, log_ratios: np.ndarray, weights: np.ndarray) -> dict[int, float]:
    """Solve the weighted repeat-sales regression for one set of sale pairs.

    Model: ``log(P2/P1) = beta[year2] - beta[year1]``. Each design-matrix row
    and its target are scaled by ``weights``; callers pass ``1/sqrt(gap)``.
    ``lsqr`` therefore minimises ``sum(weights**2 * residual**2)``.
    ``beta[min_year]`` is pinned to 0 so the system is identified.

    Args:
        years1: Year of the earlier sale in each pair.
        years2: Year of the later sale in each pair.
        log_ratios: ``log(price2 / price1)`` for each pair.
        weights: Per-pair row scaling.

    Returns:
        Mapping ``year -> log_index`` (cumulative, relative to the earliest
        year seen), in ascending year order. The mapping is empty when there
        are fewer than ``MIN_PAIRS`` pairs, or when every pair involves only
        the base year.

    NOTE(review): if the years do not form a connected graph through the
    pairs, the betas are not fully identified. ``lsqr`` then returns one
    least-squares solution among many.
    """
    if len(years1) < MIN_PAIRS:
        return {}
    years1 = np.asarray(years1)
    years2 = np.asarray(years2)
    log_ratios = np.asarray(log_ratios, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)
    n_rows = len(years1)

    # union1d is sorted and unique: the first year is the pinned base and
    # the remaining years map, in order, to columns 0..n_cols-1.
    all_years = np.union1d(years1, years2)
    base_year = all_years[0]
    free_years = all_years[1:]
    n_cols = len(free_years)
    if n_cols == 0:
        return {}

    # Vectorised COO triplets: +w in the later year's column, -w in the earlier
    # year's column, with no entry for the pinned base year. This builds the
    # same matrix the per-row loop did, without Python-level iteration.
    rows = np.arange(n_rows)
    has_late = years2 != base_year
    has_early = years1 != base_year
    row_idx = np.concatenate((rows[has_late], rows[has_early]))
    col_idx = np.concatenate((
        np.searchsorted(free_years, years2[has_late]),
        np.searchsorted(free_years, years1[has_early]),
    ))
    data = np.concatenate((weights[has_late], -weights[has_early]))

    A = csc_matrix((data, (row_idx, col_idx)), shape=(n_rows, n_cols))
    b = log_ratios * weights
    betas = lsqr(A, b, atol=1e-10, btol=1e-10)[0]

    index = {int(base_year): 0.0}
    index.update((int(year), float(beta)) for year, beta in zip(free_years, betas))
    return index
def compute_indices_for_level(pairs: pl.DataFrame, group_col: str) -> tuple[dict[str, dict[int, float]], dict[str, int]]:
    """Solve a raw (unshrunk) repeat-sales index for every group at one geographic level.

    Args:
        pairs: Pair table containing ``group_col`` plus ``year1``, ``year2``,
            ``log_ratio`` and ``weight`` columns.
        group_col: Column to group by (e.g. ``"sector"``, ``"district"``, ``"area"``).

    Returns:
        A tuple ``(indices, n_pairs)``:
        - ``indices`` maps each group key to ``{year: log_index}``.
        - ``n_pairs`` maps the same keys to the number of pairs that group had.
        Groups for which ``solve_wls_index`` returned an empty index (too few
        pairs, or only one distinct year) appear in neither mapping.
    """
    groups = pairs.group_by(group_col).agg(
        pl.col("year1"), pl.col("year2"), pl.col("log_ratio"), pl.col("weight"),
    )
    indices: dict[str, dict[int, float]] = {}
    n_pairs_map: dict[str, int] = {}
    for row in tqdm(groups.iter_rows(named=True), total=len(groups), desc=f" Solving {group_col}"):
        key = row[group_col]
        # Each aggregated cell is a Python list of that group's values.
        y1 = np.array(row["year1"], dtype=np.int32)
        y2 = np.array(row["year2"], dtype=np.int32)
        lr = np.array(row["log_ratio"], dtype=np.float64)
        w = np.array(row["weight"], dtype=np.float64)
        idx = solve_wls_index(y1, y2, lr, w)
        if idx:
            indices[key] = idx
            n_pairs_map[key] = len(y1)
    return indices, n_pairs_map
def shrink_index(raw: dict[int, float], parent: dict[int, float], n_pairs: int) -> dict[int, float]:
    """Bayesian-style shrinkage of a raw log-index toward its parent geography.

    The credibility weight is ``w = n_pairs / (n_pairs + SHRINKAGE_K)``. For
    every year in either index the result is ``w * raw + (1 - w) * parent``.
    A year present in only one index takes that index's value.

    Each index is pinned to 0 at its own earliest year, so ``raw`` and
    ``parent`` generally sit on different levels. Before blending, ``raw`` is
    shifted by a constant so it equals ``parent`` at their earliest common
    year. Without this shift the result jumps in level wherever it switches
    between parent-only years and blended years. The constant shift leaves
    every year-to-year difference within ``raw`` unchanged.

    Args:
        raw: ``{year: log_index}`` solved from this geography's own pairs.
        parent: ``{year: log_index}`` of the enclosing geography.
        n_pairs: Number of pairs behind ``raw``.

    Returns:
        ``{year: log_index}`` in ascending year order.
    """
    w = n_pairs / (n_pairs + SHRINKAGE_K)
    common = raw.keys() & parent.keys()
    if common:
        anchor = min(common)
        offset = parent[anchor] - raw[anchor]
        raw = {year: value + offset for year, value in raw.items()}
    result = {}
    for y in sorted(raw.keys() | parent.keys()):
        raw_val = raw.get(y, parent.get(y, 0.0))
        parent_val = parent.get(y, raw.get(y, 0.0))
        result[y] = w * raw_val + (1 - w) * parent_val
    return result
def forward_fill_index(index: dict[int, float], min_year: int, max_year: int) -> dict[int, float]:
    """Produce a value for every year from ``min_year`` to ``max_year`` inclusive.

    Each year takes its own value from ``index`` when present. Otherwise it
    carries forward the most recent value seen in the range. Years before the
    first present year get 0.0. Entries of ``index`` outside the range are
    ignored.
    """
    filled: dict[int, float] = {}
    carried = 0.0
    for year in range(min_year, max_year + 1):
        carried = index.get(year, carried)
        filled[year] = carried
    return filled
# Suffix patterns for walking up the postcode hierarchy. The polars column
# derivation and the Python-side parent lookups both use them, so the keys agree.
_DISTRICT_SUFFIX_RE = re.compile(r"\s+\d+$")  # "SW1A 1" -> "SW1A"
_AREA_SUFFIX_RE = re.compile(r"\d.*$")  # "SW1A" -> "SW"


def _district_of(sector: str) -> str:
    """Return a sector's district by dropping its trailing whitespace + digits."""
    return _DISTRICT_SUFFIX_RE.sub("", sector, count=1)


def _area_of(district: str) -> str:
    """Return a district's area by dropping everything from its first digit."""
    return _AREA_SUFFIX_RE.sub("", district, count=1)


def main():
    """CLI entry point: build the hierarchical price index and write it to parquet.

    Steps:
    1. Extract sale pairs and derive district/area keys.
    2. Solve raw indices at national, area, district and sector level.
    3. Shrink each level toward its parent: area -> national,
       district -> area, sector -> district.
    4. Write one forward-filled row per (sector, year). A sector without its
       own index uses its district's, then its area's, then the national index.

    Raises:
        SystemExit: if the input yields no usable sale pairs.
    """
    parser = argparse.ArgumentParser(description="Build repeat-sales price index")
    parser.add_argument("--input", type=Path, required=True, help="Path to wide.parquet")
    parser.add_argument("--output", type=Path, required=True, help="Output price_index.parquet")
    parser.add_argument(
        "--max-year",
        type=int,
        default=2025,
        help="Forward-fill every sector's index through at least this year (default: 2025)",
    )
    args = parser.parse_args()

    pairs = extract_pairs(args.input)
    if pairs.is_empty():
        # Without pairs there is nothing to regress on, and the year min/max below would be None.
        raise SystemExit(f"No usable sale pairs found in {args.input}")

    # Derive geographic hierarchy columns (same patterns as _district_of / _area_of).
    pairs = pairs.with_columns(
        pl.col("sector").str.replace(_DISTRICT_SUFFIX_RE.pattern, "").alias("district"),
    ).with_columns(
        pl.col("district").str.replace(_AREA_SUFFIX_RE.pattern, "").alias("area"),
    )

    # Solve indices at each level
    print("\nComputing national index...")
    national_idx = solve_wls_index(
        pairs["year1"].to_numpy(),
        pairs["year2"].to_numpy(),
        pairs["log_ratio"].to_numpy(),
        pairs["weight"].to_numpy(),
    )
    print(f" National index: {len(national_idx)} years")
    print("\nComputing area indices...")
    area_indices, area_pairs = compute_indices_for_level(pairs, "area")
    print(f" {len(area_indices)} areas with indices")
    print("\nComputing district indices...")
    district_indices, district_pairs = compute_indices_for_level(pairs, "district")
    print(f" {len(district_indices)} districts with indices")
    print("\nComputing sector indices...")
    sector_indices, sector_pairs = compute_indices_for_level(pairs, "sector")
    print(f" {len(sector_indices)} sectors with indices")

    # Shrink top-down so each level blends toward an already-shrunk parent.
    print("\nApplying hierarchical shrinkage...")
    area_indices = {
        area: shrink_index(idx, national_idx, area_pairs[area])
        for area, idx in tqdm(area_indices.items(), desc=" Area shrinkage")
    }
    district_indices = {
        dist: shrink_index(idx, area_indices.get(_area_of(dist), national_idx), district_pairs[dist])
        for dist, idx in tqdm(district_indices.items(), desc=" District shrinkage")
    }
    sector_indices = {
        sector: shrink_index(idx, district_indices.get(_district_of(sector), national_idx), sector_pairs[sector])
        for sector, idx in tqdm(sector_indices.items(), desc=" Sector shrinkage")
    }

    # For sectors without enough data, fall back to district/area/national
    all_sectors = pairs["sector"].unique().to_list()
    min_year = int(pairs["year1"].min())
    max_year = max(int(pairs["year2"].max()), args.max_year)
    print(f"\nFilling gaps and forward-filling ({min_year}-{max_year})...")
    rows = []
    for sector in tqdm(all_sectors, desc=" Forward-fill"):
        if sector in sector_indices:
            idx = sector_indices[sector]
        else:
            dist_key = _district_of(sector)
            idx = district_indices.get(dist_key, area_indices.get(_area_of(dist_key), national_idx))
        n = sector_pairs.get(sector, 0)
        filled = forward_fill_index(idx, min_year, max_year)
        rows.extend((sector, year, log_idx, n) for year, log_idx in filled.items())

    result = pl.DataFrame(
        rows,
        schema={"sector": pl.String, "year": pl.Int32, "log_index": pl.Float64, "n_pairs": pl.Int64},
        orient="row",
    )
    result = result.sort("sector", "year")
    result.write_parquet(args.output)
    size_mb = args.output.stat().st_size / (1024 * 1024)
    print(f"\nWrote {args.output} ({size_mb:.1f} MB)")
    print(f" {result['sector'].n_unique():,} sectors × {max_year - min_year + 1} years = {len(result):,} rows")


if __name__ == "__main__":
    main()