perfect-postcode/pipeline/transform/price_estimation/test_knn.py
2026-05-14 08:09:19 +01:00

94 lines
2.7 KiB
Python

from datetime import date
import numpy as np
import polars as pl
from pipeline.transform.price_estimation.estimate import guarded_blend_estimates
from pipeline.transform.price_estimation.knn import build_knn_pool, knn_median_psm
from pipeline.transform.price_estimation.utils import TYPE_GROUPS, type_group_expr
def _flat_index() -> pl.DataFrame:
    """Build a flat price-index fixture for sector ``AA1 1`` in 2026.

    Emits one row per type group (``Detached`` and ``All``), each with a
    ``log_index`` of 0.0 (a flat index for every group).
    """
    groups = ["Detached", "All"]
    n_rows = len(groups)
    return pl.DataFrame(
        {
            "sector": ["AA1 1"] * n_rows,
            "type_group": groups,
            "year": [2026] * n_rows,
            "log_index": [0.0] * n_rows,
        }
    )
def test_knn_excludes_same_sale_and_uses_stable_comparables():
    """kNN price-per-sqm must ignore repeat records of the subject sale.

    The pool holds three groups of sales:

    - the subject sale;
    - five records sharing its postcode, price and date, but with a
      20 sqm floor area (45,000 per sqm);
    - five nearby comparables at 200,000 for 80 sqm (2,500 per sqm).

    The expected 2,500 per-sqm median implies the subject-sale records
    are excluded.
    """
    sale_date = date(2026, 1, 1)

    def detached_sale(postcode, lat, lon, floor_area, price):
        # Every fixture row is a Detached sale on the same date.
        return {
            "Postcode": postcode,
            "Property type": "Detached",
            "lat": lat,
            "lon": lon,
            "Total floor area (sqm)": floor_area,
            "Last known price": price,
            "Date of last transaction": sale_date,
        }

    subject = detached_sale("AA1 1AA", 51.5000, -0.1000, 80.0, 900_000.0)
    # Same postcode/price/date as the subject, clustered very close to it.
    same_sale_copies = [
        detached_sale("AA1 1AA", 51.5001 + i * 0.00001, -0.1001, 20.0, 900_000.0)
        for i in range(5)
    ]
    # Distinct postcodes, slightly further away: the comparables to keep.
    comparables = [
        detached_sale(f"AA1 1B{i}", 51.5010 + i * 0.00001, -0.1010, 80.0, 200_000.0)
        for i in range(5)
    ]
    frame = pl.DataFrame([subject, *same_sale_copies, *comparables])

    # Sale date expressed as whole days since 1970-01-01.
    days_since_epoch = sale_date.toordinal() - date(1970, 1, 1).toordinal()

    pool = build_knn_pool(frame.lazy(), _flat_index(), 2026.0)
    psm = knn_median_psm(
        pool,
        lat=np.array([51.5000]),
        lon=np.array([-0.1000]),
        type_groups=np.array(["Detached"]),
        postcodes=np.array(["AA1 1AA"]),
        last_prices=np.array([900_000.0]),
        last_sale_dates=np.array([days_since_epoch]),
    )

    # 200_000 / 80 sqm: only the comparables should contribute.
    assert psm[0] == 2_500.0
def test_guarded_blend_routes_unstable_knn_to_index_and_caps_uplift():
    """Guarded blending falls back to the index and caps large uplifts.

    Row 0: the kNN estimate is 50x the last price; the test expects the
    index estimate to be returned unchanged.
    Row 1: both estimates are 10x the last price; the test expects the
    result to be capped at 600_000 (6x the last price).
    """
    index_estimates = np.array([120_000.0, 1_000_000.0])
    knn_estimates = np.array([5_000_000.0, 1_000_000.0])
    previous_prices = np.array([100_000.0, 100_000.0])

    blended = guarded_blend_estimates(
        index_est=index_estimates,
        knn_est=knn_estimates,
        last_prices=previous_prices,
    )

    routed_to_index = blended[0]
    assert routed_to_index == 120_000.0
    capped_uplift = blended[1]
    assert capped_uplift == 600_000.0
def test_bungalow_is_not_a_dead_price_index_type_group():
    """Bungalow is absent from TYPE_GROUPS and maps to a null type group.

    Both "Bungalow" and "Other" are expected to map to no type group
    (null) via ``type_group_expr``.
    """
    assert "Bungalow" not in TYPE_GROUPS

    property_types = pl.DataFrame({"Property type": ["Bungalow", "Other"]})
    with_groups = property_types.with_columns(type_group_expr())
    resolved = with_groups["type_group"].to_list()

    assert resolved == [None, None]