Morning improvements

This commit is contained in:
Andras Schmelczer 2026-03-17 13:29:03 +00:00
parent 3e9fba5303
commit 53fff3efaa
41 changed files with 2438 additions and 637 deletions

View file

@ -510,14 +510,6 @@ pub async fn post_ai_filters(
.0
.ok_or((StatusCode::UNAUTHORIZED, "Login required".into()))?;
// Email verification check (skipped in dev mode)
if !user.verified && !state.is_dev {
return Err((
StatusCode::FORBIDDEN,
"Please verify your email to use AI filters".into(),
));
}
// Check weekly token usage
let current_week = current_week_number();
let (stored_tokens, stored_week) = fetch_ai_usage(&state, &user.id).await?;

View file

@ -17,8 +17,8 @@ use crate::consts::{DEMO_BOUNDS, MAX_CELLS_PER_REQUEST};
use crate::data::travel_time::TravelData;
use crate::licensing::check_license_bounds;
use crate::parsing::{
bounds_intersect, cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_indices,
parse_filters, require_bounds, row_passes_filters, validate_h3_resolution,
cell_for_row_cached, needs_parent, parse_field_indices, parse_filters, require_bounds,
row_passes_filters, validate_h3_resolution,
};
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
use crate::state::AppState;
@ -26,6 +26,28 @@ use crate::state::AppState;
/// Row count threshold above which we use rayon parallel aggregation.
const PARALLEL_THRESHOLD: usize = 50_000;
/// Per-thread aggregation result: feature accumulators + travel time accumulators.
type ChunkResult = (
FxHashMap<u64, Aggregator>,
Vec<FxHashMap<u64, TravelTimeAgg>>,
);
/// Maximum center-to-vertex distance in degrees per H3 resolution.
/// Generous for UK latitudes (49°61°) so we never false-exclude a visible cell.
/// Used for cheap center-based bounds filtering instead of computing full cell boundary.
const H3_CENTER_BUFFERS: [f64; 13] = [
5.0, 2.0, 1.0, 0.5, // res 03 (unused in practice)
0.50, // res 4
0.20, // res 5
0.08, // res 6
0.03, // res 7
0.012, // res 8
0.005, // res 9
0.002, // res 10
0.001, // res 11
0.0005, // res 12
];
#[derive(Serialize)]
pub struct HexagonsResponse {
features: Vec<Map<String, Value>>,
@ -45,7 +67,10 @@ pub struct HexagonParams {
travel: Option<String>,
}
/// Build feature maps from aggregated cell data, filtering to only cells that intersect the query bounds.
/// Build feature maps from aggregated cell data, filtering to only cells whose
/// center is within the query bounds (expanded by a resolution-dependent buffer).
/// This is much cheaper than the previous approach of computing full cell boundaries
/// (6 vertices per cell) — just 4 float comparisons per cell.
#[allow(clippy::too_many_arguments)]
fn build_feature_maps(
groups: &FxHashMap<u64, Aggregator>,
@ -55,44 +80,69 @@ fn build_feature_maps(
num_features: usize,
indices: Option<&[usize]>,
query_bounds: (f64, f64, f64, f64),
resolution: h3o::Resolution,
travel_aggs: &[FxHashMap<u64, TravelTimeAgg>],
travel_field_keys: &[String],
) -> Vec<Map<String, Value>> {
let mut features = Vec::with_capacity(groups.len());
let (q_south, q_west, q_north, q_east) = query_bounds;
// Expand bounds by resolution-dependent buffer for center-based filtering
let buf = H3_CENTER_BUFFERS[resolution as usize];
let bound_south = q_south - buf;
let bound_north = q_north + buf;
let bound_west = q_west - buf;
let bound_east = q_east + buf;
// Pre-compute travel time key strings (avoids per-cell format!())
let travel_keys: Vec<(String, String, String)> = travel_field_keys
.iter()
.map(|key| {
(
format!("min_{key}"),
format!("max_{key}"),
format!("avg_{key}"),
)
})
.collect();
// Pre-compute default feature indices to avoid per-cell Box<dyn Iterator> allocation
let default_indices: Vec<usize>;
let feat_indices: &[usize] = match indices {
Some(idx) => idx,
None => {
default_indices = (0..num_features).collect();
&default_indices
}
};
for (&cell_id, aggregation) in groups {
let Some(cell) = h3o::CellIndex::try_from(cell_id).ok() else {
continue;
};
// Filter out cells that don't intersect the query bounds
let (c_south, c_west, c_north, c_east) = h3_cell_bounds(cell, 0.0);
if !bounds_intersect(
c_south, c_west, c_north, c_east, q_south, q_west, q_north, q_east,
) {
// Center is already needed for lat/lon output — reuse for bounds check
let center: h3o::LatLng = cell.into();
let lat = center.lat();
let lng = center.lng();
// Center-based bounds check: 4 comparisons instead of computing 6 boundary vertices
if lat < bound_south || lat > bound_north || lng < bound_west || lng > bound_east {
continue;
}
let mut map = Map::new();
map.insert("h3".into(), Value::String(cell.to_string()));
map.insert("count".into(), Value::Number(aggregation.count.into()));
let center: h3o::LatLng = cell.into();
if let (Some(lat), Some(lon)) = (
serde_json::Number::from_f64(center.lat()),
serde_json::Number::from_f64(center.lng()),
if let (Some(lat_num), Some(lon_num)) = (
serde_json::Number::from_f64(lat),
serde_json::Number::from_f64(lng),
) {
map.insert("lat".into(), Value::Number(lat));
map.insert("lon".into(), Value::Number(lon));
map.insert("lat".into(), Value::Number(lat_num));
map.insert("lon".into(), Value::Number(lon_num));
}
let iter: Box<dyn Iterator<Item = usize>> = if let Some(idx) = indices {
Box::new(idx.iter().copied())
} else {
Box::new(0..num_features)
};
for feat_index in iter {
for &feat_index in feat_indices {
if aggregation.feat_counts[feat_index] > 0 {
let avg = aggregation.sums[feat_index] / aggregation.feat_counts[feat_index] as f64;
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
@ -107,20 +157,19 @@ fn build_feature_maps(
}
}
// Add travel time aggregation fields
// Add travel time aggregation fields (using pre-computed key strings)
for (ti, agg_map) in travel_aggs.iter().enumerate() {
if let Some(agg) = agg_map.get(&cell_id) {
if agg.count > 0 {
let key = &travel_field_keys[ti];
let avg = agg.sum / agg.count as f64;
if let Some(nm) = serde_json::Number::from_f64(agg.min as f64) {
map.insert(format!("min_{key}"), Value::Number(nm));
map.insert(travel_keys[ti].0.clone(), Value::Number(nm));
}
if let Some(nm) = serde_json::Number::from_f64(agg.max as f64) {
map.insert(format!("max_{key}"), Value::Number(nm));
map.insert(travel_keys[ti].1.clone(), Value::Number(nm));
}
if let Some(nm) = serde_json::Number::from_f64(avg) {
map.insert(format!("avg_{key}"), Value::Number(nm));
map.insert(travel_keys[ti].2.clone(), Value::Number(nm));
}
}
}
@ -207,19 +256,31 @@ pub async fn get_hexagons(
.map(|_| FxHashMap::default())
.collect();
// Collect row indices for threshold-based sequential/parallel aggregation
let row_indices = state.grid.query(south, west, north, east);
// O(grid cells) count — no allocation. Used for parallel threshold decision.
let row_count = state.grid.count_in_bounds(south, west, north, east);
let t_grid = t0.elapsed();
if row_indices.len() >= PARALLEL_THRESHOLD && !has_travel {
// Parallel path: split rows across rayon threads, each with local accumulators
let parallel = row_count >= PARALLEL_THRESHOLD;
if parallel {
// Parallel: collect row indices for par_chunks, split across rayon threads.
// Now handles travel time too (postcode interner & travel data are thread-safe).
let row_indices = state.grid.query(south, west, north, east);
let chunk_size = (row_indices.len() / rayon::current_num_threads()).max(1000);
let thread_results: Vec<FxHashMap<u64, Aggregator>> = row_indices
let thread_results: Vec<ChunkResult> = row_indices
.par_chunks(chunk_size)
.map(|chunk| {
let mut local_groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
let mut local_travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0
..travel_entries.len())
.map(|_| FxHashMap::default())
.collect();
let mut h3_cache: FxHashMap<u64, u64> = FxHashMap::default();
for &row_idx in chunk {
let mut travel_minutes: Vec<Option<i16>> =
Vec::with_capacity(travel_entries.len());
'row: for &row_idx in chunk {
let row = row_idx as usize;
if !row_passes_filters(
row,
@ -230,6 +291,32 @@ pub async fn get_hexagons(
) {
continue;
}
if has_travel {
travel_minutes.clear();
let postcode = pc_interner.resolve(&pc_keys[row]);
for (ti, entry) in travel_entries.iter().enumerate() {
let row_data = travel_data[ti].get(postcode);
let minutes = row_data.map(|r| {
if entry.use_best {
r.best_minutes.unwrap_or(r.minutes)
} else {
r.minutes
}
});
travel_minutes.push(minutes);
if let (Some(fmin), Some(fmax)) =
(entry.filter_min, entry.filter_max)
{
match minutes {
Some(mins)
if (mins as f32) >= fmin && (mins as f32) <= fmax => {}
_ => continue 'row,
}
}
}
}
let cell_id = cell_for_row_cached(
row,
precomputed,
@ -237,6 +324,7 @@ pub async fn get_hexagons(
need_parent,
&mut h3_cache,
);
let agg = local_groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
@ -251,91 +339,108 @@ pub async fn get_hexagons(
} else {
agg.add_row(feature_data, row, num_features, &quant);
}
}
local_groups
})
.collect();
// Merge thread-local results into the main groups map
for local_groups in thread_results {
for (cell_id, local_agg) in local_groups {
let agg = groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
agg.merge(&local_agg);
}
}
} else {
// Sequential path (also handles travel time which needs postcode lookups)
let mut travel_minutes: Vec<Option<i16>> = Vec::with_capacity(travel_entries.len());
let mut h3_cache: FxHashMap<u64, u64> = FxHashMap::default();
'row: for &row_idx in &row_indices {
let row = row_idx as usize;
// Regular filters
if !row_passes_filters(
row,
&parsed_filters,
&parsed_enum_filters,
feature_data,
num_features,
) {
continue;
}
// Travel time filter: check each entry with a range
if has_travel {
travel_minutes.clear();
let postcode = pc_interner.resolve(&pc_keys[row]);
for (ti, entry) in travel_entries.iter().enumerate() {
let row_data = travel_data[ti].get(postcode);
let minutes = row_data.map(|r| {
if entry.use_best {
r.best_minutes.unwrap_or(r.minutes)
} else {
r.minutes
}
});
travel_minutes.push(minutes);
if let (Some(fmin), Some(fmax)) = (entry.filter_min, entry.filter_max) {
match minutes {
Some(mins) if (mins as f32) >= fmin && (mins as f32) <= fmax => {}
_ => continue 'row, // Filtered out (jump to next row_idx)
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let tagg = local_travel_aggs[ti]
.entry(cell_id)
.or_insert_with(TravelTimeAgg::new);
tagg.add(*mins as f32);
}
}
}
(local_groups, local_travel_aggs)
})
.collect();
// Merge thread-local results into the main accumulators
for (local_groups, local_travel) in thread_results {
for (cell_id, local_agg) in local_groups {
groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features))
.merge(&local_agg);
}
let cell_id =
cell_for_row_cached(row, precomputed, h3_res, need_parent, &mut h3_cache);
// Aggregate regular features
let aggregation = groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
if let Some(sel_indices) = field_indices.as_deref() {
aggregation.add_row_selective(
feature_data,
row,
num_features,
sel_indices,
&quant,
);
} else {
aggregation.add_row(feature_data, row, num_features, &quant);
}
// Aggregate travel time
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let agg = travel_aggs[ti]
for (ti, local_ta) in local_travel.into_iter().enumerate() {
for (cell_id, local_tt) in local_ta {
travel_aggs[ti]
.entry(cell_id)
.or_insert_with(TravelTimeAgg::new);
agg.add(*mins as f32);
.or_insert_with(TravelTimeAgg::new)
.merge(&local_tt);
}
}
}
} else {
// Sequential: use for_each_in_bounds to avoid Vec<u32> allocation
let mut travel_minutes: Vec<Option<i16>> = Vec::with_capacity(travel_entries.len());
let mut h3_cache: FxHashMap<u64, u64> = FxHashMap::default();
state
.grid
.for_each_in_bounds(south, west, north, east, |row_idx| {
let row = row_idx as usize;
if !row_passes_filters(
row,
&parsed_filters,
&parsed_enum_filters,
feature_data,
num_features,
) {
return;
}
if has_travel {
travel_minutes.clear();
let postcode = pc_interner.resolve(&pc_keys[row]);
for (ti, entry) in travel_entries.iter().enumerate() {
let row_data = travel_data[ti].get(postcode);
let minutes = row_data.map(|r| {
if entry.use_best {
r.best_minutes.unwrap_or(r.minutes)
} else {
r.minutes
}
});
travel_minutes.push(minutes);
if let (Some(fmin), Some(fmax)) = (entry.filter_min, entry.filter_max) {
match minutes {
Some(mins)
if (mins as f32) >= fmin && (mins as f32) <= fmax => {}
_ => return,
}
}
}
}
let cell_id =
cell_for_row_cached(row, precomputed, h3_res, need_parent, &mut h3_cache);
let aggregation = groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
if let Some(sel_indices) = field_indices.as_deref() {
aggregation.add_row_selective(
feature_data,
row,
num_features,
sel_indices,
&quant,
);
} else {
aggregation.add_row(feature_data, row, num_features, &quant);
}
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let agg = travel_aggs[ti]
.entry(cell_id)
.or_insert_with(TravelTimeAgg::new);
agg.add(*mins as f32);
}
}
});
};
let t_agg = t0.elapsed();
@ -348,6 +453,7 @@ pub async fn get_hexagons(
num_features,
field_indices.as_deref(),
(south, west, north, east),
h3_res,
&travel_aggs,
&travel_field_keys,
);
@ -357,11 +463,10 @@ pub async fn get_hexagons(
features.truncate(MAX_CELLS_PER_REQUEST);
}
let parallel = row_indices.len() >= PARALLEL_THRESHOLD && !has_travel;
let t_total = t0.elapsed();
info!(
resolution,
rows = row_indices.len(),
rows = row_count,
parallel,
cells_before_filter = groups.len(),
cells_after_filter = features.len(),
@ -369,8 +474,11 @@ pub async fn get_hexagons(
bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east),
filters = num_filters,
filters_raw = filters_str.as_deref().unwrap_or("-"),
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1),
travel_entries = travel_entries.len(),
agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0),
grid_ms = format_args!("{:.1}", t_grid.as_secs_f64() * 1000.0),
agg_ms = format_args!("{:.1}", (t_agg - t_grid).as_secs_f64() * 1000.0),
json_ms = format_args!("{:.1}", (t_total - t_agg).as_secs_f64() * 1000.0),
total_ms = format_args!("{:.1}", t_total.as_secs_f64() * 1000.0),
"GET /api/hexagons"
);

View file

@ -128,7 +128,7 @@ pub struct POICategoriesResponse {
}
pub async fn get_poi_categories(state: Arc<AppState>) -> Json<POICategoriesResponse> {
let groups: Vec<POICategoryGroup> = state.poi_category_groups.clone();
let groups: Vec<POICategoryGroup> = state.poi_category_groups.to_vec();
let total: usize = groups.iter().map(|group| group.categories.len()).sum();
info!(

View file

@ -177,6 +177,8 @@ pub async fn get_postcodes(
}
}
let t_agg = t0.elapsed();
// Build response, filtering postcodes to only those whose polygon intersects query bounds
let mut features = Vec::with_capacity(postcode_aggs.len());
let postcodes_before_filter = postcode_aggs.len();
@ -288,7 +290,10 @@ pub async fn get_postcodes(
bounds = format_args!("{:.6},{:.6},{:.6},{:.6}", south, west, north, east),
filters = num_filters,
filters_raw = filters_str.as_deref().unwrap_or("-"),
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1),
travel_entries = travel_entries.len(),
agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0),
json_ms = format_args!("{:.1}", (t_total - t_agg).as_secs_f64() * 1000.0),
total_ms = format_args!("{:.1}", t_total.as_secs_f64() * 1000.0),
"GET /api/postcodes"
);

View file

@ -0,0 +1,179 @@
use std::sync::Arc;
use std::time::Instant;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json, Response};
use serde_json::json;
use tracing::{info, warn};
use crate::consts::GRID_CELL_SIZE;
use crate::data::{self, PropertyData};
use crate::metrics::record_data_stats;
use crate::routes::{build_features_response, build_system_prompt};
use crate::state::{AppState, SharedState};
use crate::utils::GridIndex;
pub async fn post_reload(shared: Arc<SharedState>) -> Response {
if !shared.try_start_reload() {
return (StatusCode::CONFLICT, "Reload already in progress").into_response();
}
info!("Reload triggered — rebuilding property data");
let start = Instant::now();
// shared is cloned so we retain a reference after spawn_blocking
let sh = Arc::clone(&shared);
let result = tokio::task::spawn_blocking(move || rebuild_data(&sh, start)).await;
// Always clear the reload flag
shared.finish_reload();
match result {
Ok(Ok((rows, features, elapsed_ms))) => Json(json!({
"status": "ok",
"rows": rows,
"features": features,
"elapsed_ms": elapsed_ms,
}))
.into_response(),
Ok(Err(err)) => {
warn!("Reload failed: {err:#}");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("{err:#}") })),
)
.into_response()
}
Err(err) => {
warn!("Reload task panicked: {err}");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Reload task panicked: {err}") })),
)
.into_response()
}
}
}
fn rebuild_data(shared: &SharedState, start: Instant) -> anyhow::Result<(usize, usize, u128)> {
let old = shared.load_state();
// 1. Load PropertyData from parquet files
let property_data = PropertyData::load(
&shared.properties_path,
&shared.postcode_features_path,
&shared.listings_buy_path,
&shared.listings_rent_path,
)?;
let row_count = property_data.lat.len();
let feature_count = property_data.num_features;
// 2. Build spatial grid index
info!("Reload: building spatial grid index");
let grid = GridIndex::build(&property_data.lat, &property_data.lon, GRID_CELL_SIZE);
// 3. Precompute H3 cells
info!("Reload: precomputing H3 cells");
let h3_cells = data::precompute_h3(&property_data.lat, &property_data.lon)?;
// 4. Build feature lookup tables
let feature_name_to_index = property_data
.feature_names
.iter()
.enumerate()
.map(|(idx, name)| (name.clone(), idx))
.collect();
let min_keys = property_data
.feature_names
.iter()
.map(|n| format!("min_{n}"))
.collect();
let max_keys = property_data
.feature_names
.iter()
.map(|n| format!("max_{n}"))
.collect();
let avg_keys = property_data
.feature_names
.iter()
.map(|n| format!("avg_{n}"))
.collect();
// 5. Build features response and AI prompt
let features_response = build_features_response(&property_data);
let mode_destinations: Vec<(String, usize)> = old
.travel_time_store
.available_modes
.iter()
.map(|mode| {
let count = old
.travel_time_store
.destinations
.get(mode.as_str())
.map(|slugs| slugs.len())
.unwrap_or(0);
(mode.clone(), count)
})
.filter(|(_, count)| *count > 0)
.collect();
let ai_filters_system_prompt = build_system_prompt(&features_response, &mode_destinations);
// 6. Update data metrics
record_data_stats(
row_count,
old.poi_data.lat.len(),
old.postcode_data.postcodes.len(),
);
// 7. Build new AppState, sharing unchanged fields via Arc
let new_state = AppState {
data: property_data,
grid,
h3_cells,
feature_name_to_index,
min_keys,
max_keys,
avg_keys,
features_response,
ai_filters_system_prompt,
// Shared across reloads (Arc clone is cheap)
poi_data: Arc::clone(&old.poi_data),
poi_grid: Arc::clone(&old.poi_grid),
place_data: Arc::clone(&old.place_data),
postcode_data: Arc::clone(&old.postcode_data),
poi_category_groups: Arc::clone(&old.poi_category_groups),
travel_time_store: Arc::clone(&old.travel_time_store),
token_cache: Arc::clone(&old.token_cache),
// Config (cheap clone)
screenshot_url: old.screenshot_url.clone(),
public_url: old.public_url.clone(),
is_dev: old.is_dev,
index_html: old.index_html.clone(),
http_client: old.http_client.clone(),
pocketbase_url: old.pocketbase_url.clone(),
pocketbase_admin_email: old.pocketbase_admin_email.clone(),
pocketbase_admin_password: old.pocketbase_admin_password.clone(),
gemini_api_key: old.gemini_api_key.clone(),
gemini_model: old.gemini_model.clone(),
google_maps_api_key: old.google_maps_api_key.clone(),
stripe_secret_key: old.stripe_secret_key.clone(),
stripe_webhook_secret: old.stripe_webhook_secret.clone(),
stripe_referral_coupon_id: old.stripe_referral_coupon_id.clone(),
};
// 8. Atomic swap
shared.swap_state(new_state);
let elapsed = start.elapsed();
info!(
rows = row_count,
features = feature_count,
elapsed_ms = elapsed.as_millis(),
"Reload complete"
);
Ok((row_count, feature_count, elapsed.as_millis()))
}

View file

@ -93,4 +93,18 @@ impl TravelTimeAgg {
self.sum += value as f64;
self.count += 1;
}
/// Merge another aggregator's results into this one (for parallel reduction).
pub fn merge(&mut self, other: &TravelTimeAgg) {
if other.count > 0 {
if other.min < self.min {
self.min = other.min;
}
if other.max > self.max {
self.max = other.max;
}
self.sum += other.sum;
self.count += other.count;
}
}
}