Rust things

This commit is contained in:
Andras Schmelczer 2026-05-10 14:55:43 +01:00
parent fc10381692
commit 3debacab4f
30 changed files with 3257 additions and 647 deletions

View file

@ -8,10 +8,8 @@ use serde::{Deserialize, Serialize};
use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::pocketbase::get_superuser_token;
use crate::state::{AppState, SharedState};
use super::pricing::{count_licensed_users, price_for_count};
use crate::checkout_sessions::{start_license_checkout, CheckoutStart};
use crate::state::SharedState;
#[derive(Deserialize)]
pub struct CheckoutRequest {
@ -23,8 +21,8 @@ struct CheckoutResponse {
url: String,
}
/// Create a Stripe Checkout session for the lifetime license (or grant for free if in free tier).
/// Requires authentication. Optionally accepts a referral code to apply a coupon.
/// Create a reserved Stripe Checkout session for the lifetime license.
/// Requires authentication. Referral discounts are issued via invite redemption.
pub async fn post_checkout(
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
@ -36,147 +34,27 @@ pub async fn post_checkout(
None => return StatusCode::UNAUTHORIZED.into_response(),
};
let count = match count_licensed_users(&state).await {
Ok(c) => c,
Err(err) => {
warn!("Failed to count licensed users at checkout: {err}");
return StatusCode::SERVICE_UNAVAILABLE.into_response();
}
};
let price_pence = price_for_count(count);
let public_url = &state.public_url;
let success_url = format!("{public_url}/pricing?license_success=1");
// Free tier — grant license directly without Stripe
if price_pence == 0 {
if let Err(err) = grant_license(&state, &user.id).await {
warn!(user_id = %user.id, "Failed to grant free license: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
info!(user_id = %user.id, "Granted free early-bird license");
return Json(CheckoutResponse { url: success_url }).into_response();
}
// Paid tier — create Stripe checkout with dynamic price
let secret_key = &state.stripe_secret_key;
let cancel_url = format!("{public_url}/pricing");
let mut form_params = vec![
("mode", "payment".to_string()),
(
"line_items[0][price_data][unit_amount]",
price_pence.to_string(),
),
("line_items[0][price_data][currency]", "gbp".to_string()),
(
"line_items[0][price_data][product_data][name]",
"Perfect Postcodes Lifetime License".to_string(),
),
("line_items[0][quantity]", "1".to_string()),
("success_url", success_url),
("cancel_url", cancel_url),
("client_reference_id", user.id.clone()),
("customer_email", user.email.clone()),
];
// If a referral code is provided and valid, look it up and apply the coupon
if let Some(ref code) = req.referral_code {
if validate_referral_invite(&state, code).await {
form_params.push((
"discounts[0][coupon]",
state.stripe_referral_coupon_id.clone(),
));
info!(code = %code, "Applying referral coupon to checkout");
} else {
warn!(code = %code, "Referral code validation failed, proceeding without discount");
}
if req.referral_code.is_some() {
return (
StatusCode::BAD_REQUEST,
"Referral codes must be redeemed from the invite link",
)
.into_response();
}
let res = state
.http_client
.post("https://api.stripe.com/v1/checkout/sessions")
.basic_auth(secret_key, None::<&str>)
.form(&form_params)
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
let body: serde_json::Value = match resp.json().await {
Ok(v) => v,
Err(err) => {
warn!("Failed to parse Stripe response: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
let url = body["url"].as_str().unwrap_or_default().to_string();
if url.is_empty() {
warn!("Stripe session missing URL");
return StatusCode::BAD_GATEWAY.into_response();
}
info!(user_id = %user.id, price_pence, "Created Stripe checkout session");
Json(CheckoutResponse { url }).into_response()
}
Ok(resp) => {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
warn!("Stripe checkout failed ({status}): {text}");
StatusCode::BAD_GATEWAY.into_response()
match start_license_checkout(&state, &user, &success_url, &cancel_url, None, None).await {
Ok(CheckoutStart::Free) => {
info!(user_id = %user.id, "Granted free early-bird license");
Json(CheckoutResponse { url: success_url }).into_response()
}
Ok(CheckoutStart::Stripe { url }) => Json(CheckoutResponse { url }).into_response(),
Err(err) => {
warn!("Stripe request error: {err}");
warn!(user_id = %user.id, "Failed to start checkout: {err:?}");
StatusCode::BAD_GATEWAY.into_response()
}
}
}
/// Grant a license by updating the user's subscription to "licensed" in PocketBase.
async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
let token = get_superuser_token(state).await?;
let pb_url = state.pocketbase_url.trim_end_matches('/');
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
let resp = state
.http_client
.patch(&url)
.header("Authorization", format!("Bearer {token}"))
.json(&serde_json::json!({ "subscription": "licensed" }))
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
anyhow::bail!("PocketBase update failed ({status}): {text}");
}
state.token_cache.invalidate_by_user_id(user_id);
Ok(())
}
/// Check if a referral invite code exists and is unused.
async fn validate_referral_invite(state: &AppState, code: &str) -> bool {
// Only allow alphanumeric codes to prevent PocketBase filter injection
if code.is_empty() || code.len() > 20 || !code.bytes().all(|b| b.is_ascii_alphanumeric()) {
return false;
}
let pb_url = state.pocketbase_url.trim_end_matches('/');
let filter = format!(
"code=\"{}\" && invite_type=\"referral\" && used_by_id=\"\"",
code
);
let url = format!(
"{pb_url}/api/collections/invites/records?filter={}&perPage=1",
urlencoding::encode(&filter)
);
match state.http_client.get(&url).send().await {
Ok(resp) if resp.status().is_success() => {
let body: serde_json::Value = resp.json().await.unwrap_or_default();
body["totalItems"].as_u64().unwrap_or(0) > 0
}
_ => false,
}
}

View file

@ -1,6 +1,7 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::Duration;
use axum::extract::{Query, State};
use axum::http::{header, HeaderMap, StatusCode};
@ -13,14 +14,18 @@ use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::consts::NAN_U16;
use crate::data::QuantRef;
use crate::features::INTEGER_BIN_FEATURES;
use crate::data::{PostcodePoiMetrics, QuantRef};
use crate::features;
use crate::licensing::check_license_bounds;
use crate::parsing::{parse_field_indices, parse_filters, require_bounds, row_passes_filters};
use crate::parsing::{
parse_field_indices_with_poi, parse_filters_with_poi, require_bounds, row_passes_filters,
row_passes_poi_filters,
};
use crate::routes::{fetch_screenshot_bytes, FeatureInfo};
use crate::state::SharedState;
const MAX_EXPORT_POSTCODES: usize = 250;
const EXPORT_SCREENSHOT_TIMEOUT_SECS: u64 = 12;
/// Height (in pixels) reserved for the screenshot row
const IMAGE_ROW_HEIGHT: f64 = 225.0;
@ -41,11 +46,11 @@ struct PostcodeExportAgg {
}
impl PostcodeExportAgg {
fn new(num_features: usize) -> Self {
fn new(total_features: usize) -> Self {
Self {
count: 0,
sums: vec![0.0; num_features],
finite_counts: vec![0; num_features],
sums: vec![0.0; total_features],
finite_counts: vec![0; total_features],
enum_freqs: FxHashMap::default(),
}
}
@ -58,6 +63,7 @@ impl PostcodeExportAgg {
num_features: usize,
enum_indices: &FxHashMap<usize, ()>,
quant: &QuantRef,
poi_metrics: &PostcodePoiMetrics,
) {
self.count += 1;
let base = row * num_features;
@ -79,6 +85,18 @@ impl PostcodeExportAgg {
self.finite_counts[feat_idx] += 1;
}
}
let poi_offset = num_features;
for metric_idx in 0..poi_metrics.num_features() {
let raw = poi_metrics.raw_for_property_row(row, metric_idx);
if raw == NAN_U16 {
continue;
}
let value = poi_metrics.decode_raw(metric_idx, raw);
let out_idx = poi_offset + metric_idx;
self.sums[out_idx] += value as f64;
self.finite_counts[out_idx] += 1;
}
}
}
@ -138,13 +156,17 @@ pub async fn get_export(
check_license_bounds(&user.0, (south, west, north, east), None)?;
let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters(
let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(),
&state.feature_name_to_index,
&state.data.enum_values,
&quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let has_poi_filters = !parsed_poi_filters.is_empty();
let filters_str = params.filters;
let fields_str = params.fields;
@ -164,16 +186,28 @@ pub async fn get_export(
// Fetch screenshot (async, before spawn_blocking)
let auth_header = headers.get(header::AUTHORIZATION);
let screenshot_bytes = match fetch_screenshot_bytes(&state, &frontend_params, auth_header).await
let screenshot_fetch = fetch_screenshot_bytes(&state, &frontend_params, auth_header);
let screenshot_bytes = match tokio::time::timeout(
Duration::from_secs(EXPORT_SCREENSHOT_TIMEOUT_SECS),
screenshot_fetch,
)
.await
{
Ok(bytes) => {
Ok(Ok(bytes)) => {
info!(bytes = bytes.len(), "Fetched screenshot for export");
Some(bytes)
}
Err(err) => {
Ok(Err(err)) => {
warn!("Screenshot failed for export: {err}");
None
}
Err(_) => {
warn!(
timeout_secs = EXPORT_SCREENSHOT_TIMEOUT_SECS,
"Screenshot timed out for export"
);
None
}
};
// Build feature name → description map from the precomputed features response
@ -200,6 +234,9 @@ pub async fn get_export(
let feature_names = &state.data.feature_names;
let enum_values = &state.data.enum_values;
let postcode_data = &state.postcode_data;
let poi_metrics = &state.data.poi_metrics;
let poi_offset = num_features;
let total_export_features = num_features + poi_metrics.num_features();
// Build set of enum feature indices for quick lookup
let enum_indices: FxHashMap<usize, ()> = enum_values.keys().map(|&idx| (idx, ())).collect();
@ -219,6 +256,10 @@ pub async fn get_export(
) {
return;
}
if has_poi_filters && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
{
return;
}
let postcode = state.data.postcode(row);
if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) {
postcode_rows.entry(pc_idx).or_default().push(row);
@ -229,9 +270,16 @@ pub async fn get_export(
let mut postcode_aggs: Vec<(usize, PostcodeExportAgg)> =
Vec::with_capacity(postcode_rows.len());
for (pc_idx, rows) in postcode_rows {
let mut agg = PostcodeExportAgg::new(num_features);
let mut agg = PostcodeExportAgg::new(total_export_features);
for &row in &rows {
agg.add_row(feature_data, row, num_features, &enum_indices, &quant);
agg.add_row(
feature_data,
row,
num_features,
&enum_indices,
&quant,
poi_metrics,
);
}
if agg.count > 0 {
postcode_aggs.push((pc_idx, agg));
@ -265,14 +313,19 @@ pub async fn get_export(
// Determine column order: filter features first, then remaining
let filter_feature_names = extract_filter_feature_names(filters_str.as_deref());
let field_indices =
parse_field_indices(fields_str.as_deref(), &state.feature_name_to_index)
.map_err(|err| err.1)?;
let field_indices = parse_field_indices_with_poi(
fields_str.as_deref(),
&state.feature_name_to_index,
&state.data.poi_metrics.name_to_index,
)
.map_err(|err| err.1)?;
let all_feature_indices: Vec<usize> = if let Some(ref indices) = field_indices {
indices.clone()
let all_feature_indices: Vec<usize> = if let Some(ref indices) = field_indices.normal {
let mut selected = indices.clone();
selected.extend(field_indices.poi.iter().map(|idx| poi_offset + *idx));
selected
} else {
let mut ordered = Vec::with_capacity(num_features);
let mut ordered = Vec::with_capacity(total_export_features);
let mut used = FxHashSet::default();
for name in &filter_feature_names {
@ -280,6 +333,11 @@ pub async fn get_export(
if used.insert(idx) {
ordered.push(idx);
}
} else if let Some(&idx) = state.data.poi_metrics.name_to_index.get(name.as_str()) {
let virtual_idx = poi_offset + idx;
if used.insert(virtual_idx) {
ordered.push(virtual_idx);
}
}
}
for idx in 0..num_features {
@ -287,15 +345,42 @@ pub async fn get_export(
ordered.push(idx);
}
}
for idx in 0..poi_metrics.num_features() {
let virtual_idx = poi_offset + idx;
if used.insert(virtual_idx) {
ordered.push(virtual_idx);
}
}
ordered
};
// Filter-only feature indices for the Selected sheet
let filter_feature_indices: Vec<usize> = filter_feature_names
.iter()
.filter_map(|name| state.feature_name_to_index.get(name.as_str()).copied())
.filter_map(|name| {
state
.feature_name_to_index
.get(name.as_str())
.copied()
.or_else(|| {
state
.data
.poi_metrics
.name_to_index
.get(name.as_str())
.map(|idx| poi_offset + *idx)
})
})
.collect();
let feature_name_for_idx = |idx: usize| -> &str {
if idx < num_features {
&feature_names[idx]
} else {
&poi_metrics.feature_names[idx - poi_offset]
}
};
// Build feature unit map (feat_idx → (prefix, suffix)) for number formatting
let feature_units: FxHashMap<usize, (&str, &str)> = state
.features_response
@ -309,16 +394,25 @@ pub async fn get_export(
suffix,
..
} => {
let idx = state.feature_name_to_index.get(name.as_str())?;
Some((*idx, (*prefix, *suffix)))
if let Some(&idx) = state.feature_name_to_index.get(name.as_str()) {
Some((idx, (*prefix, *suffix)))
} else {
state
.data
.poi_metrics
.name_to_index
.get(name.as_str())
.map(|idx| (poi_offset + *idx, (*prefix, *suffix)))
}
}
_ => None,
})
.collect();
let integer_feature_indices: FxHashSet<usize> = INTEGER_BIN_FEATURES
let integer_feature_indices: FxHashSet<usize> = all_feature_indices
.iter()
.filter_map(|name| state.feature_name_to_index.get(*name).copied())
.copied()
.filter(|&idx| features::has_integer_bins(feature_name_for_idx(idx)))
.collect();
// Build Excel number formats per feature index for unit display
@ -435,7 +529,7 @@ pub async fn get_export(
.write_string_with_format(
header_row,
col,
&feature_names[feat_idx],
feature_name_for_idx(feat_idx),
&header_fmt,
)
.map_err(|e| format!("Failed to write header: {e}"))?;
@ -453,7 +547,7 @@ pub async fn get_export(
for (col_offset, &feat_idx) in feat_indices.iter().enumerate() {
let col = (col_offset + 2) as u16;
let desc = feature_descriptions
.get(&feature_names[feat_idx])
.get(feature_name_for_idx(feat_idx))
.map(String::as_str)
.unwrap_or("");
sheet
@ -477,7 +571,7 @@ pub async fn get_export(
for (col_offset, &feat_idx) in feat_indices.iter().enumerate() {
let col = (col_offset + 2) as u16;
if enum_indices.contains_key(&feat_idx) {
if feat_idx < num_features && enum_indices.contains_key(&feat_idx) {
if let Some(freqs) = agg.enum_freqs.get(&feat_idx) {
if let Some((&mode_bits, _)) =
freqs.iter().max_by_key(|(_, &count)| count)
@ -543,7 +637,7 @@ pub async fn get_export(
.map_err(|e| format!("Failed to set column width: {e}"))?;
for col_offset in 0..feat_indices.len() {
let col = (col_offset + 2) as u16;
let feat_name = &feature_names[feat_indices[col_offset]];
let feat_name = feature_name_for_idx(feat_indices[col_offset]);
let width = (feat_name.len() as f64 * 1.1).clamp(10.0, 30.0);
sheet
.set_column_width(col, width)

View file

@ -7,7 +7,7 @@ use serde::Serialize;
use tracing::info;
use crate::data::{Histogram, PropertyData};
use crate::features::{Feature, FEATURE_GROUPS};
use crate::features::{self, Feature, FEATURE_GROUPS};
use crate::state::SharedState;
fn is_empty(val: &str) -> bool {
@ -28,9 +28,9 @@ pub enum FeatureInfo {
max: f32,
step: f32,
histogram: Histogram,
description: &'static str,
detail: &'static str,
source: &'static str,
description: String,
detail: String,
source: String,
#[serde(skip_serializing_if = "is_empty")]
prefix: &'static str,
#[serde(skip_serializing_if = "is_empty")]
@ -45,9 +45,9 @@ pub enum FeatureInfo {
name: String,
values: Vec<String>,
counts: HashMap<String, u64>,
description: &'static str,
detail: &'static str,
source: &'static str,
description: String,
detail: String,
source: String,
},
}
@ -85,9 +85,9 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
max: stats.slider_max,
step: config.step,
histogram: stats.histogram.clone(),
description: config.description,
detail: config.detail,
source: config.source,
description: config.description.to_string(),
detail: config.detail.to_string(),
source: config.source.to_string(),
prefix: config.prefix,
suffix: config.suffix,
raw: config.raw,
@ -118,9 +118,9 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
name: config.name.to_string(),
values: values.clone(),
counts,
description: config.description,
detail: config.detail,
source: config.source,
description: config.description.to_string(),
detail: config.detail.to_string(),
source: config.source.to_string(),
});
}
}
@ -136,6 +136,58 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
}
}
let mut dynamic_poi_features = Vec::new();
for (feat_idx, name) in data.poi_metrics.feature_names.iter().enumerate() {
if let Some(category) = features::dynamic_poi_distance_category(name) {
let stats = &data.poi_metrics.feature_stats[feat_idx];
dynamic_poi_features.push(FeatureInfo::Numeric {
name: name.clone(),
min: stats.slider_min,
max: stats.slider_max,
step: 0.1,
histogram: stats.histogram.clone(),
description: format!("Distance to the closest {category} POI"),
detail: format!(
"Straight-line distance in kilometres from the postcode to the nearest {category} point of interest in the POI dataset."
),
source: "osm-pois".to_string(),
prefix: "",
suffix: " km",
raw: false,
absolute: false,
});
} else if let Some(category) = features::dynamic_poi_count_category(name) {
let stats = &data.poi_metrics.feature_stats[feat_idx];
let radius = features::dynamic_poi_count_radius(name).unwrap_or(0);
dynamic_poi_features.push(FeatureInfo::Numeric {
name: name.clone(),
min: stats.slider_min,
max: stats.slider_max,
step: 1.0,
histogram: stats.histogram.clone(),
description: format!("Number of {category} POIs within {radius}km"),
detail: format!(
"Count of {category} points of interest within a {radius}km radius of the property's postcode centroid."
),
source: "osm-pois".to_string(),
prefix: "",
suffix: "",
raw: false,
absolute: false,
});
}
}
if !dynamic_poi_features.is_empty() {
dynamic_poi_features.sort_by_key(|feature| match feature {
FeatureInfo::Numeric { name, .. } => features::dynamic_poi_feature_sort_key(name),
FeatureInfo::Enum { name, .. } => features::dynamic_poi_feature_sort_key(name),
});
groups.push(FeatureGroupResponse {
name: "Nearby POIs".to_string(),
features: dynamic_poi_features,
});
}
FeaturesResponse { groups }
}

View file

@ -9,7 +9,7 @@ use tracing::info;
use crate::consts::NAN_U16;
use crate::data::travel_time::TravelData;
use crate::parsing::{parse_filters, require_bounds};
use crate::parsing::{parse_filters_with_poi, require_bounds};
use crate::routes::travel_time::parse_optional_travel;
use crate::state::SharedState;
@ -36,18 +36,21 @@ pub async fn get_filter_counts(
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters(
let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(),
&state.feature_name_to_index,
&state.data.enum_values,
&quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_regular = parsed_filters.len() + parsed_enum_filters.len();
let num_regular = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
// Only travel entries with a filter range count as filters for impact tracking
let travel_filter_indices: Vec<usize> = travel_entries
.iter()
@ -65,6 +68,7 @@ pub async fn get_filter_counts(
}
let filters_str = params.filters;
let has_poi_filters = !parsed_poi_filters.is_empty();
let response = tokio::task::spawn_blocking(move || -> Result<FilterCountsResponse, String> {
let t0 = std::time::Instant::now();
@ -124,6 +128,23 @@ pub async fn get_filter_counts(
}
}
// Test travel time filters
if fail_count <= 1 && has_poi_filters {
for (i, f) in parsed_poi_filters.iter().enumerate() {
let raw = state
.data
.poi_metrics
.raw_for_property_row(row, f.metric_idx);
if raw == NAN_U16 || raw < f.min_u16 || raw > f.max_u16 {
fail_count += 1;
fail_index = parsed_filters.len() + parsed_enum_filters.len() + i;
if fail_count > 1 {
break;
}
}
}
}
// Test travel time filters
if fail_count <= 1 && has_travel {
let postcode = pc_interner.resolve(&pc_keys[row]);
@ -169,8 +190,15 @@ pub async fn get_filter_counts(
let name = if i < parsed_filters.len() {
state.data.feature_names[parsed_filters[i].feat_idx].clone()
} else if i < num_regular {
let ei = i - parsed_filters.len();
state.data.feature_names[parsed_enum_filters[ei].feat_idx].clone()
let enum_start = parsed_filters.len();
let poi_start = enum_start + parsed_enum_filters.len();
if i < poi_start {
let ei = i - enum_start;
state.data.feature_names[parsed_enum_filters[ei].feat_idx].clone()
} else {
let pi = i - poi_start;
state.data.poi_metrics.feature_names[parsed_poi_filters[pi].metric_idx].clone()
}
} else {
let slot = i - num_regular;
let ti = travel_filter_indices[slot];

View file

@ -13,8 +13,8 @@ use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::licensing::{check_license_bounds, resolve_share_code};
use crate::parsing::{
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_set, parse_filters,
row_passes_filters, validate_h3_resolution,
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_set, parse_filters_with_poi,
row_passes_filters, row_passes_poi_filters, validate_h3_resolution,
};
use crate::state::SharedState;
@ -110,15 +110,19 @@ pub async fn get_hexagon_stats(
let h3_str = params.h3;
let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters(
let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(),
&state.feature_name_to_index,
&state.data.enum_values,
&quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters;
let has_poi_filters = !parsed_poi_filters.is_empty();
let (fields_specified, field_set) = parse_field_set(params.fields.as_deref());
@ -161,6 +165,12 @@ pub async fn get_hexagon_stats(
feature_data,
num_features,
)
&& (!has_poi_filters
|| row_passes_poi_filters(
row,
&parsed_poi_filters,
&state.data.poi_metrics,
))
{
if has_travel {
let postcode = state.data.postcode(row);
@ -233,7 +243,7 @@ pub async fn get_hexagon_stats(
let price_history =
stats::extract_price_history(&matching_rows, &state.data, &state.feature_name_to_index);
let (numeric_features, enum_features_out) = stats::compute_feature_stats(
let (mut numeric_features, enum_features_out) = stats::compute_feature_stats(
&matching_rows,
&state.data,
&state.data.feature_names,
@ -242,6 +252,12 @@ pub async fn get_hexagon_stats(
fields_specified,
&field_set,
);
numeric_features.extend(stats::compute_poi_feature_stats(
&matching_rows,
&state.data.poi_metrics,
fields_specified,
&field_set,
));
let elapsed = start_time.elapsed();
info!(

View file

@ -11,14 +11,15 @@ use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::info;
use crate::aggregation::{Aggregator, EnumDistConfig};
use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator};
use crate::auth::OptionalUser;
use crate::consts::MAX_CELLS_PER_REQUEST;
use crate::data::travel_time::TravelData;
use crate::licensing::{check_license_bounds, resolve_share_code};
use crate::parsing::{
cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices, parse_filters,
require_bounds, row_passes_filters, validate_h3_resolution,
cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices_with_poi,
parse_filters_with_poi, require_bounds, row_passes_filters, row_passes_poi_filters,
validate_h3_resolution,
};
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
use crate::state::SharedState;
@ -29,6 +30,7 @@ const PARALLEL_THRESHOLD: usize = 50_000;
/// Per-thread aggregation result: feature accumulators + travel time accumulators.
type ChunkResult = (
FxHashMap<u64, Aggregator>,
FxHashMap<u64, PoiAggregator>,
Vec<FxHashMap<u64, TravelTimeAgg>>,
);
@ -79,11 +81,14 @@ pub struct HexagonParams {
#[allow(clippy::too_many_arguments)]
fn build_feature_maps(
groups: &FxHashMap<u64, Aggregator>,
poi_groups: &FxHashMap<u64, PoiAggregator>,
min_keys: &[String],
max_keys: &[String],
avg_keys: &[String],
num_features: usize,
indices: Option<&[usize]>,
poi_feature_names: &[String],
poi_indices: &[usize],
query_bounds: (f64, f64, f64, f64),
resolution: h3o::Resolution,
travel_aggs: &[FxHashMap<u64, TravelTimeAgg>],
@ -163,6 +168,25 @@ fn build_feature_maps(
}
}
if let Some(poi_aggregation) = poi_groups.get(&cell_id) {
for &metric_idx in poi_indices {
if poi_aggregation.counts[metric_idx] > 0 {
let avg = poi_aggregation.sums[metric_idx]
/ poi_aggregation.counts[metric_idx] as f64;
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64),
serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64),
serde_json::Number::from_f64(avg),
) {
let name = &poi_feature_names[metric_idx];
map.insert(format!("min_{name}"), Value::Number(min_num));
map.insert(format!("max_{name}"), Value::Number(max_num));
map.insert(format!("avg_{name}"), Value::Number(avg_num));
}
}
}
}
// Add travel time aggregation fields (using pre-computed key strings)
for (ti, agg_map) in travel_aggs.iter().enumerate() {
if let Some(agg) = agg_map.get(&cell_id) {
@ -209,18 +233,25 @@ pub async fn get_hexagons(
check_license_bounds(&user.0, (south, west, north, east), share_bounds)?;
let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters(
let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(),
&state.feature_name_to_index,
&state.data.enum_values,
&quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters;
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
.map_err(|err| (err.0, err.1).into_response())?;
let field_indices = parse_field_indices_with_poi(
params.fields.as_deref(),
&state.feature_name_to_index,
&state.data.poi_metrics.name_to_index,
)
.map_err(|err| (err.0, err.1).into_response())?;
let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
@ -269,6 +300,11 @@ pub async fn get_hexagons(
let min_keys = &state.min_keys;
let max_keys = &state.max_keys;
let avg_keys = &state.avg_keys;
let poi_metrics = &state.data.poi_metrics;
let poi_field_indices = field_indices.poi.as_slice();
let has_poi_fields = !poi_field_indices.is_empty();
let has_poi_filters = !parsed_poi_filters.is_empty();
let poi_num_features = poi_metrics.num_features();
let h3_res = h3o::Resolution::try_from(resolution)
.map_err(|error| format!("Invalid H3 resolution {}: {}", resolution, error))?;
@ -276,6 +312,7 @@ pub async fn get_hexagons(
let need_parent = needs_parent(resolution);
let mut groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
let mut poi_groups: FxHashMap<u64, PoiAggregator> = FxHashMap::default();
let mut travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0..travel_entries.len())
.map(|_| FxHashMap::default())
.collect();
@ -296,6 +333,7 @@ pub async fn get_hexagons(
.par_chunks(chunk_size)
.map(|chunk| {
let mut local_groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
let mut local_poi_groups: FxHashMap<u64, PoiAggregator> = FxHashMap::default();
let mut local_travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0
..travel_entries.len())
.map(|_| FxHashMap::default())
@ -315,6 +353,11 @@ pub async fn get_hexagons(
) {
continue;
}
if has_poi_filters
&& !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
{
continue;
}
if has_travel {
travel_minutes.clear();
@ -352,7 +395,7 @@ pub async fn get_hexagons(
let agg = local_groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config));
if let Some(sel_indices) = field_indices.as_deref() {
if let Some(sel_indices) = field_indices.normal.as_deref() {
agg.add_row_selective(
feature_data,
row,
@ -364,6 +407,13 @@ pub async fn get_hexagons(
agg.add_row(feature_data, row, num_features, &quant);
}
if has_poi_fields {
local_poi_groups
.entry(cell_id)
.or_insert_with(|| PoiAggregator::new(poi_num_features))
.add_row_selective(poi_metrics, row, poi_field_indices);
}
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let tagg = local_travel_aggs[ti]
@ -374,18 +424,24 @@ pub async fn get_hexagons(
}
}
(local_groups, local_travel_aggs)
(local_groups, local_poi_groups, local_travel_aggs)
})
.collect();
// Merge thread-local results into the main accumulators
for (local_groups, local_travel) in thread_results {
for (local_groups, local_poi_groups, local_travel) in thread_results {
for (cell_id, local_agg) in local_groups {
groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config))
.merge(&local_agg);
}
for (cell_id, local_agg) in local_poi_groups {
poi_groups
.entry(cell_id)
.or_insert_with(|| PoiAggregator::new(poi_num_features))
.merge(&local_agg);
}
for (ti, local_ta) in local_travel.into_iter().enumerate() {
for (cell_id, local_tt) in local_ta {
travel_aggs[ti]
@ -414,6 +470,11 @@ pub async fn get_hexagons(
) {
return;
}
if has_poi_filters
&& !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
{
return;
}
if has_travel {
travel_minutes.clear();
@ -444,7 +505,7 @@ pub async fn get_hexagons(
let aggregation = groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config));
if let Some(sel_indices) = field_indices.as_deref() {
if let Some(sel_indices) = field_indices.normal.as_deref() {
aggregation.add_row_selective(
feature_data,
row,
@ -456,6 +517,13 @@ pub async fn get_hexagons(
aggregation.add_row(feature_data, row, num_features, &quant);
}
if has_poi_fields {
poi_groups
.entry(cell_id)
.or_insert_with(|| PoiAggregator::new(poi_num_features))
.add_row_selective(poi_metrics, row, poi_field_indices);
}
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let agg = travel_aggs[ti]
@ -471,11 +539,14 @@ pub async fn get_hexagons(
let mut features = build_feature_maps(
&groups,
&poi_groups,
min_keys,
max_keys,
avg_keys,
num_features,
field_indices.as_deref(),
field_indices.normal.as_deref(),
&poi_metrics.feature_names,
poi_field_indices,
(south, west, north, east),
h3_res,
&travel_aggs,
@ -499,7 +570,11 @@ pub async fn get_hexagons(
bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east),
filters = num_filters,
filters_raw = filters_str.as_deref().unwrap_or("-"),
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1),
fields = field_indices
.normal
.as_ref()
.map(|v| (v.len() + poi_field_indices.len()) as i32)
.unwrap_or(-1),
travel_entries = travel_entries.len(),
grid_ms = format_args!("{:.1}", t_grid.as_secs_f64() * 1000.0),
agg_ms = format_args!("{:.1}", (t_agg - t_grid).as_secs_f64() * 1000.0),

View file

@ -9,11 +9,16 @@ use serde::{Deserialize, Serialize};
use tracing::{info, warn};
use crate::auth::{OptionalUser, PocketBaseUser};
use crate::checkout_sessions::{
active_referral_checkout_user, start_license_checkout, CheckoutStart,
};
use crate::pocketbase::get_superuser_token;
use crate::pocketbase_locks::acquire_pocketbase_lock;
use crate::state::{AppState, SharedState};
static INVITE_REDEMPTIONS_IN_PROGRESS: LazyLock<Mutex<HashSet<String>>> =
LazyLock::new(|| Mutex::new(HashSet::new()));
const INVITE_REDEMPTION_LOCK_TTL_SECS: u64 = 5 * 60;
struct InviteRedemptionGuard {
code: String,
@ -103,7 +108,7 @@ fn validate_invite_code(code: &str) -> Result<(), &'static str> {
}
fn generate_invite_code() -> String {
use rand::Rng;
use rand::RngExt;
let mut rng = rand::rng();
let chars: Vec<char> = (0..12)
.map(|_| {
@ -246,74 +251,26 @@ async fn grant_license_for_invite(
async fn create_referral_checkout(
state: &AppState,
user: &PocketBaseUser,
invite_id: &str,
) -> Result<String, Response> {
let count = match super::pricing::count_licensed_users(state).await {
Ok(count) => count,
Err(err) => {
warn!("Failed to count licensed users for invite checkout: {err}");
return Err(StatusCode::SERVICE_UNAVAILABLE.into_response());
}
};
let price_pence = super::pricing::price_for_count(count);
let public_url = &state.public_url;
let success_url = format!("{public_url}/pricing?license_success=1");
let cancel_url = format!("{public_url}/pricing");
let form_params = vec![
("mode", "payment".to_string()),
(
"line_items[0][price_data][unit_amount]",
price_pence.to_string(),
),
("line_items[0][price_data][currency]", "gbp".to_string()),
(
"line_items[0][price_data][product_data][name]",
"Perfect Postcodes Lifetime License".to_string(),
),
("line_items[0][quantity]", "1".to_string()),
("success_url", success_url),
("cancel_url", cancel_url),
("client_reference_id", user.id.clone()),
("customer_email", user.email.clone()),
(
"discounts[0][coupon]",
state.stripe_referral_coupon_id.clone(),
),
];
let stripe_res = state
.http_client
.post("https://api.stripe.com/v1/checkout/sessions")
.basic_auth(&state.stripe_secret_key, None::<&str>)
.form(&form_params)
.send()
.await;
match stripe_res {
Ok(resp) if resp.status().is_success() => {
let stripe_body: serde_json::Value = match resp.json().await {
Ok(value) => value,
Err(err) => {
warn!("Failed to parse Stripe checkout response: {err}");
return Err(StatusCode::BAD_GATEWAY.into_response());
}
};
let checkout_url = stripe_body["url"].as_str().unwrap_or_default().to_string();
if checkout_url.is_empty() {
warn!("Stripe checkout response did not include a URL");
return Err(StatusCode::BAD_GATEWAY.into_response());
}
Ok(checkout_url)
}
Ok(resp) => {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
warn!("Failed to create Stripe checkout for referral invite ({status}): {text}");
Err(StatusCode::BAD_GATEWAY.into_response())
}
match start_license_checkout(
state,
user,
&success_url,
&cancel_url,
Some(&state.stripe_referral_coupon_id),
Some(invite_id),
)
.await
{
Ok(CheckoutStart::Free) => Ok(success_url),
Ok(CheckoutStart::Stripe { url }) => Ok(url),
Err(err) => {
warn!("Stripe request error for referral invite: {err}");
warn!("Failed to create reserved Stripe checkout for referral invite: {err:?}");
Err(StatusCode::BAD_GATEWAY.into_response())
}
}
@ -541,6 +498,10 @@ pub async fn post_redeem_invite(
.into_response();
}
if user.is_admin || user.subscription == "licensed" {
return (StatusCode::CONFLICT, "Account already has full access").into_response();
}
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = match get_superuser_token(&state).await {
@ -561,6 +522,19 @@ pub async fn post_redeem_invite(
.into_response()
}
};
let lock_name = format!("invite:{}", req.code);
let _distributed_redemption_guard =
match acquire_pocketbase_lock(&state, &lock_name, INVITE_REDEMPTION_LOCK_TTL_SECS).await {
Ok(guard) => guard,
Err(err) => {
warn!(code = %req.code, "Failed to acquire invite redemption lock: {err}");
return (
StatusCode::CONFLICT,
"Invite redemption is already in progress",
)
.into_response();
}
};
let invite = match lookup_unused_invite(&state, pb_url, &token, &req.code).await {
Ok(Some(invite)) => invite,
@ -591,11 +565,11 @@ pub async fn post_redeem_invite(
};
if invite_type == "admin" {
if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await {
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
return response;
}
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await {
return response;
}
@ -607,15 +581,26 @@ pub async fn post_redeem_invite(
.into_response();
}
let checkout_url = match create_referral_checkout(&state, &user).await {
match active_referral_checkout_user(&state, invite_id).await {
Ok(Some(active_user_id)) if active_user_id != user.id => {
return (
StatusCode::CONFLICT,
"Invite checkout is already in progress",
)
.into_response()
}
Ok(_) => {}
Err(err) => {
warn!(code = %req.code, "Failed to check active referral checkout: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
}
let checkout_url = match create_referral_checkout(&state, &user, invite_id).await {
Ok(url) => url,
Err(response) => return response,
};
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
return response;
}
info!(user_id = %user.id, code = %req.code, "Referral invite redeemed; checkout created");
Json(RedeemResponse {
result: "checkout".to_string(),

View file

@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use tracing::info;
use crate::consts::MAX_POIS_PER_REQUEST;
use crate::data::POICategoryGroup;
use crate::data::{resolve_poi_category_filter, POICategoryGroup};
use crate::parsing::require_bounds;
use crate::state::SharedState;
@ -47,20 +47,7 @@ pub async fn get_pois(
.categories
.as_deref()
.filter(|text| !text.is_empty())
.map(|text| {
text.split(',')
.filter_map(|part| {
let name = part.trim();
state
.poi_data
.category
.values
.iter()
.position(|v| v == name)
.map(|pos| pos as u16)
})
.collect()
});
.map(|text| resolve_poi_category_filter(&state.poi_data.category.values, text));
let categories_raw = params.categories;
let num_categories = category_filter.as_ref().map(|cats| cats.len()).unwrap_or(0);

View file

@ -10,7 +10,7 @@ use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT, POSTCODE_SEARCH_OFFSET};
use crate::licensing::{check_license_point, resolve_share_code};
use crate::parsing::{parse_filters, row_passes_filters};
use crate::parsing::{parse_filters_with_poi, row_passes_filters, row_passes_poi_filters};
use crate::state::SharedState;
use crate::utils::normalize_postcode;
@ -62,15 +62,19 @@ pub async fn get_postcode_properties(
)?;
let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters(
let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(),
&state.feature_name_to_index,
&state.data.enum_values,
&quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters;
let has_poi_filters = !parsed_poi_filters.is_empty();
let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
@ -111,6 +115,12 @@ pub async fn get_postcode_properties(
feature_data,
num_features,
)
&& (!has_poi_filters
|| row_passes_poi_filters(
row,
&parsed_poi_filters,
&state.data.poi_metrics,
))
{
if has_travel
&& !row_passes_travel_filters(

View file

@ -10,7 +10,9 @@ use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::consts::POSTCODE_SEARCH_OFFSET;
use crate::licensing::{check_license_point, resolve_share_code};
use crate::parsing::{parse_field_set, parse_filters, row_passes_filters};
use crate::parsing::{
parse_field_set, parse_filters_with_poi, row_passes_filters, row_passes_poi_filters,
};
use crate::state::SharedState;
use crate::utils::normalize_postcode;
@ -64,15 +66,19 @@ pub async fn get_postcode_stats(
)?;
let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters(
let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(),
&state.feature_name_to_index,
&state.data.enum_values,
&quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters;
let has_poi_filters = !parsed_poi_filters.is_empty();
let (fields_specified, field_set) = parse_field_set(params.fields.as_deref());
let travel_entries = parse_optional_travel(params.travel.as_deref())
@ -108,6 +114,12 @@ pub async fn get_postcode_stats(
feature_data,
num_features,
)
&& (!has_poi_filters
|| row_passes_poi_filters(
row,
&parsed_poi_filters,
&state.data.poi_metrics,
))
{
if has_travel
&& !row_passes_travel_filters(row_postcode, &travel_entries, &travel_data)
@ -123,7 +135,7 @@ pub async fn get_postcode_stats(
let price_history =
stats::extract_price_history(&matching_rows, &state.data, &state.feature_name_to_index);
let (numeric_features, enum_features_out) = stats::compute_feature_stats(
let (mut numeric_features, enum_features_out) = stats::compute_feature_stats(
&matching_rows,
&state.data,
&state.data.feature_names,
@ -132,6 +144,12 @@ pub async fn get_postcode_stats(
fields_specified,
&field_set,
);
numeric_features.extend(stats::compute_poi_feature_stats(
&matching_rows,
&state.data.poi_metrics,
fields_specified,
&field_set,
));
let elapsed = start_time.elapsed();
info!(

View file

@ -10,14 +10,14 @@ use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::info;
use crate::aggregation::{Aggregator, EnumDistConfig};
use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator};
use crate::auth::OptionalUser;
use crate::consts::MAX_CELLS_PER_REQUEST;
use crate::data::travel_time::TravelData;
use crate::licensing::{check_license_bounds, resolve_share_code};
use crate::parsing::{
bounds_intersect, parse_enum_dist, parse_field_indices, parse_filters, require_bounds,
row_passes_filters,
bounds_intersect, parse_enum_dist, parse_field_indices_with_poi, parse_filters_with_poi,
require_bounds, row_passes_filters, row_passes_poi_filters,
};
use crate::pocketbase::log_user_location;
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
@ -64,18 +64,25 @@ pub async fn get_postcodes(
check_license_bounds(&user.0, (south, west, north, east), share_bounds)?;
let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters(
let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(),
&state.feature_name_to_index,
&state.data.enum_values,
&quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters;
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
.map_err(|err| (err.0, err.1).into_response())?;
let field_indices = parse_field_indices_with_poi(
params.fields.as_deref(),
&state.feature_name_to_index,
&state.data.poi_metrics.name_to_index,
)
.map_err(|err| (err.0, err.1).into_response())?;
let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
@ -123,12 +130,18 @@ pub async fn get_postcodes(
let min_keys = &state.min_keys;
let max_keys = &state.max_keys;
let avg_keys = &state.avg_keys;
let poi_metrics = &state.data.poi_metrics;
let poi_field_indices = field_indices.poi.as_slice();
let has_poi_fields = !poi_field_indices.is_empty();
let has_poi_filters = !parsed_poi_filters.is_empty();
let poi_num_features = poi_metrics.num_features();
let has_selective = field_indices.is_some();
let sel_indices = field_indices.as_deref().unwrap_or(&[]);
let has_selective = field_indices.normal.is_some();
let sel_indices = field_indices.normal.as_deref().unwrap_or(&[]);
// Single-pass: aggregate directly into postcode_aggs while iterating properties in bounds
let mut postcode_aggs: FxHashMap<usize, Aggregator> = FxHashMap::default();
let mut poi_aggs: FxHashMap<usize, PoiAggregator> = FxHashMap::default();
state
.grid
@ -143,6 +156,10 @@ pub async fn get_postcodes(
) {
return;
}
if has_poi_filters && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
{
return;
}
let postcode = state.data.postcode(row);
if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) {
@ -154,6 +171,12 @@ pub async fn get_postcodes(
} else {
agg.add_row(feature_data, row, num_features, &quant);
}
if has_poi_fields {
poi_aggs
.entry(pc_idx)
.or_insert_with(|| PoiAggregator::new(poi_num_features))
.add_row_selective(poi_metrics, row, poi_field_indices);
}
}
});
@ -250,11 +273,12 @@ pub async fn get_postcodes(
]),
);
let iter: Box<dyn Iterator<Item = usize>> = if let Some(idx) = field_indices.as_ref() {
Box::new(idx.iter().copied())
} else {
Box::new(0..num_features)
};
let iter: Box<dyn Iterator<Item = usize>> =
if let Some(idx) = field_indices.normal.as_ref() {
Box::new(idx.iter().copied())
} else {
Box::new(0..num_features)
};
for feat_index in iter {
if aggregation.feat_counts[feat_index] > 0 {
@ -272,6 +296,25 @@ pub async fn get_postcodes(
}
}
if let Some(poi_aggregation) = poi_aggs.get(&pc_idx) {
for &metric_idx in poi_field_indices {
if poi_aggregation.counts[metric_idx] > 0 {
let avg = poi_aggregation.sums[metric_idx]
/ poi_aggregation.counts[metric_idx] as f64;
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64),
serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64),
serde_json::Number::from_f64(avg),
) {
let name = &poi_metrics.feature_names[metric_idx];
props.insert(format!("min_{name}"), Value::Number(min_num));
props.insert(format!("max_{name}"), Value::Number(max_num));
props.insert(format!("avg_{name}"), Value::Number(avg_num));
}
}
}
}
// Add travel time aggregation fields
if let Some(tt_aggs) = travel_aggs.get(&pc_idx) {
for (ti, agg) in tt_aggs.iter().enumerate() {
@ -322,7 +365,11 @@ pub async fn get_postcodes(
bounds = format_args!("{:.6},{:.6},{:.6},{:.6}", south, west, north, east),
filters = num_filters,
filters_raw = filters_str.as_deref().unwrap_or("-"),
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1),
fields = field_indices
.normal
.as_ref()
.map(|v| (v.len() + poi_field_indices.len()) as i32)
.unwrap_or(-1),
travel_entries = travel_entries.len(),
agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0),
json_ms = format_args!("{:.1}", (t_total - t_agg).as_secs_f64() * 1000.0),

View file

@ -14,8 +14,8 @@ use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT};
use crate::data::RenovationEvent;
use crate::licensing::{check_license_bounds, resolve_share_code};
use crate::parsing::{
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_filters, row_passes_filters,
validate_h3_resolution,
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_filters_with_poi, row_passes_filters,
row_passes_poi_filters, validate_h3_resolution,
};
use crate::state::{AppState, SharedState};
@ -117,6 +117,12 @@ pub fn build_property(
features.insert(feat_name.clone(), value);
}
}
for (metric_idx, metric_name) in state.data.poi_metrics.feature_names.iter().enumerate() {
let value = state.data.poi_metrics.get_for_property_row(row, metric_idx);
if value.is_finite() {
features.insert(metric_name.clone(), value);
}
}
Property {
address: non_empty_string(state.data.address(row)),
@ -199,15 +205,19 @@ pub async fn get_hexagon_properties(
let h3_str = params.h3;
let quant = state.data.quant_ref();
let (parsed_filters, parsed_enum_filters) = parse_filters(
let poi_quant = state.data.poi_metrics.quant_ref();
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
params.filters.as_deref(),
&state.feature_name_to_index,
&state.data.enum_values,
&quant,
&state.data.poi_metrics.name_to_index,
&poi_quant,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
let filters_str = params.filters;
let has_poi_filters = !parsed_poi_filters.is_empty();
let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
@ -242,6 +252,12 @@ pub async fn get_hexagon_properties(
feature_data,
num_features,
)
&& (!has_poi_filters
|| row_passes_poi_filters(
row,
&parsed_poi_filters,
&state.data.poi_metrics,
))
{
if has_travel {
let postcode = state.data.postcode(row);

View file

@ -4,7 +4,7 @@ use rustc_hash::FxHashMap;
use tracing::warn;
use crate::consts::MAX_PRICE_HISTORY_POINTS;
use crate::data::{FeatureStats, PropertyData};
use crate::data::{FeatureStats, PostcodePoiMetrics, PropertyData};
use super::hexagon_stats::{EnumFeatureStats, HistogramStats, NumericFeatureStats, PricePoint};
@ -243,3 +243,80 @@ pub fn compute_feature_stats(
(numeric_features, enum_features_out)
}
pub fn compute_poi_feature_stats(
matching_rows: &[usize],
poi_metrics: &PostcodePoiMetrics,
fields_specified: bool,
field_set: &HashSet<String>,
) -> Vec<NumericFeatureStats> {
let mut out = Vec::new();
for (metric_idx, name) in poi_metrics.feature_names.iter().enumerate() {
if fields_specified && !field_set.contains(name.as_str()) {
continue;
}
let global_hist = &poi_metrics.feature_stats[metric_idx].histogram;
let p1 = global_hist.p1;
let p99 = global_hist.p99;
let num_bins = global_hist.counts.len();
let middle_bins = num_bins.saturating_sub(2);
let middle_width = if middle_bins > 0 && p99 > p1 {
(p99 - p1) / middle_bins as f32
} else {
0.0
};
let mut count = 0usize;
let mut min_value = f32::INFINITY;
let mut max_value = f32::NEG_INFINITY;
let mut sum = 0.0f64;
let mut bins = vec![0u64; num_bins];
for &row in matching_rows {
let value = poi_metrics.get_for_property_row(row, metric_idx);
if !value.is_finite() {
continue;
}
count += 1;
if value < min_value {
min_value = value;
}
if value > max_value {
max_value = value;
}
sum += value as f64;
let bin = if value < p1 {
0
} else if value >= p99 {
num_bins - 1
} else if middle_width > 0.0 {
let middle_bin = ((value - p1) / middle_width) as usize;
(1 + middle_bin).min(num_bins - 2)
} else {
num_bins / 2
};
bins[bin] += 1;
}
if count > 0 {
out.push(NumericFeatureStats {
name: name.clone(),
count,
min: min_value as f64,
max: max_value as f64,
mean: sum / count as f64,
histogram: HistogramStats {
min: global_hist.min as f64,
max: global_hist.max as f64,
p1: p1 as f64,
p99: p99 as f64,
counts: bins,
},
});
}
}
out
}

View file

@ -1,78 +1,40 @@
use std::collections::VecDeque;
use std::sync::{Arc, LazyLock};
use std::sync::Arc;
use axum::body::Bytes;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use hmac::{Hmac, Mac};
use parking_lot::Mutex;
use rustc_hash::FxHashSet;
use hmac::{Hmac, KeyInit, Mac};
use sha2::Sha256;
use tracing::{info, warn};
use crate::pocketbase::get_superuser_token;
use crate::checkout_sessions::{
grant_license, mark_checkout_completed, mark_referral_invite_used, verify_checkout_completion,
CheckoutCompletion,
};
use crate::state::SharedState;
type HmacSha256 = Hmac<Sha256>;
/// Process-local LRU of recently processed Stripe event IDs.
/// Stripe retries deliver the same event ID; we drop duplicates so we don't
/// re-run side effects (subscription writes, token cache invalidation, logs).
/// Capacity is intentionally generous: at typical webhook volumes this covers
/// far more than Stripe's retry window.
struct EventDedup {
seen: FxHashSet<String>,
queue: VecDeque<String>,
capacity: usize,
}
impl EventDedup {
fn new(capacity: usize) -> Self {
Self {
seen: FxHashSet::default(),
queue: VecDeque::with_capacity(capacity),
capacity,
}
}
/// Returns `true` if this event ID is new (and records it),
/// `false` if it was already seen recently.
fn check_and_insert(&mut self, id: &str) -> bool {
if self.seen.contains(id) {
return false;
}
self.seen.insert(id.to_string());
self.queue.push_back(id.to_string());
if self.queue.len() > self.capacity {
if let Some(old) = self.queue.pop_front() {
self.seen.remove(&old);
}
}
true
}
}
static EVENT_DEDUP: LazyLock<Mutex<EventDedup>> =
LazyLock::new(|| Mutex::new(EventDedup::new(1024)));
/// Verify Stripe webhook signature (v1 scheme).
fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
// Parse timestamp and signature from header: "t=TIMESTAMP,v1=SIGNATURE"
let mut timestamp = None;
let mut signature = None;
let mut signatures = Vec::new();
for part in sig_header.split(',') {
if let Some(ts) = part.strip_prefix("t=") {
timestamp = Some(ts);
} else if let Some(sig) = part.strip_prefix("v1=") {
signature = Some(sig);
signatures.push(sig);
}
}
let (ts, sig_hex) = match (timestamp, signature) {
(Some(t), Some(s)) => (t, s),
_ => return false,
let Some(ts) = timestamp else {
return false;
};
if signatures.is_empty() {
return false;
}
// Reject webhooks older than 5 minutes to prevent replay attacks
if let Ok(ts_secs) = ts.parse::<i64>() {
@ -87,20 +49,21 @@ fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
return false;
}
// Compute expected signature: HMAC-SHA256(secret, "TIMESTAMP.PAYLOAD")
let signed_payload = format!("{ts}.{}", String::from_utf8_lossy(payload));
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
Ok(m) => m,
Err(_) => return false,
};
mac.update(signed_payload.as_bytes());
let mut signed_payload = Vec::with_capacity(ts.len() + 1 + payload.len());
signed_payload.extend_from_slice(ts.as_bytes());
signed_payload.push(b'.');
signed_payload.extend_from_slice(payload);
// Decode the provided hex signature and verify with constant-time comparison
let sig_bytes = match hex::decode(sig_hex) {
Ok(bytes) => bytes,
Err(_) => return false,
};
mac.verify_slice(&sig_bytes).is_ok()
signatures.into_iter().any(|sig_hex| {
let Ok(sig_bytes) = hex::decode(sig_hex) else {
return false;
};
let Ok(mut mac) = HmacSha256::new_from_slice(secret.as_bytes()) else {
return false;
};
mac.update(&signed_payload);
mac.verify_slice(&sig_bytes).is_ok()
})
}
/// Handle Stripe webhook events.
@ -140,65 +103,64 @@ pub async fn post_stripe_webhook(
let event_type = event["type"].as_str().unwrap_or("");
let event_id = event["id"].as_str().unwrap_or("");
// Idempotency: drop replays/retries of an already-processed event.
// We always answer 200 so Stripe stops retrying.
if !event_id.is_empty() && !EVENT_DEDUP.lock().check_and_insert(event_id) {
info!(event_id, event_type, "Dropping duplicate Stripe webhook");
return StatusCode::OK.into_response();
}
info!(event_id, event_type, "Received Stripe webhook");
if event_type == "checkout.session.completed" {
let user_id = event["data"]["object"]["client_reference_id"]
.as_str()
.unwrap_or("");
if user_id.is_empty() {
warn!("checkout.session.completed missing client_reference_id");
return StatusCode::OK.into_response();
}
if !user_id.bytes().all(|b| b.is_ascii_alphanumeric()) || user_id.len() > 20 {
warn!(user_id, "Invalid client_reference_id format in webhook");
return StatusCode::BAD_REQUEST.into_response();
}
// Update user subscription to "licensed" via PocketBase superuser auth
let token = match get_superuser_token(&state).await {
Ok(t) => t,
Err(err) => {
warn!("Failed to auth as PocketBase superuser in webhook: {err}");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
let pb_url = state.pocketbase_url.trim_end_matches('/');
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
let res = state
.http_client
.patch(&url)
.header("Authorization", format!("Bearer {token}"))
.json(&serde_json::json!({ "subscription": "licensed" }))
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
state.token_cache.invalidate_by_user_id(user_id);
let session = &event["data"]["object"];
match verify_checkout_completion(&state, session).await {
Ok(CheckoutCompletion::Grant(checkout)) => {
if let Err(err) = mark_referral_invite_used(
&state,
&checkout.referral_invite_id,
&checkout.user_id,
)
.await
{
warn!(
user_id = %checkout.user_id,
reservation_id = %checkout.reservation_id,
referral_invite_id = %checkout.referral_invite_id,
"Failed to mark referral invite used after Stripe checkout: {err:?}"
);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
if let Err(err) = grant_license(&state, &checkout.user_id).await {
warn!(
user_id = %checkout.user_id,
reservation_id = %checkout.reservation_id,
"Failed to grant license after Stripe checkout: {err:?}"
);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
if let Err(err) = mark_checkout_completed(
&state,
&checkout.reservation_id,
checkout.paid_amount_pence,
)
.await
{
warn!(
user_id = %checkout.user_id,
reservation_id = %checkout.reservation_id,
"Failed to mark checkout completed after license grant: {err:?}"
);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
info!(
user_id,
"User subscription updated to licensed via Stripe webhook"
user_id = %checkout.user_id,
reservation_id = %checkout.reservation_id,
"User subscription updated to licensed via verified Stripe checkout"
);
}
Ok(resp) => {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
warn!(
user_id,
"Failed to update user subscription ({status}): {text}"
);
Ok(CheckoutCompletion::AlreadyHandled) => {
info!("Stripe checkout session was already handled");
}
Ok(CheckoutCompletion::Rejected(reason)) => {
warn!("Rejecting Stripe checkout completion: {reason}");
}
Err(err) => {
warn!(user_id, "PocketBase request error in webhook: {err}");
warn!("Failed to verify Stripe checkout completion: {err:?}");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
}
}