changes
This commit is contained in:
parent
524580eb25
commit
ffe080adef
82 changed files with 2652 additions and 2956 deletions
|
|
@ -20,7 +20,6 @@ from pipeline.transform.price_estimation.knn import (
|
|||
from pipeline.transform.price_estimation.utils import (
|
||||
CURRENT_YEAR,
|
||||
MAX_LOG_ADJUSTMENT,
|
||||
compute_seasonal_factors,
|
||||
interpolate_log_index,
|
||||
sector_expr,
|
||||
type_group_expr,
|
||||
|
|
@ -91,7 +90,7 @@ def extract_test_set(input_path: Path) -> pl.DataFrame:
|
|||
|
||||
|
||||
def predict(test: pl.DataFrame, index: pl.DataFrame) -> pl.DataFrame:
|
||||
"""Index-based prediction with interpolation, capping, and seasonal adjustment."""
|
||||
"""Index-based prediction with interpolation and capping."""
|
||||
test = interpolate_log_index(
|
||||
index, test, "sector", "type_group", "input_frac_year", "log_index_input"
|
||||
)
|
||||
|
|
@ -105,7 +104,6 @@ def predict(test: pl.DataFrame, index: pl.DataFrame) -> pl.DataFrame:
|
|||
* (pl.col("log_index_actual") - pl.col("log_index_input"))
|
||||
.clip(-MAX_LOG_ADJUSTMENT, MAX_LOG_ADJUSTMENT)
|
||||
.exp()
|
||||
* pl.col("_seasonal_adj")
|
||||
)
|
||||
.fill_null(pl.col("input_price").cast(pl.Float64))
|
||||
.alias("predicted"),
|
||||
|
|
@ -175,7 +173,10 @@ def print_metrics_table(metrics_by_stage: dict):
|
|||
def main():
|
||||
parser = argparse.ArgumentParser(description="Backtest price estimation model")
|
||||
parser.add_argument(
|
||||
"--input", type=Path, required=True, help="Path to wide.parquet"
|
||||
"--input", type=Path, required=True, help="Path to properties.parquet"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--postcodes", type=Path, required=True, help="Path to postcode.parquet (for lat/lon)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", type=Path, required=True, help="Output backtest_results.parquet"
|
||||
|
|
@ -184,38 +185,28 @@ def main():
|
|||
|
||||
# Build index from pre-test data only (temporal holdout)
|
||||
print(f"Building price index (pairs with year2 < {TEST_YEAR_MIN})...")
|
||||
index = build_index(args.input, max_pair_year=TEST_YEAR_MIN)
|
||||
index = build_index(args.input, max_pair_year=TEST_YEAR_MIN, postcodes_path=args.postcodes)
|
||||
print(
|
||||
f"\nHoldout index: {len(index):,} rows, {index['sector'].n_unique():,} sectors, "
|
||||
f"{index['type_group'].n_unique()} type groups"
|
||||
)
|
||||
|
||||
# Compute seasonal factors from pre-test data only
|
||||
seasonal = compute_seasonal_factors(args.input, max_sale_year=TEST_YEAR_MIN)
|
||||
months = [
|
||||
"Jan", "Feb", "Mar", "Apr", "May", "Jun",
|
||||
"Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
|
||||
]
|
||||
print(
|
||||
f"Seasonal factors: {', '.join(f'{m}={f:.3f}' for m, f in zip(months, seasonal))}"
|
||||
)
|
||||
|
||||
test = extract_test_set(args.input)
|
||||
|
||||
# Compute seasonal adjustment for each test pair
|
||||
input_months = test["input_month"].fill_null(6).to_numpy().astype(np.int32)
|
||||
actual_months = test["actual_month"].fill_null(6).to_numpy().astype(np.int32)
|
||||
seasonal_adj = seasonal[actual_months - 1] / seasonal[input_months - 1]
|
||||
test = test.with_columns(
|
||||
pl.Series("_seasonal_adj", seasonal_adj, dtype=pl.Float64),
|
||||
)
|
||||
# Join lat/lon from postcode.parquet (properties.parquet no longer has them)
|
||||
postcodes = pl.read_parquet(args.postcodes).select("Postcode", "lat", "lon")
|
||||
test = test.join(postcodes, on="Postcode", how="left")
|
||||
|
||||
print("\nPredicting with price index...")
|
||||
test = predict(test, index)
|
||||
|
||||
# --- kNN ---
|
||||
ref_fy = float(TEST_YEAR_MIN)
|
||||
trees = build_knn_pool(args.input, index, ref_fy, max_sale_year=TEST_YEAR_MIN)
|
||||
# Pass joined LazyFrame (with lat/lon) instead of raw properties path
|
||||
pool_lf = pl.scan_parquet(args.input).join(
|
||||
postcodes.lazy(), on="Postcode", how="left"
|
||||
)
|
||||
trees = build_knn_pool(pool_lf, index, ref_fy, max_sale_year=TEST_YEAR_MIN)
|
||||
|
||||
# Interpolate log_index at reference year for temporal adjustment
|
||||
test = test.with_columns(pl.lit(ref_fy).alias("_ref_fy"))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue