use std::sync::Arc;
use axum::extract::Query;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json};
use axum::Extension;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::info;
use crate::aggregation::Aggregator;
use crate::auth::OptionalUser;
use crate::consts::{DEMO_BOUNDS, MAX_CELLS_PER_REQUEST};
use crate::data::travel_time::TravelData;
use crate::licensing::check_license_bounds;
use crate::parsing::{
    bounds_intersect, cell_for_row, h3_cell_bounds, needs_parent, parse_field_indices,
    parse_filters, require_bounds, row_passes_filters, validate_h3_resolution,
};
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
use crate::state::AppState;

/// JSON response body of `GET /api/hexagons`.
///
/// Serialized as `{"features": [...]}`, where each element is one JSON object per H3 cell
/// (built by `build_feature_maps`).
#[derive(Serialize)]
pub struct HexagonsResponse {
    // One object per returned H3 cell.
    features: Vec>,
}

/// Query-string parameters accepted by `GET /api/hexagons`.
#[derive(Deserialize)]
pub struct HexagonParams {
    // H3 resolution to aggregate at; checked by `validate_h3_resolution` before use.
    resolution: u8,
    // Query bounding box; parsed by `require_bounds` into (south, west, north, east).
    // NOTE(review): presumably the request is rejected when this is absent — confirm in
    // `require_bounds`.
    bounds: Option,
    /// `;;`-separated filters: `name:min:max;;...`
    filters: Option,
    /// Comma-separated feature names whose min/max/avg values are included in the output.
    /// NOTE(review): when omitted, all features appear to be emitted — confirm in
    /// `parse_field_indices`.
    fields: Option,
    /// Pipe-separated travel time entries: `mode:slug|mode:slug:min:max`
    /// Each entry requests travel time aggregation for that mode+destination.
    /// Optional min:max applies as a filter (exclude properties outside range); when a
    /// range is given, properties with no travel time for that entry are excluded as well.
    travel: Option,
}

