perfect-postcode/server-rs/src/routes/pois.rs

145 lines
4.4 KiB
Rust

use std::sync::Arc;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::Json;
use serde::{Deserialize, Serialize};
use tracing::info;
use crate::consts::MAX_POIS_PER_REQUEST;
use crate::data::POICategoryGroup;
use crate::parsing::require_bounds;
use crate::state::SharedState;
/// A single point of interest as returned in the `pois` array of `GET /api/pois`.
#[derive(Serialize)]
// Keep the public type name `POI` (upper-case acronym) rather than renaming it to
// `Poi`; silence clippy's naming lint for that reason.
#[allow(clippy::upper_case_acronyms)]
pub struct POI {
/// Identifier of the POI (rendered as a string).
id: String,
/// Name of the POI.
name: String,
/// Category name; the same vocabulary accepted by the `categories` query filter.
category: String,
/// Name of the category group for this POI.
group: String,
/// Latitude.
lat: f32,
/// Longitude.
lng: f32,
/// Emoji string for this POI.
emoji: String,
}
/// JSON body of `GET /api/pois`: `{ "pois": [...] }`.
#[derive(Serialize)]
pub struct POIsResponse {
/// Matching POIs, possibly thinned to at most `MAX_POIS_PER_REQUEST` entries.
pois: Vec<POI>,
}
/// Query-string parameters accepted by `GET /api/pois`.
#[derive(Deserialize)]
pub struct POIParams {
/// Bounding box, parsed by `require_bounds` into `(south, west, north, east)`.
/// Declared optional so deserialization succeeds; presumably a missing value is
/// rejected by `require_bounds` — confirm there.
bounds: Option<String>,
/// Comma-separated list of category names to filter by; names are matched
/// against the POI category column after trimming whitespace.
categories: Option<String>,
}
pub async fn get_pois(
State(shared): State<Arc<SharedState>>,
Query(params): Query<POIParams>,
) -> Result<Json<POIsResponse>, (StatusCode, String)> {
let state = shared.load_state();
let (south, west, north, east) = require_bounds(params.bounds)?;
let category_filter: Option<rustc_hash::FxHashSet<u16>> = params
.categories
.as_deref()
.filter(|text| !text.is_empty())
.map(|text| {
text.split(',')
.filter_map(|part| {
let name = part.trim();
state
.poi_data
.category
.values
.iter()
.position(|v| v == name)
.map(|pos| pos as u16)
})
.collect()
});
let categories_raw = params.categories;
let num_categories = category_filter.as_ref().map(|cats| cats.len()).unwrap_or(0);
let pois = tokio::task::spawn_blocking(move || {
let t0 = std::time::Instant::now();
let row_indices = state.poi_grid.query(south, west, north, east);
let mut matching_rows: Vec<usize> = row_indices
.iter()
.filter_map(|&row_idx| {
let row = row_idx as usize;
if let Some(ref categories) = category_filter {
if !categories.contains(&state.poi_data.category.indices[row]) {
return None;
}
}
Some(row)
})
.collect();
if matching_rows.len() > MAX_POIS_PER_REQUEST {
let ratio = (matching_rows.len() / MAX_POIS_PER_REQUEST) as u32;
let step = ratio.next_power_of_two();
let mask = step - 1;
matching_rows.retain(|&row| state.poi_data.priority[row] & mask == 0);
if matching_rows.len() > MAX_POIS_PER_REQUEST {
matching_rows.sort_unstable_by_key(|&row| state.poi_data.priority[row]);
matching_rows.truncate(MAX_POIS_PER_REQUEST);
}
}
let pois: Vec<POI> = matching_rows
.iter()
.map(|&row| POI {
id: state.poi_data.id(row).to_string(),
name: state.poi_data.name[row].clone(),
category: state.poi_data.category.get(row).to_string(),
group: state.poi_data.group.get(row).to_string(),
lat: state.poi_data.lat[row],
lng: state.poi_data.lng[row],
emoji: state.poi_data.emoji.get(row).to_string(),
})
.collect();
let elapsed = t0.elapsed();
info!(
results = pois.len(),
candidates = row_indices.len(),
categories = num_categories,
categories_raw = categories_raw.as_deref().unwrap_or("-"),
ms = format_args!("{:.1}", elapsed.as_secs_f64() * 1000.0),
"GET /api/pois"
);
pois
})
.await
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))?;
Ok(Json(POIsResponse { pois }))
}
/// JSON body of `GET /api/poi-categories`: `{ "groups": [...] }`.
#[derive(Serialize)]
pub struct POICategoriesResponse {
/// All category groups, copied from the loaded state.
groups: Vec<POICategoryGroup>,
}
/// `GET /api/poi-categories` — lists every POI category group from the current state.
///
/// Logs the number of groups and the total number of categories across them.
pub async fn get_poi_categories(
    State(shared): State<Arc<SharedState>>,
) -> Json<POICategoriesResponse> {
    let all_groups: Vec<POICategoryGroup> = shared.load_state().poi_category_groups.to_vec();
    let category_count = all_groups
        .iter()
        .fold(0usize, |acc, group| acc + group.categories.len());
    info!(
        count = category_count,
        groups = all_groups.len(),
        "GET /api/poi-categories"
    );
    Json(POICategoriesResponse { groups: all_groups })
}