perfect-postcode/server-rs/src/routes/hexagons.rs

495 lines
20 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::sync::Arc;
use axum::extract::Query;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json};
use axum::Extension;
use metrics::histogram;
use rayon::prelude::*;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::info;
use crate::aggregation::Aggregator;
use crate::auth::OptionalUser;
use crate::consts::{DEMO_BOUNDS, MAX_CELLS_PER_REQUEST};
use crate::data::travel_time::TravelData;
use crate::licensing::check_license_bounds;
use crate::parsing::{
cell_for_row_cached, needs_parent, parse_field_indices, parse_filters, require_bounds,
row_passes_filters, validate_h3_resolution,
};
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
use crate::state::AppState;
/// Row count threshold at or above which aggregation is split across rayon threads.
/// Below this count the rows are aggregated sequentially on the calling thread.
const PARALLEL_THRESHOLD: usize = 50_000;
/// Aggregation result for one parallel chunk of rows.
///
/// The first element holds the feature accumulators keyed by H3 cell id; the
/// second holds one travel-time accumulator map (also keyed by cell id) per
/// requested travel entry, in the same order as the travel entries.
type ChunkResult = (
FxHashMap<u64, Aggregator>,
Vec<FxHashMap<u64, TravelTimeAgg>>,
);
/// Maximum center-to-vertex distance in degrees per H3 resolution.
/// Generous for UK latitudes (49° to 61°) so we never false-exclude a visible cell.
/// Used for cheap center-based bounds filtering instead of computing full cell boundary.
///
/// Indexed by resolution number. Only resolutions 0 through 12 have an entry,
/// so indexing with a finer resolution is out of range.
const H3_CENTER_BUFFERS: [f64; 13] = [
5.0, 2.0, 1.0, 0.5, // res 0-3 (unused in practice)
0.50, // res 4
0.20, // res 5
0.08, // res 6
0.03, // res 7
0.012, // res 8
0.005, // res 9
0.002, // res 10
0.001, // res 11
0.0005, // res 12
];
/// JSON response body returned by `get_hexagons`.
#[derive(Serialize)]
pub struct HexagonsResponse {
/// One JSON object per H3 cell, as produced by `build_feature_maps`.
features: Vec<Map<String, Value>>,
}
/// Query-string parameters accepted by `get_hexagons`.
#[derive(Deserialize)]
pub struct HexagonParams {
/// Requested H3 resolution; checked by `validate_h3_resolution` in the handler.
resolution: u8,
/// Viewport bounds. Optional at the deserialization level, but the handler
/// passes it through `require_bounds`, which yields `(south, west, north, east)`
/// or an error response.
bounds: Option<String>,
/// `;;`-separated filters: `name:min:max;;...`
filters: Option<String>,
/// Comma-separated feature names to include in min/max aggregation.
fields: Option<String>,
/// Pipe-separated travel time entries: `mode:slug|mode:slug:min:max`
/// Each entry requests travel time aggregation for that mode+destination.
/// Optional min:max applies as a filter (exclude properties outside range).
travel: Option<String>,
}
/// Builds one JSON feature map per aggregated H3 cell, keeping only cells whose
/// center lies inside the query bounds expanded by a resolution-dependent buffer.
///
/// Testing the cell center against the padded bounds costs four float
/// comparisons per cell, which is much cheaper than computing the full cell
/// boundary (6 vertices per cell).
///
/// # Arguments
/// * `groups` - feature accumulators keyed by raw H3 cell id.
/// * `min_keys`, `max_keys`, `avg_keys` - output key names, indexed by feature index.
/// * `num_features` - total feature count; every feature is emitted when `indices` is `None`.
/// * `indices` - optional subset of feature indices to emit.
/// * `query_bounds` - `(south, west, north, east)` in degrees.
/// * `resolution` - H3 resolution used to pick the center buffer.
/// * `travel_aggs` - one travel-time accumulator map per travel entry, keyed by cell id.
/// * `travel_field_keys` - base key per travel entry; emitted with `min_`/`max_`/`avg_` prefixes.
///
/// # Returns
/// One map per retained cell containing `h3`, `count`, `lat`/`lon`, the
/// min/max/avg of every selected feature that has at least one sample, and the
/// min/max/avg of every travel entry that has at least one sample in that cell.
/// Ids that are not valid H3 cell indexes are skipped, and values that cannot
/// be represented as a JSON number (NaN or infinite) are omitted.
///
/// # Panics
/// Panics if a selected feature index is out of range for the key slices or
/// for the accumulator's per-feature arrays.
#[allow(clippy::too_many_arguments)]
fn build_feature_maps(
    groups: &FxHashMap<u64, Aggregator>,
    min_keys: &[String],
    max_keys: &[String],
    avg_keys: &[String],
    num_features: usize,
    indices: Option<&[usize]>,
    query_bounds: (f64, f64, f64, f64),
    resolution: h3o::Resolution,
    travel_aggs: &[FxHashMap<u64, TravelTimeAgg>],
    travel_field_keys: &[String],
) -> Vec<Map<String, Value>> {
    let mut features = Vec::with_capacity(groups.len());
    let (q_south, q_west, q_north, q_east) = query_bounds;

    // Expand bounds by a resolution-dependent buffer for center-based filtering.
    // The table stops at resolution 12 while H3 defines resolutions up to 15, so
    // use a checked lookup: finer cells are smaller, which makes the finest
    // tabulated buffer a safe over-estimate instead of an out-of-bounds panic.
    let finest_tabulated_buf = H3_CENTER_BUFFERS[H3_CENTER_BUFFERS.len() - 1];
    let buf = H3_CENTER_BUFFERS
        .get(resolution as usize)
        .copied()
        .unwrap_or(finest_tabulated_buf);
    let bound_south = q_south - buf;
    let bound_north = q_north + buf;
    let bound_west = q_west - buf;
    let bound_east = q_east + buf;

    // Pre-compute travel time key strings (avoids per-cell format!()).
    let travel_keys: Vec<(String, String, String)> = travel_field_keys
        .iter()
        .map(|key| {
            (
                format!("min_{key}"),
                format!("max_{key}"),
                format!("avg_{key}"),
            )
        })
        .collect();

    // Resolve the feature index list once so the per-cell loop iterates a plain slice.
    let default_indices: Vec<usize>;
    let feat_indices: &[usize] = match indices {
        Some(idx) => idx,
        None => {
            default_indices = (0..num_features).collect();
            &default_indices
        }
    };

    for (&cell_id, aggregation) in groups {
        let Ok(cell) = h3o::CellIndex::try_from(cell_id) else {
            continue;
        };

        // The center is needed for the lat/lon output anyway, so reuse it for the bounds check.
        let center: h3o::LatLng = cell.into();
        let lat = center.lat();
        let lng = center.lng();
        if lat < bound_south || lat > bound_north || lng < bound_west || lng > bound_east {
            continue;
        }

        let mut map = Map::new();
        map.insert("h3".into(), Value::String(cell.to_string()));
        map.insert("count".into(), Value::Number(aggregation.count.into()));
        if let (Some(lat_num), Some(lon_num)) = (
            serde_json::Number::from_f64(lat),
            serde_json::Number::from_f64(lng),
        ) {
            map.insert("lat".into(), Value::Number(lat_num));
            map.insert("lon".into(), Value::Number(lon_num));
        }

        // A feature's three statistics are emitted together or not at all.
        for &feat_index in feat_indices {
            if aggregation.feat_counts[feat_index] > 0 {
                let avg =
                    aggregation.sums[feat_index] / aggregation.feat_counts[feat_index] as f64;
                if let (Some(min_num), Some(max_num), Some(avg_num)) = (
                    serde_json::Number::from_f64(aggregation.mins[feat_index] as f64),
                    serde_json::Number::from_f64(aggregation.maxs[feat_index] as f64),
                    serde_json::Number::from_f64(avg),
                ) {
                    map.insert(min_keys[feat_index].clone(), Value::Number(min_num));
                    map.insert(max_keys[feat_index].clone(), Value::Number(max_num));
                    map.insert(avg_keys[feat_index].clone(), Value::Number(avg_num));
                }
            }
        }

        // Travel time statistics: each value is inserted independently of the others.
        for (agg_map, (min_key, max_key, avg_key)) in travel_aggs.iter().zip(&travel_keys) {
            if let Some(agg) = agg_map.get(&cell_id) {
                if agg.count > 0 {
                    let avg = agg.sum / agg.count as f64;
                    if let Some(nm) = serde_json::Number::from_f64(agg.min as f64) {
                        map.insert(min_key.clone(), Value::Number(nm));
                    }
                    if let Some(nm) = serde_json::Number::from_f64(agg.max as f64) {
                        map.insert(max_key.clone(), Value::Number(nm));
                    }
                    if let Some(nm) = serde_json::Number::from_f64(avg) {
                        map.insert(avg_key.clone(), Value::Number(nm));
                    }
                }
            }
        }

        features.push(map);
    }
    features
}
/// Handles `GET /api/hexagons`: aggregates the rows inside the requested bounds
/// into H3 cells and returns per-cell statistics.
///
/// Request validation and parsing happen on the async task; the aggregation
/// itself runs inside `tokio::task::spawn_blocking`.
///
/// # Errors
/// Returns an error response when:
/// * the resolution or bounds are rejected by `validate_h3_resolution` / `require_bounds`;
/// * the bounds fail `check_license_bounds` (skipped when they equal `DEMO_BOUNDS`);
/// * `filters` or `travel` cannot be parsed (`400 Bad Request`);
/// * `fields` cannot be parsed (status supplied by `parse_field_indices`);
/// * travel data fails to load, the resolution cannot be converted to an
///   `h3o::Resolution`, or the blocking task fails to join (`500 Internal Server Error`).
pub async fn get_hexagons(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<HexagonParams>,
) -> Result<Json<HexagonsResponse>, axum::response::Response> {
let resolution = params.resolution;
validate_h3_resolution(resolution).map_err(IntoResponse::into_response)?;
let (south, west, north, east) =
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
// Exact float equality: only the precise demo viewport skips the license check.
let is_demo_view = (south, west, north, east) == DEMO_BOUNDS;
if !is_demo_view {
check_license_bounds(&user.0, (south, west, north, east))?;
}
let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters(
params.filters.as_deref(),
&state.feature_name_to_index,
&state.data.enum_values,
&quant,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
// Raw filter string is kept only for the log line emitted after aggregation.
let filters_str = params.filters;
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
.map_err(|err| (err.0, err.1).into_response())?;
let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
// Everything below runs on the blocking thread pool; string errors returned
// from the closure are mapped to 500 responses after the join.
let response = tokio::task::spawn_blocking(move || -> Result<HexagonsResponse, String> {
let t0 = std::time::Instant::now();
// Load travel time data from precomputed parquet files
let travel_data: Vec<TravelData> = if !travel_entries.is_empty() {
let store = &state.travel_time_store;
travel_entries
.iter()
.map(|entry| {
store
.get(&entry.mode, &entry.slug)
.map_err(|err| format!("Failed to load travel data: {}", err))
})
.collect::<Result<Vec<_>, _>>()?
} else {
Vec::new()
};
let has_travel = !travel_entries.is_empty();
// Base output key per travel entry; `build_feature_maps` adds the min_/max_/avg_ prefixes.
let travel_field_keys: Vec<String> = travel_entries
.iter()
.map(|te| format!("tt_{}_{}", te.mode, te.slug))
.collect();
let num_features = state.data.num_features;
let feature_data = &state.data.feature_data;
let quant = state.data.quant_ref();
let (pc_interner, pc_keys) = state.data.postcode_parts();
let min_keys = &state.min_keys;
let max_keys = &state.max_keys;
let avg_keys = &state.avg_keys;
let h3_res = h3o::Resolution::try_from(resolution)
.map_err(|error| format!("Invalid H3 resolution {}: {}", resolution, error))?;
let precomputed = &state.h3_cells;
let need_parent = needs_parent(resolution);
// Final accumulators: feature stats per cell, and one travel-time map per travel entry.
let mut groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
let mut travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0..travel_entries.len())
.map(|_| FxHashMap::default())
.collect();
// O(grid cells) count — no allocation. Used for parallel threshold decision.
let row_count = state.grid.count_in_bounds(south, west, north, east);
let t_grid = t0.elapsed();
let parallel = row_count >= PARALLEL_THRESHOLD;
if parallel {
// Parallel: collect row indices for par_chunks, split across rayon threads.
// Now handles travel time too (postcode interner & travel data are thread-safe).
let row_indices = state.grid.query(south, west, north, east);
// Roughly one chunk per rayon thread, but never fewer than 1000 rows per chunk.
let chunk_size = (row_indices.len() / rayon::current_num_threads()).max(1000);
let thread_results: Vec<ChunkResult> = row_indices
.par_chunks(chunk_size)
.map(|chunk| {
// Chunk-local accumulators and scratch buffers; merged after the parallel pass.
let mut local_groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
let mut local_travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0
..travel_entries.len())
.map(|_| FxHashMap::default())
.collect();
let mut h3_cache: FxHashMap<u64, u64> = FxHashMap::default();
let mut travel_minutes: Vec<Option<i16>> =
Vec::with_capacity(travel_entries.len());
'row: for &row_idx in chunk {
let row = row_idx as usize;
if !row_passes_filters(
row,
&parsed_filters,
&parsed_enum_filters,
feature_data,
num_features,
) {
continue;
}
if has_travel {
// Reused per row; holds one optional minutes value per travel entry.
travel_minutes.clear();
let postcode = pc_interner.resolve(&pc_keys[row]);
for (ti, entry) in travel_entries.iter().enumerate() {
let row_data = travel_data[ti].get(postcode);
let minutes = row_data.map(|r| {
if entry.use_best {
r.best_minutes.unwrap_or(r.minutes)
} else {
r.minutes
}
});
travel_minutes.push(minutes);
// The range filter applies only when both bounds are set. A row with no
// travel data for a filtered entry is dropped by the catch-all arm.
if let (Some(fmin), Some(fmax)) =
(entry.filter_min, entry.filter_max)
{
match minutes {
Some(mins)
if (mins as f32) >= fmin && (mins as f32) <= fmax => {}
_ => continue 'row,
}
}
}
}
let cell_id = cell_for_row_cached(
row,
precomputed,
h3_res,
need_parent,
&mut h3_cache,
);
let agg = local_groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
if let Some(sel_indices) = field_indices.as_deref() {
agg.add_row_selective(
feature_data,
row,
num_features,
sel_indices,
&quant,
);
} else {
agg.add_row(feature_data, row, num_features, &quant);
}
// Rows lacking travel data for an unfiltered entry still count toward the
// feature aggregates above; they are only left out of that entry's travel stats.
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let tagg = local_travel_aggs[ti]
.entry(cell_id)
.or_insert_with(TravelTimeAgg::new);
tagg.add(*mins as f32);
}
}
}
(local_groups, local_travel_aggs)
})
.collect();
// Merge thread-local results into the main accumulators
for (local_groups, local_travel) in thread_results {
for (cell_id, local_agg) in local_groups {
groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features))
.merge(&local_agg);
}
for (ti, local_ta) in local_travel.into_iter().enumerate() {
for (cell_id, local_tt) in local_ta {
travel_aggs[ti]
.entry(cell_id)
.or_insert_with(TravelTimeAgg::new)
.merge(&local_tt);
}
}
}
} else {
// Sequential: use for_each_in_bounds to avoid Vec<u32> allocation
// The per-row logic mirrors the parallel branch, writing straight into the
// final accumulators; `return` here plays the role of `continue 'row` there.
let mut travel_minutes: Vec<Option<i16>> = Vec::with_capacity(travel_entries.len());
let mut h3_cache: FxHashMap<u64, u64> = FxHashMap::default();
state
.grid
.for_each_in_bounds(south, west, north, east, |row_idx| {
let row = row_idx as usize;
if !row_passes_filters(
row,
&parsed_filters,
&parsed_enum_filters,
feature_data,
num_features,
) {
return;
}
if has_travel {
travel_minutes.clear();
let postcode = pc_interner.resolve(&pc_keys[row]);
for (ti, entry) in travel_entries.iter().enumerate() {
let row_data = travel_data[ti].get(postcode);
let minutes = row_data.map(|r| {
if entry.use_best {
r.best_minutes.unwrap_or(r.minutes)
} else {
r.minutes
}
});
travel_minutes.push(minutes);
if let (Some(fmin), Some(fmax)) = (entry.filter_min, entry.filter_max) {
match minutes {
Some(mins)
if (mins as f32) >= fmin && (mins as f32) <= fmax => {}
_ => return,
}
}
}
}
let cell_id =
cell_for_row_cached(row, precomputed, h3_res, need_parent, &mut h3_cache);
let aggregation = groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
if let Some(sel_indices) = field_indices.as_deref() {
aggregation.add_row_selective(
feature_data,
row,
num_features,
sel_indices,
&quant,
);
} else {
aggregation.add_row(feature_data, row, num_features, &quant);
}
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let agg = travel_aggs[ti]
.entry(cell_id)
.or_insert_with(TravelTimeAgg::new);
agg.add(*mins as f32);
}
}
});
};
let t_agg = t0.elapsed();
let mut features = build_feature_maps(
&groups,
min_keys,
max_keys,
avg_keys,
num_features,
field_indices.as_deref(),
(south, west, north, east),
h3_res,
&travel_aggs,
&travel_field_keys,
);
// Cap the response size. `features` follows hash-map iteration order, so the
// cells that survive truncation are an arbitrary subset.
let truncated = features.len() > MAX_CELLS_PER_REQUEST;
if truncated {
features.truncate(MAX_CELLS_PER_REQUEST);
}
let t_total = t0.elapsed();
// NOTE(review): `cells_after_filter` is read after truncation, so when `truncated`
// is true it reports the capped count rather than the post-bounds-filter count.
// Timings: grid_ms = row counting, agg_ms = aggregation, json_ms = feature map building.
info!(
resolution,
rows = row_count,
parallel,
cells_before_filter = groups.len(),
cells_after_filter = features.len(),
truncated,
bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east),
filters = num_filters,
filters_raw = filters_str.as_deref().unwrap_or("-"),
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1),
travel_entries = travel_entries.len(),
grid_ms = format_args!("{:.1}", t_grid.as_secs_f64() * 1000.0),
agg_ms = format_args!("{:.1}", (t_agg - t_grid).as_secs_f64() * 1000.0),
json_ms = format_args!("{:.1}", (t_total - t_agg).as_secs_f64() * 1000.0),
total_ms = format_args!("{:.1}", t_total.as_secs_f64() * 1000.0),
"GET /api/hexagons"
);
histogram!("hexagons_response_count").record(features.len() as f64);
Ok(HexagonsResponse { features })
})
.await
// First map_err: the blocking task failed to join; second: the closure returned Err.
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())?
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?;
Ok(Json(response))
}