Rust things

This commit is contained in:
Andras Schmelczer 2026-05-10 14:55:43 +01:00
parent fc10381692
commit 3debacab4f
30 changed files with 3257 additions and 647 deletions

View file

@ -11,14 +11,15 @@ use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::info;
use crate::aggregation::{Aggregator, EnumDistConfig};
use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator};
use crate::auth::OptionalUser;
use crate::consts::MAX_CELLS_PER_REQUEST;
use crate::data::travel_time::TravelData;
use crate::licensing::{check_license_bounds, resolve_share_code};
use crate::parsing::{
cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices, parse_filters,
require_bounds, row_passes_filters, validate_h3_resolution,
cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices_with_poi,
parse_filters_with_poi, require_bounds, row_passes_filters, row_passes_poi_filters,
validate_h3_resolution,
};
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
use crate::state::SharedState;
@ -29,6 +30,7 @@ const PARALLEL_THRESHOLD: usize = 50_000;
/// Per-thread aggregation result, in tuple order:
/// feature accumulators, POI metric accumulators, and travel time accumulators.
///
/// The two leading maps are keyed by cell id; the trailing `Vec` holds one
/// cell-keyed map per requested travel entry.
type ChunkResult = (
FxHashMap<u64, Aggregator>,
FxHashMap<u64, PoiAggregator>,
Vec<FxHashMap<u64, TravelTimeAgg>>,
);
@ -79,11 +81,14 @@ pub struct HexagonParams {
#[allow(clippy::too_many_arguments)]
fn build_feature_maps(
groups: &FxHashMap<u64, Aggregator>,
poi_groups: &FxHashMap<u64, PoiAggregator>,
min_keys: &[String],
max_keys: &[String],
avg_keys: &[String],
num_features: usize,
indices: Option<&[usize]>,
poi_feature_names: &[String],
poi_indices: &[usize],
query_bounds: (f64, f64, f64, f64),
resolution: h3o::Resolution,
travel_aggs: &[FxHashMap<u64, TravelTimeAgg>],
@ -163,6 +168,25 @@ fn build_feature_maps(
}
}
if let Some(poi_aggregation) = poi_groups.get(&cell_id) {
for &metric_idx in poi_indices {
if poi_aggregation.counts[metric_idx] > 0 {
let avg = poi_aggregation.sums[metric_idx]
/ poi_aggregation.counts[metric_idx] as f64;
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64),
serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64),
serde_json::Number::from_f64(avg),
) {
let name = &poi_feature_names[metric_idx];
map.insert(format!("min_{name}"), Value::Number(min_num));
map.insert(format!("max_{name}"), Value::Number(max_num));
map.insert(format!("avg_{name}"), Value::Number(avg_num));
}
}
}
}
// Add travel time aggregation fields (using pre-computed key strings)
for (ti, agg_map) in travel_aggs.iter().enumerate() {
if let Some(agg) = agg_map.get(&cell_id) {
@ -209,18 +233,25 @@ pub async fn get_hexagons(
check_license_bounds(&user.0, (south, west, north, east), share_bounds)?;
let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters(
let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(),
&state.feature_name_to_index,
&state.data.enum_values,
&quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters;
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
.map_err(|err| (err.0, err.1).into_response())?;
let field_indices = parse_field_indices_with_poi(
params.fields.as_deref(),
&state.feature_name_to_index,
&state.data.poi_metrics.name_to_index,
)
.map_err(|err| (err.0, err.1).into_response())?;
let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
@ -269,6 +300,11 @@ pub async fn get_hexagons(
let min_keys = &state.min_keys;
let max_keys = &state.max_keys;
let avg_keys = &state.avg_keys;
let poi_metrics = &state.data.poi_metrics;
let poi_field_indices = field_indices.poi.as_slice();
let has_poi_fields = !poi_field_indices.is_empty();
let has_poi_filters = !parsed_poi_filters.is_empty();
let poi_num_features = poi_metrics.num_features();
let h3_res = h3o::Resolution::try_from(resolution)
.map_err(|error| format!("Invalid H3 resolution {}: {}", resolution, error))?;
@ -276,6 +312,7 @@ pub async fn get_hexagons(
let need_parent = needs_parent(resolution);
let mut groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
let mut poi_groups: FxHashMap<u64, PoiAggregator> = FxHashMap::default();
let mut travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0..travel_entries.len())
.map(|_| FxHashMap::default())
.collect();
@ -296,6 +333,7 @@ pub async fn get_hexagons(
.par_chunks(chunk_size)
.map(|chunk| {
let mut local_groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
let mut local_poi_groups: FxHashMap<u64, PoiAggregator> = FxHashMap::default();
let mut local_travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0
..travel_entries.len())
.map(|_| FxHashMap::default())
@ -315,6 +353,11 @@ pub async fn get_hexagons(
) {
continue;
}
if has_poi_filters
&& !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
{
continue;
}
if has_travel {
travel_minutes.clear();
@ -352,7 +395,7 @@ pub async fn get_hexagons(
let agg = local_groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config));
if let Some(sel_indices) = field_indices.as_deref() {
if let Some(sel_indices) = field_indices.normal.as_deref() {
agg.add_row_selective(
feature_data,
row,
@ -364,6 +407,13 @@ pub async fn get_hexagons(
agg.add_row(feature_data, row, num_features, &quant);
}
if has_poi_fields {
local_poi_groups
.entry(cell_id)
.or_insert_with(|| PoiAggregator::new(poi_num_features))
.add_row_selective(poi_metrics, row, poi_field_indices);
}
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let tagg = local_travel_aggs[ti]
@ -374,18 +424,24 @@ pub async fn get_hexagons(
}
}
(local_groups, local_travel_aggs)
(local_groups, local_poi_groups, local_travel_aggs)
})
.collect();
// Merge thread-local results into the main accumulators
for (local_groups, local_travel) in thread_results {
for (local_groups, local_poi_groups, local_travel) in thread_results {
for (cell_id, local_agg) in local_groups {
groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config))
.merge(&local_agg);
}
for (cell_id, local_agg) in local_poi_groups {
poi_groups
.entry(cell_id)
.or_insert_with(|| PoiAggregator::new(poi_num_features))
.merge(&local_agg);
}
for (ti, local_ta) in local_travel.into_iter().enumerate() {
for (cell_id, local_tt) in local_ta {
travel_aggs[ti]
@ -414,6 +470,11 @@ pub async fn get_hexagons(
) {
return;
}
if has_poi_filters
&& !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
{
return;
}
if has_travel {
travel_minutes.clear();
@ -444,7 +505,7 @@ pub async fn get_hexagons(
let aggregation = groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config));
if let Some(sel_indices) = field_indices.as_deref() {
if let Some(sel_indices) = field_indices.normal.as_deref() {
aggregation.add_row_selective(
feature_data,
row,
@ -456,6 +517,13 @@ pub async fn get_hexagons(
aggregation.add_row(feature_data, row, num_features, &quant);
}
if has_poi_fields {
poi_groups
.entry(cell_id)
.or_insert_with(|| PoiAggregator::new(poi_num_features))
.add_row_selective(poi_metrics, row, poi_field_indices);
}
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let agg = travel_aggs[ti]
@ -471,11 +539,14 @@ pub async fn get_hexagons(
let mut features = build_feature_maps(
&groups,
&poi_groups,
min_keys,
max_keys,
avg_keys,
num_features,
field_indices.as_deref(),
field_indices.normal.as_deref(),
&poi_metrics.feature_names,
poi_field_indices,
(south, west, north, east),
h3_res,
&travel_aggs,
@ -499,7 +570,11 @@ pub async fn get_hexagons(
bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east),
filters = num_filters,
filters_raw = filters_str.as_deref().unwrap_or("-"),
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1),
fields = field_indices
.normal
.as_ref()
.map(|v| (v.len() + poi_field_indices.len()) as i32)
.unwrap_or(-1),
travel_entries = travel_entries.len(),
grid_ms = format_args!("{:.1}", t_grid.as_secs_f64() * 1000.0),
agg_ms = format_args!("{:.1}", (t_agg - t_grid).as_secs_f64() * 1000.0),