use std::sync::Arc; use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Json}; use axum::Extension; use metrics::histogram; use rayon::prelude::*; use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use tracing::info; use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator}; use crate::auth::OptionalUser; use crate::consts::MAX_CELLS_PER_REQUEST; use crate::data::travel_time::TravelData; use crate::licensing::{check_license_bounds, resolve_share_code}; use crate::parsing::{ cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices_with_poi, parse_filters_with_poi, require_bounds, row_passes_filters, row_passes_poi_filters, validate_h3_resolution, }; use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg}; use crate::state::SharedState; /// Row count threshold above which we use rayon parallel aggregation. const PARALLEL_THRESHOLD: usize = 50_000; /// Per-thread aggregation result: feature accumulators + travel time accumulators. type ChunkResult = ( FxHashMap, FxHashMap, Vec>, FxHashSet, ); /// Maximum center-to-vertex distance in degrees per H3 resolution. /// Generous for UK latitudes (49°–61°) so we never false-exclude a visible cell. /// Used for cheap center-based bounds filtering instead of computing full cell boundary. const H3_CENTER_BUFFERS: [f64; 13] = [ 5.0, 2.0, 1.0, 0.5, // res 0–3 (unused in practice) 0.50, // res 4 0.20, // res 5 0.08, // res 6 0.03, // res 7 0.012, // res 8 0.005, // res 9 0.002, // res 10 0.001, // res 11 0.0005, // res 12 ]; #[derive(Serialize)] pub struct HexagonsResponse { features: Vec>, } #[derive(Deserialize)] pub struct HexagonParams { resolution: u8, bounds: Option, /// `;;`-separated filters: `name:min:max;;...` filters: Option, /// Comma-separated feature names to include in min/max aggregation. fields: Option, /// Pipe-separated travel time entries: `mode:slug|mode:slug:min:max` /// Each entry requests travel time aggregation for that mode+destination. /// Optional min:max applies as a filter (exclude properties outside range). travel: Option, /// Feature name for enum distribution counting (pie chart visualization). /// When set, each cell includes `dist_{name}: [count_val0, count_val1, ...]`. enum_dist: Option, /// Share-link code; grants bbox-scoped access for unlicensed users. share: Option, } /// Build feature maps from aggregated cell data, filtering to only cells whose /// center is within the query bounds (expanded by a resolution-dependent buffer). /// This is much cheaper than the previous approach of computing full cell boundaries /// (6 vertices per cell) — just 4 float comparisons per cell. #[allow(clippy::too_many_arguments)] fn build_feature_maps( groups: &FxHashMap, poi_groups: &FxHashMap, selectable_cells: &FxHashSet, min_keys: &[String], max_keys: &[String], avg_keys: &[String], num_features: usize, indices: Option<&[usize]>, poi_feature_names: &[String], poi_indices: &[usize], query_bounds: (f64, f64, f64, f64), resolution: h3o::Resolution, travel_aggs: &[FxHashMap], travel_field_keys: &[String], enum_dist_key: Option<&str>, ) -> Vec> { let mut features = Vec::with_capacity(groups.len()); let (q_south, q_west, q_north, q_east) = query_bounds; // Expand bounds by resolution-dependent buffer for center-based filtering let buf = H3_CENTER_BUFFERS[resolution as usize]; let bound_south = q_south - buf; let bound_north = q_north + buf; let bound_west = q_west - buf; let bound_east = q_east + buf; // Pre-compute travel time key strings (avoids per-cell format!()) let travel_keys: Vec<(String, String, String)> = travel_field_keys .iter() .map(|key| { ( format!("min_{key}"), format!("max_{key}"), format!("avg_{key}"), ) }) .collect(); // Pre-compute default feature indices to avoid per-cell Box allocation let default_indices: Vec; let feat_indices: &[usize] = match indices { Some(idx) => idx, None => { default_indices = (0..num_features).collect(); &default_indices } }; for (&cell_id, aggregation) in groups { let Some(cell) = h3o::CellIndex::try_from(cell_id).ok() else { continue; }; // Center is already needed for lat/lon output — reuse for bounds check let center: h3o::LatLng = cell.into(); let lat = center.lat(); let lng = center.lng(); // Center-based bounds check: 4 comparisons instead of computing 6 boundary vertices if lat < bound_south || lat > bound_north || lng < bound_west || lng > bound_east { continue; } let mut map = Map::new(); map.insert("h3".into(), Value::String(cell.to_string())); map.insert("count".into(), Value::Number(aggregation.count.into())); if let (Some(lat_num), Some(lon_num)) = ( serde_json::Number::from_f64(lat), serde_json::Number::from_f64(lng), ) { map.insert("lat".into(), Value::Number(lat_num)); map.insert("lon".into(), Value::Number(lon_num)); } for &feat_index in feat_indices { if aggregation.feat_counts[feat_index] > 0 { let avg = aggregation.sums[feat_index] / aggregation.feat_counts[feat_index] as f64; if let (Some(min_num), Some(max_num), Some(avg_num)) = ( serde_json::Number::from_f64(aggregation.mins[feat_index] as f64), serde_json::Number::from_f64(aggregation.maxs[feat_index] as f64), serde_json::Number::from_f64(avg), ) { map.insert(min_keys[feat_index].clone(), Value::Number(min_num)); map.insert(max_keys[feat_index].clone(), Value::Number(max_num)); map.insert(avg_keys[feat_index].clone(), Value::Number(avg_num)); } } } if let Some(poi_aggregation) = poi_groups.get(&cell_id) { for &metric_idx in poi_indices { if poi_aggregation.counts[metric_idx] > 0 { let avg = poi_aggregation.sums[metric_idx] / poi_aggregation.counts[metric_idx] as f64; if let (Some(min_num), Some(max_num), Some(avg_num)) = ( serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64), serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64), serde_json::Number::from_f64(avg), ) { let name = &poi_feature_names[metric_idx]; map.insert(format!("min_{name}"), Value::Number(min_num)); map.insert(format!("max_{name}"), Value::Number(max_num)); map.insert(format!("avg_{name}"), Value::Number(avg_num)); } } } } // Add travel time aggregation fields (using pre-computed key strings) for (ti, agg_map) in travel_aggs.iter().enumerate() { if let Some(agg) = agg_map.get(&cell_id) { if agg.count > 0 { let avg = agg.sum / agg.count as f64; if let Some(nm) = serde_json::Number::from_f64(agg.min as f64) { map.insert(travel_keys[ti].0.clone(), Value::Number(nm)); } if let Some(nm) = serde_json::Number::from_f64(agg.max as f64) { map.insert(travel_keys[ti].1.clone(), Value::Number(nm)); } if let Some(nm) = serde_json::Number::from_f64(avg) { map.insert(travel_keys[ti].2.clone(), Value::Number(nm)); } } } } // Add enum distribution array (for pie chart visualization) if let (Some(key), Some(ref ed)) = (enum_dist_key, &aggregation.enum_dist) { let arr: Vec = ed.counts.iter().map(|&c| Value::from(c)).collect(); map.insert(key.to_string(), Value::Array(arr)); } features.push(map); } for &cell_id in selectable_cells { if groups.contains_key(&cell_id) { continue; } let Some(cell) = h3o::CellIndex::try_from(cell_id).ok() else { continue; }; let center: h3o::LatLng = cell.into(); let lat = center.lat(); let lng = center.lng(); if lat < bound_south || lat > bound_north || lng < bound_west || lng > bound_east { continue; } let mut map = Map::new(); map.insert("h3".into(), Value::String(cell.to_string())); map.insert("count".into(), Value::from(0)); if let (Some(lat_num), Some(lon_num)) = ( serde_json::Number::from_f64(lat), serde_json::Number::from_f64(lng), ) { map.insert("lat".into(), Value::Number(lat_num)); map.insert("lon".into(), Value::Number(lon_num)); } features.push(map); } features } pub async fn get_hexagons( State(shared): State>, Extension(user): Extension, Query(params): Query, ) -> Result, axum::response::Response> { let state = shared.load_state(); let resolution = params.resolution; validate_h3_resolution(resolution).map_err(IntoResponse::into_response)?; let (south, west, north, east) = require_bounds(params.bounds).map_err(IntoResponse::into_response)?; let share_bounds = resolve_share_code(&state, params.share.as_deref()).await; check_license_bounds(&user.0, (south, west, north, east), share_bounds)?; let quant = state.data.quant_ref(); let poi_quant = state.data.poi_metrics.quant_ref(); let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi( params.filters.as_deref(), &state.feature_name_to_index, &state.data.enum_values, &quant, &state.data.poi_metrics.name_to_index, &poi_quant, ) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len(); let filters_str = params.filters; let field_indices = parse_field_indices_with_poi( params.fields.as_deref(), &state.feature_name_to_index, &state.data.poi_metrics.name_to_index, ) .map_err(|err| (err.0, err.1).into_response())?; let travel_entries = parse_optional_travel(params.travel.as_deref()) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; let enum_dist_config: EnumDistConfig = parse_enum_dist( params.enum_dist.as_deref(), &state.feature_name_to_index, &state.data.enum_values, ) .map_err(|err| (err.0, err.1).into_response())?; // Pre-compute the dist_ key name (e.g. "dist_Property type") outside spawn_blocking let enum_dist_key: Option = params .enum_dist .as_ref() .map(|name| format!("dist_{}", name.trim())); let response = tokio::task::spawn_blocking(move || -> Result { let t0 = std::time::Instant::now(); // Load travel time data from precomputed parquet files let travel_data: Vec = if !travel_entries.is_empty() { let store = &state.travel_time_store; travel_entries .iter() .map(|entry| { store .get(&entry.mode, &entry.slug) .map_err(|err| format!("Failed to load travel data: {}", err)) }) .collect::, _>>()? } else { Vec::new() }; let has_travel = !travel_entries.is_empty(); let travel_field_keys: Vec = travel_entries .iter() .map(|te| format!("tt_{}_{}", te.mode, te.slug)) .collect(); let num_features = state.data.num_features; let feature_data = &state.data.feature_data; let quant = state.data.quant_ref(); let (pc_interner, pc_keys) = state.data.postcode_parts(); let min_keys = &state.min_keys; let max_keys = &state.max_keys; let avg_keys = &state.avg_keys; let poi_metrics = &state.data.poi_metrics; let poi_field_indices = field_indices.poi.as_slice(); let has_poi_fields = !poi_field_indices.is_empty(); let has_poi_filters = !parsed_poi_filters.is_empty(); let poi_num_features = poi_metrics.num_features(); let h3_res = h3o::Resolution::try_from(resolution) .map_err(|error| format!("Invalid H3 resolution {}: {}", resolution, error))?; let precomputed = &state.h3_cells; let need_parent = needs_parent(resolution); let mut groups: FxHashMap = FxHashMap::default(); let mut poi_groups: FxHashMap = FxHashMap::default(); let mut selectable_cells: FxHashSet = FxHashSet::default(); let mut travel_aggs: Vec> = (0..travel_entries.len()) .map(|_| FxHashMap::default()) .collect(); // O(grid cells) count — no allocation. Used for parallel threshold decision. let row_count = state.grid.count_in_bounds(south, west, north, east); let t_grid = t0.elapsed(); let parallel = row_count >= PARALLEL_THRESHOLD; if parallel { // Parallel: collect row indices for par_chunks, split across rayon threads. // Now handles travel time too (postcode interner & travel data are thread-safe). let row_indices = state.grid.query(south, west, north, east); let chunk_size = (row_indices.len() / rayon::current_num_threads()).max(1000); let thread_results: Vec = row_indices .par_chunks(chunk_size) .map(|chunk| { let mut local_groups: FxHashMap = FxHashMap::default(); let mut local_poi_groups: FxHashMap = FxHashMap::default(); let mut local_travel_aggs: Vec> = (0 ..travel_entries.len()) .map(|_| FxHashMap::default()) .collect(); let mut local_selectable_cells: FxHashSet = FxHashSet::default(); let mut h3_cache: FxHashMap = FxHashMap::default(); let mut travel_minutes: Vec> = Vec::with_capacity(travel_entries.len()); 'row: for &row_idx in chunk { let row = row_idx as usize; let cell_id = cell_for_row_cached( row, precomputed, h3_res, need_parent, &mut h3_cache, ); local_selectable_cells.insert(cell_id); if !row_passes_filters( row, &parsed_filters, &parsed_enum_filters, feature_data, num_features, ) { continue; } if has_poi_filters && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics) { continue; } if has_travel { travel_minutes.clear(); let postcode = pc_interner.resolve(&pc_keys[row]); for (ti, entry) in travel_entries.iter().enumerate() { let row_data = travel_data[ti].get(postcode); let minutes = row_data.map(|r| { if entry.use_best { r.best_minutes.unwrap_or(r.minutes) } else { r.minutes } }); travel_minutes.push(minutes); if let (Some(fmin), Some(fmax)) = (entry.filter_min, entry.filter_max) { match minutes { Some(mins) if (mins as f32) >= fmin && (mins as f32) <= fmax => {} _ => continue 'row, } } } } let agg = local_groups .entry(cell_id) .or_insert_with(|| Aggregator::new(num_features, enum_dist_config)); if let Some(sel_indices) = field_indices.normal.as_deref() { agg.add_row_selective( feature_data, row, num_features, sel_indices, &quant, ); } else { agg.add_row(feature_data, row, num_features, &quant); } if has_poi_fields { local_poi_groups .entry(cell_id) .or_insert_with(|| PoiAggregator::new(poi_num_features)) .add_row_selective(poi_metrics, row, poi_field_indices); } for (ti, minutes) in travel_minutes.iter().enumerate() { if let Some(mins) = minutes { let tagg = local_travel_aggs[ti] .entry(cell_id) .or_insert_with(TravelTimeAgg::new); tagg.add(*mins as f32); } } } ( local_groups, local_poi_groups, local_travel_aggs, local_selectable_cells, ) }) .collect(); // Merge thread-local results into the main accumulators for (local_groups, local_poi_groups, local_travel, local_selectable_cells) in thread_results { for (cell_id, local_agg) in local_groups { groups .entry(cell_id) .or_insert_with(|| Aggregator::new(num_features, enum_dist_config)) .merge(&local_agg); } for (cell_id, local_agg) in local_poi_groups { poi_groups .entry(cell_id) .or_insert_with(|| PoiAggregator::new(poi_num_features)) .merge(&local_agg); } for (ti, local_ta) in local_travel.into_iter().enumerate() { for (cell_id, local_tt) in local_ta { travel_aggs[ti] .entry(cell_id) .or_insert_with(TravelTimeAgg::new) .merge(&local_tt); } } selectable_cells.extend(local_selectable_cells); } } else { // Sequential: use for_each_in_bounds to avoid Vec allocation let mut travel_minutes: Vec> = Vec::with_capacity(travel_entries.len()); let mut h3_cache: FxHashMap = FxHashMap::default(); state .grid .for_each_in_bounds(south, west, north, east, |row_idx| { let row = row_idx as usize; let cell_id = cell_for_row_cached(row, precomputed, h3_res, need_parent, &mut h3_cache); selectable_cells.insert(cell_id); if !row_passes_filters( row, &parsed_filters, &parsed_enum_filters, feature_data, num_features, ) { return; } if has_poi_filters && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics) { return; } if has_travel { travel_minutes.clear(); let postcode = pc_interner.resolve(&pc_keys[row]); for (ti, entry) in travel_entries.iter().enumerate() { let row_data = travel_data[ti].get(postcode); let minutes = row_data.map(|r| { if entry.use_best { r.best_minutes.unwrap_or(r.minutes) } else { r.minutes } }); travel_minutes.push(minutes); if let (Some(fmin), Some(fmax)) = (entry.filter_min, entry.filter_max) { match minutes { Some(mins) if (mins as f32) >= fmin && (mins as f32) <= fmax => {} _ => return, } } } } let aggregation = groups .entry(cell_id) .or_insert_with(|| Aggregator::new(num_features, enum_dist_config)); if let Some(sel_indices) = field_indices.normal.as_deref() { aggregation.add_row_selective( feature_data, row, num_features, sel_indices, &quant, ); } else { aggregation.add_row(feature_data, row, num_features, &quant); } if has_poi_fields { poi_groups .entry(cell_id) .or_insert_with(|| PoiAggregator::new(poi_num_features)) .add_row_selective(poi_metrics, row, poi_field_indices); } for (ti, minutes) in travel_minutes.iter().enumerate() { if let Some(mins) = minutes { let agg = travel_aggs[ti] .entry(cell_id) .or_insert_with(TravelTimeAgg::new); agg.add(*mins as f32); } } }); }; let t_agg = t0.elapsed(); let mut features = build_feature_maps( &groups, &poi_groups, &selectable_cells, min_keys, max_keys, avg_keys, num_features, field_indices.normal.as_deref(), &poi_metrics.feature_names, poi_field_indices, (south, west, north, east), h3_res, &travel_aggs, &travel_field_keys, enum_dist_key.as_deref(), ); let truncated = features.len() > MAX_CELLS_PER_REQUEST; if truncated { features.truncate(MAX_CELLS_PER_REQUEST); } let t_total = t0.elapsed(); info!( resolution, rows = row_count, parallel, selectable_cells = selectable_cells.len(), cells_before_filter = groups.len(), cells_after_filter = features.len(), truncated, bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east), filters = num_filters, filters_raw = filters_str.as_deref().unwrap_or("-"), fields = field_indices .normal .as_ref() .map(|v| (v.len() + poi_field_indices.len()) as i32) .unwrap_or(-1), travel_entries = travel_entries.len(), grid_ms = format_args!("{:.1}", t_grid.as_secs_f64() * 1000.0), agg_ms = format_args!("{:.1}", (t_agg - t_grid).as_secs_f64() * 1000.0), json_ms = format_args!("{:.1}", (t_total - t_agg).as_secs_f64() * 1000.0), total_ms = format_args!("{:.1}", t_total.as_secs_f64() * 1000.0), "GET /api/hexagons" ); histogram!("hexagons_response_count").record(features.len() as f64); Ok(HexagonsResponse { features }) }) .await .map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())? .map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?; Ok(Json(response)) }