312 lines
12 KiB
Rust
312 lines
12 KiB
Rust
use std::sync::Arc;
|
|
|
|
use axum::extract::Query;
|
|
use axum::http::StatusCode;
|
|
use axum::response::{IntoResponse, Json};
|
|
use axum::Extension;
|
|
use rustc_hash::FxHashMap;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::{Map, Value};
|
|
use tracing::info;
|
|
|
|
use crate::aggregation::Aggregator;
|
|
use crate::auth::OptionalUser;
|
|
use crate::consts::{DEMO_BOUNDS, MAX_CELLS_PER_REQUEST};
|
|
use crate::data::travel_time::TravelData;
|
|
use crate::licensing::check_license_bounds;
|
|
use crate::parsing::{
|
|
bounds_intersect, cell_for_row, h3_cell_bounds, needs_parent, parse_field_indices,
|
|
parse_filters, require_bounds, row_passes_filters, validate_h3_resolution,
|
|
};
|
|
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
|
|
use crate::state::AppState;
|
|
|
|
#[derive(Serialize)]
|
|
pub struct HexagonsResponse {
|
|
features: Vec<Map<String, Value>>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
pub struct HexagonParams {
|
|
resolution: u8,
|
|
bounds: Option<String>,
|
|
/// `;;`-separated filters: `name:min:max;;...`
|
|
filters: Option<String>,
|
|
/// Comma-separated feature names to include in min/max aggregation.
|
|
fields: Option<String>,
|
|
/// Pipe-separated travel time entries: `mode:slug|mode:slug:min:max`
|
|
/// Each entry requests travel time aggregation for that mode+destination.
|
|
/// Optional min:max applies as a filter (exclude properties outside range).
|
|
travel: Option<String>,
|
|
}
|
|
|
|
|
|
/// Build feature maps from aggregated cell data, filtering to only cells that intersect the query bounds.
|
|
#[allow(clippy::too_many_arguments)]
|
|
fn build_feature_maps(
|
|
groups: &FxHashMap<u64, Aggregator>,
|
|
min_keys: &[String],
|
|
max_keys: &[String],
|
|
avg_keys: &[String],
|
|
num_features: usize,
|
|
indices: Option<&[usize]>,
|
|
query_bounds: (f64, f64, f64, f64),
|
|
travel_aggs: &[FxHashMap<u64, TravelTimeAgg>],
|
|
travel_field_keys: &[String],
|
|
) -> Vec<Map<String, Value>> {
|
|
let mut features = Vec::with_capacity(groups.len());
|
|
let (q_south, q_west, q_north, q_east) = query_bounds;
|
|
|
|
for (&cell_id, aggregation) in groups {
|
|
let Some(cell) = h3o::CellIndex::try_from(cell_id).ok() else {
|
|
continue;
|
|
};
|
|
|
|
// Filter out cells that don't intersect the query bounds
|
|
let (c_south, c_west, c_north, c_east) = h3_cell_bounds(cell, 0.0);
|
|
if !bounds_intersect(
|
|
c_south, c_west, c_north, c_east, q_south, q_west, q_north, q_east,
|
|
) {
|
|
continue;
|
|
}
|
|
|
|
let mut map = Map::new();
|
|
map.insert("h3".into(), Value::String(cell.to_string()));
|
|
map.insert("count".into(), Value::Number(aggregation.count.into()));
|
|
let center: h3o::LatLng = cell.into();
|
|
if let (Some(lat), Some(lon)) = (
|
|
serde_json::Number::from_f64(center.lat()),
|
|
serde_json::Number::from_f64(center.lng()),
|
|
) {
|
|
map.insert("lat".into(), Value::Number(lat));
|
|
map.insert("lon".into(), Value::Number(lon));
|
|
}
|
|
|
|
let iter: Box<dyn Iterator<Item = usize>> = if let Some(idx) = indices {
|
|
Box::new(idx.iter().copied())
|
|
} else {
|
|
Box::new(0..num_features)
|
|
};
|
|
|
|
for feat_index in iter {
|
|
if aggregation.feat_counts[feat_index] > 0 {
|
|
let avg = aggregation.sums[feat_index] / aggregation.feat_counts[feat_index] as f64;
|
|
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
|
|
serde_json::Number::from_f64(aggregation.mins[feat_index] as f64),
|
|
serde_json::Number::from_f64(aggregation.maxs[feat_index] as f64),
|
|
serde_json::Number::from_f64(avg),
|
|
) {
|
|
map.insert(min_keys[feat_index].clone(), Value::Number(min_num));
|
|
map.insert(max_keys[feat_index].clone(), Value::Number(max_num));
|
|
map.insert(avg_keys[feat_index].clone(), Value::Number(avg_num));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add travel time aggregation fields
|
|
for (ti, agg_map) in travel_aggs.iter().enumerate() {
|
|
if let Some(agg) = agg_map.get(&cell_id) {
|
|
if agg.count > 0 {
|
|
let key = &travel_field_keys[ti];
|
|
let avg = agg.sum / agg.count as f64;
|
|
if let Some(nm) = serde_json::Number::from_f64(agg.min as f64) {
|
|
map.insert(format!("min_{key}"), Value::Number(nm));
|
|
}
|
|
if let Some(nm) = serde_json::Number::from_f64(agg.max as f64) {
|
|
map.insert(format!("max_{key}"), Value::Number(nm));
|
|
}
|
|
if let Some(nm) = serde_json::Number::from_f64(avg) {
|
|
map.insert(format!("avg_{key}"), Value::Number(nm));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
features.push(map);
|
|
}
|
|
|
|
features
|
|
}
|
|
|
|
pub async fn get_hexagons(
|
|
state: Arc<AppState>,
|
|
Extension(user): Extension<OptionalUser>,
|
|
Query(params): Query<HexagonParams>,
|
|
) -> Result<Json<HexagonsResponse>, axum::response::Response> {
|
|
let resolution = params.resolution;
|
|
validate_h3_resolution(resolution).map_err(IntoResponse::into_response)?;
|
|
|
|
let (south, west, north, east) =
|
|
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
|
|
|
|
let is_demo_view = (south, west, north, east) == DEMO_BOUNDS;
|
|
if !is_demo_view {
|
|
check_license_bounds(&user.0, (south, west, north, east))?;
|
|
}
|
|
|
|
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
|
params.filters.as_deref(),
|
|
&state.feature_name_to_index,
|
|
&state.data.enum_values,
|
|
)
|
|
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
|
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
|
let filters_str = params.filters;
|
|
|
|
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
|
|
.map_err(|err| (err.0, err.1).into_response())?;
|
|
|
|
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
|
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
|
|
|
let response = tokio::task::spawn_blocking(move || -> Result<HexagonsResponse, String> {
|
|
let t0 = std::time::Instant::now();
|
|
|
|
// Load travel time data from precomputed parquet files
|
|
let travel_data: Vec<TravelData> = if !travel_entries.is_empty() {
|
|
let store = &state.travel_time_store;
|
|
travel_entries
|
|
.iter()
|
|
.map(|entry| {
|
|
store
|
|
.get(&entry.mode, &entry.slug)
|
|
.map_err(|err| format!("Failed to load travel data: {}", err))
|
|
})
|
|
.collect::<Result<Vec<_>, _>>()?
|
|
} else {
|
|
Vec::new()
|
|
};
|
|
|
|
let has_travel = !travel_entries.is_empty();
|
|
let travel_field_keys: Vec<String> = travel_entries
|
|
.iter()
|
|
.map(|te| format!("tt_{}_{}", te.mode, te.slug))
|
|
.collect();
|
|
|
|
let num_features = state.data.num_features;
|
|
let feature_data = &state.data.feature_data;
|
|
let (pc_interner, pc_keys) = state.data.postcode_parts();
|
|
let min_keys = &state.min_keys;
|
|
let max_keys = &state.max_keys;
|
|
let avg_keys = &state.avg_keys;
|
|
|
|
let h3_res = h3o::Resolution::try_from(resolution)
|
|
.map_err(|error| format!("Invalid H3 resolution {}: {}", resolution, error))?;
|
|
let precomputed = &state.h3_cells;
|
|
let need_parent = needs_parent(resolution);
|
|
|
|
let mut groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
|
|
let mut travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> =
|
|
(0..travel_entries.len()).map(|_| FxHashMap::default()).collect();
|
|
|
|
// Main aggregation loop
|
|
let aggregate_row =
|
|
|row: usize,
|
|
groups: &mut FxHashMap<u64, Aggregator>,
|
|
travel_aggs: &mut [FxHashMap<u64, TravelTimeAgg>]| {
|
|
// Regular filters
|
|
if !row_passes_filters(
|
|
row,
|
|
&parsed_filters,
|
|
&parsed_enum_filters,
|
|
feature_data,
|
|
num_features,
|
|
) {
|
|
return;
|
|
}
|
|
|
|
// Travel time filter: check each entry with a range
|
|
let mut travel_minutes: Vec<Option<i16>> = Vec::new();
|
|
if has_travel {
|
|
let postcode = pc_interner.resolve(&pc_keys[row]);
|
|
travel_minutes.reserve(travel_entries.len());
|
|
for (ti, entry) in travel_entries.iter().enumerate() {
|
|
let row_data = travel_data[ti].get(postcode);
|
|
let minutes = row_data.map(|r| {
|
|
if entry.use_best {
|
|
r.best_minutes.unwrap_or(r.minutes)
|
|
} else {
|
|
r.minutes
|
|
}
|
|
});
|
|
travel_minutes.push(minutes);
|
|
if let (Some(fmin), Some(fmax)) = (entry.filter_min, entry.filter_max) {
|
|
match minutes {
|
|
Some(mins) if (mins as f32) >= fmin && (mins as f32) <= fmax => {}
|
|
_ => return, // Filtered out
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let cell_id = cell_for_row(row, precomputed, h3_res, need_parent);
|
|
|
|
// Aggregate regular features
|
|
let aggregation = groups
|
|
.entry(cell_id)
|
|
.or_insert_with(|| Aggregator::new(num_features));
|
|
if let Some(sel_indices) = field_indices.as_deref() {
|
|
aggregation.add_row_selective(feature_data, row, num_features, sel_indices);
|
|
} else {
|
|
aggregation.add_row(feature_data, row, num_features);
|
|
}
|
|
|
|
// Aggregate travel time
|
|
for (ti, minutes) in travel_minutes.iter().enumerate() {
|
|
if let Some(mins) = minutes {
|
|
let agg = travel_aggs[ti]
|
|
.entry(cell_id)
|
|
.or_insert_with(TravelTimeAgg::new);
|
|
agg.add(*mins as f32);
|
|
}
|
|
}
|
|
};
|
|
|
|
state
|
|
.grid
|
|
.for_each_in_bounds(south, west, north, east, |row_idx| {
|
|
aggregate_row(row_idx as usize, &mut groups, &mut travel_aggs);
|
|
});
|
|
|
|
let t_agg = t0.elapsed();
|
|
|
|
let mut features = build_feature_maps(
|
|
&groups,
|
|
min_keys,
|
|
max_keys,
|
|
avg_keys,
|
|
num_features,
|
|
field_indices.as_deref(),
|
|
(south, west, north, east),
|
|
&travel_aggs,
|
|
&travel_field_keys,
|
|
);
|
|
|
|
let truncated = features.len() > MAX_CELLS_PER_REQUEST;
|
|
if truncated {
|
|
features.truncate(MAX_CELLS_PER_REQUEST);
|
|
}
|
|
|
|
let t_total = t0.elapsed();
|
|
info!(
|
|
resolution,
|
|
cells_before_filter = groups.len(),
|
|
cells_after_filter = features.len(),
|
|
truncated,
|
|
bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east),
|
|
filters = num_filters,
|
|
filters_raw = filters_str.as_deref().unwrap_or("-"),
|
|
travel_entries = travel_entries.len(),
|
|
agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0),
|
|
total_ms = format_args!("{:.1}", t_total.as_secs_f64() * 1000.0),
|
|
"GET /api/hexagons"
|
|
);
|
|
|
|
Ok(HexagonsResponse { features })
|
|
})
|
|
.await
|
|
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())?
|
|
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?;
|
|
|
|
Ok(Json(response))
|
|
}
|