Rust things
This commit is contained in:
parent
fc10381692
commit
3debacab4f
30 changed files with 3257 additions and 647 deletions
|
|
@ -11,14 +11,15 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::{Map, Value};
|
||||
use tracing::info;
|
||||
|
||||
use crate::aggregation::{Aggregator, EnumDistConfig};
|
||||
use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator};
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::MAX_CELLS_PER_REQUEST;
|
||||
use crate::data::travel_time::TravelData;
|
||||
use crate::licensing::{check_license_bounds, resolve_share_code};
|
||||
use crate::parsing::{
|
||||
cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices, parse_filters,
|
||||
require_bounds, row_passes_filters, validate_h3_resolution,
|
||||
cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices_with_poi,
|
||||
parse_filters_with_poi, require_bounds, row_passes_filters, row_passes_poi_filters,
|
||||
validate_h3_resolution,
|
||||
};
|
||||
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
|
||||
use crate::state::SharedState;
|
||||
|
|
@ -29,6 +30,7 @@ const PARALLEL_THRESHOLD: usize = 50_000;
|
|||
/// Per-thread aggregation result: feature, POI, and travel time accumulators.
|
||||
type ChunkResult = (
|
||||
FxHashMap<u64, Aggregator>,
|
||||
FxHashMap<u64, PoiAggregator>,
|
||||
Vec<FxHashMap<u64, TravelTimeAgg>>,
|
||||
);
|
||||
|
||||
|
|
@ -79,11 +81,14 @@ pub struct HexagonParams {
|
|||
#[allow(clippy::too_many_arguments)]
|
||||
fn build_feature_maps(
|
||||
groups: &FxHashMap<u64, Aggregator>,
|
||||
poi_groups: &FxHashMap<u64, PoiAggregator>,
|
||||
min_keys: &[String],
|
||||
max_keys: &[String],
|
||||
avg_keys: &[String],
|
||||
num_features: usize,
|
||||
indices: Option<&[usize]>,
|
||||
poi_feature_names: &[String],
|
||||
poi_indices: &[usize],
|
||||
query_bounds: (f64, f64, f64, f64),
|
||||
resolution: h3o::Resolution,
|
||||
travel_aggs: &[FxHashMap<u64, TravelTimeAgg>],
|
||||
|
|
@ -163,6 +168,25 @@ fn build_feature_maps(
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(poi_aggregation) = poi_groups.get(&cell_id) {
|
||||
for &metric_idx in poi_indices {
|
||||
if poi_aggregation.counts[metric_idx] > 0 {
|
||||
let avg = poi_aggregation.sums[metric_idx]
|
||||
/ poi_aggregation.counts[metric_idx] as f64;
|
||||
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
|
||||
serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64),
|
||||
serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64),
|
||||
serde_json::Number::from_f64(avg),
|
||||
) {
|
||||
let name = &poi_feature_names[metric_idx];
|
||||
map.insert(format!("min_{name}"), Value::Number(min_num));
|
||||
map.insert(format!("max_{name}"), Value::Number(max_num));
|
||||
map.insert(format!("avg_{name}"), Value::Number(avg_num));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add travel time aggregation fields (using pre-computed key strings)
|
||||
for (ti, agg_map) in travel_aggs.iter().enumerate() {
|
||||
if let Some(agg) = agg_map.get(&cell_id) {
|
||||
|
|
@ -209,18 +233,25 @@ pub async fn get_hexagons(
|
|||
check_license_bounds(&user.0, (south, west, north, east), share_bounds)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
|
||||
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
|
||||
.map_err(|err| (err.0, err.1).into_response())?;
|
||||
let field_indices = parse_field_indices_with_poi(
|
||||
params.fields.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
)
|
||||
.map_err(|err| (err.0, err.1).into_response())?;
|
||||
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
|
@ -269,6 +300,11 @@ pub async fn get_hexagons(
|
|||
let min_keys = &state.min_keys;
|
||||
let max_keys = &state.max_keys;
|
||||
let avg_keys = &state.avg_keys;
|
||||
let poi_metrics = &state.data.poi_metrics;
|
||||
let poi_field_indices = field_indices.poi.as_slice();
|
||||
let has_poi_fields = !poi_field_indices.is_empty();
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
let poi_num_features = poi_metrics.num_features();
|
||||
|
||||
let h3_res = h3o::Resolution::try_from(resolution)
|
||||
.map_err(|error| format!("Invalid H3 resolution {}: {}", resolution, error))?;
|
||||
|
|
@ -276,6 +312,7 @@ pub async fn get_hexagons(
|
|||
let need_parent = needs_parent(resolution);
|
||||
|
||||
let mut groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
|
||||
let mut poi_groups: FxHashMap<u64, PoiAggregator> = FxHashMap::default();
|
||||
let mut travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0..travel_entries.len())
|
||||
.map(|_| FxHashMap::default())
|
||||
.collect();
|
||||
|
|
@ -296,6 +333,7 @@ pub async fn get_hexagons(
|
|||
.par_chunks(chunk_size)
|
||||
.map(|chunk| {
|
||||
let mut local_groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
|
||||
let mut local_poi_groups: FxHashMap<u64, PoiAggregator> = FxHashMap::default();
|
||||
let mut local_travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0
|
||||
..travel_entries.len())
|
||||
.map(|_| FxHashMap::default())
|
||||
|
|
@ -315,6 +353,11 @@ pub async fn get_hexagons(
|
|||
) {
|
||||
continue;
|
||||
}
|
||||
if has_poi_filters
|
||||
&& !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if has_travel {
|
||||
travel_minutes.clear();
|
||||
|
|
@ -352,7 +395,7 @@ pub async fn get_hexagons(
|
|||
let agg = local_groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config));
|
||||
if let Some(sel_indices) = field_indices.as_deref() {
|
||||
if let Some(sel_indices) = field_indices.normal.as_deref() {
|
||||
agg.add_row_selective(
|
||||
feature_data,
|
||||
row,
|
||||
|
|
@ -364,6 +407,13 @@ pub async fn get_hexagons(
|
|||
agg.add_row(feature_data, row, num_features, &quant);
|
||||
}
|
||||
|
||||
if has_poi_fields {
|
||||
local_poi_groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| PoiAggregator::new(poi_num_features))
|
||||
.add_row_selective(poi_metrics, row, poi_field_indices);
|
||||
}
|
||||
|
||||
for (ti, minutes) in travel_minutes.iter().enumerate() {
|
||||
if let Some(mins) = minutes {
|
||||
let tagg = local_travel_aggs[ti]
|
||||
|
|
@ -374,18 +424,24 @@ pub async fn get_hexagons(
|
|||
}
|
||||
}
|
||||
|
||||
(local_groups, local_travel_aggs)
|
||||
(local_groups, local_poi_groups, local_travel_aggs)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Merge thread-local results into the main accumulators
|
||||
for (local_groups, local_travel) in thread_results {
|
||||
for (local_groups, local_poi_groups, local_travel) in thread_results {
|
||||
for (cell_id, local_agg) in local_groups {
|
||||
groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config))
|
||||
.merge(&local_agg);
|
||||
}
|
||||
for (cell_id, local_agg) in local_poi_groups {
|
||||
poi_groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| PoiAggregator::new(poi_num_features))
|
||||
.merge(&local_agg);
|
||||
}
|
||||
for (ti, local_ta) in local_travel.into_iter().enumerate() {
|
||||
for (cell_id, local_tt) in local_ta {
|
||||
travel_aggs[ti]
|
||||
|
|
@ -414,6 +470,11 @@ pub async fn get_hexagons(
|
|||
) {
|
||||
return;
|
||||
}
|
||||
if has_poi_filters
|
||||
&& !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if has_travel {
|
||||
travel_minutes.clear();
|
||||
|
|
@ -444,7 +505,7 @@ pub async fn get_hexagons(
|
|||
let aggregation = groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config));
|
||||
if let Some(sel_indices) = field_indices.as_deref() {
|
||||
if let Some(sel_indices) = field_indices.normal.as_deref() {
|
||||
aggregation.add_row_selective(
|
||||
feature_data,
|
||||
row,
|
||||
|
|
@ -456,6 +517,13 @@ pub async fn get_hexagons(
|
|||
aggregation.add_row(feature_data, row, num_features, &quant);
|
||||
}
|
||||
|
||||
if has_poi_fields {
|
||||
poi_groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| PoiAggregator::new(poi_num_features))
|
||||
.add_row_selective(poi_metrics, row, poi_field_indices);
|
||||
}
|
||||
|
||||
for (ti, minutes) in travel_minutes.iter().enumerate() {
|
||||
if let Some(mins) = minutes {
|
||||
let agg = travel_aggs[ti]
|
||||
|
|
@ -471,11 +539,14 @@ pub async fn get_hexagons(
|
|||
|
||||
let mut features = build_feature_maps(
|
||||
&groups,
|
||||
&poi_groups,
|
||||
min_keys,
|
||||
max_keys,
|
||||
avg_keys,
|
||||
num_features,
|
||||
field_indices.as_deref(),
|
||||
field_indices.normal.as_deref(),
|
||||
&poi_metrics.feature_names,
|
||||
poi_field_indices,
|
||||
(south, west, north, east),
|
||||
h3_res,
|
||||
&travel_aggs,
|
||||
|
|
@ -499,7 +570,11 @@ pub async fn get_hexagons(
|
|||
bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east),
|
||||
filters = num_filters,
|
||||
filters_raw = filters_str.as_deref().unwrap_or("-"),
|
||||
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1),
|
||||
fields = field_indices
|
||||
.normal
|
||||
.as_ref()
|
||||
.map(|v| (v.len() + poi_field_indices.len()) as i32)
|
||||
.unwrap_or(-1),
|
||||
travel_entries = travel_entries.len(),
|
||||
grid_ms = format_args!("{:.1}", t_grid.as_secs_f64() * 1000.0),
|
||||
agg_ms = format_args!("{:.1}", (t_agg - t_grid).as_secs_f64() * 1000.0),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue