Rust things
This commit is contained in:
parent
fc10381692
commit
3debacab4f
30 changed files with 3257 additions and 647 deletions
|
|
@ -8,10 +8,8 @@ use serde::{Deserialize, Serialize};
|
|||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::state::{AppState, SharedState};
|
||||
|
||||
use super::pricing::{count_licensed_users, price_for_count};
|
||||
use crate::checkout_sessions::{start_license_checkout, CheckoutStart};
|
||||
use crate::state::SharedState;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct CheckoutRequest {
|
||||
|
|
@ -23,8 +21,8 @@ struct CheckoutResponse {
|
|||
url: String,
|
||||
}
|
||||
|
||||
/// Create a Stripe Checkout session for the lifetime license (or grant for free if in free tier).
|
||||
/// Requires authentication. Optionally accepts a referral code to apply a coupon.
|
||||
/// Create a reserved Stripe Checkout session for the lifetime license.
|
||||
/// Requires authentication. Referral discounts are issued via invite redemption.
|
||||
pub async fn post_checkout(
|
||||
State(shared): State<Arc<SharedState>>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
|
|
@ -36,147 +34,27 @@ pub async fn post_checkout(
|
|||
None => return StatusCode::UNAUTHORIZED.into_response(),
|
||||
};
|
||||
|
||||
let count = match count_licensed_users(&state).await {
|
||||
Ok(c) => c,
|
||||
Err(err) => {
|
||||
warn!("Failed to count licensed users at checkout: {err}");
|
||||
return StatusCode::SERVICE_UNAVAILABLE.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let price_pence = price_for_count(count);
|
||||
let public_url = &state.public_url;
|
||||
let success_url = format!("{public_url}/pricing?license_success=1");
|
||||
|
||||
// Free tier — grant license directly without Stripe
|
||||
if price_pence == 0 {
|
||||
if let Err(err) = grant_license(&state, &user.id).await {
|
||||
warn!(user_id = %user.id, "Failed to grant free license: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
info!(user_id = %user.id, "Granted free early-bird license");
|
||||
return Json(CheckoutResponse { url: success_url }).into_response();
|
||||
}
|
||||
|
||||
// Paid tier — create Stripe checkout with dynamic price
|
||||
let secret_key = &state.stripe_secret_key;
|
||||
let cancel_url = format!("{public_url}/pricing");
|
||||
|
||||
let mut form_params = vec![
|
||||
("mode", "payment".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][unit_amount]",
|
||||
price_pence.to_string(),
|
||||
),
|
||||
("line_items[0][price_data][currency]", "gbp".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][product_data][name]",
|
||||
"Perfect Postcodes Lifetime License".to_string(),
|
||||
),
|
||||
("line_items[0][quantity]", "1".to_string()),
|
||||
("success_url", success_url),
|
||||
("cancel_url", cancel_url),
|
||||
("client_reference_id", user.id.clone()),
|
||||
("customer_email", user.email.clone()),
|
||||
];
|
||||
|
||||
// If a referral code is provided and valid, look it up and apply the coupon
|
||||
if let Some(ref code) = req.referral_code {
|
||||
if validate_referral_invite(&state, code).await {
|
||||
form_params.push((
|
||||
"discounts[0][coupon]",
|
||||
state.stripe_referral_coupon_id.clone(),
|
||||
));
|
||||
info!(code = %code, "Applying referral coupon to checkout");
|
||||
} else {
|
||||
warn!(code = %code, "Referral code validation failed, proceeding without discount");
|
||||
}
|
||||
if req.referral_code.is_some() {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Referral codes must be redeemed from the invite link",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let res = state
|
||||
.http_client
|
||||
.post("https://api.stripe.com/v1/checkout/sessions")
|
||||
.basic_auth(secret_key, None::<&str>)
|
||||
.form(&form_params)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: serde_json::Value = match resp.json().await {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse Stripe response: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
};
|
||||
let url = body["url"].as_str().unwrap_or_default().to_string();
|
||||
if url.is_empty() {
|
||||
warn!("Stripe session missing URL");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
info!(user_id = %user.id, price_pence, "Created Stripe checkout session");
|
||||
Json(CheckoutResponse { url }).into_response()
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
warn!("Stripe checkout failed ({status}): {text}");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
match start_license_checkout(&state, &user, &success_url, &cancel_url, None, None).await {
|
||||
Ok(CheckoutStart::Free) => {
|
||||
info!(user_id = %user.id, "Granted free early-bird license");
|
||||
Json(CheckoutResponse { url: success_url }).into_response()
|
||||
}
|
||||
Ok(CheckoutStart::Stripe { url }) => Json(CheckoutResponse { url }).into_response(),
|
||||
Err(err) => {
|
||||
warn!("Stripe request error: {err}");
|
||||
warn!(user_id = %user.id, "Failed to start checkout: {err:?}");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Grant a license by updating the user's subscription to "licensed" in PocketBase.
|
||||
async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "subscription": "licensed" }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("PocketBase update failed ({status}): {text}");
|
||||
}
|
||||
|
||||
state.token_cache.invalidate_by_user_id(user_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a referral invite code exists and is unused.
|
||||
async fn validate_referral_invite(state: &AppState, code: &str) -> bool {
|
||||
// Only allow alphanumeric codes to prevent PocketBase filter injection
|
||||
if code.is_empty() || code.len() > 20 || !code.bytes().all(|b| b.is_ascii_alphanumeric()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!(
|
||||
"code=\"{}\" && invite_type=\"referral\" && used_by_id=\"\"",
|
||||
code
|
||||
);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/invites/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
|
||||
match state.http_client.get(&url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
body["totalItems"].as_u64().unwrap_or(0) > 0
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::extract::{Query, State};
|
||||
use axum::http::{header, HeaderMap, StatusCode};
|
||||
|
|
@ -13,14 +14,18 @@ use tracing::{info, warn};
|
|||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::NAN_U16;
|
||||
use crate::data::QuantRef;
|
||||
use crate::features::INTEGER_BIN_FEATURES;
|
||||
use crate::data::{PostcodePoiMetrics, QuantRef};
|
||||
use crate::features;
|
||||
use crate::licensing::check_license_bounds;
|
||||
use crate::parsing::{parse_field_indices, parse_filters, require_bounds, row_passes_filters};
|
||||
use crate::parsing::{
|
||||
parse_field_indices_with_poi, parse_filters_with_poi, require_bounds, row_passes_filters,
|
||||
row_passes_poi_filters,
|
||||
};
|
||||
use crate::routes::{fetch_screenshot_bytes, FeatureInfo};
|
||||
use crate::state::SharedState;
|
||||
|
||||
const MAX_EXPORT_POSTCODES: usize = 250;
|
||||
const EXPORT_SCREENSHOT_TIMEOUT_SECS: u64 = 12;
|
||||
/// Height (in pixels) reserved for the screenshot row
|
||||
const IMAGE_ROW_HEIGHT: f64 = 225.0;
|
||||
|
||||
|
|
@ -41,11 +46,11 @@ struct PostcodeExportAgg {
|
|||
}
|
||||
|
||||
impl PostcodeExportAgg {
|
||||
fn new(num_features: usize) -> Self {
|
||||
fn new(total_features: usize) -> Self {
|
||||
Self {
|
||||
count: 0,
|
||||
sums: vec![0.0; num_features],
|
||||
finite_counts: vec![0; num_features],
|
||||
sums: vec![0.0; total_features],
|
||||
finite_counts: vec![0; total_features],
|
||||
enum_freqs: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
|
@ -58,6 +63,7 @@ impl PostcodeExportAgg {
|
|||
num_features: usize,
|
||||
enum_indices: &FxHashMap<usize, ()>,
|
||||
quant: &QuantRef,
|
||||
poi_metrics: &PostcodePoiMetrics,
|
||||
) {
|
||||
self.count += 1;
|
||||
let base = row * num_features;
|
||||
|
|
@ -79,6 +85,18 @@ impl PostcodeExportAgg {
|
|||
self.finite_counts[feat_idx] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let poi_offset = num_features;
|
||||
for metric_idx in 0..poi_metrics.num_features() {
|
||||
let raw = poi_metrics.raw_for_property_row(row, metric_idx);
|
||||
if raw == NAN_U16 {
|
||||
continue;
|
||||
}
|
||||
let value = poi_metrics.decode_raw(metric_idx, raw);
|
||||
let out_idx = poi_offset + metric_idx;
|
||||
self.sums[out_idx] += value as f64;
|
||||
self.finite_counts[out_idx] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -138,13 +156,17 @@ pub async fn get_export(
|
|||
check_license_bounds(&user.0, (south, west, north, east), None)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
let filters_str = params.filters;
|
||||
let fields_str = params.fields;
|
||||
|
||||
|
|
@ -164,16 +186,28 @@ pub async fn get_export(
|
|||
|
||||
// Fetch screenshot (async, before spawn_blocking)
|
||||
let auth_header = headers.get(header::AUTHORIZATION);
|
||||
let screenshot_bytes = match fetch_screenshot_bytes(&state, &frontend_params, auth_header).await
|
||||
let screenshot_fetch = fetch_screenshot_bytes(&state, &frontend_params, auth_header);
|
||||
let screenshot_bytes = match tokio::time::timeout(
|
||||
Duration::from_secs(EXPORT_SCREENSHOT_TIMEOUT_SECS),
|
||||
screenshot_fetch,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(bytes) => {
|
||||
Ok(Ok(bytes)) => {
|
||||
info!(bytes = bytes.len(), "Fetched screenshot for export");
|
||||
Some(bytes)
|
||||
}
|
||||
Err(err) => {
|
||||
Ok(Err(err)) => {
|
||||
warn!("Screenshot failed for export: {err}");
|
||||
None
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(
|
||||
timeout_secs = EXPORT_SCREENSHOT_TIMEOUT_SECS,
|
||||
"Screenshot timed out for export"
|
||||
);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// Build feature name → description map from the precomputed features response
|
||||
|
|
@ -200,6 +234,9 @@ pub async fn get_export(
|
|||
let feature_names = &state.data.feature_names;
|
||||
let enum_values = &state.data.enum_values;
|
||||
let postcode_data = &state.postcode_data;
|
||||
let poi_metrics = &state.data.poi_metrics;
|
||||
let poi_offset = num_features;
|
||||
let total_export_features = num_features + poi_metrics.num_features();
|
||||
|
||||
// Build set of enum feature indices for quick lookup
|
||||
let enum_indices: FxHashMap<usize, ()> = enum_values.keys().map(|&idx| (idx, ())).collect();
|
||||
|
|
@ -219,6 +256,10 @@ pub async fn get_export(
|
|||
) {
|
||||
return;
|
||||
}
|
||||
if has_poi_filters && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
|
||||
{
|
||||
return;
|
||||
}
|
||||
let postcode = state.data.postcode(row);
|
||||
if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) {
|
||||
postcode_rows.entry(pc_idx).or_default().push(row);
|
||||
|
|
@ -229,9 +270,16 @@ pub async fn get_export(
|
|||
let mut postcode_aggs: Vec<(usize, PostcodeExportAgg)> =
|
||||
Vec::with_capacity(postcode_rows.len());
|
||||
for (pc_idx, rows) in postcode_rows {
|
||||
let mut agg = PostcodeExportAgg::new(num_features);
|
||||
let mut agg = PostcodeExportAgg::new(total_export_features);
|
||||
for &row in &rows {
|
||||
agg.add_row(feature_data, row, num_features, &enum_indices, &quant);
|
||||
agg.add_row(
|
||||
feature_data,
|
||||
row,
|
||||
num_features,
|
||||
&enum_indices,
|
||||
&quant,
|
||||
poi_metrics,
|
||||
);
|
||||
}
|
||||
if agg.count > 0 {
|
||||
postcode_aggs.push((pc_idx, agg));
|
||||
|
|
@ -265,14 +313,19 @@ pub async fn get_export(
|
|||
// Determine column order: filter features first, then remaining
|
||||
let filter_feature_names = extract_filter_feature_names(filters_str.as_deref());
|
||||
|
||||
let field_indices =
|
||||
parse_field_indices(fields_str.as_deref(), &state.feature_name_to_index)
|
||||
.map_err(|err| err.1)?;
|
||||
let field_indices = parse_field_indices_with_poi(
|
||||
fields_str.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
)
|
||||
.map_err(|err| err.1)?;
|
||||
|
||||
let all_feature_indices: Vec<usize> = if let Some(ref indices) = field_indices {
|
||||
indices.clone()
|
||||
let all_feature_indices: Vec<usize> = if let Some(ref indices) = field_indices.normal {
|
||||
let mut selected = indices.clone();
|
||||
selected.extend(field_indices.poi.iter().map(|idx| poi_offset + *idx));
|
||||
selected
|
||||
} else {
|
||||
let mut ordered = Vec::with_capacity(num_features);
|
||||
let mut ordered = Vec::with_capacity(total_export_features);
|
||||
let mut used = FxHashSet::default();
|
||||
|
||||
for name in &filter_feature_names {
|
||||
|
|
@ -280,6 +333,11 @@ pub async fn get_export(
|
|||
if used.insert(idx) {
|
||||
ordered.push(idx);
|
||||
}
|
||||
} else if let Some(&idx) = state.data.poi_metrics.name_to_index.get(name.as_str()) {
|
||||
let virtual_idx = poi_offset + idx;
|
||||
if used.insert(virtual_idx) {
|
||||
ordered.push(virtual_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
for idx in 0..num_features {
|
||||
|
|
@ -287,15 +345,42 @@ pub async fn get_export(
|
|||
ordered.push(idx);
|
||||
}
|
||||
}
|
||||
for idx in 0..poi_metrics.num_features() {
|
||||
let virtual_idx = poi_offset + idx;
|
||||
if used.insert(virtual_idx) {
|
||||
ordered.push(virtual_idx);
|
||||
}
|
||||
}
|
||||
ordered
|
||||
};
|
||||
|
||||
// Filter-only feature indices for the Selected sheet
|
||||
let filter_feature_indices: Vec<usize> = filter_feature_names
|
||||
.iter()
|
||||
.filter_map(|name| state.feature_name_to_index.get(name.as_str()).copied())
|
||||
.filter_map(|name| {
|
||||
state
|
||||
.feature_name_to_index
|
||||
.get(name.as_str())
|
||||
.copied()
|
||||
.or_else(|| {
|
||||
state
|
||||
.data
|
||||
.poi_metrics
|
||||
.name_to_index
|
||||
.get(name.as_str())
|
||||
.map(|idx| poi_offset + *idx)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let feature_name_for_idx = |idx: usize| -> &str {
|
||||
if idx < num_features {
|
||||
&feature_names[idx]
|
||||
} else {
|
||||
&poi_metrics.feature_names[idx - poi_offset]
|
||||
}
|
||||
};
|
||||
|
||||
// Build feature unit map (feat_idx → (prefix, suffix)) for number formatting
|
||||
let feature_units: FxHashMap<usize, (&str, &str)> = state
|
||||
.features_response
|
||||
|
|
@ -309,16 +394,25 @@ pub async fn get_export(
|
|||
suffix,
|
||||
..
|
||||
} => {
|
||||
let idx = state.feature_name_to_index.get(name.as_str())?;
|
||||
Some((*idx, (*prefix, *suffix)))
|
||||
if let Some(&idx) = state.feature_name_to_index.get(name.as_str()) {
|
||||
Some((idx, (*prefix, *suffix)))
|
||||
} else {
|
||||
state
|
||||
.data
|
||||
.poi_metrics
|
||||
.name_to_index
|
||||
.get(name.as_str())
|
||||
.map(|idx| (poi_offset + *idx, (*prefix, *suffix)))
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let integer_feature_indices: FxHashSet<usize> = INTEGER_BIN_FEATURES
|
||||
let integer_feature_indices: FxHashSet<usize> = all_feature_indices
|
||||
.iter()
|
||||
.filter_map(|name| state.feature_name_to_index.get(*name).copied())
|
||||
.copied()
|
||||
.filter(|&idx| features::has_integer_bins(feature_name_for_idx(idx)))
|
||||
.collect();
|
||||
|
||||
// Build Excel number formats per feature index for unit display
|
||||
|
|
@ -435,7 +529,7 @@ pub async fn get_export(
|
|||
.write_string_with_format(
|
||||
header_row,
|
||||
col,
|
||||
&feature_names[feat_idx],
|
||||
feature_name_for_idx(feat_idx),
|
||||
&header_fmt,
|
||||
)
|
||||
.map_err(|e| format!("Failed to write header: {e}"))?;
|
||||
|
|
@ -453,7 +547,7 @@ pub async fn get_export(
|
|||
for (col_offset, &feat_idx) in feat_indices.iter().enumerate() {
|
||||
let col = (col_offset + 2) as u16;
|
||||
let desc = feature_descriptions
|
||||
.get(&feature_names[feat_idx])
|
||||
.get(feature_name_for_idx(feat_idx))
|
||||
.map(String::as_str)
|
||||
.unwrap_or("");
|
||||
sheet
|
||||
|
|
@ -477,7 +571,7 @@ pub async fn get_export(
|
|||
for (col_offset, &feat_idx) in feat_indices.iter().enumerate() {
|
||||
let col = (col_offset + 2) as u16;
|
||||
|
||||
if enum_indices.contains_key(&feat_idx) {
|
||||
if feat_idx < num_features && enum_indices.contains_key(&feat_idx) {
|
||||
if let Some(freqs) = agg.enum_freqs.get(&feat_idx) {
|
||||
if let Some((&mode_bits, _)) =
|
||||
freqs.iter().max_by_key(|(_, &count)| count)
|
||||
|
|
@ -543,7 +637,7 @@ pub async fn get_export(
|
|||
.map_err(|e| format!("Failed to set column width: {e}"))?;
|
||||
for col_offset in 0..feat_indices.len() {
|
||||
let col = (col_offset + 2) as u16;
|
||||
let feat_name = &feature_names[feat_indices[col_offset]];
|
||||
let feat_name = feature_name_for_idx(feat_indices[col_offset]);
|
||||
let width = (feat_name.len() as f64 * 1.1).clamp(10.0, 30.0);
|
||||
sheet
|
||||
.set_column_width(col, width)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use serde::Serialize;
|
|||
use tracing::info;
|
||||
|
||||
use crate::data::{Histogram, PropertyData};
|
||||
use crate::features::{Feature, FEATURE_GROUPS};
|
||||
use crate::features::{self, Feature, FEATURE_GROUPS};
|
||||
use crate::state::SharedState;
|
||||
|
||||
fn is_empty(val: &str) -> bool {
|
||||
|
|
@ -28,9 +28,9 @@ pub enum FeatureInfo {
|
|||
max: f32,
|
||||
step: f32,
|
||||
histogram: Histogram,
|
||||
description: &'static str,
|
||||
detail: &'static str,
|
||||
source: &'static str,
|
||||
description: String,
|
||||
detail: String,
|
||||
source: String,
|
||||
#[serde(skip_serializing_if = "is_empty")]
|
||||
prefix: &'static str,
|
||||
#[serde(skip_serializing_if = "is_empty")]
|
||||
|
|
@ -45,9 +45,9 @@ pub enum FeatureInfo {
|
|||
name: String,
|
||||
values: Vec<String>,
|
||||
counts: HashMap<String, u64>,
|
||||
description: &'static str,
|
||||
detail: &'static str,
|
||||
source: &'static str,
|
||||
description: String,
|
||||
detail: String,
|
||||
source: String,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -85,9 +85,9 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
|
|||
max: stats.slider_max,
|
||||
step: config.step,
|
||||
histogram: stats.histogram.clone(),
|
||||
description: config.description,
|
||||
detail: config.detail,
|
||||
source: config.source,
|
||||
description: config.description.to_string(),
|
||||
detail: config.detail.to_string(),
|
||||
source: config.source.to_string(),
|
||||
prefix: config.prefix,
|
||||
suffix: config.suffix,
|
||||
raw: config.raw,
|
||||
|
|
@ -118,9 +118,9 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
|
|||
name: config.name.to_string(),
|
||||
values: values.clone(),
|
||||
counts,
|
||||
description: config.description,
|
||||
detail: config.detail,
|
||||
source: config.source,
|
||||
description: config.description.to_string(),
|
||||
detail: config.detail.to_string(),
|
||||
source: config.source.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -136,6 +136,58 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
|
|||
}
|
||||
}
|
||||
|
||||
let mut dynamic_poi_features = Vec::new();
|
||||
for (feat_idx, name) in data.poi_metrics.feature_names.iter().enumerate() {
|
||||
if let Some(category) = features::dynamic_poi_distance_category(name) {
|
||||
let stats = &data.poi_metrics.feature_stats[feat_idx];
|
||||
dynamic_poi_features.push(FeatureInfo::Numeric {
|
||||
name: name.clone(),
|
||||
min: stats.slider_min,
|
||||
max: stats.slider_max,
|
||||
step: 0.1,
|
||||
histogram: stats.histogram.clone(),
|
||||
description: format!("Distance to the closest {category} POI"),
|
||||
detail: format!(
|
||||
"Straight-line distance in kilometres from the postcode to the nearest {category} point of interest in the POI dataset."
|
||||
),
|
||||
source: "osm-pois".to_string(),
|
||||
prefix: "",
|
||||
suffix: " km",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
});
|
||||
} else if let Some(category) = features::dynamic_poi_count_category(name) {
|
||||
let stats = &data.poi_metrics.feature_stats[feat_idx];
|
||||
let radius = features::dynamic_poi_count_radius(name).unwrap_or(0);
|
||||
dynamic_poi_features.push(FeatureInfo::Numeric {
|
||||
name: name.clone(),
|
||||
min: stats.slider_min,
|
||||
max: stats.slider_max,
|
||||
step: 1.0,
|
||||
histogram: stats.histogram.clone(),
|
||||
description: format!("Number of {category} POIs within {radius}km"),
|
||||
detail: format!(
|
||||
"Count of {category} points of interest within a {radius}km radius of the property's postcode centroid."
|
||||
),
|
||||
source: "osm-pois".to_string(),
|
||||
prefix: "",
|
||||
suffix: "",
|
||||
raw: false,
|
||||
absolute: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
if !dynamic_poi_features.is_empty() {
|
||||
dynamic_poi_features.sort_by_key(|feature| match feature {
|
||||
FeatureInfo::Numeric { name, .. } => features::dynamic_poi_feature_sort_key(name),
|
||||
FeatureInfo::Enum { name, .. } => features::dynamic_poi_feature_sort_key(name),
|
||||
});
|
||||
groups.push(FeatureGroupResponse {
|
||||
name: "Nearby POIs".to_string(),
|
||||
features: dynamic_poi_features,
|
||||
});
|
||||
}
|
||||
|
||||
FeaturesResponse { groups }
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ use tracing::info;
|
|||
|
||||
use crate::consts::NAN_U16;
|
||||
use crate::data::travel_time::TravelData;
|
||||
use crate::parsing::{parse_filters, require_bounds};
|
||||
use crate::parsing::{parse_filters_with_poi, require_bounds};
|
||||
use crate::routes::travel_time::parse_optional_travel;
|
||||
use crate::state::SharedState;
|
||||
|
||||
|
|
@ -36,18 +36,21 @@ pub async fn get_filter_counts(
|
|||
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
||||
let num_regular = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_regular = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
// Only travel entries with a filter range count as filters for impact tracking
|
||||
let travel_filter_indices: Vec<usize> = travel_entries
|
||||
.iter()
|
||||
|
|
@ -65,6 +68,7 @@ pub async fn get_filter_counts(
|
|||
}
|
||||
|
||||
let filters_str = params.filters;
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
|
||||
let response = tokio::task::spawn_blocking(move || -> Result<FilterCountsResponse, String> {
|
||||
let t0 = std::time::Instant::now();
|
||||
|
|
@ -124,6 +128,23 @@ pub async fn get_filter_counts(
|
|||
}
|
||||
}
|
||||
|
||||
// Test travel time filters
|
||||
if fail_count <= 1 && has_poi_filters {
|
||||
for (i, f) in parsed_poi_filters.iter().enumerate() {
|
||||
let raw = state
|
||||
.data
|
||||
.poi_metrics
|
||||
.raw_for_property_row(row, f.metric_idx);
|
||||
if raw == NAN_U16 || raw < f.min_u16 || raw > f.max_u16 {
|
||||
fail_count += 1;
|
||||
fail_index = parsed_filters.len() + parsed_enum_filters.len() + i;
|
||||
if fail_count > 1 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test travel time filters
|
||||
if fail_count <= 1 && has_travel {
|
||||
let postcode = pc_interner.resolve(&pc_keys[row]);
|
||||
|
|
@ -169,8 +190,15 @@ pub async fn get_filter_counts(
|
|||
let name = if i < parsed_filters.len() {
|
||||
state.data.feature_names[parsed_filters[i].feat_idx].clone()
|
||||
} else if i < num_regular {
|
||||
let ei = i - parsed_filters.len();
|
||||
state.data.feature_names[parsed_enum_filters[ei].feat_idx].clone()
|
||||
let enum_start = parsed_filters.len();
|
||||
let poi_start = enum_start + parsed_enum_filters.len();
|
||||
if i < poi_start {
|
||||
let ei = i - enum_start;
|
||||
state.data.feature_names[parsed_enum_filters[ei].feat_idx].clone()
|
||||
} else {
|
||||
let pi = i - poi_start;
|
||||
state.data.poi_metrics.feature_names[parsed_poi_filters[pi].metric_idx].clone()
|
||||
}
|
||||
} else {
|
||||
let slot = i - num_regular;
|
||||
let ti = travel_filter_indices[slot];
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@ use tracing::{info, warn};
|
|||
use crate::auth::OptionalUser;
|
||||
use crate::licensing::{check_license_bounds, resolve_share_code};
|
||||
use crate::parsing::{
|
||||
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_set, parse_filters,
|
||||
row_passes_filters, validate_h3_resolution,
|
||||
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_set, parse_filters_with_poi,
|
||||
row_passes_filters, row_passes_poi_filters, validate_h3_resolution,
|
||||
};
|
||||
use crate::state::SharedState;
|
||||
|
||||
|
|
@ -110,15 +110,19 @@ pub async fn get_hexagon_stats(
|
|||
|
||||
let h3_str = params.h3;
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
|
||||
let (fields_specified, field_set) = parse_field_set(params.fields.as_deref());
|
||||
|
||||
|
|
@ -161,6 +165,12 @@ pub async fn get_hexagon_stats(
|
|||
feature_data,
|
||||
num_features,
|
||||
)
|
||||
&& (!has_poi_filters
|
||||
|| row_passes_poi_filters(
|
||||
row,
|
||||
&parsed_poi_filters,
|
||||
&state.data.poi_metrics,
|
||||
))
|
||||
{
|
||||
if has_travel {
|
||||
let postcode = state.data.postcode(row);
|
||||
|
|
@ -233,7 +243,7 @@ pub async fn get_hexagon_stats(
|
|||
let price_history =
|
||||
stats::extract_price_history(&matching_rows, &state.data, &state.feature_name_to_index);
|
||||
|
||||
let (numeric_features, enum_features_out) = stats::compute_feature_stats(
|
||||
let (mut numeric_features, enum_features_out) = stats::compute_feature_stats(
|
||||
&matching_rows,
|
||||
&state.data,
|
||||
&state.data.feature_names,
|
||||
|
|
@ -242,6 +252,12 @@ pub async fn get_hexagon_stats(
|
|||
fields_specified,
|
||||
&field_set,
|
||||
);
|
||||
numeric_features.extend(stats::compute_poi_feature_stats(
|
||||
&matching_rows,
|
||||
&state.data.poi_metrics,
|
||||
fields_specified,
|
||||
&field_set,
|
||||
));
|
||||
|
||||
let elapsed = start_time.elapsed();
|
||||
info!(
|
||||
|
|
|
|||
|
|
@ -11,14 +11,15 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::{Map, Value};
|
||||
use tracing::info;
|
||||
|
||||
use crate::aggregation::{Aggregator, EnumDistConfig};
|
||||
use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator};
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::MAX_CELLS_PER_REQUEST;
|
||||
use crate::data::travel_time::TravelData;
|
||||
use crate::licensing::{check_license_bounds, resolve_share_code};
|
||||
use crate::parsing::{
|
||||
cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices, parse_filters,
|
||||
require_bounds, row_passes_filters, validate_h3_resolution,
|
||||
cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices_with_poi,
|
||||
parse_filters_with_poi, require_bounds, row_passes_filters, row_passes_poi_filters,
|
||||
validate_h3_resolution,
|
||||
};
|
||||
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
|
||||
use crate::state::SharedState;
|
||||
|
|
@ -29,6 +30,7 @@ const PARALLEL_THRESHOLD: usize = 50_000;
|
|||
/// Per-thread aggregation result: feature accumulators + travel time accumulators.
|
||||
type ChunkResult = (
|
||||
FxHashMap<u64, Aggregator>,
|
||||
FxHashMap<u64, PoiAggregator>,
|
||||
Vec<FxHashMap<u64, TravelTimeAgg>>,
|
||||
);
|
||||
|
||||
|
|
@ -79,11 +81,14 @@ pub struct HexagonParams {
|
|||
#[allow(clippy::too_many_arguments)]
|
||||
fn build_feature_maps(
|
||||
groups: &FxHashMap<u64, Aggregator>,
|
||||
poi_groups: &FxHashMap<u64, PoiAggregator>,
|
||||
min_keys: &[String],
|
||||
max_keys: &[String],
|
||||
avg_keys: &[String],
|
||||
num_features: usize,
|
||||
indices: Option<&[usize]>,
|
||||
poi_feature_names: &[String],
|
||||
poi_indices: &[usize],
|
||||
query_bounds: (f64, f64, f64, f64),
|
||||
resolution: h3o::Resolution,
|
||||
travel_aggs: &[FxHashMap<u64, TravelTimeAgg>],
|
||||
|
|
@ -163,6 +168,25 @@ fn build_feature_maps(
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(poi_aggregation) = poi_groups.get(&cell_id) {
|
||||
for &metric_idx in poi_indices {
|
||||
if poi_aggregation.counts[metric_idx] > 0 {
|
||||
let avg = poi_aggregation.sums[metric_idx]
|
||||
/ poi_aggregation.counts[metric_idx] as f64;
|
||||
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
|
||||
serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64),
|
||||
serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64),
|
||||
serde_json::Number::from_f64(avg),
|
||||
) {
|
||||
let name = &poi_feature_names[metric_idx];
|
||||
map.insert(format!("min_{name}"), Value::Number(min_num));
|
||||
map.insert(format!("max_{name}"), Value::Number(max_num));
|
||||
map.insert(format!("avg_{name}"), Value::Number(avg_num));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add travel time aggregation fields (using pre-computed key strings)
|
||||
for (ti, agg_map) in travel_aggs.iter().enumerate() {
|
||||
if let Some(agg) = agg_map.get(&cell_id) {
|
||||
|
|
@ -209,18 +233,25 @@ pub async fn get_hexagons(
|
|||
check_license_bounds(&user.0, (south, west, north, east), share_bounds)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
|
||||
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
|
||||
.map_err(|err| (err.0, err.1).into_response())?;
|
||||
let field_indices = parse_field_indices_with_poi(
|
||||
params.fields.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
)
|
||||
.map_err(|err| (err.0, err.1).into_response())?;
|
||||
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
|
@ -269,6 +300,11 @@ pub async fn get_hexagons(
|
|||
let min_keys = &state.min_keys;
|
||||
let max_keys = &state.max_keys;
|
||||
let avg_keys = &state.avg_keys;
|
||||
let poi_metrics = &state.data.poi_metrics;
|
||||
let poi_field_indices = field_indices.poi.as_slice();
|
||||
let has_poi_fields = !poi_field_indices.is_empty();
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
let poi_num_features = poi_metrics.num_features();
|
||||
|
||||
let h3_res = h3o::Resolution::try_from(resolution)
|
||||
.map_err(|error| format!("Invalid H3 resolution {}: {}", resolution, error))?;
|
||||
|
|
@ -276,6 +312,7 @@ pub async fn get_hexagons(
|
|||
let need_parent = needs_parent(resolution);
|
||||
|
||||
let mut groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
|
||||
let mut poi_groups: FxHashMap<u64, PoiAggregator> = FxHashMap::default();
|
||||
let mut travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0..travel_entries.len())
|
||||
.map(|_| FxHashMap::default())
|
||||
.collect();
|
||||
|
|
@ -296,6 +333,7 @@ pub async fn get_hexagons(
|
|||
.par_chunks(chunk_size)
|
||||
.map(|chunk| {
|
||||
let mut local_groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
|
||||
let mut local_poi_groups: FxHashMap<u64, PoiAggregator> = FxHashMap::default();
|
||||
let mut local_travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> = (0
|
||||
..travel_entries.len())
|
||||
.map(|_| FxHashMap::default())
|
||||
|
|
@ -315,6 +353,11 @@ pub async fn get_hexagons(
|
|||
) {
|
||||
continue;
|
||||
}
|
||||
if has_poi_filters
|
||||
&& !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if has_travel {
|
||||
travel_minutes.clear();
|
||||
|
|
@ -352,7 +395,7 @@ pub async fn get_hexagons(
|
|||
let agg = local_groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config));
|
||||
if let Some(sel_indices) = field_indices.as_deref() {
|
||||
if let Some(sel_indices) = field_indices.normal.as_deref() {
|
||||
agg.add_row_selective(
|
||||
feature_data,
|
||||
row,
|
||||
|
|
@ -364,6 +407,13 @@ pub async fn get_hexagons(
|
|||
agg.add_row(feature_data, row, num_features, &quant);
|
||||
}
|
||||
|
||||
if has_poi_fields {
|
||||
local_poi_groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| PoiAggregator::new(poi_num_features))
|
||||
.add_row_selective(poi_metrics, row, poi_field_indices);
|
||||
}
|
||||
|
||||
for (ti, minutes) in travel_minutes.iter().enumerate() {
|
||||
if let Some(mins) = minutes {
|
||||
let tagg = local_travel_aggs[ti]
|
||||
|
|
@ -374,18 +424,24 @@ pub async fn get_hexagons(
|
|||
}
|
||||
}
|
||||
|
||||
(local_groups, local_travel_aggs)
|
||||
(local_groups, local_poi_groups, local_travel_aggs)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Merge thread-local results into the main accumulators
|
||||
for (local_groups, local_travel) in thread_results {
|
||||
for (local_groups, local_poi_groups, local_travel) in thread_results {
|
||||
for (cell_id, local_agg) in local_groups {
|
||||
groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config))
|
||||
.merge(&local_agg);
|
||||
}
|
||||
for (cell_id, local_agg) in local_poi_groups {
|
||||
poi_groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| PoiAggregator::new(poi_num_features))
|
||||
.merge(&local_agg);
|
||||
}
|
||||
for (ti, local_ta) in local_travel.into_iter().enumerate() {
|
||||
for (cell_id, local_tt) in local_ta {
|
||||
travel_aggs[ti]
|
||||
|
|
@ -414,6 +470,11 @@ pub async fn get_hexagons(
|
|||
) {
|
||||
return;
|
||||
}
|
||||
if has_poi_filters
|
||||
&& !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if has_travel {
|
||||
travel_minutes.clear();
|
||||
|
|
@ -444,7 +505,7 @@ pub async fn get_hexagons(
|
|||
let aggregation = groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| Aggregator::new(num_features, enum_dist_config));
|
||||
if let Some(sel_indices) = field_indices.as_deref() {
|
||||
if let Some(sel_indices) = field_indices.normal.as_deref() {
|
||||
aggregation.add_row_selective(
|
||||
feature_data,
|
||||
row,
|
||||
|
|
@ -456,6 +517,13 @@ pub async fn get_hexagons(
|
|||
aggregation.add_row(feature_data, row, num_features, &quant);
|
||||
}
|
||||
|
||||
if has_poi_fields {
|
||||
poi_groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| PoiAggregator::new(poi_num_features))
|
||||
.add_row_selective(poi_metrics, row, poi_field_indices);
|
||||
}
|
||||
|
||||
for (ti, minutes) in travel_minutes.iter().enumerate() {
|
||||
if let Some(mins) = minutes {
|
||||
let agg = travel_aggs[ti]
|
||||
|
|
@ -471,11 +539,14 @@ pub async fn get_hexagons(
|
|||
|
||||
let mut features = build_feature_maps(
|
||||
&groups,
|
||||
&poi_groups,
|
||||
min_keys,
|
||||
max_keys,
|
||||
avg_keys,
|
||||
num_features,
|
||||
field_indices.as_deref(),
|
||||
field_indices.normal.as_deref(),
|
||||
&poi_metrics.feature_names,
|
||||
poi_field_indices,
|
||||
(south, west, north, east),
|
||||
h3_res,
|
||||
&travel_aggs,
|
||||
|
|
@ -499,7 +570,11 @@ pub async fn get_hexagons(
|
|||
bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east),
|
||||
filters = num_filters,
|
||||
filters_raw = filters_str.as_deref().unwrap_or("-"),
|
||||
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1),
|
||||
fields = field_indices
|
||||
.normal
|
||||
.as_ref()
|
||||
.map(|v| (v.len() + poi_field_indices.len()) as i32)
|
||||
.unwrap_or(-1),
|
||||
travel_entries = travel_entries.len(),
|
||||
grid_ms = format_args!("{:.1}", t_grid.as_secs_f64() * 1000.0),
|
||||
agg_ms = format_args!("{:.1}", (t_agg - t_grid).as_secs_f64() * 1000.0),
|
||||
|
|
|
|||
|
|
@ -9,11 +9,16 @@ use serde::{Deserialize, Serialize};
|
|||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::{OptionalUser, PocketBaseUser};
|
||||
use crate::checkout_sessions::{
|
||||
active_referral_checkout_user, start_license_checkout, CheckoutStart,
|
||||
};
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::pocketbase_locks::acquire_pocketbase_lock;
|
||||
use crate::state::{AppState, SharedState};
|
||||
|
||||
static INVITE_REDEMPTIONS_IN_PROGRESS: LazyLock<Mutex<HashSet<String>>> =
|
||||
LazyLock::new(|| Mutex::new(HashSet::new()));
|
||||
const INVITE_REDEMPTION_LOCK_TTL_SECS: u64 = 5 * 60;
|
||||
|
||||
struct InviteRedemptionGuard {
|
||||
code: String,
|
||||
|
|
@ -103,7 +108,7 @@ fn validate_invite_code(code: &str) -> Result<(), &'static str> {
|
|||
}
|
||||
|
||||
fn generate_invite_code() -> String {
|
||||
use rand::Rng;
|
||||
use rand::RngExt;
|
||||
let mut rng = rand::rng();
|
||||
let chars: Vec<char> = (0..12)
|
||||
.map(|_| {
|
||||
|
|
@ -246,74 +251,26 @@ async fn grant_license_for_invite(
|
|||
async fn create_referral_checkout(
|
||||
state: &AppState,
|
||||
user: &PocketBaseUser,
|
||||
invite_id: &str,
|
||||
) -> Result<String, Response> {
|
||||
let count = match super::pricing::count_licensed_users(state).await {
|
||||
Ok(count) => count,
|
||||
Err(err) => {
|
||||
warn!("Failed to count licensed users for invite checkout: {err}");
|
||||
return Err(StatusCode::SERVICE_UNAVAILABLE.into_response());
|
||||
}
|
||||
};
|
||||
let price_pence = super::pricing::price_for_count(count);
|
||||
|
||||
let public_url = &state.public_url;
|
||||
let success_url = format!("{public_url}/pricing?license_success=1");
|
||||
let cancel_url = format!("{public_url}/pricing");
|
||||
|
||||
let form_params = vec![
|
||||
("mode", "payment".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][unit_amount]",
|
||||
price_pence.to_string(),
|
||||
),
|
||||
("line_items[0][price_data][currency]", "gbp".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][product_data][name]",
|
||||
"Perfect Postcodes Lifetime License".to_string(),
|
||||
),
|
||||
("line_items[0][quantity]", "1".to_string()),
|
||||
("success_url", success_url),
|
||||
("cancel_url", cancel_url),
|
||||
("client_reference_id", user.id.clone()),
|
||||
("customer_email", user.email.clone()),
|
||||
(
|
||||
"discounts[0][coupon]",
|
||||
state.stripe_referral_coupon_id.clone(),
|
||||
),
|
||||
];
|
||||
|
||||
let stripe_res = state
|
||||
.http_client
|
||||
.post("https://api.stripe.com/v1/checkout/sessions")
|
||||
.basic_auth(&state.stripe_secret_key, None::<&str>)
|
||||
.form(&form_params)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match stripe_res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let stripe_body: serde_json::Value = match resp.json().await {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse Stripe checkout response: {err}");
|
||||
return Err(StatusCode::BAD_GATEWAY.into_response());
|
||||
}
|
||||
};
|
||||
let checkout_url = stripe_body["url"].as_str().unwrap_or_default().to_string();
|
||||
if checkout_url.is_empty() {
|
||||
warn!("Stripe checkout response did not include a URL");
|
||||
return Err(StatusCode::BAD_GATEWAY.into_response());
|
||||
}
|
||||
Ok(checkout_url)
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
warn!("Failed to create Stripe checkout for referral invite ({status}): {text}");
|
||||
Err(StatusCode::BAD_GATEWAY.into_response())
|
||||
}
|
||||
match start_license_checkout(
|
||||
state,
|
||||
user,
|
||||
&success_url,
|
||||
&cancel_url,
|
||||
Some(&state.stripe_referral_coupon_id),
|
||||
Some(invite_id),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(CheckoutStart::Free) => Ok(success_url),
|
||||
Ok(CheckoutStart::Stripe { url }) => Ok(url),
|
||||
Err(err) => {
|
||||
warn!("Stripe request error for referral invite: {err}");
|
||||
warn!("Failed to create reserved Stripe checkout for referral invite: {err:?}");
|
||||
Err(StatusCode::BAD_GATEWAY.into_response())
|
||||
}
|
||||
}
|
||||
|
|
@ -541,6 +498,10 @@ pub async fn post_redeem_invite(
|
|||
.into_response();
|
||||
}
|
||||
|
||||
if user.is_admin || user.subscription == "licensed" {
|
||||
return (StatusCode::CONFLICT, "Account already has full access").into_response();
|
||||
}
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
|
||||
let token = match get_superuser_token(&state).await {
|
||||
|
|
@ -561,6 +522,19 @@ pub async fn post_redeem_invite(
|
|||
.into_response()
|
||||
}
|
||||
};
|
||||
let lock_name = format!("invite:{}", req.code);
|
||||
let _distributed_redemption_guard =
|
||||
match acquire_pocketbase_lock(&state, &lock_name, INVITE_REDEMPTION_LOCK_TTL_SECS).await {
|
||||
Ok(guard) => guard,
|
||||
Err(err) => {
|
||||
warn!(code = %req.code, "Failed to acquire invite redemption lock: {err}");
|
||||
return (
|
||||
StatusCode::CONFLICT,
|
||||
"Invite redemption is already in progress",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let invite = match lookup_unused_invite(&state, pb_url, &token, &req.code).await {
|
||||
Ok(Some(invite)) => invite,
|
||||
|
|
@ -591,11 +565,11 @@ pub async fn post_redeem_invite(
|
|||
};
|
||||
|
||||
if invite_type == "admin" {
|
||||
if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await {
|
||||
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
|
||||
return response;
|
||||
}
|
||||
|
||||
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
|
||||
if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await {
|
||||
return response;
|
||||
}
|
||||
|
||||
|
|
@ -607,15 +581,26 @@ pub async fn post_redeem_invite(
|
|||
.into_response();
|
||||
}
|
||||
|
||||
let checkout_url = match create_referral_checkout(&state, &user).await {
|
||||
match active_referral_checkout_user(&state, invite_id).await {
|
||||
Ok(Some(active_user_id)) if active_user_id != user.id => {
|
||||
return (
|
||||
StatusCode::CONFLICT,
|
||||
"Invite checkout is already in progress",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Ok(_) => {}
|
||||
Err(err) => {
|
||||
warn!(code = %req.code, "Failed to check active referral checkout: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
let checkout_url = match create_referral_checkout(&state, &user, invite_id).await {
|
||||
Ok(url) => url,
|
||||
Err(response) => return response,
|
||||
};
|
||||
|
||||
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
|
||||
return response;
|
||||
}
|
||||
|
||||
info!(user_id = %user.id, code = %req.code, "Referral invite redeemed; checkout created");
|
||||
Json(RedeemResponse {
|
||||
result: "checkout".to_string(),
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
|
|||
use tracing::info;
|
||||
|
||||
use crate::consts::MAX_POIS_PER_REQUEST;
|
||||
use crate::data::POICategoryGroup;
|
||||
use crate::data::{resolve_poi_category_filter, POICategoryGroup};
|
||||
use crate::parsing::require_bounds;
|
||||
use crate::state::SharedState;
|
||||
|
||||
|
|
@ -47,20 +47,7 @@ pub async fn get_pois(
|
|||
.categories
|
||||
.as_deref()
|
||||
.filter(|text| !text.is_empty())
|
||||
.map(|text| {
|
||||
text.split(',')
|
||||
.filter_map(|part| {
|
||||
let name = part.trim();
|
||||
state
|
||||
.poi_data
|
||||
.category
|
||||
.values
|
||||
.iter()
|
||||
.position(|v| v == name)
|
||||
.map(|pos| pos as u16)
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
.map(|text| resolve_poi_category_filter(&state.poi_data.category.values, text));
|
||||
let categories_raw = params.categories;
|
||||
|
||||
let num_categories = category_filter.as_ref().map(|cats| cats.len()).unwrap_or(0);
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use tracing::{info, warn};
|
|||
use crate::auth::OptionalUser;
|
||||
use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT, POSTCODE_SEARCH_OFFSET};
|
||||
use crate::licensing::{check_license_point, resolve_share_code};
|
||||
use crate::parsing::{parse_filters, row_passes_filters};
|
||||
use crate::parsing::{parse_filters_with_poi, row_passes_filters, row_passes_poi_filters};
|
||||
use crate::state::SharedState;
|
||||
use crate::utils::normalize_postcode;
|
||||
|
||||
|
|
@ -62,15 +62,19 @@ pub async fn get_postcode_properties(
|
|||
)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
||||
|
|
@ -111,6 +115,12 @@ pub async fn get_postcode_properties(
|
|||
feature_data,
|
||||
num_features,
|
||||
)
|
||||
&& (!has_poi_filters
|
||||
|| row_passes_poi_filters(
|
||||
row,
|
||||
&parsed_poi_filters,
|
||||
&state.data.poi_metrics,
|
||||
))
|
||||
{
|
||||
if has_travel
|
||||
&& !row_passes_travel_filters(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ use tracing::{info, warn};
|
|||
use crate::auth::OptionalUser;
|
||||
use crate::consts::POSTCODE_SEARCH_OFFSET;
|
||||
use crate::licensing::{check_license_point, resolve_share_code};
|
||||
use crate::parsing::{parse_field_set, parse_filters, row_passes_filters};
|
||||
use crate::parsing::{
|
||||
parse_field_set, parse_filters_with_poi, row_passes_filters, row_passes_poi_filters,
|
||||
};
|
||||
use crate::state::SharedState;
|
||||
use crate::utils::normalize_postcode;
|
||||
|
||||
|
|
@ -64,15 +66,19 @@ pub async fn get_postcode_stats(
|
|||
)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
|
||||
let (fields_specified, field_set) = parse_field_set(params.fields.as_deref());
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
|
|
@ -108,6 +114,12 @@ pub async fn get_postcode_stats(
|
|||
feature_data,
|
||||
num_features,
|
||||
)
|
||||
&& (!has_poi_filters
|
||||
|| row_passes_poi_filters(
|
||||
row,
|
||||
&parsed_poi_filters,
|
||||
&state.data.poi_metrics,
|
||||
))
|
||||
{
|
||||
if has_travel
|
||||
&& !row_passes_travel_filters(row_postcode, &travel_entries, &travel_data)
|
||||
|
|
@ -123,7 +135,7 @@ pub async fn get_postcode_stats(
|
|||
let price_history =
|
||||
stats::extract_price_history(&matching_rows, &state.data, &state.feature_name_to_index);
|
||||
|
||||
let (numeric_features, enum_features_out) = stats::compute_feature_stats(
|
||||
let (mut numeric_features, enum_features_out) = stats::compute_feature_stats(
|
||||
&matching_rows,
|
||||
&state.data,
|
||||
&state.data.feature_names,
|
||||
|
|
@ -132,6 +144,12 @@ pub async fn get_postcode_stats(
|
|||
fields_specified,
|
||||
&field_set,
|
||||
);
|
||||
numeric_features.extend(stats::compute_poi_feature_stats(
|
||||
&matching_rows,
|
||||
&state.data.poi_metrics,
|
||||
fields_specified,
|
||||
&field_set,
|
||||
));
|
||||
|
||||
let elapsed = start_time.elapsed();
|
||||
info!(
|
||||
|
|
|
|||
|
|
@ -10,14 +10,14 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::{Map, Value};
|
||||
use tracing::info;
|
||||
|
||||
use crate::aggregation::{Aggregator, EnumDistConfig};
|
||||
use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator};
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::MAX_CELLS_PER_REQUEST;
|
||||
use crate::data::travel_time::TravelData;
|
||||
use crate::licensing::{check_license_bounds, resolve_share_code};
|
||||
use crate::parsing::{
|
||||
bounds_intersect, parse_enum_dist, parse_field_indices, parse_filters, require_bounds,
|
||||
row_passes_filters,
|
||||
bounds_intersect, parse_enum_dist, parse_field_indices_with_poi, parse_filters_with_poi,
|
||||
require_bounds, row_passes_filters, row_passes_poi_filters,
|
||||
};
|
||||
use crate::pocketbase::log_user_location;
|
||||
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
|
||||
|
|
@ -64,18 +64,25 @@ pub async fn get_postcodes(
|
|||
check_license_bounds(&user.0, (south, west, north, east), share_bounds)?;
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
|
||||
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
|
||||
.map_err(|err| (err.0, err.1).into_response())?;
|
||||
let field_indices = parse_field_indices_with_poi(
|
||||
params.fields.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
)
|
||||
.map_err(|err| (err.0, err.1).into_response())?;
|
||||
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
|
@ -123,12 +130,18 @@ pub async fn get_postcodes(
|
|||
let min_keys = &state.min_keys;
|
||||
let max_keys = &state.max_keys;
|
||||
let avg_keys = &state.avg_keys;
|
||||
let poi_metrics = &state.data.poi_metrics;
|
||||
let poi_field_indices = field_indices.poi.as_slice();
|
||||
let has_poi_fields = !poi_field_indices.is_empty();
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
let poi_num_features = poi_metrics.num_features();
|
||||
|
||||
let has_selective = field_indices.is_some();
|
||||
let sel_indices = field_indices.as_deref().unwrap_or(&[]);
|
||||
let has_selective = field_indices.normal.is_some();
|
||||
let sel_indices = field_indices.normal.as_deref().unwrap_or(&[]);
|
||||
|
||||
// Single-pass: aggregate directly into postcode_aggs while iterating properties in bounds
|
||||
let mut postcode_aggs: FxHashMap<usize, Aggregator> = FxHashMap::default();
|
||||
let mut poi_aggs: FxHashMap<usize, PoiAggregator> = FxHashMap::default();
|
||||
|
||||
state
|
||||
.grid
|
||||
|
|
@ -143,6 +156,10 @@ pub async fn get_postcodes(
|
|||
) {
|
||||
return;
|
||||
}
|
||||
if has_poi_filters && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let postcode = state.data.postcode(row);
|
||||
if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) {
|
||||
|
|
@ -154,6 +171,12 @@ pub async fn get_postcodes(
|
|||
} else {
|
||||
agg.add_row(feature_data, row, num_features, &quant);
|
||||
}
|
||||
if has_poi_fields {
|
||||
poi_aggs
|
||||
.entry(pc_idx)
|
||||
.or_insert_with(|| PoiAggregator::new(poi_num_features))
|
||||
.add_row_selective(poi_metrics, row, poi_field_indices);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -250,11 +273,12 @@ pub async fn get_postcodes(
|
|||
]),
|
||||
);
|
||||
|
||||
let iter: Box<dyn Iterator<Item = usize>> = if let Some(idx) = field_indices.as_ref() {
|
||||
Box::new(idx.iter().copied())
|
||||
} else {
|
||||
Box::new(0..num_features)
|
||||
};
|
||||
let iter: Box<dyn Iterator<Item = usize>> =
|
||||
if let Some(idx) = field_indices.normal.as_ref() {
|
||||
Box::new(idx.iter().copied())
|
||||
} else {
|
||||
Box::new(0..num_features)
|
||||
};
|
||||
|
||||
for feat_index in iter {
|
||||
if aggregation.feat_counts[feat_index] > 0 {
|
||||
|
|
@ -272,6 +296,25 @@ pub async fn get_postcodes(
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(poi_aggregation) = poi_aggs.get(&pc_idx) {
|
||||
for &metric_idx in poi_field_indices {
|
||||
if poi_aggregation.counts[metric_idx] > 0 {
|
||||
let avg = poi_aggregation.sums[metric_idx]
|
||||
/ poi_aggregation.counts[metric_idx] as f64;
|
||||
if let (Some(min_num), Some(max_num), Some(avg_num)) = (
|
||||
serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64),
|
||||
serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64),
|
||||
serde_json::Number::from_f64(avg),
|
||||
) {
|
||||
let name = &poi_metrics.feature_names[metric_idx];
|
||||
props.insert(format!("min_{name}"), Value::Number(min_num));
|
||||
props.insert(format!("max_{name}"), Value::Number(max_num));
|
||||
props.insert(format!("avg_{name}"), Value::Number(avg_num));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add travel time aggregation fields
|
||||
if let Some(tt_aggs) = travel_aggs.get(&pc_idx) {
|
||||
for (ti, agg) in tt_aggs.iter().enumerate() {
|
||||
|
|
@ -322,7 +365,11 @@ pub async fn get_postcodes(
|
|||
bounds = format_args!("{:.6},{:.6},{:.6},{:.6}", south, west, north, east),
|
||||
filters = num_filters,
|
||||
filters_raw = filters_str.as_deref().unwrap_or("-"),
|
||||
fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1),
|
||||
fields = field_indices
|
||||
.normal
|
||||
.as_ref()
|
||||
.map(|v| (v.len() + poi_field_indices.len()) as i32)
|
||||
.unwrap_or(-1),
|
||||
travel_entries = travel_entries.len(),
|
||||
agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0),
|
||||
json_ms = format_args!("{:.1}", (t_total - t_agg).as_secs_f64() * 1000.0),
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT};
|
|||
use crate::data::RenovationEvent;
|
||||
use crate::licensing::{check_license_bounds, resolve_share_code};
|
||||
use crate::parsing::{
|
||||
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_filters, row_passes_filters,
|
||||
validate_h3_resolution,
|
||||
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_filters_with_poi, row_passes_filters,
|
||||
row_passes_poi_filters, validate_h3_resolution,
|
||||
};
|
||||
use crate::state::{AppState, SharedState};
|
||||
|
||||
|
|
@ -117,6 +117,12 @@ pub fn build_property(
|
|||
features.insert(feat_name.clone(), value);
|
||||
}
|
||||
}
|
||||
for (metric_idx, metric_name) in state.data.poi_metrics.feature_names.iter().enumerate() {
|
||||
let value = state.data.poi_metrics.get_for_property_row(row, metric_idx);
|
||||
if value.is_finite() {
|
||||
features.insert(metric_name.clone(), value);
|
||||
}
|
||||
}
|
||||
|
||||
Property {
|
||||
address: non_empty_string(state.data.address(row)),
|
||||
|
|
@ -199,15 +205,19 @@ pub async fn get_hexagon_properties(
|
|||
|
||||
let h3_str = params.h3;
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len();
|
||||
let filters_str = params.filters;
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
let travel_entries = parse_optional_travel(params.travel.as_deref())
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
||||
|
|
@ -242,6 +252,12 @@ pub async fn get_hexagon_properties(
|
|||
feature_data,
|
||||
num_features,
|
||||
)
|
||||
&& (!has_poi_filters
|
||||
|| row_passes_poi_filters(
|
||||
row,
|
||||
&parsed_poi_filters,
|
||||
&state.data.poi_metrics,
|
||||
))
|
||||
{
|
||||
if has_travel {
|
||||
let postcode = state.data.postcode(row);
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use rustc_hash::FxHashMap;
|
|||
use tracing::warn;
|
||||
|
||||
use crate::consts::MAX_PRICE_HISTORY_POINTS;
|
||||
use crate::data::{FeatureStats, PropertyData};
|
||||
use crate::data::{FeatureStats, PostcodePoiMetrics, PropertyData};
|
||||
|
||||
use super::hexagon_stats::{EnumFeatureStats, HistogramStats, NumericFeatureStats, PricePoint};
|
||||
|
||||
|
|
@ -243,3 +243,80 @@ pub fn compute_feature_stats(
|
|||
|
||||
(numeric_features, enum_features_out)
|
||||
}
|
||||
|
||||
pub fn compute_poi_feature_stats(
|
||||
matching_rows: &[usize],
|
||||
poi_metrics: &PostcodePoiMetrics,
|
||||
fields_specified: bool,
|
||||
field_set: &HashSet<String>,
|
||||
) -> Vec<NumericFeatureStats> {
|
||||
let mut out = Vec::new();
|
||||
for (metric_idx, name) in poi_metrics.feature_names.iter().enumerate() {
|
||||
if fields_specified && !field_set.contains(name.as_str()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let global_hist = &poi_metrics.feature_stats[metric_idx].histogram;
|
||||
let p1 = global_hist.p1;
|
||||
let p99 = global_hist.p99;
|
||||
let num_bins = global_hist.counts.len();
|
||||
let middle_bins = num_bins.saturating_sub(2);
|
||||
let middle_width = if middle_bins > 0 && p99 > p1 {
|
||||
(p99 - p1) / middle_bins as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let mut count = 0usize;
|
||||
let mut min_value = f32::INFINITY;
|
||||
let mut max_value = f32::NEG_INFINITY;
|
||||
let mut sum = 0.0f64;
|
||||
let mut bins = vec![0u64; num_bins];
|
||||
|
||||
for &row in matching_rows {
|
||||
let value = poi_metrics.get_for_property_row(row, metric_idx);
|
||||
if !value.is_finite() {
|
||||
continue;
|
||||
}
|
||||
count += 1;
|
||||
if value < min_value {
|
||||
min_value = value;
|
||||
}
|
||||
if value > max_value {
|
||||
max_value = value;
|
||||
}
|
||||
sum += value as f64;
|
||||
|
||||
let bin = if value < p1 {
|
||||
0
|
||||
} else if value >= p99 {
|
||||
num_bins - 1
|
||||
} else if middle_width > 0.0 {
|
||||
let middle_bin = ((value - p1) / middle_width) as usize;
|
||||
(1 + middle_bin).min(num_bins - 2)
|
||||
} else {
|
||||
num_bins / 2
|
||||
};
|
||||
bins[bin] += 1;
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
out.push(NumericFeatureStats {
|
||||
name: name.clone(),
|
||||
count,
|
||||
min: min_value as f64,
|
||||
max: max_value as f64,
|
||||
mean: sum / count as f64,
|
||||
histogram: HistogramStats {
|
||||
min: global_hist.min as f64,
|
||||
max: global_hist.max as f64,
|
||||
p1: p1 as f64,
|
||||
p99: p99 as f64,
|
||||
counts: bins,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,78 +1,40 @@
|
|||
use std::collections::VecDeque;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::body::Bytes;
|
||||
use axum::extract::State;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use hmac::{Hmac, Mac};
|
||||
use parking_lot::Mutex;
|
||||
use rustc_hash::FxHashSet;
|
||||
use hmac::{Hmac, KeyInit, Mac};
|
||||
use sha2::Sha256;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::checkout_sessions::{
|
||||
grant_license, mark_checkout_completed, mark_referral_invite_used, verify_checkout_completion,
|
||||
CheckoutCompletion,
|
||||
};
|
||||
use crate::state::SharedState;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// Process-local LRU of recently processed Stripe event IDs.
|
||||
/// Stripe retries deliver the same event ID; we drop duplicates so we don't
|
||||
/// re-run side effects (subscription writes, token cache invalidation, logs).
|
||||
/// Capacity is intentionally generous: at typical webhook volumes this covers
|
||||
/// far more than Stripe's retry window.
|
||||
struct EventDedup {
|
||||
seen: FxHashSet<String>,
|
||||
queue: VecDeque<String>,
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl EventDedup {
|
||||
fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
seen: FxHashSet::default(),
|
||||
queue: VecDeque::with_capacity(capacity),
|
||||
capacity,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if this event ID is new (and records it),
|
||||
/// `false` if it was already seen recently.
|
||||
fn check_and_insert(&mut self, id: &str) -> bool {
|
||||
if self.seen.contains(id) {
|
||||
return false;
|
||||
}
|
||||
self.seen.insert(id.to_string());
|
||||
self.queue.push_back(id.to_string());
|
||||
if self.queue.len() > self.capacity {
|
||||
if let Some(old) = self.queue.pop_front() {
|
||||
self.seen.remove(&old);
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
static EVENT_DEDUP: LazyLock<Mutex<EventDedup>> =
|
||||
LazyLock::new(|| Mutex::new(EventDedup::new(1024)));
|
||||
|
||||
/// Verify Stripe webhook signature (v1 scheme).
|
||||
fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
|
||||
// Parse timestamp and signature from header: "t=TIMESTAMP,v1=SIGNATURE"
|
||||
let mut timestamp = None;
|
||||
let mut signature = None;
|
||||
let mut signatures = Vec::new();
|
||||
for part in sig_header.split(',') {
|
||||
if let Some(ts) = part.strip_prefix("t=") {
|
||||
timestamp = Some(ts);
|
||||
} else if let Some(sig) = part.strip_prefix("v1=") {
|
||||
signature = Some(sig);
|
||||
signatures.push(sig);
|
||||
}
|
||||
}
|
||||
|
||||
let (ts, sig_hex) = match (timestamp, signature) {
|
||||
(Some(t), Some(s)) => (t, s),
|
||||
_ => return false,
|
||||
let Some(ts) = timestamp else {
|
||||
return false;
|
||||
};
|
||||
if signatures.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Reject webhooks older than 5 minutes to prevent replay attacks
|
||||
if let Ok(ts_secs) = ts.parse::<i64>() {
|
||||
|
|
@ -87,20 +49,21 @@ fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Compute expected signature: HMAC-SHA256(secret, "TIMESTAMP.PAYLOAD")
|
||||
let signed_payload = format!("{ts}.{}", String::from_utf8_lossy(payload));
|
||||
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return false,
|
||||
};
|
||||
mac.update(signed_payload.as_bytes());
|
||||
let mut signed_payload = Vec::with_capacity(ts.len() + 1 + payload.len());
|
||||
signed_payload.extend_from_slice(ts.as_bytes());
|
||||
signed_payload.push(b'.');
|
||||
signed_payload.extend_from_slice(payload);
|
||||
|
||||
// Decode the provided hex signature and verify with constant-time comparison
|
||||
let sig_bytes = match hex::decode(sig_hex) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(_) => return false,
|
||||
};
|
||||
mac.verify_slice(&sig_bytes).is_ok()
|
||||
signatures.into_iter().any(|sig_hex| {
|
||||
let Ok(sig_bytes) = hex::decode(sig_hex) else {
|
||||
return false;
|
||||
};
|
||||
let Ok(mut mac) = HmacSha256::new_from_slice(secret.as_bytes()) else {
|
||||
return false;
|
||||
};
|
||||
mac.update(&signed_payload);
|
||||
mac.verify_slice(&sig_bytes).is_ok()
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle Stripe webhook events.
|
||||
|
|
@ -140,65 +103,64 @@ pub async fn post_stripe_webhook(
|
|||
let event_type = event["type"].as_str().unwrap_or("");
|
||||
let event_id = event["id"].as_str().unwrap_or("");
|
||||
|
||||
// Idempotency: drop replays/retries of an already-processed event.
|
||||
// We always answer 200 so Stripe stops retrying.
|
||||
if !event_id.is_empty() && !EVENT_DEDUP.lock().check_and_insert(event_id) {
|
||||
info!(event_id, event_type, "Dropping duplicate Stripe webhook");
|
||||
return StatusCode::OK.into_response();
|
||||
}
|
||||
|
||||
info!(event_id, event_type, "Received Stripe webhook");
|
||||
|
||||
if event_type == "checkout.session.completed" {
|
||||
let user_id = event["data"]["object"]["client_reference_id"]
|
||||
.as_str()
|
||||
.unwrap_or("");
|
||||
if user_id.is_empty() {
|
||||
warn!("checkout.session.completed missing client_reference_id");
|
||||
return StatusCode::OK.into_response();
|
||||
}
|
||||
if !user_id.bytes().all(|b| b.is_ascii_alphanumeric()) || user_id.len() > 20 {
|
||||
warn!(user_id, "Invalid client_reference_id format in webhook");
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
}
|
||||
|
||||
// Update user subscription to "licensed" via PocketBase superuser auth
|
||||
let token = match get_superuser_token(&state).await {
|
||||
Ok(t) => t,
|
||||
Err(err) => {
|
||||
warn!("Failed to auth as PocketBase superuser in webhook: {err}");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let res = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "subscription": "licensed" }))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
state.token_cache.invalidate_by_user_id(user_id);
|
||||
let session = &event["data"]["object"];
|
||||
match verify_checkout_completion(&state, session).await {
|
||||
Ok(CheckoutCompletion::Grant(checkout)) => {
|
||||
if let Err(err) = mark_referral_invite_used(
|
||||
&state,
|
||||
&checkout.referral_invite_id,
|
||||
&checkout.user_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
user_id = %checkout.user_id,
|
||||
reservation_id = %checkout.reservation_id,
|
||||
referral_invite_id = %checkout.referral_invite_id,
|
||||
"Failed to mark referral invite used after Stripe checkout: {err:?}"
|
||||
);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
if let Err(err) = grant_license(&state, &checkout.user_id).await {
|
||||
warn!(
|
||||
user_id = %checkout.user_id,
|
||||
reservation_id = %checkout.reservation_id,
|
||||
"Failed to grant license after Stripe checkout: {err:?}"
|
||||
);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
if let Err(err) = mark_checkout_completed(
|
||||
&state,
|
||||
&checkout.reservation_id,
|
||||
checkout.paid_amount_pence,
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
user_id = %checkout.user_id,
|
||||
reservation_id = %checkout.reservation_id,
|
||||
"Failed to mark checkout completed after license grant: {err:?}"
|
||||
);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
info!(
|
||||
user_id,
|
||||
"User subscription updated to licensed via Stripe webhook"
|
||||
user_id = %checkout.user_id,
|
||||
reservation_id = %checkout.reservation_id,
|
||||
"User subscription updated to licensed via verified Stripe checkout"
|
||||
);
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
warn!(
|
||||
user_id,
|
||||
"Failed to update user subscription ({status}): {text}"
|
||||
);
|
||||
Ok(CheckoutCompletion::AlreadyHandled) => {
|
||||
info!("Stripe checkout session was already handled");
|
||||
}
|
||||
Ok(CheckoutCompletion::Rejected(reason)) => {
|
||||
warn!("Rejecting Stripe checkout completion: {reason}");
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(user_id, "PocketBase request error in webhook: {err}");
|
||||
warn!("Failed to verify Stripe checkout completion: {err:?}");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue