"""kNN price estimation using nearby recently-sold properties.

For each target property, finds k nearest sold properties of the same type,
computes the median index-adjusted price-per-sqm, and multiplies by the
target's floor area to produce an estimate.

This module builds the neighbour pool (``build_knn_pool``) and returns the
median adjusted price-per-sqm (``knn_median_psm``); the floor-area
multiplication is not performed in this module.
"""

from pathlib import Path

import numpy as np
import polars as pl
from scipy.spatial import KDTree

from pipeline.transform.price_estimation.utils import (
    TYPE_GROUPS,
    interpolate_log_index,
    sector_expr,
    type_group_expr,
)

# Default number of neighbours queried per target (default ``k`` below).
KNN_K = 20
# Type groups with fewer usable sales are skipped when building trees, and
# queries whose effective k falls below this return NaN.
KNN_MIN_NEIGHBORS = 5
# Not referenced in this module; presumably the weight of the kNN estimate
# when blended with another estimator elsewhere — confirm with callers.
KNN_BLEND_WEIGHT = 0.35


def _scale_coords(lat: np.ndarray, lon: np.ndarray) -> np.ndarray:
    """Project lat/lon in degrees to planar coordinates for KD-tree queries.

    Returns an ``(n, 2)`` float64 array ``[lat, lon * cos(lat)]`` so that
    Euclidean distance approximates ground distance in degrees of latitude
    (equirectangular-style scaling).

    NOTE(review): ``cos(lat)`` is evaluated per point rather than at a fixed
    reference latitude, which introduces a small shear proportional to
    ``lon`` (negligible near the prime meridian, larger far from it). Pool
    and query coordinates both pass through this function, so they remain
    mutually consistent.
    """
    lat = np.asarray(lat, dtype=np.float64)
    lon = np.asarray(lon, dtype=np.float64)
    return np.column_stack([lat, lon * np.cos(np.radians(lat))])


def build_knn_pool(
    source: Path | str | pl.LazyFrame | pl.DataFrame,
    index: pl.DataFrame,
    ref_frac_year: float,
    max_sale_year: int | None = None,
) -> dict[str, tuple[KDTree, np.ndarray]]:
    """Build per-type_group KD-trees of index-adjusted price-per-sqm.

    Each usable pool property's sale price is divided by its floor area and
    moved from its sale date to ``ref_frac_year`` using the log price index.
    A KD-tree over ``_scale_coords(lat, lon)`` is then built for every
    type_group with at least ``KNN_MIN_NEIGHBORS`` usable properties.

    Args:
        source: Parquet path (``Path`` or ``str``), or an already-loaded
            polars frame (lazy or eager) of sold properties.
        index: Log price index, passed through to ``interpolate_log_index``.
        ref_frac_year: Fractional year (year + (month - 1) / 12 convention,
            matching ``_sale_fy``) that all prices are adjusted to.
        max_sale_year: If given, only sales whose year is strictly less
            than this value are kept (exclusive upper bound).

    Returns:
        Mapping type_group -> (KDTree over scaled lat/lon, adjusted
        price-per-sqm array whose rows align with the tree's data rows).
    """
    print("Building kNN pool...")
    if isinstance(source, (str, Path)):
        lf = pl.scan_parquet(source)
    elif isinstance(source, pl.DataFrame):
        lf = source.lazy()
    else:
        lf = source

    sale_date = pl.col("Date of last transaction")
    query = lf.select(
        "Postcode",
        "Property type",
        "lat",
        "lon",
        "Total floor area (sqm)",
        "Last known price",
        "Date of last transaction",
    ).filter(
        pl.col("lat").is_not_null(),
        pl.col("lon").is_not_null(),
        # NaN coordinates are not null but would poison KD-tree construction.
        pl.col("lat").is_finite(),
        pl.col("lon").is_finite(),
        pl.col("Total floor area (sqm)").is_not_null(),
        pl.col("Total floor area (sqm)") > 0,
        pl.col("Last known price").is_not_null(),
        pl.col("Last known price") > 0,
        pl.col("Postcode").is_not_null(),
        sale_date.is_not_null(),
    )
    if max_sale_year is not None:
        # Exclusive bound: sales in max_sale_year itself are dropped.
        query = query.filter(sale_date.dt.year() < max_sale_year)

    pool = (
        query.with_columns(
            sector_expr(),
            type_group_expr(),
            # Fractional sale year at month resolution (January -> .0).
            (
                sale_date.dt.year().cast(pl.Float64)
                + (sale_date.dt.month().cast(pl.Float64) - 1.0) / 12.0
            ).alias("_sale_fy"),
            # float() keeps this key Float64 like _sale_fy even for int years.
            pl.lit(float(ref_frac_year)).alias("_ref_fy"),
        )
        .filter(pl.col("type_group").is_not_null())
        .collect()
    )
    print(f" {len(pool):,} pool properties with lat/lon, floor area, price")

    # Interpolate log_index at sale date and reference date
    pool = interpolate_log_index(
        index, pool, "sector", "type_group", "_sale_fy", "_li_sale"
    )
    pool = interpolate_log_index(
        index, pool, "sector", "type_group", "_ref_fy", "_li_ref"
    )

    # adjusted_psm = price / floor_area * exp(log_index_ref - log_index_sale)
    pool = pool.with_columns(
        (
            pl.col("Last known price").cast(pl.Float64)
            / pl.col("Total floor area (sqm)").cast(pl.Float64)
            * (pl.col("_li_ref") - pl.col("_li_sale")).exp()
        ).alias("_adj_psm")
    ).filter(
        pl.col("_adj_psm").is_not_null(),
        pl.col("_adj_psm").is_finite(),
        pl.col("_adj_psm") > 0,
    )
    print(f" {len(pool):,} after index adjustment")

    # Build per-type KD-trees
    trees: dict[str, tuple[KDTree, np.ndarray]] = {}
    for tg in TYPE_GROUPS:
        sub = pool.filter(pl.col("type_group") == tg)
        n = len(sub)
        if n < KNN_MIN_NEIGHBORS:
            # Too few sales for any query to reach the minimum neighbour count.
            continue
        lat = sub["lat"].to_numpy().astype(np.float64)
        lon = sub["lon"].to_numpy().astype(np.float64)
        psm = sub["_adj_psm"].to_numpy().astype(np.float64)
        trees[tg] = (KDTree(_scale_coords(lat, lon)), psm)
        print(f" {tg}: {n:,}")
    return trees


