1375 lines
50 KiB
Rust
1375 lines
50 KiB
Rust
use anyhow::{bail, Context};
|
||
use polars::lazy::frame::LazyFrame;
|
||
use polars::prelude::*;
|
||
use rayon::prelude::*;
|
||
use serde::Serialize;
|
||
use std::path::Path;
|
||
|
||
use rustc_hash::FxHashMap;
|
||
|
||
use crate::consts::{H3_PRECOMPUTE_MAX, HISTOGRAM_BINS, NAN_U16, QUANT_SCALE};
|
||
use crate::features::{self, Bounds};
|
||
|
||
fn is_numeric_dtype(dtype: &DataType) -> bool {
|
||
matches!(
|
||
dtype,
|
||
DataType::Int8
|
||
| DataType::Int16
|
||
| DataType::Int32
|
||
| DataType::Int64
|
||
| DataType::UInt8
|
||
| DataType::UInt16
|
||
| DataType::UInt32
|
||
| DataType::UInt64
|
||
| DataType::Float32
|
||
| DataType::Float64
|
||
| DataType::Datetime(_, _)
|
||
| DataType::Date
|
||
)
|
||
}
|
||
|
||
fn is_datetime_dtype(dtype: &DataType) -> bool {
|
||
matches!(dtype, DataType::Datetime(_, _) | DataType::Date)
|
||
}
|
||
|
||
/// Histogram with outlier buckets at the edges.
|
||
/// - Bin 0: [min, p1) — low outliers
|
||
/// - Bins 1 to n-2: [p1, p99) — main distribution, evenly divided
|
||
/// - Bin n-1: [p99, max] — high outliers
|
||
#[derive(Serialize, Clone)]
|
||
pub struct Histogram {
|
||
pub min: f32,
|
||
pub max: f32,
|
||
/// 1st percentile (left edge of main distribution)
|
||
pub p1: f32,
|
||
/// 99th percentile (right edge of main distribution)
|
||
pub p99: f32,
|
||
pub counts: Vec<u64>,
|
||
}
|
||
|
||
impl Histogram {
|
||
/// Return the bin index for a given value using the outlier-bracket layout.
|
||
#[cfg(test)]
|
||
pub fn bin_for_value(&self, value: f32) -> usize {
|
||
let num_bins = self.counts.len();
|
||
if value < self.p1 {
|
||
0
|
||
} else if value >= self.p99 {
|
||
num_bins - 1
|
||
} else {
|
||
let middle_bins = num_bins.saturating_sub(2);
|
||
if middle_bins > 0 && self.p99 > self.p1 {
|
||
let width = (self.p99 - self.p1) / middle_bins as f32;
|
||
let middle_bin = ((value - self.p1) / width) as usize;
|
||
(1 + middle_bin).min(num_bins - 2)
|
||
} else {
|
||
num_bins / 2
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Width of a single middle bin (bins 1..n-2).
|
||
#[cfg(test)]
|
||
pub fn middle_bin_width(&self) -> f32 {
|
||
let middle_bins = self.counts.len().saturating_sub(2);
|
||
if middle_bins > 0 && self.p99 > self.p1 {
|
||
(self.p99 - self.p1) / middle_bins as f32
|
||
} else {
|
||
0.0
|
||
}
|
||
}
|
||
}
|
||
|
||
pub struct FeatureStats {
|
||
pub slider_min: f32,
|
||
pub slider_max: f32,
|
||
pub histogram: Histogram,
|
||
}
|
||
|
||
#[derive(Serialize, Clone)]
|
||
pub struct RenovationEvent {
|
||
pub year: i32,
|
||
pub event: String,
|
||
}
|
||
|
||
/// Borrowed view of the per-feature quantization parameters, used to convert between
/// stored `u16` codes and `f32` feature values without borrowing the whole dataset.
pub struct QuantRef<'a> {
    /// Per-feature multiplier applied to a raw code when decoding.
    pub dequant_a: &'a [f32],
    /// Per-feature offset added when decoding and subtracted when encoding.
    pub quant_min: &'a [f32],
    /// Per-feature value range used to normalise filter bounds when encoding.
    pub quant_range: &'a [f32],
    /// Number of numeric features; features at this index or above are decoded
    /// as the raw code itself.
    pub num_numeric: usize,
}
|
||
|
||
impl QuantRef<'_> {
|
||
/// Decode a raw u16 value back to f32.
|
||
#[inline]
|
||
pub fn decode(&self, feat_idx: usize, raw: u16) -> f32 {
|
||
if raw == NAN_U16 {
|
||
return f32::NAN;
|
||
}
|
||
if feat_idx >= self.num_numeric {
|
||
raw as f32
|
||
} else {
|
||
raw as f32 * self.dequant_a[feat_idx] + self.quant_min[feat_idx]
|
||
}
|
||
}
|
||
|
||
/// Encode a filter minimum bound to u16 (floors to include boundary values).
|
||
#[inline]
|
||
pub fn encode_min(&self, feat_idx: usize, value: f32) -> u16 {
|
||
if !value.is_finite() || self.quant_range[feat_idx] == 0.0 {
|
||
return 0;
|
||
}
|
||
let norm = (value - self.quant_min[feat_idx]) / self.quant_range[feat_idx];
|
||
(norm * QUANT_SCALE).floor().clamp(0.0, QUANT_SCALE) as u16
|
||
}
|
||
|
||
/// Encode a filter maximum bound to u16 (ceils to include boundary values).
|
||
#[inline]
|
||
pub fn encode_max(&self, feat_idx: usize, value: f32) -> u16 {
|
||
if !value.is_finite() || self.quant_range[feat_idx] == 0.0 {
|
||
return QUANT_SCALE as u16;
|
||
}
|
||
let norm = (value - self.quant_min[feat_idx]) / self.quant_range[feat_idx];
|
||
(norm * QUANT_SCALE).ceil().clamp(0.0, QUANT_SCALE) as u16
|
||
}
|
||
}
|
||
|
||
pub struct PropertyData {
|
||
pub lat: Vec<f32>,
|
||
pub lon: Vec<f32>,
|
||
pub feature_names: Vec<String>,
|
||
pub num_features: usize,
|
||
/// Number of numeric features (enum features start at this index).
|
||
pub num_numeric: usize,
|
||
/// Row-major flat array: feature_data[row * num_features + feat_idx].
|
||
/// Quantized to u16. NaN sentinel = u16::MAX (65535).
|
||
/// Numeric features: encoded via (val - min) / range * 65534.
|
||
/// Enum features: stored directly as u16 cast of the f32 index.
|
||
pub feature_data: Vec<u16>,
|
||
/// Per-feature: range / QUANT_SCALE for fast decode.
|
||
dequant_a: Vec<f32>,
|
||
/// Per-feature: minimum value (offset for dequantization).
|
||
quant_min: Vec<f32>,
|
||
/// Per-feature: max - min (for encoding filter bounds).
|
||
quant_range: Vec<f32>,
|
||
pub feature_stats: Vec<FeatureStats>,
|
||
/// Contiguous buffer holding all address strings end-to-end.
|
||
address_buffer: String,
|
||
/// Byte offset into `address_buffer` where each row's address starts.
|
||
address_offsets: Vec<u32>,
|
||
/// Length in bytes of each row's address.
|
||
address_lengths: Vec<u16>,
|
||
/// Interned postcodes: reader is thread-safe, keys index into it.
|
||
postcode_interner: lasso::RodeoReader,
|
||
postcode_keys: Vec<lasso::Spur>,
|
||
/// For enum features: maps feature index to list of possible string values.
|
||
/// Index in values list corresponds to the u16 value stored in feature_data.
|
||
pub enum_values: rustc_hash::FxHashMap<usize, Vec<String>>,
|
||
/// For enum features: maps feature index to per-value global counts (same order as enum_values).
|
||
pub enum_counts: rustc_hash::FxHashMap<usize, Vec<u64>>,
|
||
/// Per-row flag: true = construction date is approximate (from EPC band),
|
||
/// false = exact (from new-build transaction date).
|
||
/// Bit-packed: byte `row / 8`, bit `row % 8`. 8x smaller than Vec<bool>.
|
||
approx_build_date_bits: Vec<u8>,
|
||
/// Per-row renovation events. Keyed by (permuted) row index.
|
||
/// Only rows with events are present in the map.
|
||
renovation_history: FxHashMap<u32, Vec<RenovationEvent>>,
|
||
property_sub_type: FxHashMap<u32, String>,
|
||
price_qualifier: FxHashMap<u32, String>,
|
||
}
|
||
|
||
impl PropertyData {
|
||
/// Get the address string for a given row.
|
||
pub fn address(&self, row: usize) -> &str {
|
||
let offset = self.address_offsets[row] as usize;
|
||
let length = self.address_lengths[row] as usize;
|
||
&self.address_buffer[offset..offset + length]
|
||
}
|
||
|
||
/// Get the postcode string for a given row.
|
||
pub fn postcode(&self, row: usize) -> &str {
|
||
self.postcode_interner.resolve(&self.postcode_keys[row])
|
||
}
|
||
|
||
/// Get postcode components for field-level borrowing (avoids conflicting borrows with feature_data).
|
||
pub fn postcode_parts(&self) -> (&lasso::RodeoReader, &[lasso::Spur]) {
|
||
(&self.postcode_interner, &self.postcode_keys)
|
||
}
|
||
|
||
/// Get the is_approx_build_date flag for a given row (bit-packed).
|
||
pub fn is_approx_build_date(&self, row: usize) -> bool {
|
||
let byte = self.approx_build_date_bits[row / 8];
|
||
byte & (1 << (row % 8)) != 0
|
||
}
|
||
|
||
/// Get renovation events for a given row (empty slice if none).
|
||
pub fn renovation_history(&self, row: usize) -> &[RenovationEvent] {
|
||
self.renovation_history
|
||
.get(&(row as u32))
|
||
.map(|v| v.as_slice())
|
||
.unwrap_or(&[])
|
||
}
|
||
|
||
/// Get property sub-type for a given row.
|
||
pub fn property_sub_type(&self, row: usize) -> Option<&str> {
|
||
self.property_sub_type
|
||
.get(&(row as u32))
|
||
.map(String::as_str)
|
||
}
|
||
|
||
/// Get price qualifier for a given row.
|
||
pub fn price_qualifier(&self, row: usize) -> Option<&str> {
|
||
self.price_qualifier.get(&(row as u32)).map(String::as_str)
|
||
}
|
||
|
||
/// Decode a single feature value from quantized u16 storage.
|
||
#[inline]
|
||
pub fn get_feature(&self, row: usize, feat_idx: usize) -> f32 {
|
||
let raw = self.feature_data[row * self.num_features + feat_idx];
|
||
if raw == NAN_U16 {
|
||
return f32::NAN;
|
||
}
|
||
if feat_idx >= self.num_numeric {
|
||
raw as f32
|
||
} else {
|
||
raw as f32 * self.dequant_a[feat_idx] + self.quant_min[feat_idx]
|
||
}
|
||
}
|
||
|
||
/// Get a QuantRef for passing to aggregation/filter functions.
|
||
pub fn quant_ref(&self) -> QuantRef<'_> {
|
||
QuantRef {
|
||
dequant_a: &self.dequant_a,
|
||
quant_min: &self.quant_min,
|
||
quant_range: &self.quant_range,
|
||
num_numeric: self.num_numeric,
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Estimates a percentile from a histogram whose bins are uniform over `[min, max]`.
///
/// The target rank is `floor(count * percentile / 100)`. The function walks the bins until
/// the running total exceeds that rank, then interpolates linearly inside that bin.
///
/// Returns `min` when `count` is zero or there are no bins, and `max` when the running
/// total never exceeds the rank (for example at the 100th percentile).
fn percentile_from_uniform_histogram(
    count: usize,
    min: f32,
    max: f32,
    prelim_counts: &[u64],
    percentile: f32,
) -> f32 {
    if count == 0 || prelim_counts.is_empty() {
        return min;
    }

    let rank = (count as f64 * percentile as f64 / 100.0).floor() as u64;
    let width = (max - min) / prelim_counts.len() as f32;

    // Number of values in all bins strictly before the current one.
    let mut seen_before = 0u64;
    for (index, &in_bin) in prelim_counts.iter().enumerate() {
        let seen_through = seen_before + in_bin;
        if seen_through <= rank {
            seen_before = seen_through;
            continue;
        }

        // The rank falls inside this bin: interpolate between its edges.
        let left_edge = min + index as f32 * width;
        let fraction = match in_bin {
            0 => 0.0,
            _ => (rank - seen_before) as f32 / in_bin as f32,
        };
        return left_edge + fraction * width;
    }
    max
}
|
||
|
||
/// Build a histogram and compute slider bounds based on the feature's Bounds config.
|
||
pub fn compute_feature_stats(vals: &[f32], bounds: &Bounds, integer_bins: bool) -> FeatureStats {
|
||
// Single pass: min, max, count (skipping NaN and infinity)
|
||
let mut min = f32::INFINITY;
|
||
let mut max = f32::NEG_INFINITY;
|
||
let mut count = 0usize;
|
||
for &value in vals {
|
||
if value.is_finite() {
|
||
if value < min {
|
||
min = value;
|
||
}
|
||
if value > max {
|
||
max = value;
|
||
}
|
||
count += 1;
|
||
}
|
||
}
|
||
|
||
if count == 0 {
|
||
let (slider_min, slider_max) = match bounds {
|
||
Bounds::Fixed {
|
||
min: fmin,
|
||
max: fmax,
|
||
} => (*fmin, *fmax),
|
||
Bounds::Percentile { .. } => (0.0, 0.0),
|
||
};
|
||
return FeatureStats {
|
||
slider_min,
|
||
slider_max,
|
||
histogram: Histogram {
|
||
min: 0.0,
|
||
max: 0.0,
|
||
p1: 0.0,
|
||
p99: 0.0,
|
||
counts: vec![0; HISTOGRAM_BINS],
|
||
},
|
||
};
|
||
}
|
||
|
||
// Build preliminary histogram with uniform bins to compute percentiles
|
||
// Use full HISTOGRAM_BINS for percentile precision
|
||
let range = if max == min { 1.0 } else { max - min };
|
||
let prelim_max = min + range * (1.0 + 1e-6);
|
||
let prelim_bin_width = (prelim_max - min) / HISTOGRAM_BINS as f32;
|
||
|
||
let mut prelim_counts = vec![0u64; HISTOGRAM_BINS];
|
||
for &value in vals {
|
||
if value.is_finite() {
|
||
let bin = ((value - min) / prelim_bin_width) as usize;
|
||
prelim_counts[bin.min(HISTOGRAM_BINS - 1)] += 1;
|
||
}
|
||
}
|
||
|
||
// Compute p1 and p99 from preliminary histogram
|
||
let mut p1 = percentile_from_uniform_histogram(count, min, max, &prelim_counts, 1.0);
|
||
let mut p99 = percentile_from_uniform_histogram(count, min, max, &prelim_counts, 99.0);
|
||
|
||
// Iterative refinement for outlier-dominated distributions.
|
||
// When extreme outliers (e.g. 317M sqm from web scraping) dominate the range,
|
||
// the uniform histogram puts all real data in one bin, making percentile
|
||
// estimation useless. Zoom into the estimated data region and recompute.
|
||
let mut refined_counts = prelim_counts;
|
||
let mut refined_count = count;
|
||
let mut refined_min = min;
|
||
let mut refined_max = max;
|
||
for _ in 0..3 {
|
||
let iqr = p99 - p1;
|
||
if iqr <= 0.0 || (refined_max - refined_min) <= 5.0 * iqr {
|
||
break;
|
||
}
|
||
let new_min = (p1 - iqr).max(min);
|
||
let new_max = p99 + iqr;
|
||
if new_max <= new_min {
|
||
break;
|
||
}
|
||
let bin_width = (new_max - new_min) / HISTOGRAM_BINS as f32;
|
||
let mut counts = vec![0u64; HISTOGRAM_BINS];
|
||
let mut cnt = 0usize;
|
||
for &value in vals {
|
||
if value.is_finite() && value >= new_min && value <= new_max {
|
||
let bin = ((value - new_min) / bin_width) as usize;
|
||
counts[bin.min(HISTOGRAM_BINS - 1)] += 1;
|
||
cnt += 1;
|
||
}
|
||
}
|
||
if cnt == 0 {
|
||
break;
|
||
}
|
||
p1 = percentile_from_uniform_histogram(cnt, new_min, new_max, &counts, 1.0);
|
||
p99 = percentile_from_uniform_histogram(cnt, new_min, new_max, &counts, 99.0);
|
||
refined_counts = counts;
|
||
refined_count = cnt;
|
||
refined_min = new_min;
|
||
refined_max = new_max;
|
||
}
|
||
|
||
// For integer-binned features, snap p1/p99 to integer boundaries
|
||
// so each middle bin is exactly 1 unit wide.
|
||
if integer_bins {
|
||
p1 = p1.floor();
|
||
p99 = p99.ceil();
|
||
}
|
||
|
||
// Determine number of histogram bins
|
||
let num_bins = if integer_bins && p99 > p1 {
|
||
// One middle bin per integer + 2 outlier bins
|
||
(p99 - p1) as usize + 2
|
||
} else {
|
||
// Count unique values within the p1–p99 range to cap histogram bins.
|
||
// Using the full-range cardinality would over-allocate bins when outliers
|
||
// inflate it (e.g. bedrooms: 1–137 unique values but only ~10 within p1–p99).
|
||
let cardinality = {
|
||
let mut unique_set = rustc_hash::FxHashSet::default();
|
||
for &val in vals {
|
||
if val.is_finite() && val >= p1 && val <= p99 {
|
||
unique_set.insert(val.to_bits());
|
||
}
|
||
}
|
||
unique_set.len()
|
||
};
|
||
HISTOGRAM_BINS.min(cardinality).max(3)
|
||
};
|
||
|
||
// Build final histogram with outlier bins at edges:
|
||
// - Bin 0: [min, p1) — low outliers
|
||
// - Bins 1 to n-2: [p1, p99) — main distribution, evenly divided
|
||
// - Bin n-1: [p99, max] — high outliers
|
||
let mut counts = vec![0u64; num_bins];
|
||
let middle_bins = num_bins.saturating_sub(2);
|
||
let middle_width = if middle_bins > 0 && p99 > p1 {
|
||
(p99 - p1) / middle_bins as f32
|
||
} else {
|
||
0.0
|
||
};
|
||
|
||
for &value in vals {
|
||
if value.is_finite() {
|
||
let bin = if value < p1 {
|
||
0 // Low outlier bin
|
||
} else if value >= p99 {
|
||
num_bins - 1 // High outlier bin
|
||
} else if middle_width > 0.0 {
|
||
// Middle bins (1 to n-2)
|
||
let middle_bin = ((value - p1) / middle_width) as usize;
|
||
(1 + middle_bin).min(num_bins - 2)
|
||
} else {
|
||
num_bins / 2 // Fallback if p1 == p99
|
||
};
|
||
counts[bin] += 1;
|
||
}
|
||
}
|
||
|
||
let histogram = Histogram {
|
||
min: refined_min,
|
||
max: refined_max,
|
||
p1,
|
||
p99,
|
||
counts,
|
||
};
|
||
|
||
// Compute slider bounds (use refined histogram for accurate percentiles)
|
||
let (slider_min, slider_max) = match bounds {
|
||
Bounds::Fixed {
|
||
min: fmin,
|
||
max: fmax,
|
||
} => (*fmin, *fmax),
|
||
Bounds::Percentile { low, high } => {
|
||
let p_low = percentile_from_uniform_histogram(
|
||
refined_count,
|
||
refined_min,
|
||
refined_max,
|
||
&refined_counts,
|
||
*low as f32,
|
||
);
|
||
let p_high = percentile_from_uniform_histogram(
|
||
refined_count,
|
||
refined_min,
|
||
refined_max,
|
||
&refined_counts,
|
||
*high as f32,
|
||
);
|
||
(p_low, p_high)
|
||
}
|
||
};
|
||
|
||
FeatureStats {
|
||
slider_min,
|
||
slider_max,
|
||
histogram,
|
||
}
|
||
}
|
||
|
||
fn column_to_f32_vec(column: &Column) -> anyhow::Result<Vec<f32>> {
|
||
let float_series = column
|
||
.cast(&DataType::Float32)
|
||
.context("Failed to cast column to Float32")?;
|
||
let chunked = float_series
|
||
.f32()
|
||
.context("Failed to get f32 chunked array")?;
|
||
Ok(chunked
|
||
.into_iter()
|
||
.map(|value| value.unwrap_or(f32::NAN))
|
||
.collect())
|
||
}
|
||
|
||
/// Precompute H3 cell IDs for all rows at the maximum resolution only.
|
||
/// Parent cells for lower resolutions are derived on the fly via `CellIndex::parent()`.
|
||
pub fn precompute_h3(lat: &[f32], lon: &[f32]) -> anyhow::Result<Vec<u64>> {
|
||
let res = H3_PRECOMPUTE_MAX;
|
||
tracing::info!("Precomputing H3 cells at resolution {}", res);
|
||
|
||
let h3_res =
|
||
h3o::Resolution::try_from(res).with_context(|| format!("Invalid H3 resolution: {res}"))?;
|
||
|
||
let cells: Vec<u64> = lat
|
||
.par_iter()
|
||
.zip(lon.par_iter())
|
||
.enumerate()
|
||
.map(|(i, (&latitude, &longitude))| {
|
||
let coord = h3o::LatLng::new(latitude as f64, longitude as f64).unwrap_or_else(|err| {
|
||
panic!(
|
||
"Invalid coordinates at row {}: lat={}, lon={}: {}",
|
||
i, latitude, longitude, err
|
||
)
|
||
});
|
||
u64::from(coord.to_cell(h3_res))
|
||
})
|
||
.collect();
|
||
|
||
tracing::info!("H3 precomputation complete ({} cells)", cells.len());
|
||
Ok(cells)
|
||
}
|
||
|
||
impl PropertyData {
|
||
pub fn load(properties_path: &Path, postcode_features_path: &Path) -> anyhow::Result<Self> {
|
||
// Load postcode.parquet
|
||
tracing::info!(
|
||
"Loading postcode features from {:?}",
|
||
postcode_features_path
|
||
);
|
||
let postcode_df = LazyFrame::scan_parquet(postcode_features_path, Default::default())
|
||
.context("Failed to scan postcode parquet")?
|
||
.collect()
|
||
.context("Failed to read postcode parquet")?;
|
||
tracing::info!(rows = postcode_df.height(), "Postcode features loaded");
|
||
|
||
// Load properties.parquet and join with postcode data for lat/lon + area features
|
||
tracing::info!("Loading properties from {:?}", properties_path);
|
||
let properties_lf = LazyFrame::scan_parquet(properties_path, Default::default())
|
||
.context("Failed to scan properties parquet")?;
|
||
let combined = properties_lf
|
||
.join(
|
||
postcode_df.clone().lazy(),
|
||
[col("Postcode")],
|
||
[col("Postcode")],
|
||
JoinArgs::new(JoinType::Left),
|
||
)
|
||
.collect()
|
||
.context("Failed to join properties with postcodes")?;
|
||
let total_rows = combined.height();
|
||
tracing::info!(rows = total_rows, "Properties joined with postcodes");
|
||
|
||
// Get configured feature/enum names in config order
|
||
let numeric_names = features::all_numeric_feature_names();
|
||
let enum_names = features::all_enum_feature_names();
|
||
|
||
let schema = combined.schema();
|
||
|
||
for name in &numeric_names {
|
||
match schema.get(name) {
|
||
Some(dtype) if is_numeric_dtype(dtype) => {}
|
||
Some(dtype) => bail!(
|
||
"Configured numeric feature '{}' has non-numeric type {:?}",
|
||
name,
|
||
dtype
|
||
),
|
||
None => bail!(
|
||
"Configured numeric feature '{}' not found in combined schema",
|
||
name
|
||
),
|
||
}
|
||
}
|
||
for name in &enum_names {
|
||
match schema.get(name) {
|
||
Some(dtype) if matches!(dtype, DataType::String) || dtype.is_categorical() => {}
|
||
Some(dtype) => bail!(
|
||
"Configured enum feature '{}' has unexpected type {:?}",
|
||
name,
|
||
dtype
|
||
),
|
||
None => bail!(
|
||
"Configured enum feature '{}' not found in combined schema",
|
||
name
|
||
),
|
||
}
|
||
}
|
||
|
||
// Combine numeric and enum feature names (numeric first, then enum)
|
||
let feature_names: Vec<String> = numeric_names
|
||
.iter()
|
||
.chain(enum_names.iter())
|
||
.map(|name| name.to_string())
|
||
.collect();
|
||
let num_features = feature_names.len();
|
||
let num_numeric = numeric_names.len();
|
||
tracing::info!(
|
||
numeric = num_numeric,
|
||
enums = enum_names.len(),
|
||
total = num_features,
|
||
"Feature columns from config"
|
||
);
|
||
|
||
// Build select expressions for the combined DataFrame
|
||
let mut select_exprs: Vec<polars::prelude::Expr> = vec![];
|
||
select_exprs.push(col("lat").cast(DataType::Float32));
|
||
select_exprs.push(col("lon").cast(DataType::Float32));
|
||
|
||
// Select numeric features as Float32 (datetime columns → fractional year)
|
||
for &name in &numeric_names {
|
||
if is_datetime_dtype(schema.get(name).unwrap()) {
|
||
select_exprs.push(
|
||
(col(name).dt().year().cast(DataType::Float32)
|
||
+ (col(name).dt().month().cast(DataType::Float32) - lit(1.0f32))
|
||
/ lit(12.0f32))
|
||
.alias(name),
|
||
);
|
||
} else {
|
||
select_exprs.push(col(name).cast(DataType::Float32));
|
||
}
|
||
}
|
||
|
||
// String columns for address/postcode and property metadata
|
||
for &string_col_name in &[
|
||
"Address per Property Register",
|
||
"Address per EPC",
|
||
"Postcode",
|
||
"Property sub-type",
|
||
"Price qualifier",
|
||
] {
|
||
if schema.get(string_col_name).is_some() {
|
||
select_exprs.push(col(string_col_name).cast(DataType::String));
|
||
}
|
||
}
|
||
|
||
// Enum features as String
|
||
for &name in &enum_names {
|
||
select_exprs.push(col(name).cast(DataType::String));
|
||
}
|
||
|
||
// Optional columns
|
||
let has_approx_col = schema.get("Is construction date approximate").is_some();
|
||
if has_approx_col {
|
||
select_exprs.push(col("Is construction date approximate").cast(DataType::Float32));
|
||
}
|
||
let has_renovation_history = schema.get("renovation_history").is_some();
|
||
if has_renovation_history {
|
||
select_exprs.push(col("renovation_history"));
|
||
}
|
||
let df = combined
|
||
.lazy()
|
||
.filter(col("lat").is_not_null().and(col("lon").is_not_null()))
|
||
.select(select_exprs)
|
||
.collect()
|
||
.context("Failed to select columns from combined data")?;
|
||
|
||
let row_count = df.height();
|
||
if row_count == 0 {
|
||
bail!("No property rows have usable coordinates after joining postcode data");
|
||
}
|
||
let dropped_coordinate_rows = total_rows.saturating_sub(row_count);
|
||
if dropped_coordinate_rows > 0 {
|
||
tracing::warn!(
|
||
rows = dropped_coordinate_rows,
|
||
"Dropped properties with missing postcode coordinates"
|
||
);
|
||
}
|
||
tracing::info!(rows = row_count, "Combined data selected");
|
||
|
||
let lat_series = df
|
||
.column("lat")
|
||
.context("Missing 'lat' column")?
|
||
.cast(&DataType::Float32)
|
||
.context("Failed to cast 'lat' to Float32")?;
|
||
let lat: Vec<f32> = lat_series
|
||
.f32()
|
||
.context("Failed to read 'lat' as f32")?
|
||
.into_iter()
|
||
.map(|value| value.context("Missing 'lat' value after coordinate filter"))
|
||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||
|
||
let lon_series = df
|
||
.column("lon")
|
||
.context("Missing 'lon' column")?
|
||
.cast(&DataType::Float32)
|
||
.context("Failed to cast 'lon' to Float32")?;
|
||
let lon: Vec<f32> = lon_series
|
||
.f32()
|
||
.context("Failed to read 'lon' as f32")?
|
||
.into_iter()
|
||
.map(|value| value.context("Missing 'lon' value after coordinate filter"))
|
||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||
|
||
for (row, (&latitude, &longitude)) in lat.iter().zip(&lon).enumerate() {
|
||
if !(-90.0..=90.0).contains(&latitude) || !(-180.0..=180.0).contains(&longitude) {
|
||
bail!("Invalid coordinates at row {row}: lat={latitude}, lon={longitude}");
|
||
}
|
||
}
|
||
|
||
tracing::info!("Extracting numeric feature columns");
|
||
let numeric_col_major: Vec<Vec<f32>> = numeric_names
|
||
.par_iter()
|
||
.map(|name| {
|
||
let column = df
|
||
.column(name)
|
||
.with_context(|| format!("Missing feature column '{name}'"))?;
|
||
column_to_f32_vec(column)
|
||
})
|
||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||
|
||
tracing::info!("Computing histograms for numeric features");
|
||
let numeric_feature_stats: Vec<FeatureStats> = numeric_col_major
|
||
.par_iter()
|
||
.enumerate()
|
||
.map(|(feat_index, vals)| {
|
||
let name = numeric_names[feat_index];
|
||
let bounds = features::bounds_for(name)
|
||
.with_context(|| format!("No bounds config for feature '{}'", name))?;
|
||
let stats = compute_feature_stats(vals, bounds, features::has_integer_bins(name));
|
||
tracing::debug!(
|
||
feature = %name,
|
||
slider_min = format_args!("{:.2}", stats.slider_min),
|
||
slider_max = format_args!("{:.2}", stats.slider_max),
|
||
bins = stats.histogram.counts.len(),
|
||
"Feature stats"
|
||
);
|
||
Ok(stats)
|
||
})
|
||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||
|
||
// Compute quantization parameters from feature stats (numeric features).
|
||
// For features with Fixed bounds, use those bounds so the full configured range
|
||
// is representable — the histogram refinement can narrow min/max to exclude
|
||
// "outliers" that are actually valid data (e.g. ethnicity percentages).
|
||
// For Percentile-bounded features, use the (possibly refined) histogram range
|
||
// so extreme outliers don't destroy precision for the main distribution.
|
||
let mut quant_min = Vec::with_capacity(num_features);
|
||
let mut quant_range = Vec::with_capacity(num_features);
|
||
for (feat_idx, stats) in numeric_feature_stats.iter().enumerate() {
|
||
let (min, max) = match features::bounds_for(numeric_names[feat_idx]) {
|
||
Some(Bounds::Fixed { min, max }) => (*min, *max),
|
||
_ => (stats.histogram.min, stats.histogram.max),
|
||
};
|
||
quant_min.push(min);
|
||
quant_range.push(if max > min { max - min } else { 0.0 });
|
||
}
|
||
|
||
tracing::info!("Extracting string columns");
|
||
let extract_string_col = |df: &DataFrame, name: &str| -> anyhow::Result<Vec<String>> {
|
||
let column = df
|
||
.column(name)
|
||
.with_context(|| format!("Required column '{name}' not found in parquet"))?;
|
||
let string_column = column
|
||
.str()
|
||
.with_context(|| format!("Column '{name}' is not a string column"))?;
|
||
Ok(string_column
|
||
.into_iter()
|
||
.map(|value| value.unwrap_or("").to_string())
|
||
.collect())
|
||
};
|
||
|
||
let address_raw = extract_string_col(&df, "Address per Property Register")?;
|
||
let postcode_raw = extract_string_col(&df, "Postcode")?;
|
||
|
||
// Extract optional string columns
|
||
let extract_optional_string_col =
|
||
|df: &DataFrame, name: &str| -> anyhow::Result<Vec<Option<String>>> {
|
||
if let Ok(column) = df.column(name) {
|
||
let string_column = column
|
||
.str()
|
||
.with_context(|| format!("Column '{name}' is not a string column"))?;
|
||
Ok(string_column
|
||
.into_iter()
|
||
.map(|value| {
|
||
value.and_then(|s| {
|
||
let trimmed = s.trim();
|
||
if trimmed.is_empty() {
|
||
None
|
||
} else {
|
||
Some(trimmed.to_string())
|
||
}
|
||
})
|
||
})
|
||
.collect())
|
||
} else {
|
||
Ok(vec![None; row_count])
|
||
}
|
||
};
|
||
|
||
let property_sub_type_raw = extract_optional_string_col(&df, "Property sub-type")?;
|
||
let price_qualifier_raw = extract_optional_string_col(&df, "Price qualifier")?;
|
||
|
||
tracing::info!("Building enum features");
|
||
// enum_col_major: Vec<(values_list, encoded_as_f32)>
|
||
let enum_col_major: Vec<(Vec<String>, Vec<f32>)> = enum_names
|
||
.par_iter()
|
||
.filter_map(|&name| {
|
||
let column_data = df.column(name).ok()?;
|
||
let string_column = column_data.str().ok()?;
|
||
let unique_set: std::collections::HashSet<String> = string_column
|
||
.into_iter()
|
||
.filter_map(|value| {
|
||
let text = value.unwrap_or("");
|
||
if text.is_empty() {
|
||
None
|
||
} else {
|
||
Some(text.to_string())
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
// Use configured order if available, otherwise alphabetical
|
||
let unique: Vec<String> = if let Some(order) = features::order_for(name) {
|
||
let mut ordered: Vec<String> = Vec::new();
|
||
for &ordered_value in order {
|
||
if unique_set.contains(ordered_value) {
|
||
ordered.push(ordered_value.to_string());
|
||
}
|
||
}
|
||
// Append any values not in the configured order, alphabetically
|
||
// Use HashSet for O(1) contains instead of O(n) slice search
|
||
let order_set: rustc_hash::FxHashSet<&str> = order.iter().copied().collect();
|
||
let mut remainder: Vec<String> = unique_set
|
||
.iter()
|
||
.filter(|value| !order_set.contains(value.as_str()))
|
||
.cloned()
|
||
.collect();
|
||
remainder.sort();
|
||
ordered.extend(remainder);
|
||
ordered
|
||
} else {
|
||
let mut sorted: Vec<String> = unique_set.into_iter().collect();
|
||
sorted.sort();
|
||
sorted
|
||
};
|
||
|
||
let value_to_idx: std::collections::HashMap<&str, f32> = unique
|
||
.iter()
|
||
.enumerate()
|
||
.map(|(index, value)| (value.as_str(), index as f32))
|
||
.collect();
|
||
|
||
let encoded: Vec<f32> = string_column
|
||
.into_iter()
|
||
.map(|value| {
|
||
let text = value.unwrap_or("");
|
||
if text.is_empty() {
|
||
f32::NAN
|
||
} else {
|
||
*value_to_idx.get(text).unwrap_or(&f32::NAN)
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
tracing::debug!(column = %name, unique_values = unique.len(), "Enum feature encoded as f32");
|
||
Some((unique, encoded))
|
||
})
|
||
.collect();
|
||
|
||
// Extract is_approx_build_date: 0.0 = exact, anything else (1.0/NaN) = approximate
|
||
let is_approx_build_date_raw: Vec<bool> = if has_approx_col {
|
||
let column_data = df
|
||
.column("Is construction date approximate")
|
||
.context("Missing 'Is construction date approximate' column")?;
|
||
let float_series = column_data
|
||
.cast(&DataType::Float32)
|
||
.context("Failed to cast 'Is construction date approximate' to Float32")?;
|
||
let chunked = float_series
|
||
.f32()
|
||
.context("Failed to read 'Is construction date approximate' as f32")?;
|
||
chunked
|
||
.into_iter()
|
||
.map(|value| match value {
|
||
Some(0.0) => false,
|
||
_ => true, // 1.0 or NaN → approximate
|
||
})
|
||
.collect()
|
||
} else {
|
||
vec![true; row_count] // default: all approximate
|
||
};
|
||
|
||
// Extract renovation_history: List<Struct{year: i32, event: str}>
|
||
let mut renovation_raw: FxHashMap<u32, Vec<RenovationEvent>> = if has_renovation_history {
|
||
tracing::info!("Extracting renovation history");
|
||
let reno_col = df
|
||
.column("renovation_history")
|
||
.context("Missing renovation_history column")?;
|
||
let list_ca = reno_col
|
||
.list()
|
||
.context("renovation_history is not a list column")?;
|
||
|
||
let mut history: FxHashMap<u32, Vec<RenovationEvent>> = FxHashMap::default();
|
||
for old_row in 0..row_count {
|
||
if let Some(inner) = list_ca.get_as_series(old_row) {
|
||
if inner.is_empty() {
|
||
continue;
|
||
}
|
||
let structs = inner
|
||
.struct_()
|
||
.context("renovation_history inner is not a struct")?;
|
||
let years = structs
|
||
.field_by_name("year")
|
||
.context("Missing 'year' field in renovation_history struct")?;
|
||
let events = structs
|
||
.field_by_name("event")
|
||
.context("Missing 'event' field in renovation_history struct")?;
|
||
|
||
let mut row_events = Vec::new();
|
||
for idx in 0..inner.len() {
|
||
let year = years.get(idx).context("Failed to get year value")?;
|
||
let event = events.get(idx).context("Failed to get event value")?;
|
||
if let (AnyValue::Int32(yr), AnyValue::String(ev)) = (&year, &event) {
|
||
row_events.push(RenovationEvent {
|
||
year: *yr,
|
||
event: ev.to_string(),
|
||
});
|
||
}
|
||
}
|
||
if !row_events.is_empty() {
|
||
history.insert(old_row as u32, row_events);
|
||
}
|
||
}
|
||
}
|
||
tracing::info!(
|
||
properties_with_events = history.len(),
|
||
"Renovation history extracted"
|
||
);
|
||
history
|
||
} else {
|
||
FxHashMap::default()
|
||
};
|
||
|
||
// Sort all rows by spatial locality so that grid queries access
|
||
// contiguous memory (sequential reads instead of random DRAM accesses).
|
||
tracing::info!("Sorting rows by spatial locality");
|
||
let grid_cell_size = 0.01_f32;
|
||
let min_lat_val = lat.iter().cloned().fold(f32::INFINITY, f32::min) - grid_cell_size;
|
||
let min_lon_val = lon.iter().cloned().fold(f32::INFINITY, f32::min) - grid_cell_size;
|
||
let max_lon_val = lon.iter().cloned().fold(f32::NEG_INFINITY, f32::max) + grid_cell_size;
|
||
let grid_cols = ((max_lon_val - min_lon_val) / grid_cell_size).ceil() as u64 + 1;
|
||
|
||
let mut perm: Vec<u32> = (0..row_count as u32).collect();
|
||
perm.par_sort_unstable_by_key(|&perm_index| {
|
||
let grid_row = ((lat[perm_index as usize] - min_lat_val) / grid_cell_size) as u64;
|
||
let grid_col = ((lon[perm_index as usize] - min_lon_val) / grid_cell_size) as u64;
|
||
grid_row * grid_cols + grid_col
|
||
});
|
||
|
||
let lat: Vec<f32> = perm
|
||
.iter()
|
||
.map(|&perm_index| lat[perm_index as usize])
|
||
.collect();
|
||
let lon: Vec<f32> = perm
|
||
.iter()
|
||
.map(|&perm_index| lon[perm_index as usize])
|
||
.collect();
|
||
|
||
// Build contiguous address buffer (permuted)
|
||
tracing::info!("Building interned strings");
|
||
let total_addr_bytes: usize = address_raw.iter().map(|text| text.len()).sum();
|
||
let mut address_buffer = String::with_capacity(total_addr_bytes);
|
||
let mut address_offsets = Vec::with_capacity(row_count);
|
||
let mut address_lengths = Vec::with_capacity(row_count);
|
||
for &perm_index in &perm {
|
||
let addr = &address_raw[perm_index as usize];
|
||
let offset = address_buffer.len() as u32;
|
||
let length = addr.len().min(u16::MAX as usize) as u16;
|
||
address_offsets.push(offset);
|
||
address_lengths.push(length);
|
||
address_buffer.push_str(&addr[..length as usize]);
|
||
}
|
||
|
||
// Intern postcodes (permuted)
|
||
let mut postcode_rodeo = lasso::Rodeo::default();
|
||
let postcode_keys: Vec<lasso::Spur> = perm
|
||
.iter()
|
||
.map(|&perm_index| postcode_rodeo.get_or_intern(&postcode_raw[perm_index as usize]))
|
||
.collect();
|
||
let postcode_interner = postcode_rodeo.into_reader();
|
||
|
||
// Pack is_approx_build_date into a bitvec (8 bools per byte)
|
||
let num_bytes = row_count.div_ceil(8);
|
||
let mut approx_build_date_bits = vec![0u8; num_bytes];
|
||
for (new_row, &old_row) in perm.iter().enumerate() {
|
||
if is_approx_build_date_raw[old_row as usize] {
|
||
approx_build_date_bits[new_row / 8] |= 1 << (new_row % 8);
|
||
}
|
||
}
|
||
|
||
// Re-key renovation_history by permuted row index
|
||
let renovation_history: FxHashMap<u32, Vec<RenovationEvent>> = {
|
||
let mut map =
|
||
FxHashMap::with_capacity_and_hasher(renovation_raw.len(), Default::default());
|
||
for (new_row, &old_row) in perm.iter().enumerate() {
|
||
if let Some(events) = renovation_raw.remove(&old_row) {
|
||
map.insert(new_row as u32, events);
|
||
}
|
||
}
|
||
map
|
||
};
|
||
|
||
// Permute optional string columns into sparse HashMaps
|
||
let property_sub_type: FxHashMap<u32, String> = {
|
||
let mut map = FxHashMap::default();
|
||
for (new_row, &old_row) in perm.iter().enumerate() {
|
||
if let Some(ref s) = property_sub_type_raw[old_row as usize] {
|
||
map.insert(new_row as u32, s.clone());
|
||
}
|
||
}
|
||
map
|
||
};
|
||
let price_qualifier: FxHashMap<u32, String> = {
|
||
let mut map = FxHashMap::default();
|
||
for (new_row, &old_row) in perm.iter().enumerate() {
|
||
if let Some(ref s) = price_qualifier_raw[old_row as usize] {
|
||
map.insert(new_row as u32, s.clone());
|
||
}
|
||
}
|
||
map
|
||
};
|
||
|
||
// Build enum_values map: feature_index -> list of string values
|
||
// and enum_counts map: feature_index -> per-value global counts
|
||
let mut enum_values: rustc_hash::FxHashMap<usize, Vec<String>> =
|
||
rustc_hash::FxHashMap::default();
|
||
let mut enum_counts: rustc_hash::FxHashMap<usize, Vec<u64>> =
|
||
rustc_hash::FxHashMap::default();
|
||
for (enum_idx, (values, encoded)) in enum_col_major.iter().enumerate() {
|
||
let feature_idx = num_numeric + enum_idx;
|
||
enum_values.insert(feature_idx, values.clone());
|
||
let mut counts = vec![0u64; values.len()];
|
||
for &val in encoded {
|
||
if val.is_finite() {
|
||
let idx = val as usize;
|
||
if idx < counts.len() {
|
||
counts[idx] += 1;
|
||
}
|
||
}
|
||
}
|
||
enum_counts.insert(feature_idx, counts);
|
||
}
|
||
|
||
// Build feature_stats: numeric stats + placeholder stats for enums
|
||
let mut feature_stats = numeric_feature_stats;
|
||
for (values, _) in &enum_col_major {
|
||
// For enum features, slider range is 0 to num_values-1
|
||
let num_values = values.len();
|
||
let max_val = num_values as f32;
|
||
feature_stats.push(FeatureStats {
|
||
slider_min: 0.0,
|
||
slider_max: (num_values.saturating_sub(1)) as f32,
|
||
histogram: Histogram {
|
||
min: 0.0,
|
||
max: max_val,
|
||
p1: 0.0,
|
||
p99: max_val,
|
||
counts: vec![0; num_values.max(1)],
|
||
},
|
||
});
|
||
// Enum features: not quantized, stored directly as u16
|
||
quant_min.push(0.0);
|
||
quant_range.push(0.0);
|
||
}
|
||
let dequant_a: Vec<f32> = quant_range
|
||
.iter()
|
||
.map(|&r| if r > 0.0 { r / QUANT_SCALE } else { 0.0 })
|
||
.collect();
|
||
|
||
// Transpose to row-major AND apply spatial permutation in one pass.
|
||
// Combines numeric and enum features into a single feature_data array, quantized to u16.
|
||
tracing::info!("Transposing to row-major layout (spatially sorted, quantized to u16)");
|
||
let mut feature_data = vec![NAN_U16; row_count * num_features];
|
||
feature_data
|
||
.par_chunks_mut(num_features)
|
||
.enumerate()
|
||
.for_each(|(new_row, row_slice)| {
|
||
let old_index = perm[new_row] as usize;
|
||
// Numeric features: quantize to u16
|
||
for (feat_idx, col_vec) in numeric_col_major.iter().enumerate() {
|
||
let value = col_vec[old_index];
|
||
row_slice[feat_idx] = if value.is_finite() {
|
||
let range = quant_range[feat_idx];
|
||
if range > 0.0 {
|
||
let normalized = (value - quant_min[feat_idx]) / range;
|
||
(normalized * QUANT_SCALE).round().clamp(0.0, QUANT_SCALE) as u16
|
||
} else {
|
||
0
|
||
}
|
||
} else {
|
||
NAN_U16
|
||
};
|
||
}
|
||
// Enum features: store as u16 directly
|
||
for (enum_idx, (_, encoded)) in enum_col_major.iter().enumerate() {
|
||
let value = encoded[old_index];
|
||
row_slice[num_numeric + enum_idx] = if value.is_finite() {
|
||
value as u16
|
||
} else {
|
||
NAN_U16
|
||
};
|
||
}
|
||
});
|
||
|
||
tracing::info!("Data loading complete");
|
||
|
||
Ok(PropertyData {
|
||
lat,
|
||
lon,
|
||
feature_names,
|
||
num_features,
|
||
num_numeric,
|
||
feature_data,
|
||
dequant_a,
|
||
quant_min,
|
||
quant_range,
|
||
feature_stats,
|
||
address_buffer,
|
||
address_offsets,
|
||
address_lengths,
|
||
postcode_interner,
|
||
postcode_keys,
|
||
enum_values,
|
||
enum_counts,
|
||
approx_build_date_bits,
|
||
renovation_history,
|
||
property_sub_type,
|
||
price_qualifier,
|
||
})
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    //! Unit tests for histogram / feature-stats computation and for the
    //! NaN-skipping conventions used when aggregating feature columns.

    use super::*;
    use crate::features::Bounds;

    /// Shorthand for `Bounds::Fixed { min, max }`.
    fn make_fixed_bounds(min: f32, max: f32) -> Bounds {
        Bounds::Fixed { min, max }
    }

    /// Shorthand for `Bounds::Percentile { low, high }`.
    fn make_percentile_bounds(low: f64, high: f64) -> Bounds {
        Bounds::Percentile { low, high }
    }

    /// With no data at all, the slider falls back to the fixed bounds and
    /// the histogram holds no counts.
    #[test]
    fn histogram_empty_data() {
        let data: Vec<f32> = vec![];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.slider_min, 0.0);
        assert_eq!(stats.slider_max, 100.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 0);
    }

    /// A single value becomes both histogram min and max and is counted once.
    #[test]
    fn histogram_single_value() {
        let data = vec![50.0_f32];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.min, 50.0);
        assert_eq!(stats.histogram.max, 50.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1);
    }

    /// Histogram min/max track the observed data (0..=99), not the fixed
    /// bounds, and every value is counted exactly once.
    #[test]
    fn histogram_uniform_distribution() {
        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.min, 0.0);
        assert_eq!(stats.histogram.max, 99.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 100);
    }

    /// NaN entries are ignored for both the counts and the min/max.
    #[test]
    fn histogram_with_nan_values() {
        let data = vec![10.0_f32, f32::NAN, 20.0, f32::NAN, 30.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 3);
        assert_eq!(stats.histogram.min, 10.0);
        assert_eq!(stats.histogram.max, 30.0);
    }

    /// A column consisting only of NaN yields an empty histogram.
    #[test]
    fn histogram_all_nan() {
        let data = vec![f32::NAN, f32::NAN, f32::NAN];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 0);
    }

    /// A constant column collapses min, max, p1 and p99 onto that constant
    /// while still counting every row.
    #[test]
    fn histogram_all_same_value() {
        let data = vec![42.0_f32; 1000];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.min, 42.0);
        assert_eq!(stats.histogram.max, 42.0);
        assert_eq!(stats.histogram.p1, 42.0);
        assert_eq!(stats.histogram.p99, 42.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1000);
    }

    /// Percentile bounds keep the slider range strictly inside the two
    /// planted outliers (0.0 and 1000.0).
    #[test]
    fn histogram_percentile_bounds() {
        let mut data: Vec<f32> = vec![0.0]; // Low outlier
        data.extend((1..99).map(|i| 50.0 + i as f32 * 0.01));
        data.push(1000.0); // High outlier

        let bounds = make_percentile_bounds(2.0, 98.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert!(stats.slider_min > 0.0);
        assert!(stats.slider_max < 1000.0);
    }

    /// `bin_for_value` maps values below p1 to bin 0, values at or above p99
    /// to the last bin, and splits [p1, p99) evenly over the middle bins.
    #[test]
    fn histogram_bin_for_value() {
        let hist = Histogram {
            min: 0.0,
            max: 100.0,
            p1: 10.0,
            p99: 90.0,
            counts: vec![0; 10],
        };

        assert_eq!(hist.bin_for_value(5.0), 0); // Low outlier bin
        assert_eq!(hist.bin_for_value(95.0), 9); // High outlier bin

        // Boundaries: p1 itself opens the first middle bin, p99 itself
        // already belongs to the high outlier bin.
        assert_eq!(hist.bin_for_value(10.0), 1);
        assert_eq!(hist.bin_for_value(90.0), 9);

        // 8 middle bins over [10, 90) are 10 wide each:
        // 50 -> 1 + floor(40 / 10) = 5, 89 -> 1 + floor(79 / 10) = 8.
        assert_eq!(hist.bin_for_value(50.0), 5);
        assert_eq!(hist.bin_for_value(89.0), 8);
    }

    /// The middle-bin width is (p99 - p1) divided by the number of
    /// non-outlier bins (10 bins total, 8 in the middle).
    #[test]
    fn histogram_middle_bin_width() {
        let hist = Histogram {
            min: 0.0,
            max: 100.0,
            p1: 10.0,
            p99: 90.0,
            counts: vec![0; 10],
        };

        let expected_width = (90.0 - 10.0) / 8.0;
        assert!((hist.middle_bin_width() - expected_width).abs() < 0.001);
    }

    /// Three distinct values produce a histogram with exactly three bins.
    #[test]
    fn histogram_cardinality_caps_bins() {
        let data = vec![1.0_f32, 1.0, 2.0, 2.0, 3.0, 3.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.counts.len(), 3);
    }

    /// A min/max scan guarded by `is_finite()` ignores NaN entries.
    ///
    /// NOTE(review): this exercises a loop written inline in the test, not a
    /// production function; consider pointing it at the real implementation.
    #[test]
    fn min_max_skips_nan() {
        let values = vec![10.0_f32, f32::NAN, 20.0, f32::NAN, 5.0];

        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;
        for &v in &values {
            if v.is_finite() {
                if v < min {
                    min = v;
                }
                if v > max {
                    max = v;
                }
            }
        }

        assert_eq!(min, 5.0);
        assert_eq!(max, 20.0);
    }

    /// Counting with an `is_finite()` filter ignores NaN entries.
    ///
    /// NOTE(review): like `min_max_skips_nan`, this checks inline logic only.
    #[test]
    fn count_skips_nan() {
        let values = [1.0_f32, f32::NAN, 2.0, f32::NAN, 3.0];
        let count = values.iter().filter(|v| v.is_finite()).count();
        assert_eq!(count, 3);
    }

    /// Per-value counting of f32-encoded enum indices skips NaN and stays
    /// within the declared number of enum values.
    ///
    /// NOTE(review): the counting loop is duplicated here rather than called
    /// from production code, so a change there would not be caught.
    #[test]
    fn enum_value_counting() {
        let values = vec![0.0_f32, 1.0, 1.0, 2.0, f32::NAN, 3.0, 1.0];
        let enum_count = 4;

        let mut counts = vec![0u64; enum_count];
        for &v in &values {
            if v.is_finite() {
                let idx = v as usize;
                if idx < enum_count {
                    counts[idx] += 1;
                }
            }
        }

        assert_eq!(counts[0], 1);
        assert_eq!(counts[1], 3);
        assert_eq!(counts[2], 1);
        assert_eq!(counts[3], 1);
    }

    /// Positive and negative infinity are dropped just like NaN, leaving
    /// only the single finite value in the histogram.
    #[test]
    fn infinity_values_excluded() {
        let data = vec![f32::INFINITY, f32::NEG_INFINITY, 50.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.min, 50.0);
        assert_eq!(stats.histogram.max, 50.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1);
    }

    /// Baseline: a fully finite column is counted in full and its observed
    /// extremes become the histogram min/max.
    #[test]
    fn only_finite_values() {
        let data = vec![10.0_f32, 20.0, 30.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        assert_eq!(stats.histogram.min, 10.0);
        assert_eq!(stats.histogram.max, 30.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 3);
    }

    /// One absurd outlier must not stretch the histogram range so far that
    /// ordinary values all quantize to (nearly) the same u16.
    #[test]
    fn extreme_outlier_does_not_destroy_quantization() {
        // Simulate floor area: 10k normal values (50-200 sqm) + one 317M outlier
        let mut data: Vec<f32> = (0..10_000).map(|i| 50.0 + (i % 150) as f32).collect();
        data.push(317_000_000.0); // Extreme outlier from web scraping

        let bounds = make_percentile_bounds(0.0, 98.0);
        let stats = compute_feature_stats(&data, &bounds, false);

        // After refinement, histogram range should be much tighter than 317M
        assert!(
            stats.histogram.max < 1_000_000.0,
            "histogram.max should be refined, got {}",
            stats.histogram.max,
        );
        // p1 should be near 50, not millions
        assert!(
            stats.histogram.p1 < 100.0,
            "p1 should be near real data, got {}",
            stats.histogram.p1,
        );
        // Slider min should reflect actual data range
        assert!(
            stats.slider_min < 100.0,
            "slider_min should be near real data, got {}",
            stats.slider_min,
        );

        // Quantization using histogram.min/max should give usable range
        let qmin = stats.histogram.min;
        let qrange = stats.histogram.max - stats.histogram.min;
        assert!(qrange > 0.0 && qrange < 1_000_000.0);

        // A typical floor area (100 sqm) should be distinguishable from min
        let normalized = (100.0 - qmin) / qrange;
        let encoded = (normalized * QUANT_SCALE).round() as u16;
        assert!(
            encoded > 100,
            "100 sqm should encode to a meaningful u16 value, got {}",
            encoded,
        );
    }
}
|