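Format Rust code (rustfmt); also move EXCLUDED_COLUMNS into consts, drop the MIN/MAX/DEFAULT_RESOLUTION constants, and require `resolution` to be within the precomputed H3 range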

This commit is contained in:
Andras Schmelczer 2026-01-31 13:57:43 +00:00
parent 0fde087c3d
commit f60fbec9d4
5 changed files with 191 additions and 94 deletions

View file

@ -9,3 +9,6 @@ pub const HISTOGRAM_BINS: usize = 100;
/// H3 resolutions to precompute at startup (covers typical zoom levels) /// H3 resolutions to precompute at startup (covers typical zoom levels)
pub const H3_PRECOMPUTE_MIN: u8 = 4; pub const H3_PRECOMPUTE_MIN: u8 = 4;
pub const H3_PRECOMPUTE_MAX: u8 = 12; pub const H3_PRECOMPUTE_MAX: u8 = 12;
/// Columns to exclude from feature discovery
pub const EXCLUDED_COLUMNS: &[&str] = &["lat", "lon"];

View file

@ -1,18 +1,13 @@
use polars::prelude::*;
use polars::lazy::frame::LazyFrame; use polars::lazy::frame::LazyFrame;
use polars::prelude::*;
use rayon::prelude::*; use rayon::prelude::*;
use serde::Serialize; use serde::Serialize;
use std::path::Path; use std::path::Path;
use crate::consts::{FEATURE_PERCENTILE_LOW, FEATURE_PERCENTILE_HIGH, HISTOGRAM_BINS, H3_PRECOMPUTE_MIN, H3_PRECOMPUTE_MAX}; use crate::consts::{
EXCLUDED_COLUMNS, FEATURE_PERCENTILE_HIGH, FEATURE_PERCENTILE_LOW, H3_PRECOMPUTE_MAX,
/// Columns to exclude from feature discovery H3_PRECOMPUTE_MIN, HISTOGRAM_BINS,
const EXCLUDED_COLUMNS: &[&str] = &["lat", "lon"]; };
/// H3 valid resolution range (0-15)
pub const MIN_RESOLUTION: u8 = 0;
pub const MAX_RESOLUTION: u8 = 15;
pub const DEFAULT_RESOLUTION: u8 = 8;
/// Returns true if the polars DataType is numeric (integer or float) /// Returns true if the polars DataType is numeric (integer or float)
fn is_numeric_dtype(dtype: &DataType) -> bool { fn is_numeric_dtype(dtype: &DataType) -> bool {
@ -76,7 +71,13 @@ pub struct PropertyData {
/// Approximate a percentile from a histogram using linear interpolation. /// Approximate a percentile from a histogram using linear interpolation.
/// `p` is in [0, 100]. `total` is the sum of all bin counts. /// `p` is in [0, 100]. `total` is the sum of all bin counts.
fn percentile_from_histogram(counts: &[u64], min: f64, bin_width: f64, total: usize, p: f64) -> f64 { fn percentile_from_histogram(
counts: &[u64],
min: f64,
bin_width: f64,
total: usize,
p: f64,
) -> f64 {
let target = (p / 100.0) * (total as f64 - 1.0); let target = (p / 100.0) * (total as f64 - 1.0);
let mut cumulative = 0u64; let mut cumulative = 0u64;
for (i, &c) in counts.iter().enumerate() { for (i, &c) in counts.iter().enumerate() {
@ -104,8 +105,12 @@ fn compute_feature_stats(vals: &[f64]) -> FeatureStats {
let mut count = 0usize; let mut count = 0usize;
for &v in vals { for &v in vals {
if !v.is_nan() { if !v.is_nan() {
if v < min { min = v; } if v < min {
if v > max { max = v; } min = v;
}
if v > max {
max = v;
}
count += 1; count += 1;
} }
} }
@ -222,8 +227,12 @@ impl PropertyData {
// Add string columns (using actual column names from parquet) // Add string columns (using actual column names from parquet)
let string_cols = vec![ let string_cols = vec![
"pp_address", "postcode", "pp_property_type", "built_form", "pp_address",
"current_energy_rating", "potential_energy_rating" "postcode",
"pp_property_type",
"built_form",
"current_energy_rating",
"potential_energy_rating",
]; ];
// Build selection with proper casting // Build selection with proper casting
@ -256,10 +265,20 @@ impl PropertyData {
// Extract lat/lon using bulk iterator // Extract lat/lon using bulk iterator
let lat_series = df.column("lat").unwrap().cast(&DataType::Float64).unwrap(); let lat_series = df.column("lat").unwrap().cast(&DataType::Float64).unwrap();
let lat: Vec<f64> = lat_series.f64().unwrap().into_iter().map(|v| v.unwrap_or(0.0)).collect(); let lat: Vec<f64> = lat_series
.f64()
.unwrap()
.into_iter()
.map(|v| v.unwrap_or(0.0))
.collect();
let lon_series = df.column("lon").unwrap().cast(&DataType::Float64).unwrap(); let lon_series = df.column("lon").unwrap().cast(&DataType::Float64).unwrap();
let lon: Vec<f64> = lon_series.f64().unwrap().into_iter().map(|v| v.unwrap_or(0.0)).collect(); let lon: Vec<f64> = lon_series
.f64()
.unwrap()
.into_iter()
.map(|v| v.unwrap_or(0.0))
.collect();
// Extract feature columns (column-major, for cache-friendly histogram computation) // Extract feature columns (column-major, for cache-friendly histogram computation)
eprintln!("Extracting feature columns..."); eprintln!("Extracting feature columns...");
@ -281,8 +300,10 @@ impl PropertyData {
eprintln!( eprintln!(
" {}: p{}={:.2}, p{}={:.2}, {} bins", " {}: p{}={:.2}, p{}={:.2}, {} bins",
feature_names[i], feature_names[i],
FEATURE_PERCENTILE_LOW, stats.p_low, FEATURE_PERCENTILE_LOW,
FEATURE_PERCENTILE_HIGH, stats.p_high, stats.p_low,
FEATURE_PERCENTILE_HIGH,
stats.p_high,
stats.histogram.counts.len() stats.histogram.counts.len()
); );
stats stats
@ -292,40 +313,66 @@ impl PropertyData {
// Extract string columns (before permutation) // Extract string columns (before permutation)
eprintln!("Extracting string columns..."); eprintln!("Extracting string columns...");
let address_raw: Vec<String> = if let Ok(col) = df.column("pp_address") { let address_raw: Vec<String> = if let Ok(col) = df.column("pp_address") {
col.str().unwrap().into_iter().map(|v| v.unwrap_or("").to_string()).collect() col.str()
.unwrap()
.into_iter()
.map(|v| v.unwrap_or("").to_string())
.collect()
} else { } else {
vec![String::new(); row_count] vec![String::new(); row_count]
}; };
let postcode_raw: Vec<String> = if let Ok(col) = df.column("postcode") { let postcode_raw: Vec<String> = if let Ok(col) = df.column("postcode") {
col.str().unwrap().into_iter().map(|v| v.unwrap_or("").to_string()).collect() col.str()
.unwrap()
.into_iter()
.map(|v| v.unwrap_or("").to_string())
.collect()
} else { } else {
vec![String::new(); row_count] vec![String::new(); row_count]
}; };
let property_type_raw: Vec<String> = if let Ok(col) = df.column("pp_property_type") { let property_type_raw: Vec<String> = if let Ok(col) = df.column("pp_property_type") {
col.str().unwrap().into_iter().map(|v| v.unwrap_or("").to_string()).collect() col.str()
.unwrap()
.into_iter()
.map(|v| v.unwrap_or("").to_string())
.collect()
} else { } else {
vec![String::new(); row_count] vec![String::new(); row_count]
}; };
let built_form_raw: Vec<String> = if let Ok(col) = df.column("built_form") { let built_form_raw: Vec<String> = if let Ok(col) = df.column("built_form") {
col.str().unwrap().into_iter().map(|v| v.unwrap_or("").to_string()).collect() col.str()
.unwrap()
.into_iter()
.map(|v| v.unwrap_or("").to_string())
.collect()
} else { } else {
vec![String::new(); row_count] vec![String::new(); row_count]
}; };
let current_energy_rating_raw: Vec<String> = if let Ok(col) = df.column("current_energy_rating") { let current_energy_rating_raw: Vec<String> =
col.str().unwrap().into_iter().map(|v| v.unwrap_or("").to_string()).collect() if let Ok(col) = df.column("current_energy_rating") {
} else { col.str()
vec![String::new(); row_count] .unwrap()
}; .into_iter()
.map(|v| v.unwrap_or("").to_string())
.collect()
} else {
vec![String::new(); row_count]
};
let potential_energy_rating_raw: Vec<String> = if let Ok(col) = df.column("potential_energy_rating") { let potential_energy_rating_raw: Vec<String> =
col.str().unwrap().into_iter().map(|v| v.unwrap_or("").to_string()).collect() if let Ok(col) = df.column("potential_energy_rating") {
} else { col.str()
vec![String::new(); row_count] .unwrap()
}; .into_iter()
.map(|v| v.unwrap_or("").to_string())
.collect()
} else {
vec![String::new(); row_count]
};
// Sort all rows by spatial locality so that grid queries access // Sort all rows by spatial locality so that grid queries access
// contiguous memory (sequential reads instead of random DRAM accesses). // contiguous memory (sequential reads instead of random DRAM accesses).
@ -349,12 +396,30 @@ impl PropertyData {
let lon: Vec<f64> = perm.iter().map(|&i| lon[i as usize]).collect(); let lon: Vec<f64> = perm.iter().map(|&i| lon[i as usize]).collect();
// Apply permutation to string columns // Apply permutation to string columns
let address: Vec<String> = perm.iter().map(|&i| address_raw[i as usize].clone()).collect(); let address: Vec<String> = perm
let postcode: Vec<String> = perm.iter().map(|&i| postcode_raw[i as usize].clone()).collect(); .iter()
let property_type: Vec<String> = perm.iter().map(|&i| property_type_raw[i as usize].clone()).collect(); .map(|&i| address_raw[i as usize].clone())
let built_form: Vec<String> = perm.iter().map(|&i| built_form_raw[i as usize].clone()).collect(); .collect();
let current_energy_rating: Vec<String> = perm.iter().map(|&i| current_energy_rating_raw[i as usize].clone()).collect(); let postcode: Vec<String> = perm
let potential_energy_rating: Vec<String> = perm.iter().map(|&i| potential_energy_rating_raw[i as usize].clone()).collect(); .iter()
.map(|&i| postcode_raw[i as usize].clone())
.collect();
let property_type: Vec<String> = perm
.iter()
.map(|&i| property_type_raw[i as usize].clone())
.collect();
let built_form: Vec<String> = perm
.iter()
.map(|&i| built_form_raw[i as usize].clone())
.collect();
let current_energy_rating: Vec<String> = perm
.iter()
.map(|&i| current_energy_rating_raw[i as usize].clone())
.collect();
let potential_energy_rating: Vec<String> = perm
.iter()
.map(|&i| potential_energy_rating_raw[i as usize].clone())
.collect();
// Transpose to row-major AND apply spatial permutation in one pass. // Transpose to row-major AND apply spatial permutation in one pass.
// Result: all features for one row are contiguous, and spatially // Result: all features for one row are contiguous, and spatially
@ -422,7 +487,8 @@ impl POIData {
eprintln!("Loaded {} POIs", row_count); eprintln!("Loaded {} POIs", row_count);
// Extract columns // Extract columns
let id: Vec<String> = df.column("id") let id: Vec<String> = df
.column("id")
.unwrap() .unwrap()
.str() .str()
.unwrap() .unwrap()
@ -430,7 +496,8 @@ impl POIData {
.map(|v| v.unwrap_or("").to_string()) .map(|v| v.unwrap_or("").to_string())
.collect(); .collect();
let name: Vec<String> = df.column("name") let name: Vec<String> = df
.column("name")
.unwrap() .unwrap()
.str() .str()
.unwrap() .unwrap()
@ -438,7 +505,8 @@ impl POIData {
.map(|v| v.unwrap_or("").to_string()) .map(|v| v.unwrap_or("").to_string())
.collect(); .collect();
let category: Vec<String> = df.column("category") let category: Vec<String> = df
.column("category")
.unwrap() .unwrap()
.str() .str()
.unwrap() .unwrap()
@ -446,7 +514,8 @@ impl POIData {
.map(|v| v.unwrap_or("").to_string()) .map(|v| v.unwrap_or("").to_string())
.collect(); .collect();
let lat: Vec<f64> = df.column("lat") let lat: Vec<f64> = df
.column("lat")
.unwrap() .unwrap()
.f64() .f64()
.unwrap() .unwrap()
@ -454,7 +523,8 @@ impl POIData {
.map(|v| v.unwrap_or(0.0)) .map(|v| v.unwrap_or(0.0))
.collect(); .collect();
let lng: Vec<f64> = df.column("lng") let lng: Vec<f64> = df
.column("lng")
.unwrap() .unwrap()
.f64() .f64()
.unwrap() .unwrap()
@ -462,7 +532,8 @@ impl POIData {
.map(|v| v.unwrap_or(0.0)) .map(|v| v.unwrap_or(0.0))
.collect(); .collect();
let emoji: Vec<String> = df.column("emoji") let emoji: Vec<String> = df
.column("emoji")
.unwrap() .unwrap()
.str() .str()
.unwrap() .unwrap()

View file

@ -114,7 +114,13 @@ impl GridIndex {
} }
} }
fn clamp_bounds(&self, south: f64, west: f64, north: f64, east: f64) -> (usize, usize, usize, usize) { fn clamp_bounds(
&self,
south: f64,
west: f64,
north: f64,
east: f64,
) -> (usize, usize, usize, usize) {
let r_min = ((south - self.min_lat) / self.cell_size) as isize; let r_min = ((south - self.min_lat) / self.cell_size) as isize;
let r_max = ((north - self.min_lat) / self.cell_size) as isize; let r_max = ((north - self.min_lat) / self.cell_size) as isize;
let c_min = ((west - self.min_lon) / self.cell_size) as isize; let c_min = ((west - self.min_lon) / self.cell_size) as isize;

View file

@ -42,7 +42,10 @@ async fn main() {
let poi_data = if poi_path.exists() { let poi_data = if poi_path.exists() {
data::POIData::load(&poi_path) data::POIData::load(&poi_path)
} else { } else {
eprintln!("Warning: {} not found. POI endpoints will be unavailable.", poi_path.display()); eprintln!(
"Warning: {} not found. POI endpoints will be unavailable.",
poi_path.display()
);
data::POIData { data::POIData {
id: Vec::new(), id: Vec::new(),
name: Vec::new(), name: Vec::new(),
@ -93,7 +96,9 @@ async fn main() {
) )
.route( .route(
"/api/hexagon-properties", "/api/hexagon-properties",
get(move |query| routes::get_hexagon_properties(state_hexagon_properties.clone(), query)), get(move |query| {
routes::get_hexagon_properties(state_hexagon_properties.clone(), query)
}),
); );
// Static file serving for frontend // Static file serving for frontend

View file

@ -8,7 +8,8 @@ use axum::response::{IntoResponse, Json};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::data::{Histogram, PropertyData, POIData, POI, DEFAULT_RESOLUTION, MAX_RESOLUTION, MIN_RESOLUTION}; use crate::consts::{H3_PRECOMPUTE_MAX, H3_PRECOMPUTE_MIN};
use crate::data::{Histogram, POIData, PropertyData, POI};
use crate::index::GridIndex; use crate::index::GridIndex;
/// Shared application state /// Shared application state
@ -82,7 +83,7 @@ pub async fn get_features(state: Arc<AppState>) -> Json<FeaturesResponse> {
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct HexagonParams { pub struct HexagonParams {
resolution: Option<u8>, resolution: u8,
bounds: Option<String>, bounds: Option<String>,
/// Comma-separated filters: `name:min:max,...` /// Comma-separated filters: `name:min:max,...`
/// Rows must have non-NaN values within [min,max] for each filter. /// Rows must have non-NaN values within [min,max] for each filter.
@ -130,7 +131,6 @@ impl CellAgg {
} }
} }
} }
} }
/// Write the hexagons JSON response directly to a String buffer, /// Write the hexagons JSON response directly to a String buffer,
@ -172,20 +172,21 @@ pub async fn get_hexagons(
state: Arc<AppState>, state: Arc<AppState>,
Query(params): Query<HexagonParams>, Query(params): Query<HexagonParams>,
) -> Result<impl IntoResponse, (StatusCode, String)> { ) -> Result<impl IntoResponse, (StatusCode, String)> {
let resolution = params.resolution.unwrap_or(DEFAULT_RESOLUTION); let resolution = params.resolution;
if resolution > MAX_RESOLUTION { if resolution < H3_PRECOMPUTE_MIN || resolution > H3_PRECOMPUTE_MAX {
return Err(( return Err((
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
format!( format!(
"resolution must be between {} and {}", "resolution must be between {} and {}",
MIN_RESOLUTION, MAX_RESOLUTION H3_PRECOMPUTE_MIN, H3_PRECOMPUTE_MAX
), ),
)); ));
} }
let bounds_str = params let bounds_str = params.bounds.ok_or((
.bounds StatusCode::BAD_REQUEST,
.ok_or((StatusCode::BAD_REQUEST, "bounds parameter is required".into()))?; "bounds parameter is required".into(),
))?;
let parts: Vec<f64> = bounds_str let parts: Vec<f64> = bounds_str
.split(',') .split(',')
@ -286,46 +287,44 @@ pub async fn get_hexagons(
if let Some(precomputed) = h3_cells_for_res { if let Some(precomputed) = h3_cells_for_res {
// Fast path: precomputed H3 + visitor pattern // Fast path: precomputed H3 + visitor pattern
state.grid.for_each_in_bounds(south, west, north, east, |row_idx| { state
let row = row_idx as usize; .grid
if !row_passes(row) { .for_each_in_bounds(south, west, north, east, |row_idx| {
return; let row = row_idx as usize;
} if !row_passes(row) {
let cell_id = precomputed[row]; return;
groups }
.entry(cell_id) let cell_id = precomputed[row];
.or_insert_with(|| CellAgg::new(num_features)) groups
.add_row(feature_data, row, num_features); .entry(cell_id)
}); .or_insert_with(|| CellAgg::new(num_features))
.add_row(feature_data, row, num_features);
});
} else { } else {
// Fallback: compute H3 on-the-fly // Fallback: compute H3 on-the-fly
let h3_res = h3o::Resolution::try_from(resolution).unwrap(); let h3_res = h3o::Resolution::try_from(resolution).unwrap();
state.grid.for_each_in_bounds(south, west, north, east, |row_idx| { state
let row = row_idx as usize; .grid
if !row_passes(row) { .for_each_in_bounds(south, west, north, east, |row_idx| {
return; let row = row_idx as usize;
} if !row_passes(row) {
let cell_id = h3o::LatLng::new(state.data.lat[row], state.data.lon[row]) return;
.map(|c| u64::from(c.to_cell(h3_res))) }
.unwrap_or(0); let cell_id = h3o::LatLng::new(state.data.lat[row], state.data.lon[row])
groups .map(|c| u64::from(c.to_cell(h3_res)))
.entry(cell_id) .unwrap_or(0);
.or_insert_with(|| CellAgg::new(num_features)) groups
.add_row(feature_data, row, num_features); .entry(cell_id)
}); .or_insert_with(|| CellAgg::new(num_features))
.add_row(feature_data, row, num_features);
});
} }
let t_agg = t0.elapsed(); let t_agg = t0.elapsed();
// Write JSON directly (no serde_json::Value allocation overhead) // Write JSON directly (no serde_json::Value allocation overhead)
let mut json_buf = String::with_capacity(groups.len() * 128); let mut json_buf = String::with_capacity(groups.len() * 128);
write_hexagons_json( write_hexagons_json(&mut json_buf, &groups, &min_keys, &max_keys, num_features);
&mut json_buf,
&groups,
&min_keys,
&max_keys,
num_features,
);
let t_total = t0.elapsed(); let t_total = t0.elapsed();
eprintln!( eprintln!(
@ -364,9 +363,10 @@ pub async fn get_pois(
state: Arc<AppState>, state: Arc<AppState>,
Query(params): Query<POIParams>, Query(params): Query<POIParams>,
) -> Result<Json<POIsResponse>, (StatusCode, String)> { ) -> Result<Json<POIsResponse>, (StatusCode, String)> {
let bounds_str = params let bounds_str = params.bounds.ok_or((
.bounds StatusCode::BAD_REQUEST,
.ok_or((StatusCode::BAD_REQUEST, "bounds parameter is required".into()))?; "bounds parameter is required".into(),
))?;
let parts: Vec<f64> = bounds_str let parts: Vec<f64> = bounds_str
.split(',') .split(',')
@ -501,7 +501,12 @@ pub struct HexagonPropertiesResponse {
} }
/// Helper function to check if a row passes all filters /// Helper function to check if a row passes all filters
fn row_passes_filters(row: usize, filters: &[ParsedFilter], feature_data: &[f64], num_features: usize) -> bool { fn row_passes_filters(
row: usize,
filters: &[ParsedFilter],
feature_data: &[f64],
num_features: usize,
) -> bool {
filters.iter().all(|f| { filters.iter().all(|f| {
let v = feature_data[row * num_features + f.feat_idx]; let v = feature_data[row * num_features + f.feat_idx];
v.is_finite() && v >= f.min && v <= f.max v.is_finite() && v >= f.min && v <= f.max
@ -520,7 +525,10 @@ pub async fn get_hexagon_properties(
// 2. Validate resolution // 2. Validate resolution
let resolution = params.resolution as usize; let resolution = params.resolution as usize;
if resolution >= state.h3_cells.len() || state.h3_cells[resolution].is_empty() { if resolution >= state.h3_cells.len() || state.h3_cells[resolution].is_empty() {
return Err((StatusCode::BAD_REQUEST, "Invalid or non-precomputed resolution".to_string())); return Err((
StatusCode::BAD_REQUEST,
"Invalid or non-precomputed resolution".to_string(),
));
} }
// 3. Parse filters (reuse existing filter parsing logic from get_hexagons) // 3. Parse filters (reuse existing filter parsing logic from get_hexagons)
@ -592,7 +600,11 @@ pub async fn get_hexagon_properties(
// Helper to get non-empty string // Helper to get non-empty string
let get_string = |s: &str| -> Option<String> { let get_string = |s: &str| -> Option<String> {
if s.is_empty() { None } else { Some(s.to_string()) } if s.is_empty() {
None
} else {
Some(s.to_string())
}
}; };
Property { Property {