def knn_median_psm(
    trees: dict[str, tuple[KDTree, np.ndarray]],
    lat: np.ndarray,
    lon: np.ndarray,
    type_groups: np.ndarray,
    k: int = KNN_K,
) -> np.ndarray:
    """Return the median adjusted-PSM of the k nearest same-type neighbours.

    PSM values are at the reference date used when the pool was built. The
    result is a float64 array aligned with the inputs and is NaN where no
    value is computable:

    * the target's lat or lon is missing/non-finite,
    * the target's type_group has no tree in ``trees``,
    * ``min(k, pool size) < KNN_MIN_NEIGHBORS``.

    Args:
        trees: Output of ``build_knn_pool``.
        lat, lon: Target coordinates in degrees; array-likes are accepted and
            ``None`` entries are treated as missing.
        type_groups: Target type_group labels, same length as ``lat``.
        k: Number of neighbours to take the median over.

    Raises:
        ValueError: If ``lat``, ``lon`` and ``type_groups`` differ in length.
    """
    lat = np.asarray(lat, dtype=np.float64)
    lon = np.asarray(lon, dtype=np.float64)
    type_groups = np.asarray(type_groups)
    n = len(lat)
    if len(lon) != n or len(type_groups) != n:
        raise ValueError(
            "lat, lon and type_groups must have the same length, got "
            f"{n}, {len(lon)}, {len(type_groups)}"
        )

    result = np.full(n, np.nan)
    has_coords = np.isfinite(lat) & np.isfinite(lon)
    for tg, (tree, psm) in trees.items():
        actual_k = min(k, len(psm))
        if actual_k < KNN_MIN_NEIGHBORS:
            continue
        idx = np.flatnonzero((type_groups == tg) & has_coords)
        if idx.size == 0:
            continue
        _, nn_idx = tree.query(_scale_coords(lat[idx], lon[idx]), k=actual_k)
        # scipy returns a 1-D index array when k == 1; keep the 2-D shape.
        if nn_idx.ndim == 1:
            nn_idx = nn_idx.reshape(-1, 1)
        result[idx] = np.nanmedian(psm[nn_idx], axis=1)
    return result