Refactor and other improvements
This commit is contained in:
parent
04a78e7bfe
commit
6c90cf3c0f
47 changed files with 2705 additions and 1568 deletions
167
pipeline/transform/price_backtest.py
Normal file
167
pipeline/transform/price_backtest.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
"""Backtesting: Evaluate price index model on held-out recent sales.
|
||||
|
||||
Test set: properties with 2+ sales where the last sale is 2022-2025.
|
||||
Uses the second-to-last sale as input, predicts the last sale price.
|
||||
Compares index-based prediction against a naive baseline (raw input price).
|
||||
|
||||
Output: backtest_results.parquet with predictions vs actuals.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
|
||||
CURRENT_YEAR = 2025
|
||||
TEST_YEAR_MIN = 2022
|
||||
|
||||
|
||||
def extract_test_set(input_path: Path) -> pl.DataFrame:
|
||||
"""Extract test pairs: second-to-last sale as input, last sale as ground truth."""
|
||||
print("Loading test set...")
|
||||
df = (
|
||||
pl.scan_parquet(input_path)
|
||||
.select("Postcode", "historical_prices")
|
||||
.filter(
|
||||
pl.col("Postcode").is_not_null(),
|
||||
pl.col("historical_prices").list.len() >= 2,
|
||||
)
|
||||
.with_columns(
|
||||
pl.col("Postcode").str.slice(0, pl.col("Postcode").str.len_chars() - 2).str.strip_chars().alias("sector"),
|
||||
# Last sale (ground truth)
|
||||
pl.col("historical_prices").list.last().struct.field("year").alias("actual_year"),
|
||||
pl.col("historical_prices").list.last().struct.field("price").alias("actual_price"),
|
||||
# Second-to-last sale (input)
|
||||
pl.col("historical_prices").list.get(-2).struct.field("year").alias("input_year"),
|
||||
pl.col("historical_prices").list.get(-2).struct.field("price").alias("input_price"),
|
||||
)
|
||||
.filter(
|
||||
pl.col("actual_year") >= TEST_YEAR_MIN,
|
||||
pl.col("input_price") > 0,
|
||||
pl.col("actual_price") > 0,
|
||||
pl.col("actual_year") > pl.col("input_year"),
|
||||
)
|
||||
.collect()
|
||||
)
|
||||
print(f" {len(df):,} test pairs (last sale {TEST_YEAR_MIN}-{CURRENT_YEAR})")
|
||||
return df
|
||||
|
||||
|
||||
def predict(test: pl.DataFrame, index: pl.DataFrame) -> pl.DataFrame:
|
||||
"""Index-based prediction: adjust input price by sector index change."""
|
||||
# Join index at input year
|
||||
test = test.join(
|
||||
index.select("sector", "year", pl.col("log_index").alias("log_index_input")),
|
||||
left_on=["sector", "input_year"],
|
||||
right_on=["sector", "year"],
|
||||
how="left",
|
||||
)
|
||||
# Join index at actual year
|
||||
test = test.join(
|
||||
index.select("sector", "year", pl.col("log_index").alias("log_index_actual")),
|
||||
left_on=["sector", "actual_year"],
|
||||
right_on=["sector", "year"],
|
||||
how="left",
|
||||
)
|
||||
|
||||
test = test.with_columns(
|
||||
(
|
||||
pl.col("input_price").cast(pl.Float64)
|
||||
* (pl.col("log_index_actual") - pl.col("log_index_input")).exp()
|
||||
).fill_null(pl.col("input_price").cast(pl.Float64)).alias("predicted"),
|
||||
)
|
||||
return test
|
||||
|
||||
|
||||
def compute_metrics(actual: np.ndarray, predicted: np.ndarray) -> dict:
|
||||
"""Compute error metrics."""
|
||||
valid = np.isfinite(predicted) & np.isfinite(actual) & (actual > 0)
|
||||
actual = actual[valid]
|
||||
predicted = predicted[valid]
|
||||
|
||||
ape = np.abs(predicted - actual) / actual
|
||||
signed_err = predicted - actual
|
||||
|
||||
return {
|
||||
"MdAPE (%)": float(np.median(ape) * 100),
|
||||
"% within 10%": float(np.mean(ape <= 0.10) * 100),
|
||||
"% within 20%": float(np.mean(ape <= 0.20) * 100),
|
||||
"% within 30%": float(np.mean(ape <= 0.30) * 100),
|
||||
"MAE (£)": float(np.mean(np.abs(signed_err))),
|
||||
"Mean signed error (£)": float(np.mean(signed_err)),
|
||||
"n": int(len(actual)),
|
||||
}
|
||||
|
||||
|
||||
def print_metrics_table(metrics_by_stage: dict):
|
||||
"""Print a comparison table of metrics."""
|
||||
print("\n" + "=" * 55)
|
||||
print("BACKTEST RESULTS")
|
||||
print("=" * 55)
|
||||
|
||||
metric_names = ["MdAPE (%)", "% within 10%", "% within 20%", "% within 30%", "MAE (£)", "Mean signed error (£)", "n"]
|
||||
stages = list(metrics_by_stage.keys())
|
||||
|
||||
# Header
|
||||
header = f"{'Metric':<25s}"
|
||||
for stage in stages:
|
||||
header += f" {stage:>14s}"
|
||||
print(header)
|
||||
print("-" * 55)
|
||||
|
||||
for metric in metric_names:
|
||||
row = f"{metric:<25s}"
|
||||
for stage in stages:
|
||||
val = metrics_by_stage[stage][metric]
|
||||
if metric == "n":
|
||||
row += f" {val:>14,d}"
|
||||
elif "£" in metric:
|
||||
row += f" {val:>13,.0f}"
|
||||
else:
|
||||
row += f" {val:>13.1f}%"
|
||||
print(row)
|
||||
|
||||
print("=" * 55)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Backtest price estimation model")
|
||||
parser.add_argument("--input", type=Path, required=True, help="Path to wide.parquet")
|
||||
parser.add_argument("--index", type=Path, required=True, help="Path to price_index.parquet")
|
||||
parser.add_argument("--output", type=Path, required=True, help="Output backtest_results.parquet")
|
||||
args = parser.parse_args()
|
||||
|
||||
index = pl.read_parquet(args.index)
|
||||
print(f"Price index: {len(index):,} rows, {index['sector'].n_unique():,} sectors")
|
||||
|
||||
test = extract_test_set(args.input)
|
||||
|
||||
print("\nPredicting with price index...")
|
||||
test = predict(test, index)
|
||||
|
||||
# Compute and print metrics
|
||||
actual = test["actual_price"].to_numpy().astype(np.float64)
|
||||
metrics = {
|
||||
"Naive": compute_metrics(actual, test["input_price"].to_numpy().astype(np.float64)),
|
||||
"Index": compute_metrics(actual, test["predicted"].to_numpy().astype(np.float64)),
|
||||
}
|
||||
|
||||
print_metrics_table(metrics)
|
||||
|
||||
# Save results
|
||||
result = test.select(
|
||||
"Postcode", "sector",
|
||||
"input_year", "input_price",
|
||||
"actual_year", "actual_price",
|
||||
"predicted",
|
||||
)
|
||||
|
||||
result.write_parquet(args.output)
|
||||
size_mb = args.output.stat().st_size / (1024 * 1024)
|
||||
print(f"\nWrote {args.output} ({size_mb:.1f} MB)")
|
||||
print(f" {len(result):,} rows")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue