Morning improvements

This commit is contained in:
Andras Schmelczer 2026-03-17 13:29:03 +00:00
parent 3e9fba5303
commit 53fff3efaa
41 changed files with 2438 additions and 637 deletions

View file

@ -17,8 +17,8 @@ use crate::consts::{DEMO_BOUNDS, MAX_CELLS_PER_REQUEST};
use crate::data::travel_time::TravelData;
use crate::licensing::check_license_bounds;
use crate::parsing::{
bounds_intersect, cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_indices,
parse_filters, require_bounds, row_passes_filters, validate_h3_resolution,
cell_for_row_cached, needs_parent, parse_field_indices, parse_filters, require_bounds,
row_passes_filters, validate_h3_resolution,
};
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
use crate::state::AppState;
@ -26,6 +26,28 @@ use crate::state::AppState;
/// Row count threshold at or above which we use rayon parallel aggregation.
/// The handler compares the number of rows inside the query bounds against this
/// value with `>=`; queries below it are aggregated sequentially.
const PARALLEL_THRESHOLD: usize = 50_000;
/// Per-thread aggregation result: feature accumulators + travel time accumulators.
///
/// * `.0` — feature aggregators keyed by `u64` cell id.
/// * `.1` — one map per travel entry (indexed in the same order as the travel
///   entries), each keyed by cell id.
///
/// Every parallel chunk produces one of these; the handler then merges them
/// into the request-wide accumulators.
type ChunkResult = (
FxHashMap<u64, Aggregator>,
Vec<FxHashMap<u64, TravelTimeAgg>>,
);
/// Maximum center-to-vertex distance in degrees per H3 resolution.
/// Generous for UK latitudes (49° to 61°) so we never false-exclude a visible cell.
/// Used for cheap center-based bounds filtering instead of computing full cell boundary.
///
/// Indexed by `resolution as usize`; the same buffer is applied to both the
/// latitude and the longitude edges of the query bounds.
///
/// NOTE(review): the table has 13 entries (res 0-12) while H3 defines resolutions
/// up to 15, so indexing with 13-15 would panic. Presumably upstream resolution
/// validation caps requests at 12 — confirm.
const H3_CENTER_BUFFERS: [f64; 13] = [
5.0, 2.0, 1.0, 0.5, // res 0-3 (unused in practice)
0.50, // res 4
0.20, // res 5
0.08, // res 6
0.03, // res 7
0.012, // res 8
0.005, // res 9
0.002, // res 10
0.001, // res 11
0.0005, // res 12
];
#[derive(Serialize)]
pub struct HexagonsResponse {
features: Vec<Map<String, Value>>,
@ -45,7 +67,10 @@ pub struct HexagonParams {
travel: Option<String>,
}
/// Build feature maps from aggregated cell data, filtering to only cells that intersect the query bounds.
/// Build feature maps from aggregated cell data, filtering to only cells whose
/// center is within the query bounds (expanded by a resolution-dependent buffer).
/// This is much cheaper than the previous approach of computing full cell boundaries
/// (6 vertices per cell) — just 4 float comparisons per cell.
#[allow(clippy::too_many_arguments)]
fn build_feature_maps(
groups: &FxHashMap<u64, Aggregator>,
@ -55,44 +80,69 @@ fn build_feature_maps(
num_features: usize,
indices: Option<&[usize]>,
query_bounds: (f64, f64, f64, f64),
resolution: h3o::Resolution,
travel_aggs: &[FxHashMap<u64, TravelTimeAgg>],
travel_field_keys: &[String],
) -> Vec<Map<String, Value>> {
let mut features = Vec::with_capacity(groups.len());
let (q_south, q_west, q_north, q_east) = query_bounds;
// Expand bounds by resolution-dependent buffer for center-based filtering
let buf = H3_CENTER_BUFFERS[resolution as usize];
let bound_south = q_south - buf;
let bound_north = q_north + buf;
let bound_west = q_west - buf;
let bound_east = q_east + buf;
// Pre-compute travel time key strings (avoids per-cell format!())
let travel_keys: Vec<(String, String, String)> = travel_field_keys
.iter()
.map(|key| {
(
format!("min_{key}"),
format!("max_{key}"),
format!("avg_{key}"),
)
})
.collect();
// Pre-compute default feature indices to avoid per-cell Box<dyn Iterator> allocation
let default_indices: Vec<usize>;
let feat_indices: &[usize] = match indices {
Some(idx) => idx,
None => {
default_indices = (0..num_features).collect();
&default_indices
}
};
for (&cell_id, aggregation) in groups {
let Some(cell) = h3o::CellIndex::try_from(cell_id).ok() else {
continue;
};
// Filter out cells that don't intersect the query bounds
let (c_south, c_west, c_north, c_east) = h3_cell_bounds(cell, 0.0);
if !bounds_intersect(
c_south, c_west, c_north, c_east, q_south, q_west, q_north, q_east,
) {
// Center is already needed for lat/lon output — reuse for bounds check
let center: h3o::LatLng = cell.into();
let lat = center.lat();
let lng = center.lng();
// Center-based bounds check: 4 comparisons instead of computing 6 boundary vertices
if lat < bound_south || lat > bound_north || lng < bound_west || lng > bound_east {
continue;
}
let mut map = Map::new();
map.insert("h3".into(), Value::String(cell.to_string()));
map.insert("count".into(), Value::Number(aggregation.count.into()));
let center: h3o::LatLng = cell.into();
if let (Some(lat), Some(lon)) = (
serde_json::Number::from_f64(center.lat()),
serde_json::Number::from_f64(center.lng()),
if let (Some(lat_num), Some(lon_num)) = (
serde_json::Number::from_f64(lat),
serde_json::Number::from_f64(lng),
) {
map.insert("lat".into(), Value::Number(lat));
map.insert("lon".into(), Value::Number(lon));
map.insert("lat".into(), Value::Number(lat_num));
map.insert("lon".into(), Value::Number(lon_num));
}
let iter: Box<dyn Iterator<Item = usize>> = if let Some(idx) = indices {
Box::new(idx.iter().copied())
} else {
Box::new(0..num_features)
};
for feat_index in iter {
for &feat_index in feat_indices {
if aggregation.feat_counts[feat_index] > 0 {
let avg = aggregation.sums[feat_index] / aggregation.feat_counts[feat_index] as f64;
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
@ -107,20 +157,19 @@ fn build_feature_maps(
}
}
// Add travel time aggregation fields
// Add travel time aggregation fields (using pre-computed key strings)
for (ti, agg_map) in travel_aggs.iter().enumerate() {
if let Some(agg) = agg_map.get(&cell_id) {
if agg.count > 0 {
let key = &travel_field_keys[ti];
let avg = agg.sum / agg.count as f64;
if let Some(nm) = serde_json::Number::from_f64(agg.min as f64) {
map.insert(format!("min_{key}"), Value::Number(nm));
map.insert(travel_keys[ti].0.clone(), Value::Number(nm));
}
if let Some(nm) = serde_json::Number::from_f64(agg.max as f64) {
map.insert(format!("max_{key}"), Value::Number(nm));
map.insert(travel_keys[ti].1.clone(), Value::Number(nm));
}
if let Some(nm) = serde_json::Number::from_f64(avg) {
map.insert(format!("avg_{key}"), Value::Number(nm));
map.insert(travel_keys[ti].2.clone(), Value::Number(nm));
}
}
}
@ -207,19 +256,31 @@ pub async fn get_hexagons(
.map(|_| FxHashMap::default())
.collect();
// Collect row indices for threshold-based sequential/parallel aggregation
let row_indices = state.grid.query(south, west, north, east);
// O(grid cells) count — no allocation. Used for parallel threshold decision.
let row_count = state.grid.count_in_bounds(south, west, north, east);
let t_grid = t0.elapsed();
if row_indices.len() >= PARALLEL_THRESHOLD && !has_travel {
// Parallel path: split rows across rayon threads, each with local accumulators
let parallel = row_count >= PARALLEL_THRESHOLD;
if parallel {
// Parallel: collect row indices for par_chunks, split across rayon threads.
// Now handles travel time too (postcode interner & travel data are thread-safe).
let row_indices = state.grid.query(south, west, north, east);
let chunk_size = (row_indices.len() / rayon::current_num_threads()).max(1000);
let thread_results: Vec<FxHashMap<u64, Aggregator>> = row_indices
let thread_results: Vec<ChunkResult> = row_indices
.par_chunks(chunk_size)
.map(|chunk| {
let mut local_groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
let mut local_travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0
..travel_entries.len())
.map(|_| FxHashMap::default())
.collect();
let mut h3_cache: FxHashMap<u64, u64> = FxHashMap::default();
for &row_idx in chunk {
let mut travel_minutes: Vec<Option<i16>> =
Vec::with_capacity(travel_entries.len());
'row: for &row_idx in chunk {
let row = row_idx as usize;
if !row_passes_filters(
row,
@ -230,6 +291,32 @@ pub async fn get_hexagons(
) {
continue;
}
if has_travel {
travel_minutes.clear();
let postcode = pc_interner.resolve(&pc_keys[row]);
for (ti, entry) in travel_entries.iter().enumerate() {
let row_data = travel_data[ti].get(postcode);
let minutes = row_data.map(|r| {
if entry.use_best {
r.best_minutes.unwrap_or(r.minutes)
} else {
r.minutes
}
});
travel_minutes.push(minutes);
if let (Some(fmin), Some(fmax)) =
(entry.filter_min, entry.filter_max)
{
match minutes {
Some(mins)
if (mins as f32) >= fmin && (mins as f32) <= fmax => {}
_ => continue 'row,
}
}
}
}
let cell_id = cell_for_row_cached(
row,
precomputed,
@ -237,6 +324,7 @@ pub async fn get_hexagons(
need_parent,
&mut h3_cache,
);
let agg = local_groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
@ -251,91 +339,108 @@ pub async fn get_hexagons(
} else {
agg.add_row(feature_data, row, num_features, &quant);
}
}
local_groups
})
.collect();
// Merge thread-local results into the main groups map
for local_groups in thread_results {
for (cell_id, local_agg) in local_groups {
let agg = groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
agg.merge(&local_agg);
}
}
} else {
// Sequential path (also handles travel time which needs postcode lookups)
let mut travel_minutes: Vec<Option<i16>> = Vec::with_capacity(travel_entries.len());
let mut h3_cache: FxHashMap<u64, u64> = FxHashMap::default();
'row: for &row_idx in &row_indices {
let row = row_idx as usize;
// Regular filters
if !row_passes_filters(
row,
&parsed_filters,
&parsed_enum_filters,
feature_data,
num_features,
) {
continue;
}
// Travel time filter: check each entry with a range
if has_travel {
travel_minutes.clear();
let postcode = pc_interner.resolve(&pc_keys[row]);
for (ti, entry) in travel_entries.iter().enumerate() {
let row_data = travel_data[ti].get(postcode);
let minutes = row_data.map(|r| {
if entry.use_best {
r.best_minutes.unwrap_or(r.minutes)
} else {
r.minutes
}
});
travel_minutes.push(minutes);
if let (Some(fmin), Some(fmax)) = (entry.filter_min, entry.filter_max) {
match minutes {
Some(mins) if (mins as f32) >= fmin && (mins as f32) <= fmax => {}
_ => continue 'row, // Filtered out (jump to next row_idx)
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let tagg = local_travel_aggs[ti]
.entry(cell_id)
.or_insert_with(TravelTimeAgg::new);
tagg.add(*mins as f32);
}
}
}
(local_groups, local_travel_aggs)
})
.collect();
// Merge thread-local results into the main accumulators
for (local_groups, local_travel) in thread_results {
for (cell_id, local_agg) in local_groups {
groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features))
.merge(&local_agg);
}
let cell_id =
cell_for_row_cached(row, precomputed, h3_res, need_parent, &mut h3_cache);
// Aggregate regular features
let aggregation = groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
if let Some(sel_indices) = field_indices.as_deref() {
aggregation.add_row_selective(
feature_data,
row,
num_features,
sel_indices,
&quant,
);
} else {
aggregation.add_row(feature_data, row, num_features, &quant);
}
// Aggregate travel time
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let agg = travel_aggs[ti]
for (ti, local_ta) in local_travel.into_iter().enumerate() {
for (cell_id, local_tt) in local_ta {
travel_aggs[ti]
.entry(cell_id)
.or_insert_with(TravelTimeAgg::new);
agg.add(*mins as f32);
.or_insert_with(TravelTimeAgg::new)
.merge(&local_tt);
}
}
}
} else {
// Sequential: use for_each_in_bounds to avoid Vec<u32> allocation
let mut travel_minutes: Vec<Option<i16>> = Vec::with_capacity(travel_entries.len());
let mut h3_cache: FxHashMap<u64, u64> = FxHashMap::default();
state
.grid
.for_each_in_bounds(south, west, north, east, |row_idx| {
let row = row_idx as usize;
if !row_passes_filters(
row,
&parsed_filters,
&parsed_enum_filters,
feature_data,
num_features,
) {
return;
}
if has_travel {
travel_minutes.clear();
let postcode = pc_interner.resolve(&pc_keys[row]);
for (ti, entry) in travel_entries.iter().enumerate() {
let row_data = travel_data[ti].get(postcode);
let minutes = row_data.map(|r| {
if entry.use_best {
r.best_minutes.unwrap_or(r.minutes)
} else {
r.minutes
}
});
travel_minutes.push(minutes);
if let (Some(fmin), Some(fmax)) = (entry.filter_min, entry.filter_max) {
match minutes {
Some(mins)
if (mins as f32) >= fmin && (mins as f32) <= fmax => {}
_ => return,
}
}
}
}
let cell_id =
cell_for_row_cached(row, precomputed, h3_res, need_parent, &mut h3_cache);
let aggregation = groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
if let Some(sel_indices) = field_indices.as_deref() {
aggregation.add_row_selective(
feature_data,
row,
num_features,
sel_indices,
&quant,
);
} else {
aggregation.add_row(feature_data, row, num_features, &quant);
}
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let agg = travel_aggs[ti]
.entry(cell_id)
.or_insert_with(TravelTimeAgg::new);
agg.add(*mins as f32);
}
}
});
};
let t_agg = t0.elapsed();
@ -348,6 +453,7 @@ pub async fn get_hexagons(
num_features,
field_indices.as_deref(),
(south, west, north, east),
h3_res,
&travel_aggs,
&travel_field_keys,
);
@ -357,11 +463,10 @@ pub async fn get_hexagons(
features.truncate(MAX_CELLS_PER_REQUEST);
}
let parallel = row_indices.len() >= PARALLEL_THRESHOLD && !has_travel;
let t_total = t0.elapsed();
info!(
resolution,
rows = row_indices.len(),
rows = row_count,
parallel,
cells_before_filter = groups.len(),
cells_after_filter = features.len(),
@ -369,8 +474,11 @@ pub async fn get_hexagons(
bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east),
filters = num_filters,
filters_raw = filters_str.as_deref().unwrap_or("-"),
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1),
travel_entries = travel_entries.len(),
agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0),
grid_ms = format_args!("{:.1}", t_grid.as_secs_f64() * 1000.0),
agg_ms = format_args!("{:.1}", (t_agg - t_grid).as_secs_f64() * 1000.0),
json_ms = format_args!("{:.1}", (t_total - t_agg).as_secs_f64() * 1000.0),
total_ms = format_args!("{:.1}", t_total.as_secs_f64() * 1000.0),
"GET /api/hexagons"
);