/// Build feature maps from aggregated cell data, filtering to only cells that intersect the query bounds.
#[allow(clippy::too_many_arguments)] fn build_feature_maps( groups: &FxHashMap, min_keys: &[String], max_keys: &[String], avg_keys: &[String], num_features: usize, indices: Option<&[usize]>, query_bounds: (f64, f64, f64, f64), travel_aggs: &[FxHashMap], travel_field_keys: &[String], ) -> Vec> { let mut features = Vec::with_capacity(groups.len()); let (q_south, q_west, q_north, q_east) = query_bounds; for (&cell_id, aggregation) in groups { let Some(cell) = h3o::CellIndex::try_from(cell_id).ok() else { continue; }; // Filter out cells that don't intersect the query bounds let (c_south, c_west, c_north, c_east) = h3_cell_bounds(cell, 0.0); if !bounds_intersect( c_south, c_west, c_north, c_east, q_south, q_west, q_north, q_east, ) { continue; } let mut map = Map::new(); map.insert("h3".into(), Value::String(cell.to_string())); map.insert("count".into(), Value::Number(aggregation.count.into())); let center: h3o::LatLng = cell.into(); if let (Some(lat), Some(lon)) = ( serde_json::Number::from_f64(center.lat()), serde_json::Number::from_f64(center.lng()), ) { map.insert("lat".into(), Value::Number(lat)); map.insert("lon".into(), Value::Number(lon)); } let iter: Box> = if let Some(idx) = indices { Box::new(idx.iter().copied()) } else { Box::new(0..num_features) }; for feat_index in iter { if aggregation.feat_counts[feat_index] > 0 { let avg = aggregation.sums[feat_index] / aggregation.feat_counts[feat_index] as f64; if let (Some(min_num), Some(max_num), Some(avg_num)) = ( serde_json::Number::from_f64(aggregation.mins[feat_index] as f64), serde_json::Number::from_f64(aggregation.maxs[feat_index] as f64), serde_json::Number::from_f64(avg), ) { map.insert(min_keys[feat_index].clone(), Value::Number(min_num)); map.insert(max_keys[feat_index].clone(), Value::Number(max_num)); map.insert(avg_keys[feat_index].clone(), Value::Number(avg_num)); } } } // Add travel time aggregation fields for (ti, agg_map) in travel_aggs.iter().enumerate() { if let Some(agg) = agg_map.get(&cell_id) 
{ if agg.count > 0 { let key = &travel_field_keys[ti]; let avg = agg.sum / agg.count as f64; if let Some(nm) = serde_json::Number::from_f64(agg.min as f64) { map.insert(format!("min_{key}"), Value::Number(nm)); } if let Some(nm) = serde_json::Number::from_f64(agg.max as f64) { map.insert(format!("max_{key}"), Value::Number(nm)); } if let Some(nm) = serde_json::Number::from_f64(avg) { map.insert(format!("avg_{key}"), Value::Number(nm)); } } } } features.push(map); } features } pub async fn get_hexagons( state: Arc, Extension(user): Extension, Query(params): Query, ) -> Result, axum::response::Response> { let resolution = params.resolution; validate_h3_resolution(resolution).map_err(IntoResponse::into_response)?; let (south, west, north, east) = require_bounds(params.bounds).map_err(IntoResponse::into_response)?; let is_demo_view = (south, west, north, east) == DEMO_BOUNDS; if !is_demo_view { check_license_bounds(&user.0, (south, west, north, east))?; } let (parsed_filters, parsed_enum_filters) = parse_filters( params.filters.as_deref(), &state.feature_name_to_index, &state.data.enum_values, ) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; let num_filters = parsed_filters.len() + parsed_enum_filters.len(); let filters_str = params.filters; let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index) .map_err(|err| (err.0, err.1).into_response())?; let travel_entries = parse_optional_travel(params.travel.as_deref()) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; let response = tokio::task::spawn_blocking(move || -> Result { let t0 = std::time::Instant::now(); // Load travel time data from precomputed parquet files let travel_data: Vec = if !travel_entries.is_empty() { let store = &state.travel_time_store; travel_entries .iter() .map(|entry| { store .get(&entry.mode, &entry.slug) .map_err(|err| format!("Failed to load travel data: {}", err)) }) .collect::, _>>()? 
} else { Vec::new() }; let has_travel = !travel_entries.is_empty(); let travel_field_keys: Vec = travel_entries .iter() .map(|te| format!("tt_{}_{}", te.mode, te.slug)) .collect(); let num_features = state.data.num_features; let feature_data = &state.data.feature_data; let (pc_interner, pc_keys) = state.data.postcode_parts(); let min_keys = &state.min_keys; let max_keys = &state.max_keys; let avg_keys = &state.avg_keys; let h3_res = h3o::Resolution::try_from(resolution) .map_err(|error| format!("Invalid H3 resolution {}: {}", resolution, error))?; let precomputed = &state.h3_cells; let need_parent = needs_parent(resolution); let mut groups: FxHashMap = FxHashMap::default(); let mut travel_aggs: Vec> = (0..travel_entries.len()).map(|_| FxHashMap::default()).collect(); // Main aggregation loop let aggregate_row = |row: usize, groups: &mut FxHashMap, travel_aggs: &mut [FxHashMap]| { // Regular filters if !row_passes_filters( row, &parsed_filters, &parsed_enum_filters, feature_data, num_features, ) { return; } // Travel time filter: check each entry with a range let mut travel_minutes: Vec> = Vec::new(); if has_travel { let postcode = pc_interner.resolve(&pc_keys[row]); travel_minutes.reserve(travel_entries.len()); for (ti, entry) in travel_entries.iter().enumerate() { let row_data = travel_data[ti].get(postcode); let minutes = row_data.map(|r| { if entry.use_best { r.best_minutes.unwrap_or(r.minutes) } else { r.minutes } }); travel_minutes.push(minutes); if let (Some(fmin), Some(fmax)) = (entry.filter_min, entry.filter_max) { match minutes { Some(mins) if (mins as f32) >= fmin && (mins as f32) <= fmax => {} _ => return, // Filtered out } } } } let cell_id = cell_for_row(row, precomputed, h3_res, need_parent); // Aggregate regular features let aggregation = groups .entry(cell_id) .or_insert_with(|| Aggregator::new(num_features)); if let Some(sel_indices) = field_indices.as_deref() { aggregation.add_row_selective(feature_data, row, num_features, sel_indices); } else { 
aggregation.add_row(feature_data, row, num_features); } // Aggregate travel time for (ti, minutes) in travel_minutes.iter().enumerate() { if let Some(mins) = minutes { let agg = travel_aggs[ti] .entry(cell_id) .or_insert_with(TravelTimeAgg::new); agg.add(*mins as f32); } } }; state .grid .for_each_in_bounds(south, west, north, east, |row_idx| { aggregate_row(row_idx as usize, &mut groups, &mut travel_aggs); }); let t_agg = t0.elapsed(); let mut features = build_feature_maps( &groups, min_keys, max_keys, avg_keys, num_features, field_indices.as_deref(), (south, west, north, east), &travel_aggs, &travel_field_keys, ); let truncated = features.len() > MAX_CELLS_PER_REQUEST; if truncated { features.truncate(MAX_CELLS_PER_REQUEST); } let t_total = t0.elapsed(); info!( resolution, cells_before_filter = groups.len(), cells_after_filter = features.len(), truncated, bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east), filters = num_filters, filters_raw = filters_str.as_deref().unwrap_or("-"), travel_entries = travel_entries.len(), agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0), total_ms = format_args!("{:.1}", t_total.as_secs_f64() * 1000.0), "GET /api/hexagons" ); Ok(HexagonsResponse { features }) }) .await .map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())? .map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?; Ok(Json(response)) }