diff --git a/frontend/public/assets/poi-icons/logos/the_food_warehouse.png b/frontend/public/assets/poi-icons/logos/the_food_warehouse.png index 0fb14a2..9d16347 100644 Binary files a/frontend/public/assets/poi-icons/logos/the_food_warehouse.png and b/frontend/public/assets/poi-icons/logos/the_food_warehouse.png differ diff --git a/frontend/public/video/poster.jpg b/frontend/public/video/poster.jpg index 88dc2b5..ea74d31 100644 Binary files a/frontend/public/video/poster.jpg and b/frontend/public/video/poster.jpg differ diff --git a/server-rs/src/aggregation.rs b/server-rs/src/aggregation.rs index ca4f82f..3d7bb36 100644 --- a/server-rs/src/aggregation.rs +++ b/server-rs/src/aggregation.rs @@ -1,5 +1,5 @@ use crate::consts::NAN_U16; -use crate::data::QuantRef; +use crate::data::{PostcodePoiMetrics, QuantRef}; /// Optional per-enum-value distribution tracking for a single feature. /// Counts how many rows have each enum value (by raw u16 index). @@ -21,6 +21,69 @@ pub struct Aggregator { pub enum_dist: Option, } +/// Accumulator for postcode-level POI metrics stored outside `feature_data`. +/// Only constructed when a request selects POI metric fields. 
+pub struct PoiAggregator { + pub mins: Box<[f32]>, + pub maxs: Box<[f32]>, + pub sums: Box<[f64]>, + pub counts: Box<[u32]>, +} + +impl PoiAggregator { + pub fn new(num_features: usize) -> Self { + Self { + mins: vec![f32::INFINITY; num_features].into_boxed_slice(), + maxs: vec![f32::NEG_INFINITY; num_features].into_boxed_slice(), + sums: vec![0.0f64; num_features].into_boxed_slice(), + counts: vec![0u32; num_features].into_boxed_slice(), + } + } + + #[inline] + pub fn add_row_selective( + &mut self, + poi_metrics: &PostcodePoiMetrics, + row: usize, + indices: &[usize], + ) { + let Some(metric_row) = poi_metrics.metric_row_for_property(row) else { + return; + }; + for &metric_idx in indices { + let raw = poi_metrics.raw_for_metric_row(metric_row, metric_idx); + if raw == NAN_U16 { + continue; + } + let value = poi_metrics.decode_raw(metric_idx, raw); + if value < self.mins[metric_idx] { + self.mins[metric_idx] = value; + } + if value > self.maxs[metric_idx] { + self.maxs[metric_idx] = value; + } + self.sums[metric_idx] += value as f64; + self.counts[metric_idx] += 1; + } + } + + pub fn merge(&mut self, other: &PoiAggregator) { + for i in 0..self.counts.len() { + if other.counts[i] == 0 { + continue; + } + if other.mins[i] < self.mins[i] { + self.mins[i] = other.mins[i]; + } + if other.maxs[i] > self.maxs[i] { + self.maxs[i] = other.maxs[i]; + } + self.sums[i] += other.sums[i]; + self.counts[i] += other.counts[i]; + } + } +} + /// Configuration for enum distribution tracking, passed to Aggregator::new. 
/// (feature_index, number_of_enum_values) pub type EnumDistConfig = Option<(usize, usize)>; diff --git a/server-rs/src/checkout_sessions.rs b/server-rs/src/checkout_sessions.rs new file mode 100644 index 0000000..081ded6 --- /dev/null +++ b/server-rs/src/checkout_sessions.rs @@ -0,0 +1,807 @@ +use std::sync::LazyLock; +use std::time::{SystemTime, UNIX_EPOCH}; + +use anyhow::{anyhow, Context}; +use serde_json::Value; +use tokio::sync::Mutex; +use tracing::warn; + +use crate::auth::PocketBaseUser; +use crate::pocketbase::get_superuser_token; +use crate::pocketbase_locks::acquire_pocketbase_lock; +use crate::routes::pricing::{count_licensed_users, price_for_count}; +use crate::state::AppState; + +pub const CHECKOUT_CURRENCY: &str = "gbp"; + +const CHECKOUT_SESSION_TTL_SECS: u64 = 31 * 60; +const CHECKOUT_PRODUCT_NAME: &str = "Perfect Postcodes Lifetime License"; +const CHECKOUT_COLLECTION: &str = "checkout_sessions"; +const CHECKOUT_PRICING_LOCK_NAME: &str = "checkout:pricing"; +const CHECKOUT_PRICING_LOCK_TTL_SECS: u64 = 5 * 60; +const REFERRAL_DISCOUNT_PERCENT: u64 = 30; + +static CHECKOUT_RESERVATION_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +pub enum CheckoutStart { + Free, + Stripe { url: String }, +} + +pub enum CheckoutCompletion { + Grant(VerifiedCheckout), + AlreadyHandled, + Rejected(String), +} + +pub struct VerifiedCheckout { + pub reservation_id: String, + pub user_id: String, + pub paid_amount_pence: u64, + pub referral_invite_id: String, +} + +#[derive(Debug)] +struct PendingCheckout { + id: String, + user_id: String, + stripe_session_id: String, + checkout_url: String, + amount_pence: u64, + expected_total_pence: u64, + currency: String, + referral_invite_id: String, + status: String, +} + +pub fn now_unix_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +pub async fn start_license_checkout( + state: &AppState, + user: &PocketBaseUser, + success_url: &str, + cancel_url: &str, + 
discount_coupon_id: Option<&str>, + referral_invite_id: Option<&str>, +) -> anyhow::Result { + let _guard = CHECKOUT_RESERVATION_LOCK.lock().await; + let pricing_lock = acquire_pocketbase_lock( + state, + CHECKOUT_PRICING_LOCK_NAME, + CHECKOUT_PRICING_LOCK_TTL_SECS, + ) + .await?; + let result = start_license_checkout_locked( + state, + user, + success_url, + cancel_url, + discount_coupon_id, + referral_invite_id, + ) + .await; + if let Err(err) = pricing_lock.release().await { + warn!("Failed to release checkout pricing lock: {err}"); + } + result +} + +async fn start_license_checkout_locked( + state: &AppState, + user: &PocketBaseUser, + success_url: &str, + cancel_url: &str, + discount_coupon_id: Option<&str>, + referral_invite_id: Option<&str>, +) -> anyhow::Result { + let now = now_unix_secs(); + expire_stale_pending_checkouts(state, now).await?; + + if let Some(existing) = find_active_checkout_for_user( + state, + &user.id, + discount_coupon_id.unwrap_or_default(), + referral_invite_id.unwrap_or_default(), + now, + ) + .await? 
+ { + if !existing.checkout_url.is_empty() { + return Ok(CheckoutStart::Stripe { + url: existing.checkout_url, + }); + } + if let Err(err) = mark_checkout_status(state, &existing.id, "failed").await { + warn!( + reservation_id = %existing.id, + "Failed to fail incomplete checkout reservation: {err}" + ); + } + } + + let licensed_count = count_licensed_users(state).await?; + let pending_count = count_active_pending_checkouts(state, now).await?; + let price_pence = price_for_count(licensed_count + pending_count); + + if price_pence == 0 { + grant_license(state, &user.id).await?; + return Ok(CheckoutStart::Free); + } + + let expires_at_unix = now + CHECKOUT_SESSION_TTL_SECS; + let expected_total_pence = expected_total_for_checkout(price_pence, discount_coupon_id); + let reservation_id = create_pending_checkout( + state, + PendingCheckoutInput { + user_id: &user.id, + amount_pence: price_pence, + expected_total_pence, + currency: CHECKOUT_CURRENCY, + discount_coupon_id: discount_coupon_id.unwrap_or_default(), + referral_invite_id: referral_invite_id.unwrap_or_default(), + expires_at_unix, + }, + ) + .await?; + + let stripe_result = create_stripe_session( + state, + user, + &reservation_id, + price_pence, + success_url, + cancel_url, + expires_at_unix, + discount_coupon_id, + ) + .await; + + let (stripe_session_id, url) = match stripe_result { + Ok(session) => session, + Err(err) => { + if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await { + warn!( + reservation_id, + "Failed to mark checkout reservation failed: {mark_err}" + ); + } + return Err(err); + } + }; + + if let Err(err) = attach_stripe_session(state, &reservation_id, &stripe_session_id, &url).await + { + if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await { + warn!( + reservation_id, + "Failed to mark checkout reservation failed: {mark_err}" + ); + } + return Err(err); + } + + Ok(CheckoutStart::Stripe { url }) +} + +pub async fn 
verify_checkout_completion( + state: &AppState, + session: &Value, +) -> anyhow::Result { + let session_id = match session["id"].as_str() { + Some(id) if is_safe_stripe_session_id(id) => id, + _ => { + return Ok(CheckoutCompletion::Rejected( + "missing or invalid session id".into(), + )) + } + }; + + let checkout = match find_checkout_by_stripe_session(state, session_id).await? { + Some(checkout) => checkout, + None => { + return Ok(CheckoutCompletion::Rejected( + "checkout session has no reservation".into(), + )) + } + }; + + if checkout.status == "completed" { + return Ok(CheckoutCompletion::AlreadyHandled); + } + if checkout.status != "pending" && checkout.status != "expired" { + return Ok(CheckoutCompletion::Rejected(format!( + "checkout reservation is {}", + checkout.status + ))); + } + if checkout.stripe_session_id != session_id { + mark_checkout_status(state, &checkout.id, "invalid").await?; + return Ok(CheckoutCompletion::Rejected( + "checkout reservation session id mismatch".into(), + )); + } + + let client_reference_id = session["client_reference_id"].as_str().unwrap_or_default(); + if client_reference_id != checkout.user_id { + mark_checkout_status(state, &checkout.id, "invalid").await?; + return Ok(CheckoutCompletion::Rejected( + "checkout client_reference_id mismatch".into(), + )); + } + + let payment_status = session["payment_status"].as_str().unwrap_or_default(); + if payment_status != "paid" { + return Ok(CheckoutCompletion::Rejected(format!( + "checkout payment_status is {payment_status}" + ))); + } + + let currency = session["currency"] + .as_str() + .unwrap_or_default() + .to_ascii_lowercase(); + if currency != checkout.currency { + mark_checkout_status(state, &checkout.id, "invalid").await?; + return Ok(CheckoutCompletion::Rejected( + "checkout currency mismatch".into(), + )); + } + + let amount_subtotal = match number_field(session, "amount_subtotal") { + Some(amount) => amount, + None => { + mark_checkout_status(state, &checkout.id, 
"invalid").await?; + return Ok(CheckoutCompletion::Rejected( + "checkout amount_subtotal missing".into(), + )); + } + }; + if amount_subtotal != checkout.amount_pence { + mark_checkout_status(state, &checkout.id, "invalid").await?; + return Ok(CheckoutCompletion::Rejected( + "checkout amount_subtotal mismatch".into(), + )); + } + + let amount_total = match number_field(session, "amount_total") { + Some(amount) => amount, + None => { + mark_checkout_status(state, &checkout.id, "invalid").await?; + return Ok(CheckoutCompletion::Rejected( + "checkout amount_total missing".into(), + )); + } + }; + if amount_total != checkout.expected_total_pence { + mark_checkout_status(state, &checkout.id, "invalid").await?; + return Ok(CheckoutCompletion::Rejected( + "checkout amount_total mismatch".into(), + )); + } + + Ok(CheckoutCompletion::Grant(VerifiedCheckout { + reservation_id: checkout.id, + user_id: checkout.user_id, + paid_amount_pence: amount_total, + referral_invite_id: checkout.referral_invite_id, + })) +} + +pub async fn mark_checkout_completed( + state: &AppState, + reservation_id: &str, + paid_amount_pence: u64, +) -> anyhow::Result<()> { + let token = get_superuser_token(state).await?; + let pb_url = state.pocketbase_url.trim_end_matches('/'); + let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}"); + let resp = state + .http_client + .patch(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&serde_json::json!({ + "status": "completed", + "paid_amount_pence": paid_amount_pence, + "completed_at_unix": now_unix_secs().to_string(), + })) + .send() + .await?; + + ensure_success(resp) + .await + .context("PocketBase checkout completion update failed") +} + +pub async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> { + let token = get_superuser_token(state).await?; + + let pb_url = state.pocketbase_url.trim_end_matches('/'); + let url = 
format!("{pb_url}/api/collections/users/records/{user_id}"); + let resp = state + .http_client + .patch(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&serde_json::json!({ "subscription": "licensed" })) + .send() + .await?; + + ensure_success(resp) + .await + .context("PocketBase license update failed")?; + + state.token_cache.invalidate_by_user_id(user_id); + Ok(()) +} + +pub async fn mark_referral_invite_used( + state: &AppState, + invite_id: &str, + user_id: &str, +) -> anyhow::Result<()> { + if invite_id.is_empty() { + return Ok(()); + } + if !is_safe_pocketbase_id(invite_id) || !is_safe_pocketbase_id(user_id) { + return Err(anyhow!("invalid PocketBase id")); + } + + let token = get_superuser_token(state).await?; + let pb_url = state.pocketbase_url.trim_end_matches('/'); + let existing_used_by = fetch_invite_used_by(state, pb_url, &token, invite_id).await?; + if existing_used_by == user_id { + return Ok(()); + } + if !existing_used_by.is_empty() { + return Err(anyhow!("referral invite already used by another account")); + } + + let url = format!("{pb_url}/api/collections/invites/records/{invite_id}"); + let resp = state + .http_client + .patch(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&serde_json::json!({ + "used_by_id": user_id, + "used_at": now_unix_secs().to_string(), + })) + .send() + .await?; + + ensure_success(resp) + .await + .context("PocketBase invite usage update failed") +} + +async fn fetch_invite_used_by( + state: &AppState, + pb_url: &str, + token: &str, + invite_id: &str, +) -> anyhow::Result { + let url = format!("{pb_url}/api/collections/invites/records/{invite_id}"); + let resp = state + .http_client + .get(&url) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + ensure_success_ref(&resp).await?; + + let body: Value = resp.json().await?; + Ok(body["used_by_id"].as_str().unwrap_or_default().to_string()) +} + +pub async fn active_referral_checkout_user( + state: 
&AppState, + invite_id: &str, +) -> anyhow::Result> { + if !is_safe_pocketbase_id(invite_id) { + return Err(anyhow!("invalid PocketBase invite id")); + } + + let now = now_unix_secs(); + let token = get_superuser_token(state).await?; + let pb_url = state.pocketbase_url.trim_end_matches('/'); + let filter = format!( + "status=\"pending\" && expires_at_unix>={now} && referral_invite_id=\"{}\"", + invite_id + ); + let url = format!( + "{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1", + urlencoding::encode(&filter) + ); + let resp = state + .http_client + .get(&url) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + ensure_success_ref(&resp).await?; + + let body: Value = resp.json().await?; + Ok(body["items"] + .as_array() + .and_then(|items| items.first()) + .and_then(|item| item["user"].as_str()) + .map(str::to_string)) +} + +async fn count_active_pending_checkouts(state: &AppState, now: u64) -> anyhow::Result { + let token = get_superuser_token(state).await?; + let pb_url = state.pocketbase_url.trim_end_matches('/'); + let filter = format!("status=\"pending\" && expires_at_unix>={now}"); + let url = format!( + "{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1", + urlencoding::encode(&filter) + ); + let resp = state + .http_client + .get(&url) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + ensure_success_ref(&resp).await?; + + let body: Value = resp.json().await?; + Ok(body["totalItems"].as_u64().unwrap_or(0)) +} + +async fn find_active_checkout_for_user( + state: &AppState, + user_id: &str, + discount_coupon_id: &str, + referral_invite_id: &str, + now: u64, +) -> anyhow::Result> { + if !is_safe_pocketbase_id(user_id) { + return Err(anyhow!("invalid PocketBase user id")); + } + + let token = get_superuser_token(state).await?; + let pb_url = state.pocketbase_url.trim_end_matches('/'); + let filter = format!( + "status=\"pending\" && expires_at_unix>={now} 
&& user=\"{}\" && discount_coupon_id=\"{}\" && referral_invite_id=\"{}\"", + user_id, discount_coupon_id, referral_invite_id + ); + let url = format!( + "{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1", + urlencoding::encode(&filter) + ); + let resp = state + .http_client + .get(&url) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + ensure_success_ref(&resp).await?; + + let body: Value = resp.json().await?; + let item = body["items"] + .as_array() + .and_then(|items| items.first()) + .cloned(); + + item.map(parse_pending_checkout).transpose() +} + +async fn expire_stale_pending_checkouts(state: &AppState, now: u64) -> anyhow::Result<()> { + let token = get_superuser_token(state).await?; + let pb_url = state.pocketbase_url.trim_end_matches('/'); + let filter = format!("status=\"pending\" && expires_at_unix<{now}"); + let url = format!( + "{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=50", + urlencoding::encode(&filter) + ); + let resp = state + .http_client + .get(&url) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + ensure_success_ref(&resp).await?; + + let body: Value = resp.json().await?; + let Some(items) = body["items"].as_array() else { + return Ok(()); + }; + + for id in items.iter().filter_map(|item| item["id"].as_str()) { + if let Err(err) = mark_checkout_status(state, id, "expired").await { + warn!( + reservation_id = id, + "Failed to expire checkout reservation: {err}" + ); + } + } + + Ok(()) +} + +struct PendingCheckoutInput<'a> { + user_id: &'a str, + amount_pence: u64, + expected_total_pence: u64, + currency: &'a str, + discount_coupon_id: &'a str, + referral_invite_id: &'a str, + expires_at_unix: u64, +} + +async fn create_pending_checkout( + state: &AppState, + input: PendingCheckoutInput<'_>, +) -> anyhow::Result { + let token = get_superuser_token(state).await?; + let pb_url = state.pocketbase_url.trim_end_matches('/'); + let url = 
format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records"); + let resp = state + .http_client + .post(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&serde_json::json!({ + "user": input.user_id, + "stripe_session_id": "", + "checkout_url": "", + "amount_pence": input.amount_pence, + "expected_total_pence": input.expected_total_pence, + "currency": input.currency, + "discount_coupon_id": input.discount_coupon_id, + "referral_invite_id": input.referral_invite_id, + "status": "pending", + "expires_at_unix": input.expires_at_unix, + "paid_amount_pence": 0, + "completed_at_unix": "", + })) + .send() + .await?; + + ensure_success_ref(&resp).await?; + + let body: Value = resp.json().await?; + body["id"] + .as_str() + .map(str::to_string) + .ok_or_else(|| anyhow!("PocketBase checkout reservation missing id")) +} + +#[allow(clippy::too_many_arguments)] +async fn create_stripe_session( + state: &AppState, + user: &PocketBaseUser, + reservation_id: &str, + price_pence: u64, + success_url: &str, + cancel_url: &str, + expires_at_unix: u64, + discount_coupon_id: Option<&str>, +) -> anyhow::Result<(String, String)> { + let mut form_params = vec![ + ("mode", "payment".to_string()), + ("payment_method_types[0]", "card".to_string()), + ( + "line_items[0][price_data][unit_amount]", + price_pence.to_string(), + ), + ( + "line_items[0][price_data][currency]", + CHECKOUT_CURRENCY.to_string(), + ), + ( + "line_items[0][price_data][product_data][name]", + CHECKOUT_PRODUCT_NAME.to_string(), + ), + ("line_items[0][quantity]", "1".to_string()), + ("success_url", success_url.to_string()), + ("cancel_url", cancel_url.to_string()), + ("expires_at", expires_at_unix.to_string()), + ("client_reference_id", user.id.clone()), + ("customer_email", user.email.clone()), + ("metadata[pending_checkout_id]", reservation_id.to_string()), + ("metadata[expected_amount_pence]", price_pence.to_string()), + ( + "metadata[expected_total_pence]", + 
expected_total_for_checkout(price_pence, discount_coupon_id).to_string(), + ), + ("metadata[expected_currency]", CHECKOUT_CURRENCY.to_string()), + ]; + + if let Some(coupon_id) = discount_coupon_id.filter(|id| !id.is_empty()) { + form_params.push(("discounts[0][coupon]", coupon_id.to_string())); + form_params.push(("metadata[discount_coupon_id]", coupon_id.to_string())); + } + + let resp = state + .http_client + .post("https://api.stripe.com/v1/checkout/sessions") + .basic_auth(&state.stripe_secret_key, None::<&str>) + .form(&form_params) + .send() + .await + .context("Stripe checkout request failed")?; + + ensure_success_ref(&resp) + .await + .context("Stripe checkout failed")?; + + let body: Value = resp + .json() + .await + .context("Failed to parse Stripe response")?; + let session_id = body["id"] + .as_str() + .filter(|id| is_safe_stripe_session_id(id)) + .map(str::to_string) + .ok_or_else(|| anyhow!("Stripe session missing valid id"))?; + let url = body["url"] + .as_str() + .map(str::to_string) + .filter(|url| !url.is_empty()) + .ok_or_else(|| anyhow!("Stripe session missing URL"))?; + + Ok((session_id, url)) +} + +async fn attach_stripe_session( + state: &AppState, + reservation_id: &str, + stripe_session_id: &str, + checkout_url: &str, +) -> anyhow::Result<()> { + let token = get_superuser_token(state).await?; + let pb_url = state.pocketbase_url.trim_end_matches('/'); + let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}"); + let resp = state + .http_client + .patch(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&serde_json::json!({ + "stripe_session_id": stripe_session_id, + "checkout_url": checkout_url, + })) + .send() + .await?; + + ensure_success(resp) + .await + .context("PocketBase checkout session attach failed") +} + +async fn mark_checkout_status( + state: &AppState, + reservation_id: &str, + status: &str, +) -> anyhow::Result<()> { + let token = get_superuser_token(state).await?; + let 
pb_url = state.pocketbase_url.trim_end_matches('/'); + let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}"); + let resp = state + .http_client + .patch(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&serde_json::json!({ "status": status })) + .send() + .await?; + + ensure_success(resp) + .await + .with_context(|| format!("PocketBase checkout status update failed for {reservation_id}")) +} + +async fn find_checkout_by_stripe_session( + state: &AppState, + stripe_session_id: &str, +) -> anyhow::Result> { + let token = get_superuser_token(state).await?; + let pb_url = state.pocketbase_url.trim_end_matches('/'); + let filter = format!("stripe_session_id=\"{}\"", stripe_session_id); + let url = format!( + "{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1", + urlencoding::encode(&filter) + ); + let resp = state + .http_client + .get(&url) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + ensure_success_ref(&resp).await?; + + let body: Value = resp.json().await?; + let item = body["items"] + .as_array() + .and_then(|items| items.first()) + .cloned(); + + item.map(parse_pending_checkout).transpose() +} + +fn parse_pending_checkout(item: Value) -> anyhow::Result { + Ok(PendingCheckout { + id: item["id"] + .as_str() + .ok_or_else(|| anyhow!("checkout reservation missing id"))? + .to_string(), + user_id: item["user"] + .as_str() + .ok_or_else(|| anyhow!("checkout reservation missing user"))? 
+ .to_string(), + stripe_session_id: item["stripe_session_id"] + .as_str() + .unwrap_or_default() + .to_string(), + checkout_url: item["checkout_url"] + .as_str() + .unwrap_or_default() + .to_string(), + amount_pence: number_field(&item, "amount_pence") + .ok_or_else(|| anyhow!("checkout reservation missing amount_pence"))?, + expected_total_pence: number_field(&item, "expected_total_pence") + .ok_or_else(|| anyhow!("checkout reservation missing expected_total_pence"))?, + currency: item["currency"] + .as_str() + .unwrap_or_default() + .to_ascii_lowercase(), + referral_invite_id: item["referral_invite_id"] + .as_str() + .unwrap_or_default() + .to_string(), + status: item["status"].as_str().unwrap_or_default().to_string(), + }) +} + +fn expected_total_for_checkout(amount_pence: u64, discount_coupon_id: Option<&str>) -> u64 { + if discount_coupon_id.is_some_and(|id| !id.is_empty()) { + return ((amount_pence * (100 - REFERRAL_DISCOUNT_PERCENT)) / 100).max(1); + } + amount_pence +} + +fn number_field(value: &Value, field: &str) -> Option { + value[field].as_u64().or_else(|| { + value[field] + .as_f64() + .filter(|n| n.is_finite() && *n >= 0.0 && n.fract() == 0.0) + .map(|n| n as u64) + }) +} + +fn is_safe_stripe_session_id(id: &str) -> bool { + !id.is_empty() + && id.len() <= 128 + && id + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'-') +} + +fn is_safe_pocketbase_id(id: &str) -> bool { + !id.is_empty() && id.len() <= 32 && id.bytes().all(|b| b.is_ascii_alphanumeric()) +} + +async fn ensure_success(resp: reqwest::Response) -> anyhow::Result<()> { + if resp.status().is_success() { + return Ok(()); + } + + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + Err(anyhow!("upstream returned {status}: {text}")) +} + +async fn ensure_success_ref(resp: &reqwest::Response) -> anyhow::Result<()> { + if resp.status().is_success() { + return Ok(()); + } + + Err(anyhow!("upstream returned {}", resp.status())) +} diff --git 
a/server-rs/src/data/places.rs b/server-rs/src/data/places.rs index 108bc3c..251f707 100644 --- a/server-rs/src/data/places.rs +++ b/server-rs/src/data/places.rs @@ -97,7 +97,7 @@ fn build_search_text(name: &str, place_type: &str) -> String { } if place_type == "station" { - let suffix_aliases: [(&str, &[&str]); 5] = [ + let suffix_aliases: [(&str, &[&str]); 6] = [ ( " tube station", &[" underground station", " station", " tube", " underground"], @@ -118,6 +118,7 @@ fn build_search_text(name: &str, place_type: &str) -> String { " elizabeth line station", &[" station", " elizabeth line", " crossrail station"], ), + (" dlr station", &[" station", " dlr"]), ]; for (suffix, replacements) in suffix_aliases { @@ -139,10 +140,15 @@ fn extract_str_col(df: &DataFrame, name: &str) -> anyhow::Result> { let string_column = column .str() .with_context(|| format!("Column '{name}' is not a string column"))?; - Ok(string_column + string_column .into_iter() - .map(|value| value.unwrap_or("").to_string()) - .collect()) + .enumerate() + .map(|(row, value)| { + value + .map(ToString::to_string) + .with_context(|| format!("Column '{name}' has null at row {row}")) + }) + .collect() } fn extract_f32_col(df: &DataFrame, name: &str) -> anyhow::Result> { @@ -155,33 +161,37 @@ fn extract_f32_col(df: &DataFrame, name: &str) -> anyhow::Result> { let float_column = cast .f32() .with_context(|| format!("Column '{name}' is not a float32 column"))?; - Ok(float_column + float_column .into_iter() - .map(|value| value.unwrap_or(0.0)) - .collect()) + .enumerate() + .map(|(row, value)| value.with_context(|| format!("Column '{name}' has null at row {row}"))) + .collect() } -fn extract_bool_col_or_default( - df: &DataFrame, - name: &str, - default_value: bool, -) -> anyhow::Result> { - let Ok(column) = df.column(name) else { - return Ok(vec![default_value; df.height()]); - }; +fn extract_bool_col(df: &DataFrame, name: &str) -> anyhow::Result> { + let column = df + .column(name) + .with_context(|| 
format!("Missing column '{name}' in places data"))?; let bool_column = column .bool() .with_context(|| format!("Column '{name}' is not a boolean column"))?; - Ok(bool_column + bool_column .into_iter() - .map(|value| value.unwrap_or(default_value)) - .collect()) + .enumerate() + .map(|(row, value)| value.with_context(|| format!("Column '{name}' has null at row {row}"))) + .collect() } impl PlaceData { pub fn load(parquet_path: &Path) -> anyhow::Result { + super::run_polars_io(|| Self::load_inner(parquet_path)) + } + + fn load_inner(parquet_path: &Path) -> anyhow::Result { info!("Loading place data from {:?}...", parquet_path); + let parquet_path = PlRefPath::try_from_path(parquet_path) + .context("Failed to normalize places parquet path")?; let df = LazyFrame::scan_parquet(parquet_path, Default::default()) .context("Failed to scan places parquet")? .collect() @@ -210,7 +220,7 @@ impl PlaceData { let type_rank_vec: Vec = place_type_raw.iter().map(|pt| type_rank(pt)).collect(); let place_type = InternedColumn::build(&place_type_raw); let travel_destination = if df.column("travel_destination").is_ok() { - extract_bool_col_or_default(&df, "travel_destination", true)? + extract_bool_col(&df, "travel_destination")? 
} else { place_type_raw .iter() @@ -296,6 +306,7 @@ mod tests { assert!(build_search_text("King's Cross tube station", "station") .contains("kings cross underground")); assert!(build_search_text("St Albans", "city").contains("saint albans")); + assert!(build_search_text("Shadwell DLR station", "station").contains("shadwell station")); } #[test] diff --git a/server-rs/src/data/poi.rs b/server-rs/src/data/poi.rs index 503610b..9220c26 100644 --- a/server-rs/src/data/poi.rs +++ b/server-rs/src/data/poi.rs @@ -5,6 +5,7 @@ use anyhow::{bail, Context}; use polars::frame::DataFrame; use polars::lazy::frame::LazyFrame; use polars::prelude::*; +use rustc_hash::FxHashSet; use serde::Serialize; use tracing::info; @@ -17,6 +18,94 @@ pub struct POICategoryGroup { pub categories: Vec, } +const GROCERY_DASHBOARD_CATEGORIES: &[&str] = &[ + "Supermarket", + "Convenience Store", + "Bakery", + "Greengrocer", + "Aldi", + "Amazon", + "Asda", + "Booths", + "Budgens", + "Centra", + "Co-op", + "COOK", + "Costco", + "Dunnes Stores", + "Farmfoods", + "Heron Foods", + "Iceland", + "Lidl", + "Makro", + "M&S", + "Morrisons", + "Planet Organic", + "Sainsbury's", + "Spar", + "Tesco", + "The Food Warehouse", + "Waitrose", + "Whole Foods Market", +]; + +const DASHBOARD_POI_GROUPS: &[(&str, &[&str])] = &[ + ( + "Public Transport", + &[ + "Rail station", + "Tube station", + "Bus station", + "Bus stop", + "Airport", + ], + ), + ("Groceries", GROCERY_DASHBOARD_CATEGORIES), + ("Food & Drink", &["Café", "Restaurant", "Pub", "Fast Food"]), + ("Green Space", &["Park", "Playground"]), + ("Education", &["School"]), + ( + "Health", + &["GP Surgery", "Pharmacy", "Dentist", "Hospital & Clinic"], + ), + ( + "Leisure", + &[ + "Gym & Fitness", + "Sports Centre", + "Cinema", + "Theatre", + "Library", + ], + ), + ( + "Practical", + &["Post Office", "Bank", "EV Charging", "Fuel Station"], + ), +]; + +fn add_category_filter_index( + category_values: &[String], + category: &str, + selected: &mut FxHashSet, +) { + if 
let Some(pos) = category_values.iter().position(|value| value == category) { + selected.insert(pos as u16); + } +} + +pub fn resolve_poi_category_filter(category_values: &[String], categories: &str) -> FxHashSet { + let mut selected = FxHashSet::default(); + for part in categories.split(',') { + let category = part.trim(); + if category.is_empty() { + continue; + } + add_category_filter_index(category_values, category, &mut selected); + } + selected +} + pub struct POIData { /// Contiguous buffer holding all POI ID strings end-to-end. id_buffer: String, @@ -53,13 +142,18 @@ fn extract_str_col(df: &DataFrame, name: &str) -> anyhow::Result> { let string_column = column .str() .with_context(|| format!("Column '{name}' is not a string column"))?; - Ok(string_column + string_column .into_iter() - .map(|value| value.unwrap_or("").to_string()) - .collect()) + .enumerate() + .map(|(row, value)| { + value + .map(ToString::to_string) + .with_context(|| format!("Column '{name}' has null at row {row}")) + }) + .collect() } -fn extract_f32_col(df: &DataFrame, name: &str, default: f32) -> anyhow::Result> { +fn extract_f32_col(df: &DataFrame, name: &str) -> anyhow::Result> { let column = df .column(name) .with_context(|| format!("Missing column '{name}' in POI data"))?; @@ -69,16 +163,23 @@ fn extract_f32_col(df: &DataFrame, name: &str, default: f32) -> anyhow::Result anyhow::Result { + super::run_polars_io(|| Self::load_inner(parquet_path)) + } + + fn load_inner(parquet_path: &Path) -> anyhow::Result { info!("Loading POI data from {:?}...", parquet_path); + let parquet_path = PlRefPath::try_from_path(parquet_path) + .context("Failed to normalize POI parquet path")?; let df = LazyFrame::scan_parquet(parquet_path, Default::default()) .context("Failed to scan POI parquet")? 
.collect() @@ -91,18 +192,10 @@ impl POIData { let name = extract_str_col(&df, "name")?; let category_raw = extract_str_col(&df, "category")?; let group_raw = extract_str_col(&df, "group")?; - let lat = extract_f32_col(&df, "lat", 0.0)?; - let lng = extract_f32_col(&df, "lng", 0.0)?; + let lat = extract_f32_col(&df, "lat")?; + let lng = extract_f32_col(&df, "lng")?; let emoji_raw = extract_str_col(&df, "emoji")?; - let icon_category_raw = if df - .get_column_names() - .iter() - .any(|name| name.as_str() == "icon_category") - { - extract_str_col(&df, "icon_category")? - } else { - category_raw.clone() - }; + let icon_category_raw = extract_str_col(&df, "icon_category")?; // Pack POI IDs into a contiguous buffer let total_id_bytes: usize = id_raw.iter().map(|s| s.len()).sum(); @@ -152,7 +245,7 @@ impl POIData { }) } - /// Build category groups from the loaded POI data, validated against POI_GROUP_ORDER. + /// Build dashboard category groups from every category present in the loaded POI data. pub fn category_groups(&self) -> anyhow::Result> { let mut group_cats: HashMap> = HashMap::new(); let num_pois = self.category.indices.len(); @@ -174,18 +267,78 @@ impl POIData { ); } - POI_GROUP_ORDER + let preferred_order: HashMap<&str, HashMap<&str, usize>> = DASHBOARD_POI_GROUPS .iter() - .map(|group_name| { - let name = group_name.to_string(); - let mut categories: Vec = group_cats - .remove(&name) - .context("POI group validated but missing from map")? 
- .into_iter() - .collect(); - categories.sort(); - Ok(POICategoryGroup { name, categories }) + .map(|(group, categories)| { + ( + *group, + categories + .iter() + .enumerate() + .map(|(idx, category)| (*category, idx)) + .collect(), + ) }) - .collect() + .collect(); + + let groups: Vec = POI_GROUP_ORDER + .iter() + .filter_map(|group_name| { + let mut categories: Vec = group_cats + .get(*group_name) + .map(|categories| categories.iter().cloned().collect()) + .unwrap_or_default(); + if categories.is_empty() { + return None; + } + let group_order = preferred_order.get(*group_name); + categories.sort_by(|a, b| { + let a_order = group_order.and_then(|order| order.get(a.as_str())).copied(); + let b_order = group_order.and_then(|order| order.get(b.as_str())).copied(); + match (a_order, b_order) { + (Some(left), Some(right)) => left.cmp(&right), + (Some(_), None) => std::cmp::Ordering::Less, + (None, Some(_)) => std::cmp::Ordering::Greater, + (None, None) => a.cmp(b), + } + }); + Some(POICategoryGroup { + name: (*group_name).to_string(), + categories, + }) + }) + .collect(); + + Ok(groups) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn category_filter_matches_exact_present_categories() { + let values = vec![ + "Supermarket".to_string(), + "Tesco".to_string(), + "Aldi".to_string(), + "Rail station".to_string(), + ]; + + let selected = resolve_poi_category_filter(&values, "Supermarket,Rail station"); + + assert!(selected.contains(&0)); + assert!(selected.contains(&3)); + assert_eq!(selected.len(), 2); + } + + #[test] + fn unknown_category_filter_matches_nothing() { + let values = vec!["Supermarket".to_string()]; + + let selected = resolve_poi_category_filter(&values, "Unknown"); + + assert!(selected.is_empty()); } } diff --git a/server-rs/src/data/postcodes.rs b/server-rs/src/data/postcodes.rs index e9b9259..10387bc 100644 --- a/server-rs/src/data/postcodes.rs +++ b/server-rs/src/data/postcodes.rs @@ -195,33 +195,38 @@ impl PostcodeData { // Extract 
all outer rings from the geometry let rings: Vec> = match feature.geometry { - Geometry::Polygon { coordinates } => coordinates - .first() - .map(|ring| { - vec![ring - .iter() - .map(|[lon, lat]| [*lon as f32, *lat as f32]) - .collect()] - }) - .unwrap_or_default(), + Geometry::Polygon { coordinates } => { + let ring = coordinates.first().with_context(|| { + format!("Postcode '{postcode}' polygon has no outer ring") + })?; + vec![ring + .iter() + .map(|[lon, lat]| [*lon as f32, *lat as f32]) + .collect()] + } Geometry::MultiPolygon { coordinates } => coordinates .iter() - .filter_map(|poly| { - poly.first().map(|ring| { - ring.iter() - .map(|[lon, lat]| [*lon as f32, *lat as f32]) - .collect() - }) + .enumerate() + .map(|(idx, poly)| { + let ring = poly.first().with_context(|| { + format!( + "Postcode '{postcode}' multipolygon part {idx} has no outer ring" + ) + })?; + Ok(ring + .iter() + .map(|[lon, lat]| [*lon as f32, *lat as f32]) + .collect()) }) - .collect(), + .collect::>>()?, }; // Compute centroid across all vertices from all rings let total_vertices: usize = rings.iter().map(|ring| ring.len()).sum(); - let centroid = if total_vertices == 0 { - tracing::warn!(postcode = %postcode, "Postcode polygon has zero vertices, defaulting centroid to (0,0)"); - (0.0, 0.0) - } else { + if total_vertices == 0 { + anyhow::bail!("Postcode '{postcode}' polygon has zero vertices"); + } + let centroid = { let mut sum_lat: f32 = 0.0; let mut sum_lon: f32 = 0.0; for ring in &rings { diff --git a/server-rs/src/data/property.rs b/server-rs/src/data/property.rs index 612a021..279a484 100644 --- a/server-rs/src/data/property.rs +++ b/server-rs/src/data/property.rs @@ -14,6 +14,7 @@ const ADDRESS_SEARCH_CANDIDATE_LIMIT: usize = 50_000; const ADDRESS_SEARCH_MAX_POSTINGS_PER_TOKEN: usize = 250_000; const ADDRESS_SEARCH_PREFIX_MIN_LEN: usize = 4; const ADDRESS_SEARCH_PREFIX_MAX_LEN: usize = 8; +const NO_POI_METRIC_ROW: u32 = u32::MAX; fn is_numeric_dtype(dtype: &DataType) -> bool { 
matches!( @@ -495,6 +496,187 @@ impl QuantRef<'_> { } } +pub struct PostcodePoiMetrics { + pub feature_names: Vec, + pub name_to_index: FxHashMap, + /// Metric-major storage: columns[metric_idx][postcode_metric_idx]. + pub columns: Vec>, + pub feature_stats: Vec, + /// Per-property row lookup into the postcode metric table. + row_to_metric_idx: Vec, + dequant_a: Vec, + quant_min: Vec, + quant_range: Vec, +} + +impl PostcodePoiMetrics { + fn empty(row_count: usize) -> Self { + Self { + feature_names: Vec::new(), + name_to_index: FxHashMap::default(), + columns: Vec::new(), + feature_stats: Vec::new(), + row_to_metric_idx: vec![NO_POI_METRIC_ROW; row_count], + dequant_a: Vec::new(), + quant_min: Vec::new(), + quant_range: Vec::new(), + } + } + + fn from_postcode_df(df: &DataFrame, feature_names: Vec) -> anyhow::Result { + if feature_names.is_empty() { + return Ok(Self::empty(0)); + } + + tracing::info!( + metrics = feature_names.len(), + postcodes = df.height(), + "Building postcode POI metric side table" + ); + + let col_major: Vec> = feature_names + .par_iter() + .map(|name| { + let column = df + .column(name.as_str()) + .with_context(|| format!("Missing POI metric column '{name}'"))?; + column_to_f32_vec(column) + }) + .collect::>>()?; + + let feature_stats: Vec = col_major + .par_iter() + .enumerate() + .map(|(metric_idx, vals)| { + let name = feature_names[metric_idx].as_str(); + let bounds = features::bounds_for(name) + .with_context(|| format!("No bounds config for POI metric '{name}'"))?; + Ok(compute_feature_stats( + vals, + &bounds, + features::has_integer_bins(name), + )) + }) + .collect::>>()?; + + let mut quant_min = Vec::with_capacity(feature_names.len()); + let mut quant_range = Vec::with_capacity(feature_names.len()); + for (metric_idx, stats) in feature_stats.iter().enumerate() { + let (min, max) = match features::bounds_for(feature_names[metric_idx].as_str()) { + Some(Bounds::Fixed { min, max }) => (min, max), + _ => (stats.histogram.min, 
stats.histogram.max), + }; + quant_min.push(min); + quant_range.push(if max > min { max - min } else { 0.0 }); + } + let dequant_a: Vec = quant_range + .iter() + .map(|&range| { + if range > 0.0 { + range / QUANT_SCALE + } else { + 0.0 + } + }) + .collect(); + + let columns: Vec> = col_major + .par_iter() + .enumerate() + .map(|(metric_idx, vals)| { + let range = quant_range[metric_idx]; + let min = quant_min[metric_idx]; + vals.iter() + .map(|&value| { + if !value.is_finite() { + NAN_U16 + } else if range > 0.0 { + let normalized = (value - min) / range; + (normalized * QUANT_SCALE).round().clamp(0.0, QUANT_SCALE) as u16 + } else { + 0 + } + }) + .collect() + }) + .collect(); + + let name_to_index = feature_names + .iter() + .enumerate() + .map(|(idx, name)| (name.clone(), idx)) + .collect(); + + Ok(Self { + feature_names, + name_to_index, + columns, + feature_stats, + row_to_metric_idx: Vec::new(), + dequant_a, + quant_min, + quant_range, + }) + } + + fn set_row_mapping(&mut self, row_to_metric_idx: Vec) { + self.row_to_metric_idx = row_to_metric_idx; + } + + pub fn is_empty(&self) -> bool { + self.feature_names.is_empty() + } + + pub fn num_features(&self) -> usize { + self.feature_names.len() + } + + pub fn quant_ref(&self) -> QuantRef<'_> { + QuantRef { + dequant_a: &self.dequant_a, + quant_min: &self.quant_min, + quant_range: &self.quant_range, + num_numeric: self.feature_names.len(), + } + } + + #[inline] + pub fn metric_row_for_property(&self, row: usize) -> Option { + self.row_to_metric_idx + .get(row) + .copied() + .filter(|&idx| idx != NO_POI_METRIC_ROW) + .map(|idx| idx as usize) + } + + #[inline] + pub fn raw_for_metric_row(&self, metric_row: usize, metric_idx: usize) -> u16 { + self.columns[metric_idx][metric_row] + } + + #[inline] + pub fn raw_for_property_row(&self, row: usize, metric_idx: usize) -> u16 { + let Some(metric_row) = self.metric_row_for_property(row) else { + return NAN_U16; + }; + self.raw_for_metric_row(metric_row, metric_idx) + } + + 
#[inline] + pub fn decode_raw(&self, metric_idx: usize, raw: u16) -> f32 { + if raw == NAN_U16 { + f32::NAN + } else { + raw as f32 * self.dequant_a[metric_idx] + self.quant_min[metric_idx] + } + } + + #[inline] + pub fn get_for_property_row(&self, row: usize, metric_idx: usize) -> f32 { + self.decode_raw(metric_idx, self.raw_for_property_row(row, metric_idx)) + } +} + pub struct PropertyData { pub lat: Vec, pub lon: Vec, @@ -514,6 +696,7 @@ pub struct PropertyData { /// Per-feature: max - min (for encoding filter bounds). quant_range: Vec, pub feature_stats: Vec, + pub poi_metrics: PostcodePoiMetrics, /// Unquantized last sale price used by the price-history chart. last_known_price_raw: Vec, /// Contiguous buffer holding all address strings end-to-end. @@ -1055,19 +1238,54 @@ pub fn precompute_h3(lat: &[f32], lon: &[f32]) -> anyhow::Result> { impl PropertyData { pub fn load(properties_path: &Path, postcode_features_path: &Path) -> anyhow::Result { + super::run_polars_io(|| Self::load_inner(properties_path, postcode_features_path)) + } + + fn load_inner(properties_path: &Path, postcode_features_path: &Path) -> anyhow::Result { // Load postcode.parquet tracing::info!( "Loading postcode features from {:?}", postcode_features_path ); + let postcode_features_path = PlRefPath::try_from_path(postcode_features_path) + .context("Failed to normalize postcode parquet path")?; let postcode_df = LazyFrame::scan_parquet(postcode_features_path, Default::default()) .context("Failed to scan postcode parquet")? 
.collect() .context("Failed to read postcode parquet")?; tracing::info!(rows = postcode_df.height(), "Postcode features loaded"); + let mut poi_metric_names: Vec = postcode_df + .get_column_names() + .iter() + .map(|name| name.as_str()) + .filter(|&name| features::is_dynamic_poi_feature(name)) + .map(str::to_string) + .collect(); + poi_metric_names.sort_by_key(|name| features::dynamic_poi_feature_sort_key(name)); + + let poi_metric_by_postcode: FxHashMap = if poi_metric_names.is_empty() { + FxHashMap::default() + } else { + let postcode_column = postcode_df + .column("Postcode") + .context("Postcode feature parquet missing 'Postcode' column")? + .str() + .context("'Postcode' column in postcode feature parquet is not a string")?; + postcode_column + .into_iter() + .enumerate() + .filter_map(|(idx, postcode)| { + postcode.map(|postcode| (postcode.to_string(), idx as u32)) + }) + .collect() + }; + let mut poi_metrics = PostcodePoiMetrics::from_postcode_df(&postcode_df, poi_metric_names)?; + // Load properties.parquet and join with postcode data for lat/lon + area features tracing::info!("Loading properties from {:?}", properties_path); + let properties_path = PlRefPath::try_from_path(properties_path) + .context("Failed to normalize properties parquet path")?; let properties_lf = LazyFrame::scan_parquet(properties_path, Default::default()) .context("Failed to scan properties parquet")?; let combined = properties_lf @@ -1082,14 +1300,20 @@ impl PropertyData { let total_rows = combined.height(); tracing::info!(rows = total_rows, "Properties joined with postcodes"); - // Get configured feature/enum names in config order - let numeric_names = features::all_numeric_feature_names(); + // Get configured feature/enum names in config order. Dynamic POI + // metrics live in a postcode-level side table so they do not widen the + // hot row-major property feature matrix. 
+ let configured_numeric_names = features::all_numeric_feature_names(); let enum_names = features::all_enum_feature_names(); let schema = combined.schema(); + let numeric_names: Vec = configured_numeric_names + .iter() + .map(|name| (*name).to_string()) + .collect(); for name in &numeric_names { - match schema.get(name) { + match schema.get(name.as_str()) { Some(dtype) if is_numeric_dtype(dtype) => {} Some(dtype) => bail!( "Configured numeric feature '{}' has non-numeric type {:?}", @@ -1120,8 +1344,8 @@ impl PropertyData { // Combine numeric and enum feature names (numeric first, then enum) let feature_names: Vec = numeric_names .iter() - .chain(enum_names.iter()) .map(|name| name.to_string()) + .chain(enum_names.iter().map(|name| name.to_string())) .collect(); let num_features = feature_names.len(); let num_numeric = numeric_names.len(); @@ -1138,16 +1362,16 @@ impl PropertyData { select_exprs.push(col("lon").cast(DataType::Float32)); // Select numeric features as Float32 (datetime columns → fractional year) - for &name in &numeric_names { - if is_datetime_dtype(schema.get(name).unwrap()) { + for name in &numeric_names { + if is_datetime_dtype(schema.get(name.as_str()).unwrap()) { select_exprs.push( - (col(name).dt().year().cast(DataType::Float32) - + (col(name).dt().month().cast(DataType::Float32) - lit(1.0f32)) + (col(name.as_str()).dt().year().cast(DataType::Float32) + + (col(name.as_str()).dt().month().cast(DataType::Float32) - lit(1.0f32)) / lit(12.0f32)) - .alias(name), + .alias(name.as_str()), ); } else { - select_exprs.push(col(name).cast(DataType::Float32)); + select_exprs.push(col(name.as_str()).cast(DataType::Float32)); } } @@ -1233,7 +1457,7 @@ impl PropertyData { .par_iter() .map(|name| { let column = df - .column(name) + .column(name.as_str()) .with_context(|| format!("Missing feature column '{name}'"))?; column_to_f32_vec(column) }) @@ -1244,10 +1468,10 @@ impl PropertyData { .par_iter() .enumerate() .map(|(feat_index, vals)| { - let name = 
numeric_names[feat_index]; + let name = numeric_names[feat_index].as_str(); let bounds = features::bounds_for(name) .with_context(|| format!("No bounds config for feature '{}'", name))?; - let stats = compute_feature_stats(vals, bounds, features::has_integer_bins(name)); + let stats = compute_feature_stats(vals, &bounds, features::has_integer_bins(name)); tracing::debug!( feature = %name, slider_min = format_args!("{:.2}", stats.slider_min), @@ -1268,8 +1492,8 @@ impl PropertyData { let mut quant_min = Vec::with_capacity(num_features); let mut quant_range = Vec::with_capacity(num_features); for (feat_idx, stats) in numeric_feature_stats.iter().enumerate() { - let (min, max) = match features::bounds_for(numeric_names[feat_idx]) { - Some(Bounds::Fixed { min, max }) => (*min, *max), + let (min, max) = match features::bounds_for(numeric_names[feat_idx].as_str()) { + Some(Bounds::Fixed { min, max }) => (min, max), _ => (stats.histogram.min, stats.histogram.max), }; quant_min.push(min); @@ -1284,10 +1508,15 @@ impl PropertyData { let string_column = column .str() .with_context(|| format!("Column '{name}' is not a string column"))?; - Ok(string_column + string_column .into_iter() - .map(|value| value.unwrap_or("").to_string()) - .collect()) + .enumerate() + .map(|(row, value)| { + value + .map(ToString::to_string) + .with_context(|| format!("Required column '{name}' has null at row {row}")) + }) + .collect() }; let address_raw = extract_string_col(&df, "Address per Property Register")?; @@ -1325,18 +1554,18 @@ impl PropertyData { // enum_col_major: Vec<(values_list, encoded_as_f32)> let enum_col_major: Vec<(Vec, Vec)> = enum_names .par_iter() - .filter_map(|&name| { - let column_data = df.column(name).ok()?; - let string_column = column_data.str().ok()?; + .map(|&name| -> anyhow::Result<(Vec, Vec)> { + let column_data = df + .column(name) + .with_context(|| format!("Required enum column '{name}' not found"))?; + let string_column = column_data + .str() + .with_context(|| 
format!("Enum column '{name}' is not a string column"))?; let unique_set: std::collections::HashSet = string_column .into_iter() .filter_map(|value| { - let text = value.unwrap_or(""); - if text.is_empty() { - None - } else { - Some(text.to_string()) - } + let text = value?.trim(); + (!text.is_empty()).then(|| text.to_string()) }) .collect(); @@ -1373,20 +1602,22 @@ impl PropertyData { let encoded: Vec = string_column .into_iter() - .map(|value| { - let text = value.unwrap_or(""); - if text.is_empty() { - f32::NAN - } else { - *value_to_idx.get(text).unwrap_or(&f32::NAN) - } + .enumerate() + .map(|(row, value)| { + let Some(text) = value.map(str::trim).filter(|text| !text.is_empty()) + else { + return Ok(f32::NAN); + }; + value_to_idx.get(text).copied().with_context(|| { + format!("Enum column '{name}' has unknown value '{text}' at row {row}") + }) }) - .collect(); + .collect::>>()?; tracing::debug!(column = %name, unique_values = unique.len(), "Enum feature encoded as f32"); - Some((unique, encoded)) + Ok((unique, encoded)) }) - .collect(); + .collect::>>()?; // Extract is_approx_build_date: 0.0 = exact, anything else (1.0/NaN) = approximate let is_approx_build_date_raw: Vec = if has_approx_col { @@ -1487,13 +1718,13 @@ impl PropertyData { .collect(); let last_known_price_raw: Vec = numeric_names .iter() - .position(|&name| name == "Last known price") + .position(|name| name == "Last known price") .map(|price_idx| { perm.iter() .map(|&perm_index| numeric_col_major[price_idx][perm_index as usize]) .collect() }) - .unwrap_or_else(|| vec![f32::NAN; row_count]); + .context("Required numeric column 'Last known price' not configured")?; // Build contiguous address buffer and address search index (permuted) tracing::info!("Building interned strings"); @@ -1561,6 +1792,20 @@ impl PropertyData { } let postcode_interner = postcode_rodeo.into_reader(); + let row_to_poi_metric_idx: Vec = if poi_metrics.is_empty() { + vec![NO_POI_METRIC_ROW; row_count] + } else { + perm.iter() 
+ .map(|&old_row| { + poi_metric_by_postcode + .get(postcode_raw[old_row as usize].as_str()) + .copied() + .unwrap_or(NO_POI_METRIC_ROW) + }) + .collect() + }; + poi_metrics.set_row_mapping(row_to_poi_metric_idx); + // Pack is_approx_build_date into a bitvec (8 bools per byte) let num_bytes = row_count.div_ceil(8); let mut approx_build_date_bits = vec![0u8; num_bytes]; @@ -1697,6 +1942,7 @@ impl PropertyData { quant_min, quant_range, feature_stats, + poi_metrics, last_known_price_raw, address_buffer, address_offsets, diff --git a/server-rs/src/data/travel_time.rs b/server-rs/src/data/travel_time.rs index 0fc6d5f..f90dca8 100644 --- a/server-rs/src/data/travel_time.rs +++ b/server-rs/src/data/travel_time.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use anyhow::Context; use parking_lot::Mutex; use polars::lazy::frame::LazyFrame; +use polars::prelude::PlRefPath; use rustc_hash::{FxHashMap, FxHashSet}; use tracing::info; @@ -155,15 +156,23 @@ impl TravelTimeStore { /// Returns a cached or freshly-loaded postcode → travel_minutes mapping. pub fn get(&self, mode: &str, slug: &str) -> anyhow::Result { let key = (mode.to_string(), slug.to_string()); - - // Check cache first - { - let mut cache = self.cache.lock(); - if let Some(data) = cache.get(&key) { - return Ok(data); - } + if let Some(data) = self.get_cached(&key) { + return Ok(data); } + super::run_polars_io(|| self.load_uncached(key)) + } + + fn get_cached(&self, key: &(String, String)) -> Option { + let mut cache = self.cache.lock(); + cache.get(key) + } + + fn load_uncached(&self, key: (String, String)) -> anyhow::Result { + if let Some(data) = self.get_cached(&key) { + return Ok(data); + } + let (mode, slug) = &key; // Resolve slug to actual filename (may have numeric prefix). // Reject unknown slugs rather than falling back to raw input to prevent path traversal. 
let file_stem = self @@ -175,7 +184,9 @@ impl TravelTimeStore { .join(mode) .join(format!("{}.parquet", file_stem)); - let df = LazyFrame::scan_parquet(&path, Default::default()) + let parquet_path = PlRefPath::try_from_path(&path) + .with_context(|| format!("Failed to normalize path: {}", path.display()))?; + let df = LazyFrame::scan_parquet(parquet_path, Default::default()) .with_context(|| format!("Failed to scan: {}", path.display()))? .collect() .with_context(|| format!("Failed to read: {}", path.display()))?; diff --git a/server-rs/src/features.rs b/server-rs/src/features.rs index d4025b8..e779710 100644 --- a/server-rs/src/features.rs +++ b/server-rs/src/features.rs @@ -1,6 +1,7 @@ //! Static feature configuration. Every numeric and enum column in wide.parquet //! must be declared here. Unknown columns cause a startup panic. +#[derive(Clone, Copy)] pub enum Bounds { /// Fixed min/max values for the slider Fixed { min: f32, max: f32 }, @@ -61,6 +62,26 @@ pub struct FeatureGroup { } pub static FEATURE_GROUPS: &[FeatureGroup] = &[ + FeatureGroup { + name: "Transport", + features: &[ + Feature::Numeric(FeatureConfig { + name: "Distance to nearest train or tube station (km)", + bounds: Bounds::Percentile { + low: 2.0, + high: 98.0, + }, + step: 0.1, + description: "Distance to the closest train or tube station", + detail: "Straight-line distance in kilometres from the postcode to the nearest rail station or Tube/metro/tram stop.", + source: "naptan", + prefix: "", + suffix: " km", + raw: false, + absolute: false, + }), + ], + }, FeatureGroup { name: "Properties", features: &[ @@ -78,6 +99,21 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[ detail: "From HM Land Registry Price Paid data. Freehold means you own the building and the land it stands on. 
Leasehold means you own the building but not the land: you have a lease from the freeholder for a set number of years.", source: "price-paid", }), + Feature::Numeric(FeatureConfig { + name: "Estimated current price", + bounds: Bounds::Fixed { + min: 0.0, + max: 2_500_000.0, + }, + step: 10000.0, + description: "Modelled estimate of the current property value", + detail: "Based on the last sale price, local repeat-sales price movement, and nearby recently sold properties. The repeat-sales index is tracked by postcode sector and property type, with smoothing and neighbour blending where data is sparse. Recent sales stay close to the recorded price; older sales depend more on the model.", + source: "price-paid", + prefix: "£", + suffix: "", + raw: false, + absolute: true, + }), Feature::Numeric(FeatureConfig { name: "Last known price", bounds: Bounds::Fixed { @@ -94,19 +130,19 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[ absolute: true, }), Feature::Numeric(FeatureConfig { - name: "Estimated current price", - bounds: Bounds::Fixed { - min: 0.0, - max: 2_500_000.0, + name: "Est. price per sqm", + bounds: Bounds::Percentile { + low: 0.0, + high: 98.0, }, - step: 10000.0, - description: "Inflation-adjusted estimate of the current property value", - detail: "Based on the last sale price, adjusted for local price changes over time using a repeat-sales index (tracked per postcode sector and property type). If post-sale improvements are detected from EPC records, a renovation premium is added. Recent sales will be close to the original price; older sales are adjusted more.", + step: 100.0, + description: "Estimated current price divided by total floor area", + detail: "Calculated by dividing the modelled estimated current price by the total floor area from the EPC certificate. 
Provides a more up-to-date price-per-area comparison than the historical sale price per sqm.", source: "price-paid", prefix: "£", suffix: "", raw: false, - absolute: true, + absolute: false, }), Feature::Numeric(FeatureConfig { name: "Price per sqm", @@ -123,21 +159,6 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[ raw: false, absolute: false, }), - Feature::Numeric(FeatureConfig { - name: "Est. price per sqm", - bounds: Bounds::Percentile { - low: 0.0, - high: 98.0, - }, - step: 100.0, - description: "Estimated current price divided by total floor area", - detail: "Calculated by dividing the inflation-adjusted estimated current price (including any renovation premium) by the total floor area from the EPC certificate. Provides a more up-to-date price-per-area comparison than the historical sale price per sqm.", - source: "price-paid", - prefix: "£", - suffix: "", - raw: false, - absolute: false, - }), Feature::Numeric(FeatureConfig { name: "Estimated monthly rent", bounds: Bounds::Percentile { low: 2.0, high: 98.0 }, @@ -248,26 +269,6 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[ }), ], }, - FeatureGroup { - name: "Transport", - features: &[ - Feature::Numeric(FeatureConfig { - name: "Distance to nearest train or tube station (km)", - bounds: Bounds::Percentile { - low: 2.0, - high: 98.0, - }, - step: 0.1, - description: "Distance to the closest train or tube station", - detail: "Straight-line distance in kilometres from the postcode to the nearest rail station or Tube/metro/tram stop.", - source: "naptan", - prefix: "", - suffix: " km", - raw: false, - absolute: false, - }), - ], - }, FeatureGroup { name: "Education", features: &[ @@ -393,18 +394,18 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[ }), Feature::Numeric(FeatureConfig { name: "Education, Skills and Training Score", - bounds: Bounds::Percentile { - low: 2.0, - high: 98.0, + bounds: Bounds::Fixed { + min: 0.0, + max: 100.0, }, - step: 0.1, - description: "Education quality score for the local 
area (higher = better)", - detail: "From the English Indices of Deprivation (inverted so higher = better). Covers school attainment, entry to higher education, adult qualifications, and English language proficiency. Higher scores indicate less deprivation.", + step: 1.0, + description: "Education and skills deprivation percentile (higher = less deprived)", + detail: "From the English Indices of Deprivation, converted to a national percentile where 0% is most deprived and 100% is least deprived. Covers school attainment, entry to higher education, adult qualifications, and English language proficiency.", source: "iod", prefix: "", - suffix: "", - raw: false, - absolute: false, + suffix: "%", + raw: true, + absolute: true, }), ], }, @@ -413,72 +414,78 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[ features: &[ Feature::Numeric(FeatureConfig { name: "Income Score", - bounds: Bounds::Fixed { min: 0.0, max: 1.0 }, - step: 0.01, - description: "Income deprivation rate, inverted (higher = less deprived)", - detail: "From the English Indices of Deprivation (inverted so higher = better). Higher values indicate less income deprivation. Based on Income Support, income-based Jobseeker's Allowance, income-based Employment and Support Allowance, Pension Credit, Working Tax Credit and Child Tax Credit, Universal Credit, and asylum seekers.", + bounds: Bounds::Fixed { + min: 0.0, + max: 100.0, + }, + step: 1.0, + description: "Income deprivation percentile (higher = less deprived)", + detail: "From the English Indices of Deprivation, converted to a national percentile where 0% is most income deprived and 100% is least income deprived. 
Based on Income Support, income-based Jobseeker's Allowance, income-based Employment and Support Allowance, Pension Credit, Working Tax Credit and Child Tax Credit, Universal Credit, and asylum seekers.", source: "iod", prefix: "", - suffix: "", - raw: false, - absolute: false, + suffix: "%", + raw: true, + absolute: true, }), Feature::Numeric(FeatureConfig { name: "Employment Score", - bounds: Bounds::Fixed { min: 0.0, max: 1.0 }, - step: 0.01, - description: "Employment deprivation rate, inverted (higher = less deprived)", - detail: "From the English Indices of Deprivation (inverted so higher = better). Higher values indicate less employment deprivation. Based on claimants of Jobseeker's Allowance, Employment and Support Allowance, Incapacity Benefit, Severe Disablement Allowance, Carer's Allowance, and relevant Universal Credit claimants.", + bounds: Bounds::Fixed { + min: 0.0, + max: 100.0, + }, + step: 1.0, + description: "Employment deprivation percentile (higher = less deprived)", + detail: "From the English Indices of Deprivation, converted to a national percentile where 0% is most employment deprived and 100% is least employment deprived. Based on claimants of Jobseeker's Allowance, Employment and Support Allowance, Incapacity Benefit, Severe Disablement Allowance, Carer's Allowance, and relevant Universal Credit claimants.", source: "iod", prefix: "", - suffix: "", - raw: false, - absolute: false, + suffix: "%", + raw: true, + absolute: true, }), Feature::Numeric(FeatureConfig { name: "Health Deprivation and Disability Score", - bounds: Bounds::Percentile { - low: 2.0, - high: 98.0, + bounds: Bounds::Fixed { + min: 0.0, + max: 100.0, }, - step: 0.1, - description: "Health and disability score (higher = better health outcomes)", - detail: "From the English Indices of Deprivation (inverted so higher = better). Higher scores indicate lower risk of premature death and better quality of life. 
Derived from years of potential life lost, comparative illness and disability ratio, acute morbidity, and mood and anxiety disorders.", + step: 1.0, + description: "Health and disability deprivation percentile (higher = better outcomes)", + detail: "From the English Indices of Deprivation, converted to a national percentile where 0% is most health deprived and 100% is least health deprived. Derived from years of potential life lost, comparative illness and disability ratio, acute morbidity, and mood and anxiety disorders.", source: "iod", prefix: "", - suffix: "", - raw: false, - absolute: false, + suffix: "%", + raw: true, + absolute: true, }), Feature::Numeric(FeatureConfig { name: "Housing Conditions Score", - bounds: Bounds::Percentile { - low: 2.0, - high: 98.0, + bounds: Bounds::Fixed { + min: 0.0, + max: 100.0, }, - step: 0.1, - description: "Housing quality and conditions (higher = better)", - detail: "From the English Indices of Deprivation, Living Environment domain (inverted so higher = better). Measures the quality of housing stock: central heating availability, housing condition, and Decent Homes standards. Higher scores indicate better housing conditions.", + step: 1.0, + description: "Housing conditions percentile (higher = better conditions)", + detail: "From the English Indices of Deprivation, Living Environment domain, converted to a national percentile where 0% is most deprived and 100% is least deprived. 
Measures the quality of housing stock: central heating availability, housing condition, and Decent Homes standards.", source: "iod", prefix: "", - suffix: "", - raw: false, - absolute: false, + suffix: "%", + raw: true, + absolute: true, }), Feature::Numeric(FeatureConfig { name: "Air Quality and Road Safety Score", - bounds: Bounds::Percentile { - low: 2.0, - high: 98.0, + bounds: Bounds::Fixed { + min: 0.0, + max: 100.0, }, - step: 0.1, - description: "Air quality and road safety (higher = better)", - detail: "From the English Indices of Deprivation, Living Environment domain (inverted so higher = better). Measures the outdoor living environment quality through air quality indicators and road traffic accident casualties involving pedestrians and cyclists. Higher scores indicate better outdoor environments.", + step: 1.0, + description: "Air quality and road safety percentile (higher = better conditions)", + detail: "From the English Indices of Deprivation, Living Environment domain, converted to a national percentile where 0% is most deprived and 100% is least deprived. Measures the outdoor living environment through air quality indicators and road traffic accident casualties involving pedestrians and cyclists.", source: "iod", prefix: "", - suffix: "", - raw: false, - absolute: false, + suffix: "%", + raw: true, + absolute: true, }), ], }, @@ -996,6 +1003,126 @@ pub static FEATURE_GROUPS: &[FeatureGroup] = &[ raw: false, absolute: false, }), + Feature::Numeric(FeatureConfig { + name: "Distance to nearest grocery store (km)", + bounds: Bounds::Percentile { + low: 2.0, + high: 98.0, + }, + step: 0.1, + description: "Distance to the closest grocery shop or supermarket", + detail: "Straight-line distance in kilometres from the postcode to the nearest grocery shop, supermarket, or convenience store. 
Uses OpenStreetMap POIs, with Waitrose and Tesco coverage from GEOLYTIX retail points.", + source: "osm-pois", + prefix: "", + suffix: " km", + raw: false, + absolute: false, + }), + Feature::Numeric(FeatureConfig { + name: "Distance to nearest tube station (km)", + bounds: Bounds::Percentile { + low: 2.0, + high: 98.0, + }, + step: 0.1, + description: "Distance to the closest Tube, metro, tram, or DLR stop", + detail: "Straight-line distance in kilometres from the postcode to the nearest NaPTAN station classified as Tube, metro, tram, or DLR.", + source: "naptan", + prefix: "", + suffix: " km", + raw: false, + absolute: false, + }), + Feature::Numeric(FeatureConfig { + name: "Distance to nearest rail station (km)", + bounds: Bounds::Percentile { + low: 2.0, + high: 98.0, + }, + step: 0.1, + description: "Distance to the closest National Rail station", + detail: "Straight-line distance in kilometres from the postcode to the nearest NaPTAN railway station.", + source: "naptan", + prefix: "", + suffix: " km", + raw: false, + absolute: false, + }), + Feature::Numeric(FeatureConfig { + name: "Distance to nearest Waitrose (km)", + bounds: Bounds::Percentile { + low: 2.0, + high: 98.0, + }, + step: 0.1, + description: "Distance to the closest Waitrose store", + detail: "Straight-line distance in kilometres from the postcode to the nearest Waitrose or Little Waitrose store in the GEOLYTIX Grocery Retail Points dataset.", + source: "geolytix-retail-points", + prefix: "", + suffix: " km", + raw: false, + absolute: false, + }), + Feature::Numeric(FeatureConfig { + name: "Distance to nearest Tesco (km)", + bounds: Bounds::Percentile { + low: 2.0, + high: 98.0, + }, + step: 0.1, + description: "Distance to the closest Tesco store", + detail: "Straight-line distance in kilometres from the postcode to the nearest Tesco store in the GEOLYTIX Grocery Retail Points dataset.", + source: "geolytix-retail-points", + prefix: "", + suffix: " km", + raw: false, + absolute: false, + }), + 
Feature::Numeric(FeatureConfig { + name: "Distance to nearest cafe (km)", + bounds: Bounds::Percentile { + low: 2.0, + high: 98.0, + }, + step: 0.1, + description: "Distance to the closest cafe", + detail: "Straight-line distance in kilometres from the postcode to the nearest cafe, ice-cream shop, or internet cafe mapped in OpenStreetMap.", + source: "osm-pois", + prefix: "", + suffix: " km", + raw: false, + absolute: false, + }), + Feature::Numeric(FeatureConfig { + name: "Distance to nearest pub (km)", + bounds: Bounds::Percentile { + low: 2.0, + high: 98.0, + }, + step: 0.1, + description: "Distance to the closest pub", + detail: "Straight-line distance in kilometres from the postcode to the nearest pub, social club, brewery, distillery, or winery mapped in OpenStreetMap.", + source: "osm-pois", + prefix: "", + suffix: " km", + raw: false, + absolute: false, + }), + Feature::Numeric(FeatureConfig { + name: "Distance to nearest restaurant (km)", + bounds: Bounds::Percentile { + low: 2.0, + high: 98.0, + }, + step: 0.1, + description: "Distance to the closest restaurant", + detail: "Straight-line distance in kilometres from the postcode to the nearest restaurant or food court mapped in OpenStreetMap.", + source: "osm-pois", + prefix: "", + suffix: " km", + raw: false, + absolute: false, + }), Feature::Numeric(FeatureConfig { name: "Number of parks within 1km", bounds: Bounds::Percentile { @@ -1105,20 +1232,76 @@ pub fn order_for(name: &str) -> Option<&'static [&'static str]> { /// Whether this feature should use integer-width histogram bins. pub fn has_integer_bins(name: &str) -> bool { - INTEGER_BIN_FEATURES.contains(&name) + INTEGER_BIN_FEATURES.contains(&name) || dynamic_poi_count_radius(name).is_some() } /// Look up the Bounds config for a numeric feature by name. 
-pub fn bounds_for(name: &str) -> Option<&'static Bounds> { +pub fn bounds_for(name: &str) -> Option { + if dynamic_poi_distance_category(name).is_some() { + return Some(Bounds::Percentile { + low: 2.0, + high: 98.0, + }); + } + if dynamic_poi_count_radius(name).is_some() { + return Some(Bounds::Percentile { + low: 5.0, + high: 95.0, + }); + } + FEATURE_GROUPS .iter() .flat_map(|group| group.features.iter()) .find_map(|feature| match feature { - Feature::Numeric(c) if c.name == name => Some(&c.bounds), + Feature::Numeric(c) if c.name == name => Some(c.bounds), _ => None, }) } +pub fn dynamic_poi_distance_category(name: &str) -> Option<&str> { + name.strip_prefix("Distance to nearest ") + .and_then(|rest| rest.strip_suffix(" POI (km)")) + .filter(|category| !category.is_empty()) +} + +pub fn dynamic_poi_count_radius(name: &str) -> Option { + let rest = name.strip_prefix("Number of ")?; + let (_category, suffix) = rest.rsplit_once(" POIs within ")?; + match suffix { + "2km" => Some(2), + "5km" => Some(5), + _ => None, + } +} + +pub fn dynamic_poi_count_category(name: &str) -> Option<&str> { + let rest = name.strip_prefix("Number of ")?; + let (category, suffix) = rest.rsplit_once(" POIs within ")?; + matches!(suffix, "2km" | "5km") + .then_some(category) + .filter(|category| !category.is_empty()) +} + +pub fn is_dynamic_poi_feature(name: &str) -> bool { + dynamic_poi_distance_category(name).is_some() || dynamic_poi_count_category(name).is_some() +} + +pub fn dynamic_poi_feature_sort_key(name: &str) -> (u8, String) { + if let Some(category) = dynamic_poi_distance_category(name) { + return (0, category.to_ascii_lowercase()); + } + if let Some(category) = dynamic_poi_count_category(name) { + let metric_order = match dynamic_poi_count_radius(name) { + Some(2) => 1, + Some(5) => 2, + _ => 3, + }; + return (metric_order, category.to_ascii_lowercase()); + } + (9, name.to_ascii_lowercase()) +} + /// Canonical display order for POI category groups. 
/// The server will panic at startup if the data contains groups not in this list or vice versa. pub const POI_GROUP_ORDER: &[&str] = &[ diff --git a/server-rs/src/main.rs b/server-rs/src/main.rs index 611616a..3d6a324 100644 --- a/server-rs/src/main.rs +++ b/server-rs/src/main.rs @@ -2,6 +2,7 @@ mod aggregation; mod auth; +mod checkout_sessions; mod consts; mod data; mod features; @@ -10,6 +11,7 @@ mod metrics; mod og_middleware; pub mod parsing; mod pocketbase; +mod pocketbase_locks; mod routes; mod state; pub mod utils; diff --git a/server-rs/src/parsing.rs b/server-rs/src/parsing.rs index 4a825be..f0c92e0 100644 --- a/server-rs/src/parsing.rs +++ b/server-rs/src/parsing.rs @@ -4,8 +4,11 @@ mod filters; mod h3; pub use bounds::{bounds_intersect, h3_cell_bounds, parse_bounds, require_bounds}; -pub use fields::{parse_enum_dist, parse_field_indices, parse_field_set}; +pub use fields::{ + parse_enum_dist, parse_field_indices, parse_field_indices_with_poi, parse_field_set, +}; pub use filters::{ - count_filter_impacts, parse_filters, row_passes_filters, ParsedEnumFilter, ParsedFilter, + count_filter_impacts, parse_filters, parse_filters_with_poi, row_passes_filters, + row_passes_poi_filters, ParsedEnumFilter, ParsedFilter, ParsedPoiFilter, }; pub use h3::{cell_for_row, cell_for_row_cached, needs_parent, validate_h3_resolution}; diff --git a/server-rs/src/parsing/fields.rs b/server-rs/src/parsing/fields.rs index 38d7531..003d2f9 100644 --- a/server-rs/src/parsing/fields.rs +++ b/server-rs/src/parsing/fields.rs @@ -31,6 +31,55 @@ pub fn parse_field_indices( Ok(Some(indices)) } +pub struct ParsedFieldIndices { + /// None means no `fields` param was supplied, so normal aggregation keeps + /// its existing "all configured features" behavior. + pub normal: Option>, + pub poi: Vec, +} + +/// Parse `?fields=` against both the row-major feature matrix and the +/// postcode-level POI side table. 
+pub fn parse_field_indices_with_poi( + fields: Option<&str>, + name_to_index: &FxHashMap, + poi_name_to_index: &FxHashMap, +) -> Result { + let Some(fields_str) = fields else { + return Ok(ParsedFieldIndices { + normal: None, + poi: Vec::new(), + }); + }; + if fields_str.is_empty() { + return Ok(ParsedFieldIndices { + normal: Some(Vec::new()), + poi: Vec::new(), + }); + } + + let mut normal = Vec::new(); + let mut poi = Vec::new(); + for name in fields_str.split(";;") { + let name = name.trim(); + if name.is_empty() { + continue; + } + if let Some(&idx) = name_to_index.get(name) { + normal.push(idx); + } else if let Some(&idx) = poi_name_to_index.get(name) { + poi.push(idx); + } else { + return Err((StatusCode::BAD_REQUEST, format!("Unknown field: {}", name))); + } + } + + Ok(ParsedFieldIndices { + normal: Some(normal), + poi, + }) +} + /// Parse an optional `?enum_dist=` query param into (feature_index, num_values) for /// per-value distribution counting. Returns None if not requested. /// Returns 400 if the feature name is unknown or not an enum feature. 
@@ -73,3 +122,28 @@ pub fn parse_field_set(fields: Option<&str>) -> (bool, HashSet) { .unwrap_or_default(); (fields_specified, field_set) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_field_indices_with_poi_splits_normal_and_side_fields() { + let normal: FxHashMap = [("Price".to_string(), 0), ("Area".to_string(), 1)] + .into_iter() + .collect(); + let poi: FxHashMap = [("Distance to nearest cafe POI (km)".to_string(), 2)] + .into_iter() + .collect(); + + let parsed = parse_field_indices_with_poi( + Some("Price;;Distance to nearest cafe POI (km)"), + &normal, + &poi, + ) + .unwrap(); + + assert_eq!(parsed.normal, Some(vec![0])); + assert_eq!(parsed.poi, vec![2]); + } +} diff --git a/server-rs/src/parsing/filters.rs b/server-rs/src/parsing/filters.rs index e74898a..ee21deb 100644 --- a/server-rs/src/parsing/filters.rs +++ b/server-rs/src/parsing/filters.rs @@ -1,7 +1,7 @@ use rustc_hash::{FxHashMap, FxHashSet}; use crate::consts::NAN_U16; -use crate::data::QuantRef; +use crate::data::{PostcodePoiMetrics, QuantRef}; /// Filter for numeric features: value must be in [min_u16, max_u16] range (quantized). #[derive(Debug)] @@ -19,6 +19,20 @@ pub struct ParsedEnumFilter { pub allowed: FxHashSet, } +/// Filter for postcode-level POI metrics stored in the side table. +#[derive(Debug)] +pub struct ParsedPoiFilter { + pub metric_idx: usize, + pub min_u16: u16, + pub max_u16: u16, +} + +pub type ParsedFiltersWithPoi = ( + Vec, + Vec, + Vec, +); + /// Parse `;;`-separated filter string into numeric and enum filters. /// Numeric format: `name:min:max` /// Enum format: `name:val1|val2|val3` (pipe-separated string values) @@ -110,6 +124,101 @@ pub fn parse_filters( Ok((numeric, enums)) } +/// Parse filters while allowing dynamic POI metric names that live outside the +/// row-major property feature matrix. 
+pub fn parse_filters_with_poi( + filter_str: Option<&str>, + feature_name_to_index: &FxHashMap, + enum_values: &FxHashMap>, + quant: &QuantRef, + poi_name_to_index: &FxHashMap, + poi_quant: &QuantRef, +) -> Result { + let mut numeric = Vec::new(); + let mut enums = Vec::new(); + let mut poi = Vec::new(); + + let input = match filter_str.filter(|text| !text.is_empty()) { + Some(text) => text, + None => return Ok((numeric, enums, poi)), + }; + + for entry in input.split(";;") { + let parts: Vec<&str> = entry.splitn(2, ':').collect(); + if parts.len() != 2 { + return Err(format!("Malformed filter entry (missing ':'): '{entry}'")); + } + let name = parts[0].trim(); + let rest = parts[1].trim(); + + if let Some(&feat_idx) = feature_name_to_index.get(name) { + if let Some(values) = enum_values.get(&feat_idx) { + let mut allowed: FxHashSet = FxHashSet::default(); + for value in rest.split('|') { + let value = value.trim(); + match values.iter().position(|existing| existing == value) { + Some(position) => { + allowed.insert(position as u16); + } + None => { + return Err(format!( + "Unknown value '{}' for enum feature '{}'. 
Valid values: {:?}", + value, name, values + )); + } + } + } + enums.push(ParsedEnumFilter { feat_idx, allowed }); + } else { + let (min, max) = parse_numeric_filter_bounds(name, rest, entry)?; + numeric.push(ParsedFilter { + feat_idx, + min_u16: quant.encode_min(feat_idx, min), + max_u16: quant.encode_max(feat_idx, max), + }); + } + } else if let Some(&metric_idx) = poi_name_to_index.get(name) { + let (min, max) = parse_numeric_filter_bounds(name, rest, entry)?; + poi.push(ParsedPoiFilter { + metric_idx, + min_u16: poi_quant.encode_min(metric_idx, min), + max_u16: poi_quant.encode_max(metric_idx, max), + }); + } else { + return Err(format!("Unknown feature in filter: '{name}'")); + } + } + + numeric.sort_unstable_by_key(|f| f.max_u16.saturating_sub(f.min_u16)); + enums.sort_unstable_by_key(|f| f.allowed.len()); + poi.sort_unstable_by_key(|f| f.max_u16.saturating_sub(f.min_u16)); + + Ok((numeric, enums, poi)) +} + +fn parse_numeric_filter_bounds(name: &str, rest: &str, entry: &str) -> Result<(f32, f32), String> { + let num_parts: Vec<&str> = rest.splitn(2, ':').collect(); + if num_parts.len() != 2 { + return Err(format!( + "Numeric filter '{name}' must have format 'name:min:max', got '{entry}'" + )); + } + let min = num_parts[0] + .trim() + .parse::() + .map_err(|err| format!("Invalid min value in filter '{name}': {err}"))?; + let max = num_parts[1] + .trim() + .parse::() + .map_err(|err| format!("Invalid max value in filter '{name}': {err}"))?; + if min.is_finite() && max.is_finite() && min > max { + return Err(format!( + "Numeric filter '{name}' has inverted range: min ({min}) > max ({max})" + )); + } + Ok((min, max)) +} + /// Check if a row passes all filters. /// All features (numeric and enum) are stored in feature_data as quantized u16. 
pub fn row_passes_filters( @@ -130,6 +239,18 @@ pub fn row_passes_filters( }) } +#[inline] +pub fn row_passes_poi_filters( + row: usize, + filters: &[ParsedPoiFilter], + poi_metrics: &PostcodePoiMetrics, +) -> bool { + filters.iter().all(|filter| { + let raw = poi_metrics.raw_for_property_row(row, filter.metric_idx); + raw != NAN_U16 && raw >= filter.min_u16 && raw <= filter.max_u16 + }) +} + /// Single-pass marginal impact counting. /// /// Returns `(total_passing, impacts)` where `impacts[i]` is how many MORE rows @@ -330,6 +451,35 @@ mod tests { assert_eq!(enums[0].allowed.len(), 2); } + #[test] + fn parse_filters_with_poi_splits_side_table_filters() { + let tq = test_quant(3, 2); + let poi_tq = test_quant(2, 2); + let poi_map: FxHashMap = [ + ("Distance to nearest cafe POI (km)".into(), 0), + ("Number of cafe POIs within 2km".into(), 1), + ] + .into_iter() + .collect(); + + let (numeric, enums, poi) = parse_filters_with_poi( + Some("price:100:500;;rating:A;;Distance to nearest cafe POI (km):0:1.5"), + &feature_name_to_index(), + &enum_values(), + &tq.as_ref(), + &poi_map, + &poi_tq.as_ref(), + ) + .unwrap(); + + assert_eq!(numeric.len(), 1); + assert_eq!(enums.len(), 1); + assert_eq!(poi.len(), 1); + assert_eq!(poi[0].metric_idx, 0); + assert_eq!(poi[0].min_u16, 0); + assert_eq!(poi[0].max_u16, 99); + } + #[test] fn parse_filters_empty() { let tq = test_quant(3, 2); diff --git a/server-rs/src/pocketbase.rs b/server-rs/src/pocketbase.rs index 130dd36..a86d4d8 100644 --- a/server-rs/src/pocketbase.rs +++ b/server-rs/src/pocketbase.rs @@ -88,6 +88,8 @@ struct CreateCollection { update_rule: Option, #[serde(skip_serializing_if = "Option::is_none")] delete_rule: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + indexes: Vec, } #[derive(Serialize)] @@ -308,12 +310,13 @@ async fn ensure_user_fields(client: &Client, base_url: &str, token: &str) -> any let has_ai_tokens_used = fields.iter().any(|f| f["name"] == "ai_tokens_used"); let has_ai_tokens_week = 
fields.iter().any(|f| f["name"] == "ai_tokens_week"); - if has_is_admin + let has_all_required_fields = has_is_admin && has_subscription && has_newsletter && has_ai_tokens_used - && has_ai_tokens_week - { + && has_ai_tokens_week; + + if has_all_required_fields { info!("PocketBase users collection already has all required fields"); return Ok(()); } @@ -372,6 +375,52 @@ async fn ensure_user_fields(client: &Client, base_url: &str, token: &str) -> any Ok(()) } +/// Ensure clients can manage normal account data but cannot self-grant paid or +/// admin-only state. Superuser writes from the Rust API bypass these rules. +async fn ensure_user_auth_rules( + client: &Client, + base_url: &str, + token: &str, +) -> anyhow::Result<()> { + let url = format!("{base_url}/api/collections/users"); + let self_only = "id = @request.auth.id"; + let protected_fields_absent = concat!( + "@request.body.subscription:isset = false", + " && @request.body.is_admin:isset = false", + " && @request.body.ai_tokens_used:isset = false", + " && @request.body.ai_tokens_week:isset = false" + ); + let protected_fields_unchanged = concat!( + "@request.body.subscription:changed = false", + " && @request.body.is_admin:changed = false", + " && @request.body.ai_tokens_used:changed = false", + " && @request.body.ai_tokens_week:changed = false" + ); + let update_rule = format!("{self_only} && {protected_fields_unchanged}"); + + let resp = client + .patch(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&serde_json::json!({ + "listRule": self_only, + "viewRule": self_only, + "createRule": protected_fields_absent, + "updateRule": update_rule, + "deleteRule": self_only, + })) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + anyhow::bail!("Failed to update users collection API rules ({status}): {text}"); + } + + info!("PocketBase users collection API rules hardened"); + Ok(()) +} + /// Ensure a 
collection has API rules allowing users to manage their own records. async fn ensure_user_owned_rules( client: &Client, @@ -404,6 +453,263 @@ async fn ensure_user_owned_rules( Ok(()) } +/// Ensure a collection is accessible only via server-side superuser calls. +async fn ensure_server_only_rules( + client: &Client, + base_url: &str, + token: &str, + collection_name: &str, +) -> anyhow::Result<()> { + let url = format!("{base_url}/api/collections/{collection_name}"); + let resp = client + .patch(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&serde_json::json!({ + "listRule": serde_json::Value::Null, + "viewRule": serde_json::Value::Null, + "createRule": serde_json::Value::Null, + "updateRule": serde_json::Value::Null, + "deleteRule": serde_json::Value::Null, + })) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + anyhow::bail!("Failed to lock {collection_name} API rules ({status}): {text}"); + } + + info!("PocketBase collection '{collection_name}' locked to superuser access"); + Ok(()) +} + +async fn ensure_checkout_sessions_fields( + client: &Client, + base_url: &str, + token: &str, +) -> anyhow::Result<()> { + let url = format!("{base_url}/api/collections/checkout_sessions"); + let resp = client + .get(&url) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + anyhow::bail!("Failed to fetch checkout_sessions collection ({status}): {text}"); + } + + let body: serde_json::Value = resp.json().await?; + let fields = body["fields"] + .as_array() + .ok_or_else(|| anyhow::anyhow!("checkout_sessions collection has no fields array"))?; + let users_id = find_users_collection_id(client, base_url, token).await?; + + let mut new_fields = fields.clone(); + let mut add_field = |name: &str, field: serde_json::Value| { + if 
!fields.iter().any(|f| f["name"] == name) { + new_fields.push(field); + } + }; + + add_field( + "user", + serde_json::json!({ + "name": "user", + "type": "relation", + "required": true, + "maxSelect": 1, + "collectionId": users_id, + }), + ); + add_field( + "stripe_session_id", + serde_json::json!({ "name": "stripe_session_id", "type": "text", "required": false }), + ); + add_field( + "checkout_url", + serde_json::json!({ "name": "checkout_url", "type": "text", "required": false }), + ); + add_field( + "amount_pence", + serde_json::json!({ "name": "amount_pence", "type": "number" }), + ); + add_field( + "expected_total_pence", + serde_json::json!({ "name": "expected_total_pence", "type": "number" }), + ); + add_field( + "currency", + serde_json::json!({ "name": "currency", "type": "text", "required": true }), + ); + add_field( + "discount_coupon_id", + serde_json::json!({ "name": "discount_coupon_id", "type": "text", "required": false }), + ); + add_field( + "referral_invite_id", + serde_json::json!({ "name": "referral_invite_id", "type": "text", "required": false }), + ); + add_field( + "status", + serde_json::json!({ "name": "status", "type": "text", "required": true }), + ); + add_field( + "expires_at_unix", + serde_json::json!({ "name": "expires_at_unix", "type": "number" }), + ); + add_field( + "paid_amount_pence", + serde_json::json!({ "name": "paid_amount_pence", "type": "number" }), + ); + add_field( + "completed_at_unix", + serde_json::json!({ "name": "completed_at_unix", "type": "text", "required": false }), + ); + + if new_fields.len() == fields.len() { + return Ok(()); + } + + let patch_resp = client + .patch(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&serde_json::json!({ "fields": new_fields })) + .send() + .await?; + + if !patch_resp.status().is_success() { + let status = patch_resp.status(); + let text = patch_resp.text().await.unwrap_or_default(); + anyhow::bail!("Failed to patch checkout_sessions fields ({status}): 
{text}"); + } + + info!("PocketBase checkout_sessions collection fields updated"); + Ok(()) +} + +async fn ensure_checkout_locks_fields( + client: &Client, + base_url: &str, + token: &str, +) -> anyhow::Result<()> { + let url = format!("{base_url}/api/collections/checkout_locks"); + let resp = client + .get(&url) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + anyhow::bail!("Failed to fetch checkout_locks collection ({status}): {text}"); + } + + let body: serde_json::Value = resp.json().await?; + let fields = body["fields"] + .as_array() + .ok_or_else(|| anyhow::anyhow!("checkout_locks collection has no fields array"))?; + + let mut new_fields = fields.clone(); + let mut add_field = |name: &str, field: serde_json::Value| { + if !fields.iter().any(|f| f["name"] == name) { + new_fields.push(field); + } + }; + + add_field( + "name", + serde_json::json!({ "name": "name", "type": "text", "required": true }), + ); + add_field( + "owner", + serde_json::json!({ "name": "owner", "type": "text", "required": true }), + ); + add_field( + "expires_at_unix", + serde_json::json!({ "name": "expires_at_unix", "type": "number" }), + ); + + if new_fields.len() == fields.len() { + return Ok(()); + } + + let patch_resp = client + .patch(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&serde_json::json!({ "fields": new_fields })) + .send() + .await?; + + if !patch_resp.status().is_success() { + let status = patch_resp.status(); + let text = patch_resp.text().await.unwrap_or_default(); + anyhow::bail!("Failed to patch checkout_locks fields ({status}): {text}"); + } + + info!("PocketBase checkout_locks collection fields updated"); + Ok(()) +} + +async fn ensure_collection_indexes( + client: &Client, + base_url: &str, + token: &str, + collection_name: &str, + required_indexes: &[(&str, &str)], +) -> anyhow::Result<()> { + 
let url = format!("{base_url}/api/collections/{collection_name}"); + let resp = client + .get(&url) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + anyhow::bail!("Failed to fetch {collection_name} collection ({status}): {text}"); + } + + let body: serde_json::Value = resp.json().await?; + let indexes = body["indexes"].as_array().cloned().unwrap_or_default(); + let mut new_indexes = indexes.clone(); + + for (index_name, create_sql) in required_indexes { + let exists = indexes + .iter() + .filter_map(|idx| idx.as_str()) + .any(|idx| idx.contains(index_name)); + if !exists { + new_indexes.push(serde_json::Value::String((*create_sql).to_string())); + } + } + + if new_indexes.len() == indexes.len() { + return Ok(()); + } + + let patch_resp = client + .patch(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&serde_json::json!({ "indexes": new_indexes })) + .send() + .await?; + + if !patch_resp.status().is_success() { + let status = patch_resp.status(); + let text = patch_resp.text().await.unwrap_or_default(); + anyhow::bail!("Failed to patch {collection_name} indexes ({status}): {text}"); + } + + info!("PocketBase collection '{collection_name}' indexes updated"); + Ok(()) +} + /// Ensure the `saved_searches` collection has API rules allowing users to manage their own records. 
async fn ensure_saved_searches_rules( client: &Client, @@ -608,6 +914,7 @@ pub async fn ensure_collections( let existing = list_collections(client, base_url, &token).await?; ensure_user_fields(client, base_url, &token).await?; + ensure_user_auth_rules(client, base_url, &token).await?; if !existing.iter().any(|n| n == "saved_searches") { let users_id = find_users_collection_id(client, base_url, &token).await?; @@ -633,6 +940,7 @@ pub async fn ensure_collections( create_rule: user_only.clone(), update_rule: user_only.clone(), delete_rule: user_only, + indexes: Vec::new(), }, ) .await?; @@ -667,6 +975,7 @@ pub async fn ensure_collections( create_rule: user_only.clone(), update_rule: user_only.clone(), delete_rule: user_only, + indexes: Vec::new(), }, ) .await?; @@ -698,6 +1007,7 @@ pub async fn ensure_collections( create_rule: None, update_rule: None, delete_rule: None, + indexes: Vec::new(), }, ) .await?; @@ -705,6 +1015,86 @@ pub async fn ensure_collections( ensure_autodate_fields(client, base_url, &token, "invites").await?; } + if !existing.iter().any(|n| n == "checkout_sessions") { + let users_id = find_users_collection_id(client, base_url, &token).await?; + create_collection( + client, + base_url, + &token, + CreateCollection { + name: "checkout_sessions".to_string(), + r#type: "base".to_string(), + fields: vec![ + Field::relation("user", &users_id), + Field::text("stripe_session_id", false), + Field::text("checkout_url", false), + Field::number("amount_pence"), + Field::number("expected_total_pence"), + Field::text("currency", true), + Field::text("discount_coupon_id", false), + Field::text("referral_invite_id", false), + Field::text("status", true), + Field::number("expires_at_unix"), + Field::number("paid_amount_pence"), + Field::text("completed_at_unix", false), + Field::autodate("created", true, false), + Field::autodate("updated", true, true), + ], + list_rule: None, + view_rule: None, + create_rule: None, + update_rule: None, + delete_rule: None, + 
indexes: Vec::new(), + }, + ) + .await?; + } else { + ensure_server_only_rules(client, base_url, &token, "checkout_sessions").await?; + ensure_checkout_sessions_fields(client, base_url, &token).await?; + ensure_autodate_fields(client, base_url, &token, "checkout_sessions").await?; + } + + let checkout_locks_name_index = + "CREATE UNIQUE INDEX idx_checkout_locks_name ON checkout_locks (name)"; + if !existing.iter().any(|n| n == "checkout_locks") { + create_collection( + client, + base_url, + &token, + CreateCollection { + name: "checkout_locks".to_string(), + r#type: "base".to_string(), + fields: vec![ + Field::text("name", true), + Field::text("owner", true), + Field::number("expires_at_unix"), + Field::autodate("created", true, false), + Field::autodate("updated", true, true), + ], + list_rule: None, + view_rule: None, + create_rule: None, + update_rule: None, + delete_rule: None, + indexes: vec![checkout_locks_name_index.to_string()], + }, + ) + .await?; + } else { + ensure_server_only_rules(client, base_url, &token, "checkout_locks").await?; + ensure_checkout_locks_fields(client, base_url, &token).await?; + ensure_autodate_fields(client, base_url, &token, "checkout_locks").await?; + ensure_collection_indexes( + client, + base_url, + &token, + "checkout_locks", + &[("idx_checkout_locks_name", checkout_locks_name_index)], + ) + .await?; + } + if !existing.iter().any(|n| n == "short_urls") { create_collection( client, @@ -724,6 +1114,7 @@ pub async fn ensure_collections( create_rule: None, update_rule: None, delete_rule: None, + indexes: Vec::new(), }, ) .await?; @@ -753,6 +1144,7 @@ pub async fn ensure_collections( create_rule: None, update_rule: None, delete_rule: None, + indexes: Vec::new(), }, ) .await?; @@ -785,6 +1177,7 @@ pub async fn ensure_collections( create_rule: None, update_rule: None, delete_rule: None, + indexes: Vec::new(), }, ) .await?; diff --git a/server-rs/src/pocketbase_locks.rs b/server-rs/src/pocketbase_locks.rs new file mode 100644 index 
0000000..e7a843f --- /dev/null +++ b/server-rs/src/pocketbase_locks.rs @@ -0,0 +1,264 @@ +use std::time::{Duration, Instant}; + +use anyhow::{anyhow, bail, Context}; +use rand::RngExt; +use serde_json::Value; +use tokio::time::sleep; +use tracing::warn; + +use crate::pocketbase::get_superuser_token; +use crate::state::AppState; + +const LOCK_COLLECTION: &str = "checkout_locks"; +const LOCK_ACQUIRE_TIMEOUT_SECS: u64 = 10; +const LOCK_RETRY_DELAY_MS: u64 = 100; + +pub struct PocketBaseLock { + client: reqwest::Client, + pb_url: String, + token: String, + record_id: Option, + name: String, +} + +struct ExistingLock { + id: String, + expires_at_unix: u64, +} + +pub async fn acquire_pocketbase_lock( + state: &AppState, + name: &str, + ttl_secs: u64, +) -> anyhow::Result { + validate_lock_name(name)?; + + let token = get_superuser_token(state).await?; + let pb_url = state.pocketbase_url.trim_end_matches('/').to_string(); + let owner = random_owner(); + let deadline = Instant::now() + Duration::from_secs(LOCK_ACQUIRE_TIMEOUT_SECS); + + loop { + let now = now_unix_secs(); + if let Some(record_id) = + try_create_lock(state, &pb_url, &token, name, &owner, now + ttl_secs).await? + { + return Ok(PocketBaseLock { + client: state.http_client.clone(), + pb_url, + token, + record_id: Some(record_id), + name: name.to_string(), + }); + } + + if let Some(existing) = find_lock(state, &pb_url, &token, name).await? 
{ + if existing.expires_at_unix <= now { + if let Err(err) = delete_lock_record(state, &pb_url, &token, &existing.id).await { + warn!( + lock_name = name, + lock_id = %existing.id, + "Failed to delete stale PocketBase lock: {err}" + ); + } + continue; + } + } + + if Instant::now() >= deadline { + bail!("Timed out acquiring PocketBase lock '{name}'"); + } + + sleep(Duration::from_millis(LOCK_RETRY_DELAY_MS)).await; + } +} + +impl PocketBaseLock { + pub async fn release(mut self) -> anyhow::Result<()> { + let Some(record_id) = self.record_id.take() else { + return Ok(()); + }; + release_lock_record(&self.client, &self.pb_url, &self.token, &record_id) + .await + .with_context(|| format!("Failed to release PocketBase lock '{}'", self.name)) + } +} + +impl Drop for PocketBaseLock { + fn drop(&mut self) { + let Some(record_id) = self.record_id.take() else { + return; + }; + + let client = self.client.clone(); + let pb_url = self.pb_url.clone(); + let token = self.token.clone(); + let name = self.name.clone(); + tokio::spawn(async move { + if let Err(err) = release_lock_record(&client, &pb_url, &token, &record_id).await { + warn!( + lock_name = %name, + lock_id = %record_id, + "Failed to release PocketBase lock on drop: {err}" + ); + } + }); + } +} + +async fn try_create_lock( + state: &AppState, + pb_url: &str, + token: &str, + name: &str, + owner: &str, + expires_at_unix: u64, +) -> anyhow::Result> { + let url = format!("{pb_url}/api/collections/{LOCK_COLLECTION}/records"); + let resp = state + .http_client + .post(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&serde_json::json!({ + "name": name, + "owner": owner, + "expires_at_unix": expires_at_unix, + })) + .send() + .await?; + + if resp.status().is_success() { + let body: Value = resp.json().await?; + return body["id"] + .as_str() + .map(str::to_string) + .map(Some) + .ok_or_else(|| anyhow!("PocketBase lock record missing id")); + } + + let status = resp.status(); + let text = 
resp.text().await.unwrap_or_default(); + if status.is_client_error() { + return Ok(None); + } + + Err(anyhow!("PocketBase lock create failed ({status}): {text}")) +} + +async fn find_lock( + state: &AppState, + pb_url: &str, + token: &str, + name: &str, +) -> anyhow::Result> { + let filter = format!("name=\"{}\"", name); + let url = format!( + "{pb_url}/api/collections/{LOCK_COLLECTION}/records?filter={}&perPage=1", + urlencoding::encode(&filter) + ); + let resp = state + .http_client + .get(&url) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + ensure_success_ref(&resp).await?; + + let body: Value = resp.json().await?; + let Some(item) = body["items"].as_array().and_then(|items| items.first()) else { + return Ok(None); + }; + let id = item["id"] + .as_str() + .ok_or_else(|| anyhow!("PocketBase lock missing id"))? + .to_string(); + let expires_at_unix = number_field(item, "expires_at_unix").unwrap_or(0); + + Ok(Some(ExistingLock { + id, + expires_at_unix, + })) +} + +async fn delete_lock_record( + state: &AppState, + pb_url: &str, + token: &str, + record_id: &str, +) -> anyhow::Result<()> { + release_lock_record(&state.http_client, pb_url, token, record_id).await +} + +async fn release_lock_record( + client: &reqwest::Client, + pb_url: &str, + token: &str, + record_id: &str, +) -> anyhow::Result<()> { + let url = format!("{pb_url}/api/collections/{LOCK_COLLECTION}/records/{record_id}"); + let resp = client + .delete(&url) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + if resp.status().is_success() || resp.status() == reqwest::StatusCode::NOT_FOUND { + return Ok(()); + } + + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + Err(anyhow!("PocketBase lock delete failed ({status}): {text}")) +} + +fn validate_lock_name(name: &str) -> anyhow::Result<()> { + if name.is_empty() || name.len() > 80 { + bail!("invalid PocketBase lock name length"); + } + if !name + .bytes() + 
.all(|b| b.is_ascii_alphanumeric() || b == b':' || b == b'_' || b == b'-') + { + bail!("invalid PocketBase lock name characters"); + } + Ok(()) +} + +fn random_owner() -> String { + let mut rng = rand::rng(); + (0..24) + .map(|_| { + let idx: u8 = rng.random_range(0..36); + if idx < 10 { + (b'0' + idx) as char + } else { + (b'a' + idx - 10) as char + } + }) + .collect() +} + +fn now_unix_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +fn number_field(value: &Value, field: &str) -> Option { + value[field].as_u64().or_else(|| { + value[field] + .as_f64() + .filter(|n| n.is_finite() && *n >= 0.0 && n.fract() == 0.0) + .map(|n| n as u64) + }) +} + +async fn ensure_success_ref(resp: &reqwest::Response) -> anyhow::Result<()> { + if resp.status().is_success() { + return Ok(()); + } + + Err(anyhow!("upstream returned {}", resp.status())) +} diff --git a/server-rs/src/routes/checkout.rs b/server-rs/src/routes/checkout.rs index 34bf338..cf3db11 100644 --- a/server-rs/src/routes/checkout.rs +++ b/server-rs/src/routes/checkout.rs @@ -8,10 +8,8 @@ use serde::{Deserialize, Serialize}; use tracing::{info, warn}; use crate::auth::OptionalUser; -use crate::pocketbase::get_superuser_token; -use crate::state::{AppState, SharedState}; - -use super::pricing::{count_licensed_users, price_for_count}; +use crate::checkout_sessions::{start_license_checkout, CheckoutStart}; +use crate::state::SharedState; #[derive(Deserialize)] pub struct CheckoutRequest { @@ -23,8 +21,8 @@ struct CheckoutResponse { url: String, } -/// Create a Stripe Checkout session for the lifetime license (or grant for free if in free tier). -/// Requires authentication. Optionally accepts a referral code to apply a coupon. +/// Create a reserved Stripe Checkout session for the lifetime license. +/// Requires authentication. Referral discounts are issued via invite redemption. 
pub async fn post_checkout( State(shared): State>, Extension(user): Extension, @@ -36,147 +34,27 @@ pub async fn post_checkout( None => return StatusCode::UNAUTHORIZED.into_response(), }; - let count = match count_licensed_users(&state).await { - Ok(c) => c, - Err(err) => { - warn!("Failed to count licensed users at checkout: {err}"); - return StatusCode::SERVICE_UNAVAILABLE.into_response(); - } - }; - - let price_pence = price_for_count(count); let public_url = &state.public_url; let success_url = format!("{public_url}/pricing?license_success=1"); - - // Free tier — grant license directly without Stripe - if price_pence == 0 { - if let Err(err) = grant_license(&state, &user.id).await { - warn!(user_id = %user.id, "Failed to grant free license: {err}"); - return StatusCode::BAD_GATEWAY.into_response(); - } - info!(user_id = %user.id, "Granted free early-bird license"); - return Json(CheckoutResponse { url: success_url }).into_response(); - } - - // Paid tier — create Stripe checkout with dynamic price - let secret_key = &state.stripe_secret_key; let cancel_url = format!("{public_url}/pricing"); - let mut form_params = vec![ - ("mode", "payment".to_string()), - ( - "line_items[0][price_data][unit_amount]", - price_pence.to_string(), - ), - ("line_items[0][price_data][currency]", "gbp".to_string()), - ( - "line_items[0][price_data][product_data][name]", - "Perfect Postcodes Lifetime License".to_string(), - ), - ("line_items[0][quantity]", "1".to_string()), - ("success_url", success_url), - ("cancel_url", cancel_url), - ("client_reference_id", user.id.clone()), - ("customer_email", user.email.clone()), - ]; - - // If a referral code is provided and valid, look it up and apply the coupon - if let Some(ref code) = req.referral_code { - if validate_referral_invite(&state, code).await { - form_params.push(( - "discounts[0][coupon]", - state.stripe_referral_coupon_id.clone(), - )); - info!(code = %code, "Applying referral coupon to checkout"); - } else { - warn!(code = 
%code, "Referral code validation failed, proceeding without discount"); - } + if req.referral_code.is_some() { + return ( + StatusCode::BAD_REQUEST, + "Referral codes must be redeemed from the invite link", + ) + .into_response(); } - let res = state - .http_client - .post("https://api.stripe.com/v1/checkout/sessions") - .basic_auth(secret_key, None::<&str>) - .form(&form_params) - .send() - .await; - - match res { - Ok(resp) if resp.status().is_success() => { - let body: serde_json::Value = match resp.json().await { - Ok(v) => v, - Err(err) => { - warn!("Failed to parse Stripe response: {err}"); - return StatusCode::BAD_GATEWAY.into_response(); - } - }; - let url = body["url"].as_str().unwrap_or_default().to_string(); - if url.is_empty() { - warn!("Stripe session missing URL"); - return StatusCode::BAD_GATEWAY.into_response(); - } - info!(user_id = %user.id, price_pence, "Created Stripe checkout session"); - Json(CheckoutResponse { url }).into_response() - } - Ok(resp) => { - let status = resp.status(); - let text = resp.text().await.unwrap_or_default(); - warn!("Stripe checkout failed ({status}): {text}"); - StatusCode::BAD_GATEWAY.into_response() + match start_license_checkout(&state, &user, &success_url, &cancel_url, None, None).await { + Ok(CheckoutStart::Free) => { + info!(user_id = %user.id, "Granted free early-bird license"); + Json(CheckoutResponse { url: success_url }).into_response() } + Ok(CheckoutStart::Stripe { url }) => Json(CheckoutResponse { url }).into_response(), Err(err) => { - warn!("Stripe request error: {err}"); + warn!(user_id = %user.id, "Failed to start checkout: {err:?}"); StatusCode::BAD_GATEWAY.into_response() } } } - -/// Grant a license by updating the user's subscription to "licensed" in PocketBase. 
-async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> { - let token = get_superuser_token(state).await?; - - let pb_url = state.pocketbase_url.trim_end_matches('/'); - let url = format!("{pb_url}/api/collections/users/records/{user_id}"); - let resp = state - .http_client - .patch(&url) - .header("Authorization", format!("Bearer {token}")) - .json(&serde_json::json!({ "subscription": "licensed" })) - .send() - .await?; - - if !resp.status().is_success() { - let status = resp.status(); - let text = resp.text().await.unwrap_or_default(); - anyhow::bail!("PocketBase update failed ({status}): {text}"); - } - - state.token_cache.invalidate_by_user_id(user_id); - Ok(()) -} - -/// Check if a referral invite code exists and is unused. -async fn validate_referral_invite(state: &AppState, code: &str) -> bool { - // Only allow alphanumeric codes to prevent PocketBase filter injection - if code.is_empty() || code.len() > 20 || !code.bytes().all(|b| b.is_ascii_alphanumeric()) { - return false; - } - - let pb_url = state.pocketbase_url.trim_end_matches('/'); - let filter = format!( - "code=\"{}\" && invite_type=\"referral\" && used_by_id=\"\"", - code - ); - let url = format!( - "{pb_url}/api/collections/invites/records?filter={}&perPage=1", - urlencoding::encode(&filter) - ); - - match state.http_client.get(&url).send().await { - Ok(resp) if resp.status().is_success() => { - let body: serde_json::Value = resp.json().await.unwrap_or_default(); - body["totalItems"].as_u64().unwrap_or(0) > 0 - } - _ => false, - } -} diff --git a/server-rs/src/routes/export.rs b/server-rs/src/routes/export.rs index 56087f5..f60ae09 100644 --- a/server-rs/src/routes/export.rs +++ b/server-rs/src/routes/export.rs @@ -1,6 +1,7 @@ use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use std::time::Duration; use axum::extract::{Query, State}; use axum::http::{header, HeaderMap, StatusCode}; @@ -13,14 +14,18 @@ use tracing::{info, 
warn}; use crate::auth::OptionalUser; use crate::consts::NAN_U16; -use crate::data::QuantRef; -use crate::features::INTEGER_BIN_FEATURES; +use crate::data::{PostcodePoiMetrics, QuantRef}; +use crate::features; use crate::licensing::check_license_bounds; -use crate::parsing::{parse_field_indices, parse_filters, require_bounds, row_passes_filters}; +use crate::parsing::{ + parse_field_indices_with_poi, parse_filters_with_poi, require_bounds, row_passes_filters, + row_passes_poi_filters, +}; use crate::routes::{fetch_screenshot_bytes, FeatureInfo}; use crate::state::SharedState; const MAX_EXPORT_POSTCODES: usize = 250; +const EXPORT_SCREENSHOT_TIMEOUT_SECS: u64 = 12; /// Height (in pixels) reserved for the screenshot row const IMAGE_ROW_HEIGHT: f64 = 225.0; @@ -41,11 +46,11 @@ struct PostcodeExportAgg { } impl PostcodeExportAgg { - fn new(num_features: usize) -> Self { + fn new(total_features: usize) -> Self { Self { count: 0, - sums: vec![0.0; num_features], - finite_counts: vec![0; num_features], + sums: vec![0.0; total_features], + finite_counts: vec![0; total_features], enum_freqs: FxHashMap::default(), } } @@ -58,6 +63,7 @@ impl PostcodeExportAgg { num_features: usize, enum_indices: &FxHashMap, quant: &QuantRef, + poi_metrics: &PostcodePoiMetrics, ) { self.count += 1; let base = row * num_features; @@ -79,6 +85,18 @@ impl PostcodeExportAgg { self.finite_counts[feat_idx] += 1; } } + + let poi_offset = num_features; + for metric_idx in 0..poi_metrics.num_features() { + let raw = poi_metrics.raw_for_property_row(row, metric_idx); + if raw == NAN_U16 { + continue; + } + let value = poi_metrics.decode_raw(metric_idx, raw); + let out_idx = poi_offset + metric_idx; + self.sums[out_idx] += value as f64; + self.finite_counts[out_idx] += 1; + } } } @@ -138,13 +156,17 @@ pub async fn get_export( check_license_bounds(&user.0, (south, west, north, east), None)?; let quant = state.data.quant_ref(); - let (parsed_filters, parsed_enum_filters) = parse_filters( + let poi_quant = 
state.data.poi_metrics.quant_ref(); + let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi( params.filters.as_deref(), &state.feature_name_to_index, &state.data.enum_values, &quant, + &state.data.poi_metrics.name_to_index, + &poi_quant, ) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; + let has_poi_filters = !parsed_poi_filters.is_empty(); let filters_str = params.filters; let fields_str = params.fields; @@ -164,16 +186,28 @@ pub async fn get_export( // Fetch screenshot (async, before spawn_blocking) let auth_header = headers.get(header::AUTHORIZATION); - let screenshot_bytes = match fetch_screenshot_bytes(&state, &frontend_params, auth_header).await + let screenshot_fetch = fetch_screenshot_bytes(&state, &frontend_params, auth_header); + let screenshot_bytes = match tokio::time::timeout( + Duration::from_secs(EXPORT_SCREENSHOT_TIMEOUT_SECS), + screenshot_fetch, + ) + .await { - Ok(bytes) => { + Ok(Ok(bytes)) => { info!(bytes = bytes.len(), "Fetched screenshot for export"); Some(bytes) } - Err(err) => { + Ok(Err(err)) => { warn!("Screenshot failed for export: {err}"); None } + Err(_) => { + warn!( + timeout_secs = EXPORT_SCREENSHOT_TIMEOUT_SECS, + "Screenshot timed out for export" + ); + None + } }; // Build feature name → description map from the precomputed features response @@ -200,6 +234,9 @@ pub async fn get_export( let feature_names = &state.data.feature_names; let enum_values = &state.data.enum_values; let postcode_data = &state.postcode_data; + let poi_metrics = &state.data.poi_metrics; + let poi_offset = num_features; + let total_export_features = num_features + poi_metrics.num_features(); // Build set of enum feature indices for quick lookup let enum_indices: FxHashMap = enum_values.keys().map(|&idx| (idx, ())).collect(); @@ -219,6 +256,10 @@ pub async fn get_export( ) { return; } + if has_poi_filters && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics) + { + return; + } let postcode = 
state.data.postcode(row); if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) { postcode_rows.entry(pc_idx).or_default().push(row); @@ -229,9 +270,16 @@ pub async fn get_export( let mut postcode_aggs: Vec<(usize, PostcodeExportAgg)> = Vec::with_capacity(postcode_rows.len()); for (pc_idx, rows) in postcode_rows { - let mut agg = PostcodeExportAgg::new(num_features); + let mut agg = PostcodeExportAgg::new(total_export_features); for &row in &rows { - agg.add_row(feature_data, row, num_features, &enum_indices, &quant); + agg.add_row( + feature_data, + row, + num_features, + &enum_indices, + &quant, + poi_metrics, + ); } if agg.count > 0 { postcode_aggs.push((pc_idx, agg)); @@ -265,14 +313,19 @@ pub async fn get_export( // Determine column order: filter features first, then remaining let filter_feature_names = extract_filter_feature_names(filters_str.as_deref()); - let field_indices = - parse_field_indices(fields_str.as_deref(), &state.feature_name_to_index) - .map_err(|err| err.1)?; + let field_indices = parse_field_indices_with_poi( + fields_str.as_deref(), + &state.feature_name_to_index, + &state.data.poi_metrics.name_to_index, + ) + .map_err(|err| err.1)?; - let all_feature_indices: Vec = if let Some(ref indices) = field_indices { - indices.clone() + let all_feature_indices: Vec = if let Some(ref indices) = field_indices.normal { + let mut selected = indices.clone(); + selected.extend(field_indices.poi.iter().map(|idx| poi_offset + *idx)); + selected } else { - let mut ordered = Vec::with_capacity(num_features); + let mut ordered = Vec::with_capacity(total_export_features); let mut used = FxHashSet::default(); for name in &filter_feature_names { @@ -280,6 +333,11 @@ pub async fn get_export( if used.insert(idx) { ordered.push(idx); } + } else if let Some(&idx) = state.data.poi_metrics.name_to_index.get(name.as_str()) { + let virtual_idx = poi_offset + idx; + if used.insert(virtual_idx) { + ordered.push(virtual_idx); + } } } for idx in 0..num_features 
{ @@ -287,15 +345,42 @@ pub async fn get_export( ordered.push(idx); } } + for idx in 0..poi_metrics.num_features() { + let virtual_idx = poi_offset + idx; + if used.insert(virtual_idx) { + ordered.push(virtual_idx); + } + } ordered }; // Filter-only feature indices for the Selected sheet let filter_feature_indices: Vec = filter_feature_names .iter() - .filter_map(|name| state.feature_name_to_index.get(name.as_str()).copied()) + .filter_map(|name| { + state + .feature_name_to_index + .get(name.as_str()) + .copied() + .or_else(|| { + state + .data + .poi_metrics + .name_to_index + .get(name.as_str()) + .map(|idx| poi_offset + *idx) + }) + }) .collect(); + let feature_name_for_idx = |idx: usize| -> &str { + if idx < num_features { + &feature_names[idx] + } else { + &poi_metrics.feature_names[idx - poi_offset] + } + }; + // Build feature unit map (feat_idx → (prefix, suffix)) for number formatting let feature_units: FxHashMap = state .features_response @@ -309,16 +394,25 @@ pub async fn get_export( suffix, .. 
} => { - let idx = state.feature_name_to_index.get(name.as_str())?; - Some((*idx, (*prefix, *suffix))) + if let Some(&idx) = state.feature_name_to_index.get(name.as_str()) { + Some((idx, (*prefix, *suffix))) + } else { + state + .data + .poi_metrics + .name_to_index + .get(name.as_str()) + .map(|idx| (poi_offset + *idx, (*prefix, *suffix))) + } } _ => None, }) .collect(); - let integer_feature_indices: FxHashSet = INTEGER_BIN_FEATURES + let integer_feature_indices: FxHashSet = all_feature_indices .iter() - .filter_map(|name| state.feature_name_to_index.get(*name).copied()) + .copied() + .filter(|&idx| features::has_integer_bins(feature_name_for_idx(idx))) .collect(); // Build Excel number formats per feature index for unit display @@ -435,7 +529,7 @@ pub async fn get_export( .write_string_with_format( header_row, col, - &feature_names[feat_idx], + feature_name_for_idx(feat_idx), &header_fmt, ) .map_err(|e| format!("Failed to write header: {e}"))?; @@ -453,7 +547,7 @@ pub async fn get_export( for (col_offset, &feat_idx) in feat_indices.iter().enumerate() { let col = (col_offset + 2) as u16; let desc = feature_descriptions - .get(&feature_names[feat_idx]) + .get(feature_name_for_idx(feat_idx)) .map(String::as_str) .unwrap_or(""); sheet @@ -477,7 +571,7 @@ pub async fn get_export( for (col_offset, &feat_idx) in feat_indices.iter().enumerate() { let col = (col_offset + 2) as u16; - if enum_indices.contains_key(&feat_idx) { + if feat_idx < num_features && enum_indices.contains_key(&feat_idx) { if let Some(freqs) = agg.enum_freqs.get(&feat_idx) { if let Some((&mode_bits, _)) = freqs.iter().max_by_key(|(_, &count)| count) @@ -543,7 +637,7 @@ pub async fn get_export( .map_err(|e| format!("Failed to set column width: {e}"))?; for col_offset in 0..feat_indices.len() { let col = (col_offset + 2) as u16; - let feat_name = &feature_names[feat_indices[col_offset]]; + let feat_name = feature_name_for_idx(feat_indices[col_offset]); let width = (feat_name.len() as f64 * 
1.1).clamp(10.0, 30.0); sheet .set_column_width(col, width) diff --git a/server-rs/src/routes/features.rs b/server-rs/src/routes/features.rs index 91292b5..5d34d3c 100644 --- a/server-rs/src/routes/features.rs +++ b/server-rs/src/routes/features.rs @@ -7,7 +7,7 @@ use serde::Serialize; use tracing::info; use crate::data::{Histogram, PropertyData}; -use crate::features::{Feature, FEATURE_GROUPS}; +use crate::features::{self, Feature, FEATURE_GROUPS}; use crate::state::SharedState; fn is_empty(val: &str) -> bool { @@ -28,9 +28,9 @@ pub enum FeatureInfo { max: f32, step: f32, histogram: Histogram, - description: &'static str, - detail: &'static str, - source: &'static str, + description: String, + detail: String, + source: String, #[serde(skip_serializing_if = "is_empty")] prefix: &'static str, #[serde(skip_serializing_if = "is_empty")] @@ -45,9 +45,9 @@ pub enum FeatureInfo { name: String, values: Vec, counts: HashMap, - description: &'static str, - detail: &'static str, - source: &'static str, + description: String, + detail: String, + source: String, }, } @@ -85,9 +85,9 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse { max: stats.slider_max, step: config.step, histogram: stats.histogram.clone(), - description: config.description, - detail: config.detail, - source: config.source, + description: config.description.to_string(), + detail: config.detail.to_string(), + source: config.source.to_string(), prefix: config.prefix, suffix: config.suffix, raw: config.raw, @@ -118,9 +118,9 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse { name: config.name.to_string(), values: values.clone(), counts, - description: config.description, - detail: config.detail, - source: config.source, + description: config.description.to_string(), + detail: config.detail.to_string(), + source: config.source.to_string(), }); } } @@ -136,6 +136,58 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse { } } + let mut 
dynamic_poi_features = Vec::new(); + for (feat_idx, name) in data.poi_metrics.feature_names.iter().enumerate() { + if let Some(category) = features::dynamic_poi_distance_category(name) { + let stats = &data.poi_metrics.feature_stats[feat_idx]; + dynamic_poi_features.push(FeatureInfo::Numeric { + name: name.clone(), + min: stats.slider_min, + max: stats.slider_max, + step: 0.1, + histogram: stats.histogram.clone(), + description: format!("Distance to the closest {category} POI"), + detail: format!( + "Straight-line distance in kilometres from the postcode to the nearest {category} point of interest in the POI dataset." + ), + source: "osm-pois".to_string(), + prefix: "", + suffix: " km", + raw: false, + absolute: false, + }); + } else if let Some(category) = features::dynamic_poi_count_category(name) { + let stats = &data.poi_metrics.feature_stats[feat_idx]; + let radius = features::dynamic_poi_count_radius(name).unwrap_or(0); + dynamic_poi_features.push(FeatureInfo::Numeric { + name: name.clone(), + min: stats.slider_min, + max: stats.slider_max, + step: 1.0, + histogram: stats.histogram.clone(), + description: format!("Number of {category} POIs within {radius}km"), + detail: format!( + "Count of {category} points of interest within a {radius}km radius of the property's postcode centroid." + ), + source: "osm-pois".to_string(), + prefix: "", + suffix: "", + raw: false, + absolute: false, + }); + } + } + if !dynamic_poi_features.is_empty() { + dynamic_poi_features.sort_by_key(|feature| match feature { + FeatureInfo::Numeric { name, .. } => features::dynamic_poi_feature_sort_key(name), + FeatureInfo::Enum { name, .. 
} => features::dynamic_poi_feature_sort_key(name), + }); + groups.push(FeatureGroupResponse { + name: "Nearby POIs".to_string(), + features: dynamic_poi_features, + }); + } + FeaturesResponse { groups } } diff --git a/server-rs/src/routes/filter_counts.rs b/server-rs/src/routes/filter_counts.rs index c6f12a1..08d20ef 100644 --- a/server-rs/src/routes/filter_counts.rs +++ b/server-rs/src/routes/filter_counts.rs @@ -9,7 +9,7 @@ use tracing::info; use crate::consts::NAN_U16; use crate::data::travel_time::TravelData; -use crate::parsing::{parse_filters, require_bounds}; +use crate::parsing::{parse_filters_with_poi, require_bounds}; use crate::routes::travel_time::parse_optional_travel; use crate::state::SharedState; @@ -36,18 +36,21 @@ pub async fn get_filter_counts( require_bounds(params.bounds).map_err(IntoResponse::into_response)?; let quant = state.data.quant_ref(); - let (parsed_filters, parsed_enum_filters) = parse_filters( + let poi_quant = state.data.poi_metrics.quant_ref(); + let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi( params.filters.as_deref(), &state.feature_name_to_index, &state.data.enum_values, &quant, + &state.data.poi_metrics.name_to_index, + &poi_quant, ) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; let travel_entries = parse_optional_travel(params.travel.as_deref()) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; - let num_regular = parsed_filters.len() + parsed_enum_filters.len(); + let num_regular = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len(); // Only travel entries with a filter range count as filters for impact tracking let travel_filter_indices: Vec = travel_entries .iter() @@ -65,6 +68,7 @@ pub async fn get_filter_counts( } let filters_str = params.filters; + let has_poi_filters = !parsed_poi_filters.is_empty(); let response = tokio::task::spawn_blocking(move || -> Result { let t0 = std::time::Instant::now(); @@ -124,6 +128,23 @@ pub 
async fn get_filter_counts( } } + // Test POI metric filters + if fail_count <= 1 && has_poi_filters { + for (i, f) in parsed_poi_filters.iter().enumerate() { + let raw = state + .data + .poi_metrics + .raw_for_property_row(row, f.metric_idx); + if raw == NAN_U16 || raw < f.min_u16 || raw > f.max_u16 { + fail_count += 1; + fail_index = parsed_filters.len() + parsed_enum_filters.len() + i; + if fail_count > 1 { + break; + } + } + } + } + // Test travel time filters if fail_count <= 1 && has_travel { let postcode = pc_interner.resolve(&pc_keys[row]); @@ -169,8 +190,15 @@ pub async fn get_filter_counts( let name = if i < parsed_filters.len() { state.data.feature_names[parsed_filters[i].feat_idx].clone() } else if i < num_regular { - let ei = i - parsed_filters.len(); - state.data.feature_names[parsed_enum_filters[ei].feat_idx].clone() + let enum_start = parsed_filters.len(); + let poi_start = enum_start + parsed_enum_filters.len(); + if i < poi_start { + let ei = i - enum_start; + state.data.feature_names[parsed_enum_filters[ei].feat_idx].clone() + } else { + let pi = i - poi_start; + state.data.poi_metrics.feature_names[parsed_poi_filters[pi].metric_idx].clone() + } } else { let slot = i - num_regular; let ti = travel_filter_indices[slot]; diff --git a/server-rs/src/routes/hexagon_stats.rs b/server-rs/src/routes/hexagon_stats.rs index 9bbb4cf..01226a1 100644 --- a/server-rs/src/routes/hexagon_stats.rs +++ b/server-rs/src/routes/hexagon_stats.rs @@ -13,8 +13,8 @@ use tracing::{info, warn}; use crate::auth::OptionalUser; use crate::licensing::{check_license_bounds, resolve_share_code}; use crate::parsing::{ - cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_set, parse_filters, - row_passes_filters, validate_h3_resolution, + cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_set, parse_filters_with_poi, + row_passes_filters, row_passes_poi_filters, validate_h3_resolution, }; use crate::state::SharedState; @@ -110,15 +110,19 @@ pub async fn
get_hexagon_stats( let h3_str = params.h3; let quant = state.data.quant_ref(); - let (parsed_filters, parsed_enum_filters) = parse_filters( + let poi_quant = state.data.poi_metrics.quant_ref(); + let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi( params.filters.as_deref(), &state.feature_name_to_index, &state.data.enum_values, &quant, + &state.data.poi_metrics.name_to_index, + &poi_quant, ) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; - let num_filters = parsed_filters.len() + parsed_enum_filters.len(); + let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len(); let filters_str = params.filters; + let has_poi_filters = !parsed_poi_filters.is_empty(); let (fields_specified, field_set) = parse_field_set(params.fields.as_deref()); @@ -161,6 +165,12 @@ pub async fn get_hexagon_stats( feature_data, num_features, ) + && (!has_poi_filters + || row_passes_poi_filters( + row, + &parsed_poi_filters, + &state.data.poi_metrics, + )) { if has_travel { let postcode = state.data.postcode(row); @@ -233,7 +243,7 @@ pub async fn get_hexagon_stats( let price_history = stats::extract_price_history(&matching_rows, &state.data, &state.feature_name_to_index); - let (numeric_features, enum_features_out) = stats::compute_feature_stats( + let (mut numeric_features, enum_features_out) = stats::compute_feature_stats( &matching_rows, &state.data, &state.data.feature_names, @@ -242,6 +252,12 @@ pub async fn get_hexagon_stats( fields_specified, &field_set, ); + numeric_features.extend(stats::compute_poi_feature_stats( + &matching_rows, + &state.data.poi_metrics, + fields_specified, + &field_set, + )); let elapsed = start_time.elapsed(); info!( diff --git a/server-rs/src/routes/hexagons.rs b/server-rs/src/routes/hexagons.rs index fe557e0..434deaa 100644 --- a/server-rs/src/routes/hexagons.rs +++ b/server-rs/src/routes/hexagons.rs @@ -11,14 +11,15 @@ use serde::{Deserialize, Serialize}; use 
serde_json::{Map, Value}; use tracing::info; -use crate::aggregation::{Aggregator, EnumDistConfig}; +use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator}; use crate::auth::OptionalUser; use crate::consts::MAX_CELLS_PER_REQUEST; use crate::data::travel_time::TravelData; use crate::licensing::{check_license_bounds, resolve_share_code}; use crate::parsing::{ - cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices, parse_filters, - require_bounds, row_passes_filters, validate_h3_resolution, + cell_for_row_cached, needs_parent, parse_enum_dist, parse_field_indices_with_poi, + parse_filters_with_poi, require_bounds, row_passes_filters, row_passes_poi_filters, + validate_h3_resolution, }; use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg}; use crate::state::SharedState; @@ -29,6 +30,7 @@ const PARALLEL_THRESHOLD: usize = 50_000; /// Per-thread aggregation result: feature accumulators + travel time accumulators. type ChunkResult = ( FxHashMap, + FxHashMap, Vec>, ); @@ -79,11 +81,14 @@ pub struct HexagonParams { #[allow(clippy::too_many_arguments)] fn build_feature_maps( groups: &FxHashMap, + poi_groups: &FxHashMap, min_keys: &[String], max_keys: &[String], avg_keys: &[String], num_features: usize, indices: Option<&[usize]>, + poi_feature_names: &[String], + poi_indices: &[usize], query_bounds: (f64, f64, f64, f64), resolution: h3o::Resolution, travel_aggs: &[FxHashMap], @@ -163,6 +168,25 @@ fn build_feature_maps( } } + if let Some(poi_aggregation) = poi_groups.get(&cell_id) { + for &metric_idx in poi_indices { + if poi_aggregation.counts[metric_idx] > 0 { + let avg = poi_aggregation.sums[metric_idx] + / poi_aggregation.counts[metric_idx] as f64; + if let (Some(min_num), Some(max_num), Some(avg_num)) = ( + serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64), + serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64), + serde_json::Number::from_f64(avg), + ) { + let name = 
&poi_feature_names[metric_idx]; + map.insert(format!("min_{name}"), Value::Number(min_num)); + map.insert(format!("max_{name}"), Value::Number(max_num)); + map.insert(format!("avg_{name}"), Value::Number(avg_num)); + } + } + } + } + // Add travel time aggregation fields (using pre-computed key strings) for (ti, agg_map) in travel_aggs.iter().enumerate() { if let Some(agg) = agg_map.get(&cell_id) { @@ -209,18 +233,25 @@ pub async fn get_hexagons( check_license_bounds(&user.0, (south, west, north, east), share_bounds)?; let quant = state.data.quant_ref(); - let (parsed_filters, parsed_enum_filters) = parse_filters( + let poi_quant = state.data.poi_metrics.quant_ref(); + let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi( params.filters.as_deref(), &state.feature_name_to_index, &state.data.enum_values, &quant, + &state.data.poi_metrics.name_to_index, + &poi_quant, ) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; - let num_filters = parsed_filters.len() + parsed_enum_filters.len(); + let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len(); let filters_str = params.filters; - let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index) - .map_err(|err| (err.0, err.1).into_response())?; + let field_indices = parse_field_indices_with_poi( + params.fields.as_deref(), + &state.feature_name_to_index, + &state.data.poi_metrics.name_to_index, + ) + .map_err(|err| (err.0, err.1).into_response())?; let travel_entries = parse_optional_travel(params.travel.as_deref()) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; @@ -269,6 +300,11 @@ pub async fn get_hexagons( let min_keys = &state.min_keys; let max_keys = &state.max_keys; let avg_keys = &state.avg_keys; + let poi_metrics = &state.data.poi_metrics; + let poi_field_indices = field_indices.poi.as_slice(); + let has_poi_fields = !poi_field_indices.is_empty(); + let has_poi_filters = 
!parsed_poi_filters.is_empty(); + let poi_num_features = poi_metrics.num_features(); let h3_res = h3o::Resolution::try_from(resolution) .map_err(|error| format!("Invalid H3 resolution {}: {}", resolution, error))?; @@ -276,6 +312,7 @@ pub async fn get_hexagons( let need_parent = needs_parent(resolution); let mut groups: FxHashMap = FxHashMap::default(); + let mut poi_groups: FxHashMap = FxHashMap::default(); let mut travel_aggs: Vec> = (0..travel_entries.len()) .map(|_| FxHashMap::default()) .collect(); @@ -296,6 +333,7 @@ pub async fn get_hexagons( .par_chunks(chunk_size) .map(|chunk| { let mut local_groups: FxHashMap = FxHashMap::default(); + let mut local_poi_groups: FxHashMap = FxHashMap::default(); let mut local_travel_aggs: Vec> = (0 ..travel_entries.len()) .map(|_| FxHashMap::default()) @@ -315,6 +353,11 @@ pub async fn get_hexagons( ) { continue; } + if has_poi_filters + && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics) + { + continue; + } if has_travel { travel_minutes.clear(); @@ -352,7 +395,7 @@ pub async fn get_hexagons( let agg = local_groups .entry(cell_id) .or_insert_with(|| Aggregator::new(num_features, enum_dist_config)); - if let Some(sel_indices) = field_indices.as_deref() { + if let Some(sel_indices) = field_indices.normal.as_deref() { agg.add_row_selective( feature_data, row, @@ -364,6 +407,13 @@ pub async fn get_hexagons( agg.add_row(feature_data, row, num_features, &quant); } + if has_poi_fields { + local_poi_groups + .entry(cell_id) + .or_insert_with(|| PoiAggregator::new(poi_num_features)) + .add_row_selective(poi_metrics, row, poi_field_indices); + } + for (ti, minutes) in travel_minutes.iter().enumerate() { if let Some(mins) = minutes { let tagg = local_travel_aggs[ti] @@ -374,18 +424,24 @@ pub async fn get_hexagons( } } - (local_groups, local_travel_aggs) + (local_groups, local_poi_groups, local_travel_aggs) }) .collect(); // Merge thread-local results into the main accumulators - for (local_groups, local_travel) in 
thread_results { + for (local_groups, local_poi_groups, local_travel) in thread_results { for (cell_id, local_agg) in local_groups { groups .entry(cell_id) .or_insert_with(|| Aggregator::new(num_features, enum_dist_config)) .merge(&local_agg); } + for (cell_id, local_agg) in local_poi_groups { + poi_groups + .entry(cell_id) + .or_insert_with(|| PoiAggregator::new(poi_num_features)) + .merge(&local_agg); + } for (ti, local_ta) in local_travel.into_iter().enumerate() { for (cell_id, local_tt) in local_ta { travel_aggs[ti] @@ -414,6 +470,11 @@ pub async fn get_hexagons( ) { return; } + if has_poi_filters + && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics) + { + return; + } if has_travel { travel_minutes.clear(); @@ -444,7 +505,7 @@ pub async fn get_hexagons( let aggregation = groups .entry(cell_id) .or_insert_with(|| Aggregator::new(num_features, enum_dist_config)); - if let Some(sel_indices) = field_indices.as_deref() { + if let Some(sel_indices) = field_indices.normal.as_deref() { aggregation.add_row_selective( feature_data, row, @@ -456,6 +517,13 @@ pub async fn get_hexagons( aggregation.add_row(feature_data, row, num_features, &quant); } + if has_poi_fields { + poi_groups + .entry(cell_id) + .or_insert_with(|| PoiAggregator::new(poi_num_features)) + .add_row_selective(poi_metrics, row, poi_field_indices); + } + for (ti, minutes) in travel_minutes.iter().enumerate() { if let Some(mins) = minutes { let agg = travel_aggs[ti] @@ -471,11 +539,14 @@ pub async fn get_hexagons( let mut features = build_feature_maps( &groups, + &poi_groups, min_keys, max_keys, avg_keys, num_features, - field_indices.as_deref(), + field_indices.normal.as_deref(), + &poi_metrics.feature_names, + poi_field_indices, (south, west, north, east), h3_res, &travel_aggs, @@ -499,7 +570,11 @@ pub async fn get_hexagons( bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east), filters = num_filters, filters_raw = filters_str.as_deref().unwrap_or("-"), - fields = 
field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1), + fields = field_indices + .normal + .as_ref() + .map(|v| (v.len() + poi_field_indices.len()) as i32) + .unwrap_or(-1), travel_entries = travel_entries.len(), grid_ms = format_args!("{:.1}", t_grid.as_secs_f64() * 1000.0), agg_ms = format_args!("{:.1}", (t_agg - t_grid).as_secs_f64() * 1000.0), diff --git a/server-rs/src/routes/invites.rs b/server-rs/src/routes/invites.rs index fa7c633..b110858 100644 --- a/server-rs/src/routes/invites.rs +++ b/server-rs/src/routes/invites.rs @@ -9,11 +9,16 @@ use serde::{Deserialize, Serialize}; use tracing::{info, warn}; use crate::auth::{OptionalUser, PocketBaseUser}; +use crate::checkout_sessions::{ + active_referral_checkout_user, start_license_checkout, CheckoutStart, +}; use crate::pocketbase::get_superuser_token; +use crate::pocketbase_locks::acquire_pocketbase_lock; use crate::state::{AppState, SharedState}; static INVITE_REDEMPTIONS_IN_PROGRESS: LazyLock>> = LazyLock::new(|| Mutex::new(HashSet::new())); +const INVITE_REDEMPTION_LOCK_TTL_SECS: u64 = 5 * 60; struct InviteRedemptionGuard { code: String, @@ -103,7 +108,7 @@ fn validate_invite_code(code: &str) -> Result<(), &'static str> { } fn generate_invite_code() -> String { - use rand::Rng; + use rand::RngExt; let mut rng = rand::rng(); let chars: Vec = (0..12) .map(|_| { @@ -246,74 +251,26 @@ async fn grant_license_for_invite( async fn create_referral_checkout( state: &AppState, user: &PocketBaseUser, + invite_id: &str, ) -> Result { - let count = match super::pricing::count_licensed_users(state).await { - Ok(count) => count, - Err(err) => { - warn!("Failed to count licensed users for invite checkout: {err}"); - return Err(StatusCode::SERVICE_UNAVAILABLE.into_response()); - } - }; - let price_pence = super::pricing::price_for_count(count); - let public_url = &state.public_url; let success_url = format!("{public_url}/pricing?license_success=1"); let cancel_url = format!("{public_url}/pricing"); - let 
form_params = vec![ - ("mode", "payment".to_string()), - ( - "line_items[0][price_data][unit_amount]", - price_pence.to_string(), - ), - ("line_items[0][price_data][currency]", "gbp".to_string()), - ( - "line_items[0][price_data][product_data][name]", - "Perfect Postcodes Lifetime License".to_string(), - ), - ("line_items[0][quantity]", "1".to_string()), - ("success_url", success_url), - ("cancel_url", cancel_url), - ("client_reference_id", user.id.clone()), - ("customer_email", user.email.clone()), - ( - "discounts[0][coupon]", - state.stripe_referral_coupon_id.clone(), - ), - ]; - - let stripe_res = state - .http_client - .post("https://api.stripe.com/v1/checkout/sessions") - .basic_auth(&state.stripe_secret_key, None::<&str>) - .form(&form_params) - .send() - .await; - - match stripe_res { - Ok(resp) if resp.status().is_success() => { - let stripe_body: serde_json::Value = match resp.json().await { - Ok(value) => value, - Err(err) => { - warn!("Failed to parse Stripe checkout response: {err}"); - return Err(StatusCode::BAD_GATEWAY.into_response()); - } - }; - let checkout_url = stripe_body["url"].as_str().unwrap_or_default().to_string(); - if checkout_url.is_empty() { - warn!("Stripe checkout response did not include a URL"); - return Err(StatusCode::BAD_GATEWAY.into_response()); - } - Ok(checkout_url) - } - Ok(resp) => { - let status = resp.status(); - let text = resp.text().await.unwrap_or_default(); - warn!("Failed to create Stripe checkout for referral invite ({status}): {text}"); - Err(StatusCode::BAD_GATEWAY.into_response()) - } + match start_license_checkout( + state, + user, + &success_url, + &cancel_url, + Some(&state.stripe_referral_coupon_id), + Some(invite_id), + ) + .await + { + Ok(CheckoutStart::Free) => Ok(success_url), + Ok(CheckoutStart::Stripe { url }) => Ok(url), Err(err) => { - warn!("Stripe request error for referral invite: {err}"); + warn!("Failed to create reserved Stripe checkout for referral invite: {err:?}"); 
Err(StatusCode::BAD_GATEWAY.into_response()) } } @@ -541,6 +498,10 @@ pub async fn post_redeem_invite( .into_response(); } + if user.is_admin || user.subscription == "licensed" { + return (StatusCode::CONFLICT, "Account already has full access").into_response(); + } + let pb_url = state.pocketbase_url.trim_end_matches('/'); let token = match get_superuser_token(&state).await { @@ -561,6 +522,19 @@ pub async fn post_redeem_invite( .into_response() } }; + let lock_name = format!("invite:{}", req.code); + let _distributed_redemption_guard = + match acquire_pocketbase_lock(&state, &lock_name, INVITE_REDEMPTION_LOCK_TTL_SECS).await { + Ok(guard) => guard, + Err(err) => { + warn!(code = %req.code, "Failed to acquire invite redemption lock: {err}"); + return ( + StatusCode::CONFLICT, + "Invite redemption is already in progress", + ) + .into_response(); + } + }; let invite = match lookup_unused_invite(&state, pb_url, &token, &req.code).await { Ok(Some(invite)) => invite, @@ -591,11 +565,11 @@ pub async fn post_redeem_invite( }; if invite_type == "admin" { - if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await { + if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await { return response; } - if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await { + if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await { return response; } @@ -607,15 +581,26 @@ pub async fn post_redeem_invite( .into_response(); } - let checkout_url = match create_referral_checkout(&state, &user).await { + match active_referral_checkout_user(&state, invite_id).await { + Ok(Some(active_user_id)) if active_user_id != user.id => { + return ( + StatusCode::CONFLICT, + "Invite checkout is already in progress", + ) + .into_response() + } + Ok(_) => {} + Err(err) => { + warn!(code = %req.code, "Failed to check active referral checkout: {err}"); + return 
StatusCode::BAD_GATEWAY.into_response(); + } + } + + let checkout_url = match create_referral_checkout(&state, &user, invite_id).await { Ok(url) => url, Err(response) => return response, }; - if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await { - return response; - } - info!(user_id = %user.id, code = %req.code, "Referral invite redeemed; checkout created"); Json(RedeemResponse { result: "checkout".to_string(), diff --git a/server-rs/src/routes/pois.rs b/server-rs/src/routes/pois.rs index e6f4cde..88946c9 100644 --- a/server-rs/src/routes/pois.rs +++ b/server-rs/src/routes/pois.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use tracing::info; use crate::consts::MAX_POIS_PER_REQUEST; -use crate::data::POICategoryGroup; +use crate::data::{resolve_poi_category_filter, POICategoryGroup}; use crate::parsing::require_bounds; use crate::state::SharedState; @@ -47,20 +47,7 @@ pub async fn get_pois( .categories .as_deref() .filter(|text| !text.is_empty()) - .map(|text| { - text.split(',') - .filter_map(|part| { - let name = part.trim(); - state - .poi_data - .category - .values - .iter() - .position(|v| v == name) - .map(|pos| pos as u16) - }) - .collect() - }); + .map(|text| resolve_poi_category_filter(&state.poi_data.category.values, text)); let categories_raw = params.categories; let num_categories = category_filter.as_ref().map(|cats| cats.len()).unwrap_or(0); diff --git a/server-rs/src/routes/postcode_properties.rs b/server-rs/src/routes/postcode_properties.rs index c954c26..e0aa684 100644 --- a/server-rs/src/routes/postcode_properties.rs +++ b/server-rs/src/routes/postcode_properties.rs @@ -10,7 +10,7 @@ use tracing::{info, warn}; use crate::auth::OptionalUser; use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT, POSTCODE_SEARCH_OFFSET}; use crate::licensing::{check_license_point, resolve_share_code}; -use crate::parsing::{parse_filters, row_passes_filters}; +use crate::parsing::{parse_filters_with_poi, 
row_passes_filters, row_passes_poi_filters}; use crate::state::SharedState; use crate::utils::normalize_postcode; @@ -62,15 +62,19 @@ pub async fn get_postcode_properties( )?; let quant = state.data.quant_ref(); - let (parsed_filters, parsed_enum_filters) = parse_filters( + let poi_quant = state.data.poi_metrics.quant_ref(); + let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi( params.filters.as_deref(), &state.feature_name_to_index, &state.data.enum_values, &quant, + &state.data.poi_metrics.name_to_index, + &poi_quant, ) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; - let num_filters = parsed_filters.len() + parsed_enum_filters.len(); + let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len(); let filters_str = params.filters; + let has_poi_filters = !parsed_poi_filters.is_empty(); let travel_entries = parse_optional_travel(params.travel.as_deref()) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; @@ -111,6 +115,12 @@ pub async fn get_postcode_properties( feature_data, num_features, ) + && (!has_poi_filters + || row_passes_poi_filters( + row, + &parsed_poi_filters, + &state.data.poi_metrics, + )) { if has_travel && !row_passes_travel_filters( diff --git a/server-rs/src/routes/postcode_stats.rs b/server-rs/src/routes/postcode_stats.rs index c2f9ac4..a261409 100644 --- a/server-rs/src/routes/postcode_stats.rs +++ b/server-rs/src/routes/postcode_stats.rs @@ -10,7 +10,9 @@ use tracing::{info, warn}; use crate::auth::OptionalUser; use crate::consts::POSTCODE_SEARCH_OFFSET; use crate::licensing::{check_license_point, resolve_share_code}; -use crate::parsing::{parse_field_set, parse_filters, row_passes_filters}; +use crate::parsing::{ + parse_field_set, parse_filters_with_poi, row_passes_filters, row_passes_poi_filters, +}; use crate::state::SharedState; use crate::utils::normalize_postcode; @@ -64,15 +66,19 @@ pub async fn get_postcode_stats( )?; let quant = 
state.data.quant_ref(); - let (parsed_filters, parsed_enum_filters) = parse_filters( + let poi_quant = state.data.poi_metrics.quant_ref(); + let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi( params.filters.as_deref(), &state.feature_name_to_index, &state.data.enum_values, &quant, + &state.data.poi_metrics.name_to_index, + &poi_quant, ) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; - let num_filters = parsed_filters.len() + parsed_enum_filters.len(); + let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len(); let filters_str = params.filters; + let has_poi_filters = !parsed_poi_filters.is_empty(); let (fields_specified, field_set) = parse_field_set(params.fields.as_deref()); let travel_entries = parse_optional_travel(params.travel.as_deref()) @@ -108,6 +114,12 @@ pub async fn get_postcode_stats( feature_data, num_features, ) + && (!has_poi_filters + || row_passes_poi_filters( + row, + &parsed_poi_filters, + &state.data.poi_metrics, + )) { if has_travel && !row_passes_travel_filters(row_postcode, &travel_entries, &travel_data) @@ -123,7 +135,7 @@ pub async fn get_postcode_stats( let price_history = stats::extract_price_history(&matching_rows, &state.data, &state.feature_name_to_index); - let (numeric_features, enum_features_out) = stats::compute_feature_stats( + let (mut numeric_features, enum_features_out) = stats::compute_feature_stats( &matching_rows, &state.data, &state.data.feature_names, @@ -132,6 +144,12 @@ pub async fn get_postcode_stats( fields_specified, &field_set, ); + numeric_features.extend(stats::compute_poi_feature_stats( + &matching_rows, + &state.data.poi_metrics, + fields_specified, + &field_set, + )); let elapsed = start_time.elapsed(); info!( diff --git a/server-rs/src/routes/postcodes.rs b/server-rs/src/routes/postcodes.rs index d78d92a..541087f 100644 --- a/server-rs/src/routes/postcodes.rs +++ b/server-rs/src/routes/postcodes.rs @@ -10,14 +10,14 @@ 
use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use tracing::info; -use crate::aggregation::{Aggregator, EnumDistConfig}; +use crate::aggregation::{Aggregator, EnumDistConfig, PoiAggregator}; use crate::auth::OptionalUser; use crate::consts::MAX_CELLS_PER_REQUEST; use crate::data::travel_time::TravelData; use crate::licensing::{check_license_bounds, resolve_share_code}; use crate::parsing::{ - bounds_intersect, parse_enum_dist, parse_field_indices, parse_filters, require_bounds, - row_passes_filters, + bounds_intersect, parse_enum_dist, parse_field_indices_with_poi, parse_filters_with_poi, + require_bounds, row_passes_filters, row_passes_poi_filters, }; use crate::pocketbase::log_user_location; use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg}; @@ -64,18 +64,25 @@ pub async fn get_postcodes( check_license_bounds(&user.0, (south, west, north, east), share_bounds)?; let quant = state.data.quant_ref(); - let (parsed_filters, parsed_enum_filters) = parse_filters( + let poi_quant = state.data.poi_metrics.quant_ref(); + let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi( params.filters.as_deref(), &state.feature_name_to_index, &state.data.enum_values, &quant, + &state.data.poi_metrics.name_to_index, + &poi_quant, ) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; - let num_filters = parsed_filters.len() + parsed_enum_filters.len(); + let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len(); let filters_str = params.filters; - let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index) - .map_err(|err| (err.0, err.1).into_response())?; + let field_indices = parse_field_indices_with_poi( + params.fields.as_deref(), + &state.feature_name_to_index, + &state.data.poi_metrics.name_to_index, + ) + .map_err(|err| (err.0, err.1).into_response())?; let travel_entries = parse_optional_travel(params.travel.as_deref()) 
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; @@ -123,12 +130,18 @@ pub async fn get_postcodes( let min_keys = &state.min_keys; let max_keys = &state.max_keys; let avg_keys = &state.avg_keys; + let poi_metrics = &state.data.poi_metrics; + let poi_field_indices = field_indices.poi.as_slice(); + let has_poi_fields = !poi_field_indices.is_empty(); + let has_poi_filters = !parsed_poi_filters.is_empty(); + let poi_num_features = poi_metrics.num_features(); - let has_selective = field_indices.is_some(); - let sel_indices = field_indices.as_deref().unwrap_or(&[]); + let has_selective = field_indices.normal.is_some(); + let sel_indices = field_indices.normal.as_deref().unwrap_or(&[]); // Single-pass: aggregate directly into postcode_aggs while iterating properties in bounds let mut postcode_aggs: FxHashMap = FxHashMap::default(); + let mut poi_aggs: FxHashMap = FxHashMap::default(); state .grid @@ -143,6 +156,10 @@ pub async fn get_postcodes( ) { return; } + if has_poi_filters && !row_passes_poi_filters(row, &parsed_poi_filters, poi_metrics) + { + return; + } let postcode = state.data.postcode(row); if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) { @@ -154,6 +171,12 @@ pub async fn get_postcodes( } else { agg.add_row(feature_data, row, num_features, &quant); } + if has_poi_fields { + poi_aggs + .entry(pc_idx) + .or_insert_with(|| PoiAggregator::new(poi_num_features)) + .add_row_selective(poi_metrics, row, poi_field_indices); + } } }); @@ -250,11 +273,12 @@ pub async fn get_postcodes( ]), ); - let iter: Box> = if let Some(idx) = field_indices.as_ref() { - Box::new(idx.iter().copied()) - } else { - Box::new(0..num_features) - }; + let iter: Box> = + if let Some(idx) = field_indices.normal.as_ref() { + Box::new(idx.iter().copied()) + } else { + Box::new(0..num_features) + }; for feat_index in iter { if aggregation.feat_counts[feat_index] > 0 { @@ -272,6 +296,25 @@ pub async fn get_postcodes( } } + if let Some(poi_aggregation) = 
poi_aggs.get(&pc_idx) { + for &metric_idx in poi_field_indices { + if poi_aggregation.counts[metric_idx] > 0 { + let avg = poi_aggregation.sums[metric_idx] + / poi_aggregation.counts[metric_idx] as f64; + if let (Some(min_num), Some(max_num), Some(avg_num)) = ( + serde_json::Number::from_f64(poi_aggregation.mins[metric_idx] as f64), + serde_json::Number::from_f64(poi_aggregation.maxs[metric_idx] as f64), + serde_json::Number::from_f64(avg), + ) { + let name = &poi_metrics.feature_names[metric_idx]; + props.insert(format!("min_{name}"), Value::Number(min_num)); + props.insert(format!("max_{name}"), Value::Number(max_num)); + props.insert(format!("avg_{name}"), Value::Number(avg_num)); + } + } + } + } + // Add travel time aggregation fields if let Some(tt_aggs) = travel_aggs.get(&pc_idx) { for (ti, agg) in tt_aggs.iter().enumerate() { @@ -322,7 +365,11 @@ pub async fn get_postcodes( bounds = format_args!("{:.6},{:.6},{:.6},{:.6}", south, west, north, east), filters = num_filters, filters_raw = filters_str.as_deref().unwrap_or("-"), - fields = field_indices.as_ref().map(|v| v.len() as i32).unwrap_or(-1), + fields = field_indices + .normal + .as_ref() + .map(|v| (v.len() + poi_field_indices.len()) as i32) + .unwrap_or(-1), travel_entries = travel_entries.len(), agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0), json_ms = format_args!("{:.1}", (t_total - t_agg).as_secs_f64() * 1000.0), diff --git a/server-rs/src/routes/properties.rs b/server-rs/src/routes/properties.rs index f935c5a..6b08328 100644 --- a/server-rs/src/routes/properties.rs +++ b/server-rs/src/routes/properties.rs @@ -14,8 +14,8 @@ use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT}; use crate::data::RenovationEvent; use crate::licensing::{check_license_bounds, resolve_share_code}; use crate::parsing::{ - cell_for_row_cached, h3_cell_bounds, needs_parent, parse_filters, row_passes_filters, - validate_h3_resolution, + cell_for_row_cached, h3_cell_bounds, needs_parent, 
parse_filters_with_poi, row_passes_filters, + row_passes_poi_filters, validate_h3_resolution, }; use crate::state::{AppState, SharedState}; @@ -117,6 +117,12 @@ pub fn build_property( features.insert(feat_name.clone(), value); } } + for (metric_idx, metric_name) in state.data.poi_metrics.feature_names.iter().enumerate() { + let value = state.data.poi_metrics.get_for_property_row(row, metric_idx); + if value.is_finite() { + features.insert(metric_name.clone(), value); + } + } Property { address: non_empty_string(state.data.address(row)), @@ -199,15 +205,19 @@ pub async fn get_hexagon_properties( let h3_str = params.h3; let quant = state.data.quant_ref(); - let (parsed_filters, parsed_enum_filters) = parse_filters( + let poi_quant = state.data.poi_metrics.quant_ref(); + let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = parse_filters_with_poi( params.filters.as_deref(), &state.feature_name_to_index, &state.data.enum_values, &quant, + &state.data.poi_metrics.name_to_index, + &poi_quant, ) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; - let num_filters = parsed_filters.len() + parsed_enum_filters.len(); + let num_filters = parsed_filters.len() + parsed_enum_filters.len() + parsed_poi_filters.len(); let filters_str = params.filters; + let has_poi_filters = !parsed_poi_filters.is_empty(); let travel_entries = parse_optional_travel(params.travel.as_deref()) .map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?; @@ -242,6 +252,12 @@ pub async fn get_hexagon_properties( feature_data, num_features, ) + && (!has_poi_filters + || row_passes_poi_filters( + row, + &parsed_poi_filters, + &state.data.poi_metrics, + )) { if has_travel { let postcode = state.data.postcode(row); diff --git a/server-rs/src/routes/stats.rs b/server-rs/src/routes/stats.rs index 46def2d..94ea754 100644 --- a/server-rs/src/routes/stats.rs +++ b/server-rs/src/routes/stats.rs @@ -4,7 +4,7 @@ use rustc_hash::FxHashMap; use tracing::warn; use 
crate::consts::MAX_PRICE_HISTORY_POINTS; -use crate::data::{FeatureStats, PropertyData}; +use crate::data::{FeatureStats, PostcodePoiMetrics, PropertyData}; use super::hexagon_stats::{EnumFeatureStats, HistogramStats, NumericFeatureStats, PricePoint}; @@ -243,3 +243,80 @@ pub fn compute_feature_stats( (numeric_features, enum_features_out) } + +pub fn compute_poi_feature_stats( + matching_rows: &[usize], + poi_metrics: &PostcodePoiMetrics, + fields_specified: bool, + field_set: &HashSet, +) -> Vec { + let mut out = Vec::new(); + for (metric_idx, name) in poi_metrics.feature_names.iter().enumerate() { + if fields_specified && !field_set.contains(name.as_str()) { + continue; + } + + let global_hist = &poi_metrics.feature_stats[metric_idx].histogram; + let p1 = global_hist.p1; + let p99 = global_hist.p99; + let num_bins = global_hist.counts.len(); + let middle_bins = num_bins.saturating_sub(2); + let middle_width = if middle_bins > 0 && p99 > p1 { + (p99 - p1) / middle_bins as f32 + } else { + 0.0 + }; + + let mut count = 0usize; + let mut min_value = f32::INFINITY; + let mut max_value = f32::NEG_INFINITY; + let mut sum = 0.0f64; + let mut bins = vec![0u64; num_bins]; + + for &row in matching_rows { + let value = poi_metrics.get_for_property_row(row, metric_idx); + if !value.is_finite() { + continue; + } + count += 1; + if value < min_value { + min_value = value; + } + if value > max_value { + max_value = value; + } + sum += value as f64; + + let bin = if value < p1 { + 0 + } else if value >= p99 { + num_bins - 1 + } else if middle_width > 0.0 { + let middle_bin = ((value - p1) / middle_width) as usize; + (1 + middle_bin).min(num_bins - 2) + } else { + num_bins / 2 + }; + bins[bin] += 1; + } + + if count > 0 { + out.push(NumericFeatureStats { + name: name.clone(), + count, + min: min_value as f64, + max: max_value as f64, + mean: sum / count as f64, + histogram: HistogramStats { + min: global_hist.min as f64, + max: global_hist.max as f64, + p1: p1 as f64, + p99: p99 
as f64, + counts: bins, + }, + }); + } + } + + out +} diff --git a/server-rs/src/routes/stripe_webhook.rs b/server-rs/src/routes/stripe_webhook.rs index b894ad1..76d7f80 100644 --- a/server-rs/src/routes/stripe_webhook.rs +++ b/server-rs/src/routes/stripe_webhook.rs @@ -1,78 +1,40 @@ -use std::collections::VecDeque; -use std::sync::{Arc, LazyLock}; +use std::sync::Arc; use axum::body::Bytes; use axum::extract::State; use axum::http::{HeaderMap, StatusCode}; use axum::response::{IntoResponse, Response}; -use hmac::{Hmac, Mac}; -use parking_lot::Mutex; -use rustc_hash::FxHashSet; +use hmac::{Hmac, KeyInit, Mac}; use sha2::Sha256; use tracing::{info, warn}; -use crate::pocketbase::get_superuser_token; +use crate::checkout_sessions::{ + grant_license, mark_checkout_completed, mark_referral_invite_used, verify_checkout_completion, + CheckoutCompletion, +}; use crate::state::SharedState; type HmacSha256 = Hmac; -/// Process-local LRU of recently processed Stripe event IDs. -/// Stripe retries deliver the same event ID; we drop duplicates so we don't -/// re-run side effects (subscription writes, token cache invalidation, logs). -/// Capacity is intentionally generous: at typical webhook volumes this covers -/// far more than Stripe's retry window. -struct EventDedup { - seen: FxHashSet, - queue: VecDeque, - capacity: usize, -} - -impl EventDedup { - fn new(capacity: usize) -> Self { - Self { - seen: FxHashSet::default(), - queue: VecDeque::with_capacity(capacity), - capacity, - } - } - - /// Returns `true` if this event ID is new (and records it), - /// `false` if it was already seen recently. 
- fn check_and_insert(&mut self, id: &str) -> bool { - if self.seen.contains(id) { - return false; - } - self.seen.insert(id.to_string()); - self.queue.push_back(id.to_string()); - if self.queue.len() > self.capacity { - if let Some(old) = self.queue.pop_front() { - self.seen.remove(&old); - } - } - true - } -} - -static EVENT_DEDUP: LazyLock> = - LazyLock::new(|| Mutex::new(EventDedup::new(1024))); - /// Verify Stripe webhook signature (v1 scheme). fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool { // Parse timestamp and signature from header: "t=TIMESTAMP,v1=SIGNATURE" let mut timestamp = None; - let mut signature = None; + let mut signatures = Vec::new(); for part in sig_header.split(',') { if let Some(ts) = part.strip_prefix("t=") { timestamp = Some(ts); } else if let Some(sig) = part.strip_prefix("v1=") { - signature = Some(sig); + signatures.push(sig); } } - let (ts, sig_hex) = match (timestamp, signature) { - (Some(t), Some(s)) => (t, s), - _ => return false, + let Some(ts) = timestamp else { + return false; }; + if signatures.is_empty() { + return false; + } // Reject webhooks older than 5 minutes to prevent replay attacks if let Ok(ts_secs) = ts.parse::() { @@ -87,20 +49,21 @@ fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool { return false; } - // Compute expected signature: HMAC-SHA256(secret, "TIMESTAMP.PAYLOAD") - let signed_payload = format!("{ts}.{}", String::from_utf8_lossy(payload)); - let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) { - Ok(m) => m, - Err(_) => return false, - }; - mac.update(signed_payload.as_bytes()); + let mut signed_payload = Vec::with_capacity(ts.len() + 1 + payload.len()); + signed_payload.extend_from_slice(ts.as_bytes()); + signed_payload.push(b'.'); + signed_payload.extend_from_slice(payload); - // Decode the provided hex signature and verify with constant-time comparison - let sig_bytes = match hex::decode(sig_hex) { - Ok(bytes) => bytes, - Err(_) => 
return false, - }; - mac.verify_slice(&sig_bytes).is_ok() + signatures.into_iter().any(|sig_hex| { + let Ok(sig_bytes) = hex::decode(sig_hex) else { + return false; + }; + let Ok(mut mac) = HmacSha256::new_from_slice(secret.as_bytes()) else { + return false; + }; + mac.update(&signed_payload); + mac.verify_slice(&sig_bytes).is_ok() + }) } /// Handle Stripe webhook events. @@ -140,65 +103,64 @@ pub async fn post_stripe_webhook( let event_type = event["type"].as_str().unwrap_or(""); let event_id = event["id"].as_str().unwrap_or(""); - // Idempotency: drop replays/retries of an already-processed event. - // We always answer 200 so Stripe stops retrying. - if !event_id.is_empty() && !EVENT_DEDUP.lock().check_and_insert(event_id) { - info!(event_id, event_type, "Dropping duplicate Stripe webhook"); - return StatusCode::OK.into_response(); - } - info!(event_id, event_type, "Received Stripe webhook"); if event_type == "checkout.session.completed" { - let user_id = event["data"]["object"]["client_reference_id"] - .as_str() - .unwrap_or(""); - if user_id.is_empty() { - warn!("checkout.session.completed missing client_reference_id"); - return StatusCode::OK.into_response(); - } - if !user_id.bytes().all(|b| b.is_ascii_alphanumeric()) || user_id.len() > 20 { - warn!(user_id, "Invalid client_reference_id format in webhook"); - return StatusCode::BAD_REQUEST.into_response(); - } - - // Update user subscription to "licensed" via PocketBase superuser auth - let token = match get_superuser_token(&state).await { - Ok(t) => t, - Err(err) => { - warn!("Failed to auth as PocketBase superuser in webhook: {err}"); - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - }; - - let pb_url = state.pocketbase_url.trim_end_matches('/'); - let url = format!("{pb_url}/api/collections/users/records/{user_id}"); - let res = state - .http_client - .patch(&url) - .header("Authorization", format!("Bearer {token}")) - .json(&serde_json::json!({ "subscription": "licensed" })) - .send() - 
.await; - - match res { - Ok(resp) if resp.status().is_success() => { - state.token_cache.invalidate_by_user_id(user_id); + let session = &event["data"]["object"]; + match verify_checkout_completion(&state, session).await { + Ok(CheckoutCompletion::Grant(checkout)) => { + if let Err(err) = mark_referral_invite_used( + &state, + &checkout.referral_invite_id, + &checkout.user_id, + ) + .await + { + warn!( + user_id = %checkout.user_id, + reservation_id = %checkout.reservation_id, + referral_invite_id = %checkout.referral_invite_id, + "Failed to mark referral invite used after Stripe checkout: {err:?}" + ); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + if let Err(err) = grant_license(&state, &checkout.user_id).await { + warn!( + user_id = %checkout.user_id, + reservation_id = %checkout.reservation_id, + "Failed to grant license after Stripe checkout: {err:?}" + ); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + if let Err(err) = mark_checkout_completed( + &state, + &checkout.reservation_id, + checkout.paid_amount_pence, + ) + .await + { + warn!( + user_id = %checkout.user_id, + reservation_id = %checkout.reservation_id, + "Failed to mark checkout completed after license grant: {err:?}" + ); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } info!( - user_id, - "User subscription updated to licensed via Stripe webhook" + user_id = %checkout.user_id, + reservation_id = %checkout.reservation_id, + "User subscription updated to licensed via verified Stripe checkout" ); } - Ok(resp) => { - let status = resp.status(); - let text = resp.text().await.unwrap_or_default(); - warn!( - user_id, - "Failed to update user subscription ({status}): {text}" - ); + Ok(CheckoutCompletion::AlreadyHandled) => { + info!("Stripe checkout session was already handled"); + } + Ok(CheckoutCompletion::Rejected(reason)) => { + warn!("Rejecting Stripe checkout completion: {reason}"); } Err(err) => { - warn!(user_id, "PocketBase request error in 
webhook: {err}"); + warn!("Failed to verify Stripe checkout completion: {err:?}"); + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } }