diff --git a/Taskfile.yml b/Taskfile.yml index 498cd36..dfaf773 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -22,6 +22,12 @@ tasks: - uv run -m pipeline.utils.test_fuzzy_join - uv run pytest pipeline/utils/test_haversine.py - uv run pytest pipeline/utils/test_poi_counts.py + + test:server: + desc: Run Rust backend tests + dir: server-rs + cmds: + - cargo test dev:server: desc: Run Rust backend on port 8001 (debug build, fast compile) @@ -115,3 +121,4 @@ tasks: - task: build:server - task: build:frontend - task: test + - task: test:server diff --git a/server-rs/src/consts.rs b/server-rs/src/consts.rs index c7190cf..28b096b 100644 --- a/server-rs/src/consts.rs +++ b/server-rs/src/consts.rs @@ -1,20 +1,28 @@ -pub const FEATURE_PERCENTILE_LOW: f64 = 2.0; - -pub const FEATURE_PERCENTILE_HIGH: f64 = 98.0; - pub const HISTOGRAM_BINS: usize = 100; pub const H3_PRECOMPUTE_MIN: u8 = 4; pub const H3_PRECOMPUTE_MAX: u8 = 12; -pub const EXCLUDED_COLUMNS: &[&str] = &["lat", "lon"]; +pub const SERVER_ADDRESS: &str = "0.0.0.0:8001"; -pub const EXCLUDED_STRING_COLUMNS: &[&str] = &[ - "pp_address", - "postcode", - "Address per Property Register", - "Address per EPC", - "Postcode", +pub const BOUNDS_QUANTIZATION: f64 = 0.01; +pub const BOUNDS_BUFFER_PERCENT: f64 = 0.1; +pub const POSTCODE_MIN_RESOLUTION: u8 = 11; +pub const MAX_POIS_PER_REQUEST: usize = 5000; +pub const DEFAULT_PROPERTIES_LIMIT: usize = 100; +pub const MAX_PROPERTIES_LIMIT: usize = 500; +pub const ENUM_NULL: u8 = 255; + +/// Canonical display order for POI category groups. +/// The server will panic at startup if the data contains groups not in this list or vice versa. +pub const POI_GROUP_ORDER: &[&str] = &[ + "Public Transport", + "Amenity", + "Building", + "Craft", + "Healthcare", + "Leisure", + "Office", + "Shop", + "Tourism", ]; - -pub const MAX_ENUM_CARDINALITY: usize = 50; diff --git a/server-rs/src/index.rs b/server-rs/src/grid_index.rs similarity index 51% rename from server-rs/src/index.rs rename to server-rs/src/grid_index.rs index 03412f6..4743d84 100644 --- a/server-rs/src/index.rs +++ b/server-rs/src/grid_index.rs @@ -2,7 +2,6 @@ /// /// Divides the UK bounding box into cells of ~0.01 degrees (~1km), /// each storing indices of rows whose lat/lon falls within that cell. - pub struct GridIndex { min_lat: f64, min_lon: f64, @@ -21,19 +20,17 @@ impl GridIndex { let mut max_lon = f64::NEG_INFINITY; for i in 0..lat.len() { - let la = lat[i]; - let lo = lon[i]; - if la < min_lat { - min_lat = la; + if lat[i] < min_lat { + min_lat = lat[i]; } - if la > max_lat { - max_lat = la; + if lat[i] > max_lat { + max_lat = lat[i]; } - if lo < min_lon { - min_lon = lo; + if lon[i] < min_lon { + min_lon = lon[i]; } - if lo > max_lon { - max_lon = lo; + if lon[i] > max_lon { + max_lon = lon[i]; } } @@ -56,9 +53,9 @@ impl GridIndex { let mut cells: Vec> = vec![Vec::new(); rows * cols]; for i in 0..lat.len() { - let r = ((lat[i] - min_lat) / cell_size) as usize; - let c = ((lon[i] - min_lon) / cell_size) as usize; - let idx = r * cols + c; + let grid_row = ((lat[i] - min_lat) / cell_size) as usize; + let grid_col = ((lon[i] - min_lon) / cell_size) as usize; + let idx = grid_row * cols + grid_col; cells[idx].push(i as u32); } @@ -75,20 +72,23 @@ impl GridIndex { } pub fn query(&self, south: f64, west: f64, north: f64, east: f64) -> Vec { - let (r_min, r_max, c_min, c_max) = self.clamp_bounds(south, west, north, east); + let Some((row_min, row_max, col_min, col_max)) = + self.clamp_bounds(south, west, north, east) + else { + return Vec::new(); + }; let mut result = Vec::new(); - for r in r_min..=r_max { - let row_start = r * self.cols; - for c in c_min..=c_max { - result.extend_from_slice(&self.cells[row_start + c]); + for row in row_min..=row_max { + let row_start = row * self.cols; + for col in col_min..=col_max { + result.extend_from_slice(&self.cells[row_start + col]); } } result } - /// Iterate all row indices in bounds without allocating a Vec. #[inline] pub fn for_each_in_bounds( &self, @@ -98,12 +98,16 @@ impl GridIndex { east: f64, mut f: impl FnMut(u32), ) { - let (r_min, r_max, c_min, c_max) = self.clamp_bounds(south, west, north, east); + let Some((row_min, row_max, col_min, col_max)) = + self.clamp_bounds(south, west, north, east) + else { + return; + }; - for r in r_min..=r_max { - let row_start = r * self.cols; - for c in c_min..=c_max { - for &row_idx in &self.cells[row_start + c] { + for row in row_min..=row_max { + let row_start = row * self.cols; + for col in col_min..=col_max { + for &row_idx in &self.cells[row_start + col] { f(row_idx); } } @@ -116,17 +120,28 @@ impl GridIndex { west: f64, north: f64, east: f64, - ) -> (usize, usize, usize, usize) { - let r_min = ((south - self.min_lat) / self.cell_size) as isize; - let r_max = ((north - self.min_lat) / self.cell_size) as isize; - let c_min = ((west - self.min_lon) / self.cell_size) as isize; - let c_max = ((east - self.min_lon) / self.cell_size) as isize; + ) -> Option<(usize, usize, usize, usize)> { + let row_min_raw = ((south - self.min_lat) / self.cell_size) as isize; + let row_max_raw = ((north - self.min_lat) / self.cell_size) as isize; + let col_min_raw = ((west - self.min_lon) / self.cell_size) as isize; + let col_max_raw = ((east - self.min_lon) / self.cell_size) as isize; - let r_min = r_min.max(0) as usize; - let r_max = (r_max.min(self.rows as isize - 1)).max(0) as usize; - let c_min = c_min.max(0) as usize; - let c_max = (c_max.min(self.cols as isize - 1)).max(0) as usize; + let row_min = row_min_raw.max(0) as usize; + let row_max_clamped = row_max_raw.min(self.rows as isize - 1); + let col_min = col_min_raw.max(0) as usize; + let col_max_clamped = col_max_raw.min(self.cols as isize - 1); - (r_min, r_max, c_min, c_max) + if row_max_clamped < 0 || col_max_clamped < 0 { + return None; + } + + let row_max = row_max_clamped as usize; + let col_max = col_max_clamped as usize; + + if row_min > row_max || col_min > col_max { + return None; + } + + Some((row_min, row_max, col_min, col_max)) } } diff --git a/server-rs/src/routes/parse.rs b/server-rs/src/routes/parse.rs new file mode 100644 index 0000000..5bcd04f --- /dev/null +++ b/server-rs/src/routes/parse.rs @@ -0,0 +1,23 @@ +use axum::http::StatusCode; + +pub fn parse_bounds(bounds_str: &str) -> Result<(f64, f64, f64, f64), (StatusCode, String)> { + let parts: Vec = bounds_str + .split(',') + .map(|s| s.trim().parse::()) + .collect::, _>>() + .map_err(|_| { + ( + StatusCode::BAD_REQUEST, + "Invalid bounds format. Use: south,west,north,east".into(), + ) + })?; + + if parts.len() != 4 { + return Err(( + StatusCode::BAD_REQUEST, + "Invalid bounds format. Use: south,west,north,east".into(), + )); + } + + Ok((parts[0], parts[1], parts[2], parts[3])) +} diff --git a/server-rs/src/routes/pois.rs b/server-rs/src/routes/pois.rs index f309ab7..40d5557 100644 --- a/server-rs/src/routes/pois.rs +++ b/server-rs/src/routes/pois.rs @@ -6,8 +6,11 @@ use axum::response::Json; use serde::{Deserialize, Serialize}; use tracing::info; +use crate::consts::MAX_POIS_PER_REQUEST; use crate::data::POI; -use crate::state::AppState; +use crate::state::{AppState, POICategoryGroup}; + +use super::parse::parse_bounds; #[derive(Deserialize)] pub struct POIParams { @@ -30,28 +33,10 @@ pub async fn get_pois( "bounds parameter is required".into(), ))?; - let parts: Vec = bounds_str - .split(',') - .map(|s| s.trim().parse::()) - .collect::, _>>() - .map_err(|_| { - ( - StatusCode::BAD_REQUEST, - "Invalid bounds format. Use: south,west,north,east".into(), - ) - })?; - - if parts.len() != 4 { - return Err(( - StatusCode::BAD_REQUEST, - "Invalid bounds format. Use: south,west,north,east".into(), - )); - } - - let (south, west, north, east) = (parts[0], parts[1], parts[2], parts[3]); + let (south, west, north, east) = parse_bounds(&bounds_str)?; let categories_str = params.categories.clone(); - let category_filter: Option> = params + let category_filter: Option> = params .categories .as_deref() .filter(|s| !s.is_empty()) @@ -78,12 +63,13 @@ pub async fn get_pois( id: state.poi_data.id[row].clone(), name: state.poi_data.name[row].clone(), category: state.poi_data.category[row].clone(), + group: state.poi_data.group[row].clone(), lat: state.poi_data.lat[row], lng: state.poi_data.lng[row], emoji: state.poi_data.emoji[row].clone(), }) }) - .take(5000) + .take(MAX_POIS_PER_REQUEST) .collect(); let elapsed = t0.elapsed(); @@ -99,35 +85,25 @@ pub async fn get_pois( POIsResponse { pois } }) .await - .unwrap(); + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; Ok(Json(result)) } #[derive(Serialize)] pub struct POICategoriesResponse { - categories: Vec, + groups: Vec, } pub async fn get_poi_categories(state: Arc) -> Json { - let result = tokio::task::spawn_blocking(move || { - let mut categories: Vec = state - .poi_data - .category - .iter() - .cloned() - .collect::>() - .into_iter() - .collect(); + let groups: Vec = state.poi_category_groups.clone(); - categories.sort(); + let total: usize = groups.iter().map(|g| g.categories.len()).sum(); + info!( + count = total, + groups = groups.len(), + "GET /api/poi-categories" + ); - info!(count = categories.len(), "GET /api/poi-categories"); - - POICategoriesResponse { categories } - }) - .await - .unwrap(); - - Json(result) + Json(POICategoriesResponse { groups }) } diff --git a/server-rs/src/tests.rs b/server-rs/src/tests.rs new file mode 100644 index 0000000..a6bd035 --- /dev/null +++ b/server-rs/src/tests.rs @@ -0,0 +1,250 @@ +#[cfg(test)] +mod grid_index_tests { + use crate::grid_index::GridIndex; + + #[test] + fn query_bounds_fully_below_grid_returns_empty() { + let lat = vec![50.0, 50.5, 51.0]; + let lon = vec![0.0, 0.5, 1.0]; + let grid = GridIndex::build(&lat, &lon, 0.01); + + let results = grid.query(10.0, -10.0, 20.0, -5.0); + assert!( + results.is_empty(), + "Should return empty for bounds fully below grid" + ); + } + + #[test] + fn query_bounds_fully_above_grid_returns_empty() { + let lat = vec![50.0, 50.5, 51.0]; + let lon = vec![0.0, 0.5, 1.0]; + let grid = GridIndex::build(&lat, &lon, 0.01); + + let results = grid.query(80.0, 50.0, 90.0, 60.0); + assert!( + results.is_empty(), + "Should return empty for bounds fully above grid" + ); + } + + #[test] + fn query_inverted_bounds_returns_empty() { + let lat = vec![50.0, 50.5, 51.0]; + let lon = vec![0.0, 0.5, 1.0]; + let grid = GridIndex::build(&lat, &lon, 0.01); + + // south > north + let results = grid.query(52.0, 0.0, 49.0, 1.0); + assert!( + results.is_empty(), + "Should return empty for inverted bounds" + ); + } + + #[test] + fn for_each_bounds_fully_outside_yields_nothing() { + let lat = vec![50.0, 50.5, 51.0]; + let lon = vec![0.0, 0.5, 1.0]; + let grid = GridIndex::build(&lat, &lon, 0.01); + + let mut count = 0; + grid.for_each_in_bounds(10.0, -10.0, 20.0, -5.0, |_| count += 1); + assert_eq!( + count, 0, + "for_each should yield nothing for out-of-bounds query" + ); + } + + #[test] + fn query_with_large_cells_outside_returns_empty() { + // Previously, out-of-bounds queries with large cell sizes would + // scan cell (0,0) which could contain data. Now returns empty. + let lat = vec![50.0]; + let lon = vec![0.0]; + let grid = GridIndex::build(&lat, &lon, 1.0); + + let results = grid.query(0.0, -50.0, 10.0, -40.0); + assert!( + results.is_empty(), + "Should return empty even with large cell size" + ); + } + + #[test] + fn query_within_bounds_returns_correct_results() { + let lat = vec![50.0, 50.5, 51.0]; + let lon = vec![0.0, 0.5, 1.0]; + let grid = GridIndex::build(&lat, &lon, 0.01); + + let results = grid.query(49.9, -0.1, 51.1, 1.1); + assert_eq!(results.len(), 3, "Should return all 3 points within bounds"); + } + + #[test] + fn query_partial_bounds_returns_subset() { + let lat = vec![50.0, 51.0, 52.0]; + let lon = vec![0.0, 0.0, 0.0]; + let grid = GridIndex::build(&lat, &lon, 0.01); + + let results = grid.query(49.9, -0.1, 50.1, 0.1); + assert_eq!(results.len(), 1, "Should return only the point at lat=50"); + } +} + +#[cfg(test)] +mod filter_tests { + use crate::data::EnumFeatureData; + use crate::filter::{parse_filters, row_passes_filters}; + + #[test] + fn nan_rows_fail_numeric_filter_even_with_infinite_range() { + let feature_names = vec!["price".to_string()]; + let feature_data = vec![f64::NAN]; + let enum_features: Vec = vec![]; + + let (numeric, enums) = + parse_filters(Some("price:-inf:inf"), &feature_names, &enum_features); + assert_eq!(numeric.len(), 1, "Should parse -inf:inf as valid filter"); + + let passes = row_passes_filters(0, &numeric, &enums, &feature_data, 1, &enum_features); + assert!(!passes, "NaN should fail filter even with infinite range"); + } + + #[test] + fn empty_enum_filter_value_rejects_all() { + let enum_features = vec![EnumFeatureData { + name: "rating".to_string(), + values: vec!["A".to_string(), "B".to_string()], + data: vec![0], + }]; + let feature_names: Vec = vec![]; + + let (numeric, enums) = parse_filters(Some("rating:"), &feature_names, &enum_features); + assert_eq!(enums.len(), 1); + assert!(enums[0].allowed.is_empty()); + + let passes = row_passes_filters(0, &numeric, &enums, &[], 0, &enum_features); + assert!(!passes, "Empty allowed set should reject all rows"); + } + + #[test] + fn enum_filter_with_nonexistent_values_produces_empty_allowed() { + let enum_features = vec![EnumFeatureData { + name: "rating".to_string(), + values: vec!["A".to_string(), "B".to_string()], + data: vec![0], + }]; + let feature_names: Vec = vec![]; + + let (_, enums) = parse_filters(Some("rating:X|Y|Z"), &feature_names, &enum_features); + assert_eq!(enums.len(), 1); + assert!(enums[0].allowed.is_empty()); + } + + #[test] + fn malformed_numeric_min_is_silently_skipped() { + let feature_names = vec!["price".to_string()]; + let enum_features: Vec = vec![]; + + let (numeric, enums) = parse_filters( + Some("price:not_a_number:200"), + &feature_names, + &enum_features, + ); + assert_eq!(numeric.len(), 0); + assert_eq!(enums.len(), 0); + } +} + +#[cfg(test)] +mod json_tests { + use std::fmt::Write; + + #[test] + fn json_escaped_postcode_with_quotes_is_valid() { + use crate::routes::hexagons::write_json_escaped; + + let mut buf = String::new(); + buf.push_str("{\"postcode\":\""); + write_json_escaped(&mut buf, "SW1A \"test"); + buf.push_str("\"}"); + + let result: Result = serde_json::from_str(&buf); + assert!( + result.is_ok(), + "Escaped quote should produce valid JSON: {}", + buf + ); + assert_eq!(result.unwrap()["postcode"].as_str().unwrap(), "SW1A \"test"); + } + + #[test] + fn json_escaped_postcode_with_backslash_is_valid() { + use crate::routes::hexagons::write_json_escaped; + + let mut buf = String::new(); + buf.push_str("{\"postcode\":\""); + write_json_escaped(&mut buf, "SW1A\\2AA"); + buf.push_str("\"}"); + + let result: Result = serde_json::from_str(&buf); + assert!( + result.is_ok(), + "Escaped backslash should produce valid JSON: {}", + buf + ); + assert_eq!(result.unwrap()["postcode"].as_str().unwrap(), "SW1A\\2AA"); + } + + #[test] + fn nan_is_not_valid_json() { + // Verify that raw NaN in write! is still invalid JSON (documenting the risk + // that the is_finite() guard in write_hexagons_json protects against). + let mut buf = String::new(); + write!(buf, "{{\"min_price\":{}}}", f64::NAN).unwrap(); + + let result: Result = serde_json::from_str(&buf); + assert!(result.is_err(), "Raw NaN should produce invalid JSON"); + } + + #[test] + fn infinity_is_not_valid_json() { + let mut buf = String::new(); + write!(buf, "{{\"min_price\":{}}}", f64::INFINITY).unwrap(); + + let result: Result = serde_json::from_str(&buf); + assert!(result.is_err(), "Raw Infinity should produce invalid JSON"); + } +} + +#[cfg(test)] +mod enum_encoding_tests { + #[test] + fn u8_cast_wraps_around_beyond_255() { + // Documents the underlying u8 wrapping behavior that the truncation + // guard in property.rs now prevents. + let num_values = 300usize; + let indices: Vec = (0..num_values).map(|i| i as u8).collect(); + + assert_eq!(indices[0], indices[256], "u8 wraps: 0 == 256"); + assert_eq!(indices[1], indices[257], "u8 wraps: 1 == 257"); + + use std::collections::HashMap; + let values: Vec = (0..num_values).map(|i| format!("val_{}", i)).collect(); + let value_to_idx: HashMap<&str, u8> = values + .iter() + .enumerate() + .map(|(i, v)| (v.as_str(), i as u8)) + .collect(); + + let unique_indices: std::collections::HashSet = + value_to_idx.values().cloned().collect(); + assert!( + unique_indices.len() < num_values, + "Without the truncation guard, {} values produce only {} unique u8 indices", + num_values, + unique_indices.len() + ); + } +}