Split up
Some checks failed
Build and publish Docker image / build-and-push (push) Failing after 15s
CI / Check (push) Failing after 1m58s

This commit is contained in:
Andras Schmelczer 2026-06-12 21:51:37 +01:00
parent cf39ad754e
commit f59d01227b
91 changed files with 10370 additions and 7562 deletions

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,589 @@
//! The checkout session state machine: starting a checkout (with pricing and
//! reservation under a cross-instance lock), verifying Stripe's completion
//! payload, completing/granting, and reversing/reinstating after refunds or
//! disputes.
use std::sync::LazyLock;
use anyhow::{anyhow, Context};
use serde_json::Value;
use tokio::sync::Mutex;
use tracing::warn;
use crate::auth::PocketBaseUser;
use crate::pocketbase::get_superuser_token;
use crate::pocketbase_locks::acquire_pocketbase_lock;
use crate::routes::pricing::{count_licensed_users, price_for_count};
use crate::state::AppState;
use super::records::{
attach_stripe_session, count_active_pending_checkouts, create_pending_checkout,
expire_stale_pending_checkouts, find_active_checkout_for_user,
find_checkout_by_payment_intent_or_checkout_session, find_checkout_by_stripe_session,
has_other_completed_checkout_for_user, mark_checkout_completed, mark_checkout_reinstated,
mark_checkout_reversed, mark_checkout_status, PendingCheckoutInput,
};
use super::referral::{
mark_referral_invite_used, release_referral_invite_reservation, reserve_referral_invite,
};
use super::stripe::create_stripe_session;
use super::{
ensure_success, is_safe_reversal_reason, is_safe_stripe_session_id, now_unix_secs,
number_field, CheckoutCompletion, CheckoutStart, PaymentReinstatementOutcome,
PaymentReversalOutcome, VerifiedCheckout, CHECKOUT_CURRENCY, REFERRAL_DISCOUNT_PERCENT,
};
/// Lifetime, in seconds, of a pending checkout reservation. It is added to the
/// current Unix time to produce the `expires_at_unix` value stored on the
/// reservation and handed to `create_stripe_session`.
// NOTE(review): 31 minutes presumably leaves a margin above a provider-side
// minimum session lifetime — confirm against the Stripe integration.
const CHECKOUT_SESSION_TTL_SECS: u64 = 31 * 60;
/// Name passed to `acquire_pocketbase_lock` for the pricing lock.
const CHECKOUT_PRICING_LOCK_NAME: &str = "checkout:pricing";
/// TTL, in seconds, passed to `acquire_pocketbase_lock` for the pricing lock.
const CHECKOUT_PRICING_LOCK_TTL_SECS: u64 = 5 * 60;
/// Process-local async mutex taken by every public entry point in this module
/// that starts, completes, reverses or reinstates a checkout, or grants a
/// license under the pricing lock. It is a `tokio` mutex because the guard is
/// held across `.await` points.
static CHECKOUT_RESERVATION_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
/// Starts a license checkout for `user` with pricing and reservation
/// serialised behind both locks.
///
/// Takes the process-local `CHECKOUT_RESERVATION_LOCK` first, then the
/// PocketBase lock named `CHECKOUT_PRICING_LOCK_NAME`, runs
/// `start_license_checkout_locked`, and finally releases the PocketBase lock.
///
/// # Errors
///
/// Fails if the PocketBase lock cannot be acquired, or with whatever error
/// the locked checkout start produced. A failure to release the lock is only
/// logged; it never replaces the checkout result.
pub async fn start_license_checkout(
    state: &AppState,
    user: &PocketBaseUser,
    success_url: &str,
    cancel_url: &str,
    discount_coupon_id: Option<&str>,
    referral_invite_id: Option<&str>,
) -> anyhow::Result<CheckoutStart> {
    let _local_guard = CHECKOUT_RESERVATION_LOCK.lock().await;
    let shared_lock = acquire_pocketbase_lock(
        state,
        CHECKOUT_PRICING_LOCK_NAME,
        CHECKOUT_PRICING_LOCK_TTL_SECS,
    )
    .await?;

    let outcome = start_license_checkout_locked(
        state,
        user,
        success_url,
        cancel_url,
        discount_coupon_id,
        referral_invite_id,
    )
    .await;

    // Releasing is best-effort: the checkout outcome is returned regardless.
    if let Some(err) = shared_lock.release().await.err() {
        warn!("Failed to release checkout pricing lock: {err}");
    }
    outcome
}
async fn start_license_checkout_locked(
state: &AppState,
user: &PocketBaseUser,
success_url: &str,
cancel_url: &str,
discount_coupon_id: Option<&str>,
referral_invite_id: Option<&str>,
) -> anyhow::Result<CheckoutStart> {
let now = now_unix_secs();
expire_stale_pending_checkouts(state, now).await?;
if let Some(existing) = find_active_checkout_for_user(
state,
&user.id,
discount_coupon_id.unwrap_or_default(),
referral_invite_id.unwrap_or_default(),
now,
)
.await?
{
if !existing.checkout_url.is_empty() {
return Ok(CheckoutStart::Stripe {
url: existing.checkout_url,
});
}
if let Err(err) = mark_checkout_status(state, &existing.id, "failed").await {
warn!(
reservation_id = %existing.id,
"Failed to fail incomplete checkout reservation: {err}"
);
}
}
let licensed_count = count_licensed_users(state).await?;
let pending_count = count_active_pending_checkouts(state, now).await?;
let price_pence = price_for_count(licensed_count + pending_count);
if price_pence == 0 {
grant_license(state, &user.id).await?;
return Ok(CheckoutStart::Free);
}
let expires_at_unix = now + CHECKOUT_SESSION_TTL_SECS;
let expected_total_pence = expected_total_for_checkout(price_pence, discount_coupon_id);
let reservation_id = create_pending_checkout(
state,
PendingCheckoutInput {
user_id: &user.id,
amount_pence: price_pence,
expected_total_pence,
currency: CHECKOUT_CURRENCY,
discount_coupon_id: discount_coupon_id.unwrap_or_default(),
referral_invite_id: referral_invite_id.unwrap_or_default(),
expires_at_unix,
},
)
.await?;
if let Some(invite_id) = referral_invite_id.filter(|id| !id.is_empty()) {
if let Err(err) =
reserve_referral_invite(state, invite_id, &user.id, &reservation_id, expires_at_unix)
.await
{
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
warn!(
reservation_id,
"Failed to mark checkout reservation failed: {mark_err}"
);
}
return Err(err);
}
}
let stripe_result = create_stripe_session(
state,
user,
&reservation_id,
price_pence,
success_url,
cancel_url,
expires_at_unix,
discount_coupon_id,
)
.await;
let (stripe_session_id, url) = match stripe_result {
Ok(session) => session,
Err(err) => {
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
warn!(
reservation_id,
"Failed to mark checkout reservation failed: {mark_err}"
);
}
if let Some(invite_id) = referral_invite_id.filter(|id| !id.is_empty()) {
if let Err(release_err) =
release_referral_invite_reservation(state, invite_id, &reservation_id).await
{
warn!(
reservation_id,
referral_invite_id = invite_id,
"Failed to release referral invite reservation: {release_err}"
);
}
}
return Err(err);
}
};
if let Err(err) = attach_stripe_session(state, &reservation_id, &stripe_session_id, &url).await
{
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
warn!(
reservation_id,
"Failed to mark checkout reservation failed: {mark_err}"
);
}
if let Some(invite_id) = referral_invite_id.filter(|id| !id.is_empty()) {
if let Err(release_err) =
release_referral_invite_reservation(state, invite_id, &reservation_id).await
{
warn!(
reservation_id,
referral_invite_id = invite_id,
"Failed to release referral invite reservation: {release_err}"
);
}
}
return Err(err);
}
Ok(CheckoutStart::Stripe { url })
}
/// Checks a Stripe checkout-session payload against the stored reservation.
///
/// Returns [`CheckoutCompletion::Rejected`] with a reason when the payload is
/// malformed, has no reservation, is not paid, or disagrees with the
/// reservation. A disagreement on session id, `client_reference_id`,
/// currency, `amount_subtotal` or `amount_total` also marks the reservation
/// `invalid`. When everything matches, returns
/// [`CheckoutCompletion::AlreadyHandled`] if the reservation is already
/// `completed`, otherwise [`CheckoutCompletion::Grant`].
///
/// # Errors
///
/// Only PocketBase lookup or status-update failures are returned as errors;
/// verification failures are reported through `Rejected`.
pub async fn verify_checkout_completion(
    state: &AppState,
    session: &Value,
) -> anyhow::Result<CheckoutCompletion> {
    let Some(session_id) = session["id"]
        .as_str()
        .filter(|id| is_safe_stripe_session_id(id))
    else {
        return Ok(CheckoutCompletion::Rejected(
            "missing or invalid session id".into(),
        ));
    };
    let Some(payment_intent_id) = session["payment_intent"]
        .as_str()
        .filter(|id| is_safe_stripe_session_id(id))
    else {
        return Ok(CheckoutCompletion::Rejected(
            "missing or invalid payment intent id".into(),
        ));
    };
    let Some(checkout) = find_checkout_by_stripe_session(state, session_id).await? else {
        return Ok(CheckoutCompletion::Rejected(
            "checkout session has no reservation".into(),
        ));
    };

    let already_completed = checkout.status == "completed";
    if !matches!(
        checkout.status.as_str(),
        "completed" | "pending" | "expired"
    ) {
        return Ok(CheckoutCompletion::Rejected(format!(
            "checkout reservation is {}",
            checkout.status
        )));
    }

    if checkout.stripe_session_id != session_id {
        return invalidate_and_reject(
            state,
            &checkout.id,
            "checkout reservation session id mismatch",
        )
        .await;
    }
    if session["client_reference_id"].as_str().unwrap_or_default() != checkout.user_id {
        return invalidate_and_reject(
            state,
            &checkout.id,
            "checkout client_reference_id mismatch",
        )
        .await;
    }

    // An unpaid session is rejected without invalidating the reservation.
    let payment_status = session["payment_status"].as_str().unwrap_or_default();
    if payment_status != "paid" {
        return Ok(CheckoutCompletion::Rejected(format!(
            "checkout payment_status is {payment_status}"
        )));
    }

    let currency = session["currency"]
        .as_str()
        .unwrap_or_default()
        .to_ascii_lowercase();
    if currency != checkout.currency {
        return invalidate_and_reject(state, &checkout.id, "checkout currency mismatch").await;
    }

    let Some(amount_subtotal) = number_field(session, "amount_subtotal") else {
        return invalidate_and_reject(state, &checkout.id, "checkout amount_subtotal missing")
            .await;
    };
    if amount_subtotal != checkout.amount_pence {
        return invalidate_and_reject(state, &checkout.id, "checkout amount_subtotal mismatch")
            .await;
    }

    let Some(amount_total) = number_field(session, "amount_total") else {
        return invalidate_and_reject(state, &checkout.id, "checkout amount_total missing").await;
    };
    if amount_total != checkout.expected_total_pence {
        return invalidate_and_reject(state, &checkout.id, "checkout amount_total mismatch").await;
    }

    let verified = VerifiedCheckout {
        reservation_id: checkout.id,
        user_id: checkout.user_id,
        stripe_session_id: session_id.to_string(),
        payment_intent_id: payment_intent_id.to_string(),
        paid_amount_pence: amount_total,
        referral_invite_id: checkout.referral_invite_id,
    };
    Ok(if already_completed {
        CheckoutCompletion::AlreadyHandled(verified)
    } else {
        CheckoutCompletion::Grant(verified)
    })
}

/// Marks the reservation `invalid` and yields a rejection carrying `reason`.
/// A failed status update is propagated as an error instead.
async fn invalidate_and_reject(
    state: &AppState,
    reservation_id: &str,
    reason: &str,
) -> anyhow::Result<CheckoutCompletion> {
    mark_checkout_status(state, reservation_id, "invalid").await?;
    Ok(CheckoutCompletion::Rejected(reason.into()))
}
/// Completes a verified checkout while holding both the process-local
/// reservation lock and the PocketBase pricing lock.
///
/// # Errors
///
/// Fails if the PocketBase lock cannot be acquired, or with the error from
/// `complete_verified_checkout_locked`. A failed lock release is only logged.
pub async fn complete_verified_checkout(
    state: &AppState,
    checkout: &VerifiedCheckout,
) -> anyhow::Result<()> {
    let _local_guard = CHECKOUT_RESERVATION_LOCK.lock().await;
    let shared_lock = acquire_pocketbase_lock(
        state,
        CHECKOUT_PRICING_LOCK_NAME,
        CHECKOUT_PRICING_LOCK_TTL_SECS,
    )
    .await?;

    let outcome = complete_verified_checkout_locked(state, checkout).await;

    // Best-effort release; the completion outcome wins either way.
    if let Some(err) = shared_lock.release().await.err() {
        warn!("Failed to release checkout pricing lock: {err}");
    }
    outcome
}
/// Completes `checkout` after re-reading the reservation; the caller must
/// already hold the locks.
///
/// An already-`completed` reservation only has its referral invite (if any)
/// marked used again. Otherwise the live record must still match the verified
/// data and be `pending` or `expired`; then the license is granted, the
/// reservation is marked completed, and the referral invite is marked used.
///
/// # Errors
///
/// Fails if the reservation is missing, no longer matches the verified
/// checkout (in which case it is also marked `invalid`), is in another
/// status, or if any PocketBase update fails.
async fn complete_verified_checkout_locked(
    state: &AppState,
    checkout: &VerifiedCheckout,
) -> anyhow::Result<()> {
    let Some(live_checkout) =
        find_checkout_by_stripe_session(state, &checkout.stripe_session_id).await?
    else {
        return Err(anyhow!(
            "checkout reservation disappeared before completion"
        ));
    };

    if live_checkout.status == "completed" {
        return consume_referral_invite_if_any(state, checkout).await;
    }

    let still_matches = live_checkout.id == checkout.reservation_id
        && live_checkout.user_id == checkout.user_id
        && live_checkout.referral_invite_id == checkout.referral_invite_id;
    if !still_matches {
        mark_checkout_status(state, &checkout.reservation_id, "invalid").await?;
        return Err(anyhow!("checkout reservation changed before completion"));
    }

    if !matches!(live_checkout.status.as_str(), "pending" | "expired") {
        return Err(anyhow!("checkout reservation is {}", live_checkout.status));
    }

    // Grant first, then record completion, then consume the invite.
    grant_license(state, &checkout.user_id).await?;
    mark_checkout_completed(
        state,
        &checkout.reservation_id,
        checkout.paid_amount_pence,
        &checkout.payment_intent_id,
    )
    .await?;
    consume_referral_invite_if_any(state, checkout).await
}

/// Marks the checkout's referral invite as used; a checkout without a
/// referral invite is a no-op.
async fn consume_referral_invite_if_any(
    state: &AppState,
    checkout: &VerifiedCheckout,
) -> anyhow::Result<()> {
    if checkout.referral_invite_id.is_empty() {
        return Ok(());
    }
    mark_referral_invite_used(
        state,
        &checkout.referral_invite_id,
        &checkout.user_id,
        &checkout.reservation_id,
    )
    .await?;
    Ok(())
}
/// Grants a license to `user_id` while holding both the process-local
/// reservation lock and the PocketBase pricing lock.
///
/// # Errors
///
/// Fails if the PocketBase lock cannot be acquired or if granting the license
/// fails. A failed lock release is only logged.
pub async fn grant_license_with_pricing_lock(
    state: &AppState,
    user_id: &str,
) -> anyhow::Result<()> {
    let _local_guard = CHECKOUT_RESERVATION_LOCK.lock().await;
    let shared_lock = acquire_pocketbase_lock(
        state,
        CHECKOUT_PRICING_LOCK_NAME,
        CHECKOUT_PRICING_LOCK_TTL_SECS,
    )
    .await?;

    let outcome = grant_license(state, user_id).await;

    // Best-effort release; the grant outcome is what the caller sees.
    if let Some(err) = shared_lock.release().await.err() {
        warn!("Failed to release checkout pricing lock: {err}");
    }
    outcome
}
/// Sets the user's `subscription` to `"licensed"`. Takes no locks itself.
pub async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
    let licensed_tier = "licensed";
    set_user_subscription(state, user_id, licensed_tier).await
}
/// Reverses the checkout behind `payment_intent_id`, e.g. after a refund or
/// dispute, and reports what was done.
///
/// * No matching reservation yields `NoMatchingCheckout`.
/// * A refund smaller than the paid amount yields `IgnoredPartialRefund`.
///   The paid amount is the larger of the recorded `paid_amount_pence` and
///   `expected_total_pence`.
/// * An already `reversed` reservation yields `AlreadyHandled`.
/// * A `pending`, `expired` or `failed` reservation is marked reversed
///   without touching the user's license.
/// * A `completed` reservation is marked reversed, and the license is revoked
///   unless the user has another completed checkout.
/// * Any other status yields `NotReversible`.
///
/// # Errors
///
/// Fails on an unsafe payment intent id or reason, and on PocketBase/Stripe
/// lookup or update failures.
pub async fn reverse_license_for_payment_intent(
    state: &AppState,
    payment_intent_id: &str,
    reason: &str,
    refunded_amount_pence: Option<u64>,
) -> anyhow::Result<PaymentReversalOutcome> {
    if !is_safe_stripe_session_id(payment_intent_id) {
        return Err(anyhow!("invalid Stripe payment intent id"));
    }
    if !is_safe_reversal_reason(reason) {
        return Err(anyhow!("invalid Stripe reversal reason"));
    }

    let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
    let Some(checkout) =
        find_checkout_by_payment_intent_or_checkout_session(state, payment_intent_id).await?
    else {
        return Ok(PaymentReversalOutcome::NoMatchingCheckout);
    };

    let paid_amount_pence =
        std::cmp::max(checkout.paid_amount_pence, checkout.expected_total_pence);
    if let Some(refunded_amount_pence) =
        refunded_amount_pence.filter(|refunded| *refunded < paid_amount_pence)
    {
        return Ok(PaymentReversalOutcome::IgnoredPartialRefund {
            user_id: checkout.user_id,
            refunded_amount_pence,
            paid_amount_pence,
        });
    }

    if checkout.status == "reversed" {
        return Ok(PaymentReversalOutcome::AlreadyHandled {
            user_id: checkout.user_id,
        });
    }

    // These statuses are reversed without revoking anything.
    let reverse_without_revoking =
        ["pending", "expired", "failed"].contains(&checkout.status.as_str());
    if reverse_without_revoking {
        mark_checkout_reversed(state, &checkout.id, reason, payment_intent_id).await?;
        return Ok(PaymentReversalOutcome::Applied {
            user_id: checkout.user_id,
        });
    }

    if checkout.status != "completed" {
        return Ok(PaymentReversalOutcome::NotReversible {
            user_id: checkout.user_id,
            status: checkout.status,
        });
    }

    // Keep the license when another completed checkout still backs it.
    let has_other_license = has_other_completed_checkout_for_user(
        state,
        &checkout.user_id,
        &checkout.id,
        payment_intent_id,
    )
    .await?;
    if !has_other_license {
        revoke_license(state, &checkout.user_id).await?;
    }
    mark_checkout_reversed(state, &checkout.id, reason, payment_intent_id).await?;

    Ok(PaymentReversalOutcome::Applied {
        user_id: checkout.user_id,
    })
}
/// Re-grants the license for a checkout that was previously reversed by a
/// dispute.
///
/// * No matching reservation yields `NoMatchingCheckout`.
/// * A reservation that is already `completed` yields `AlreadyHandled`.
/// * A reservation that is not `reversed`, or whose stored reversal reason
///   does not start with `charge.dispute.`, yields `Ignored` with an
///   explanation.
/// * Otherwise the license is granted and the reservation is marked
///   reinstated.
///
/// # Errors
///
/// Fails on an unsafe payment intent id or reason, and on PocketBase/Stripe
/// lookup or update failures.
pub async fn reinstate_license_for_payment_intent(
    state: &AppState,
    payment_intent_id: &str,
    reason: &str,
) -> anyhow::Result<PaymentReinstatementOutcome> {
    if !is_safe_stripe_session_id(payment_intent_id) {
        return Err(anyhow!("invalid Stripe payment intent id"));
    }
    if !is_safe_reversal_reason(reason) {
        return Err(anyhow!("invalid Stripe reinstatement reason"));
    }

    let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
    let Some(checkout) =
        find_checkout_by_payment_intent_or_checkout_session(state, payment_intent_id).await?
    else {
        return Ok(PaymentReinstatementOutcome::NoMatchingCheckout);
    };

    if checkout.status == "completed" {
        return Ok(PaymentReinstatementOutcome::AlreadyHandled {
            user_id: checkout.user_id,
        });
    }
    if checkout.status != "reversed" {
        let explanation = format!("checkout status is {}", checkout.status);
        return Ok(PaymentReinstatementOutcome::Ignored {
            user_id: checkout.user_id,
            reason: explanation,
        });
    }

    // Only reversals caused by a dispute may be undone here.
    let reversed_by_dispute = checkout.reversal_reason.starts_with("charge.dispute.");
    if !reversed_by_dispute {
        let explanation = format!("checkout was reversed by {}", checkout.reversal_reason);
        return Ok(PaymentReinstatementOutcome::Ignored {
            user_id: checkout.user_id,
            reason: explanation,
        });
    }

    grant_license(state, &checkout.user_id).await?;
    mark_checkout_reinstated(state, &checkout.id, reason).await?;
    Ok(PaymentReinstatementOutcome::Applied {
        user_id: checkout.user_id,
    })
}
/// Sets the user's `subscription` back to `"free"`.
async fn revoke_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
    let free_tier = "free";
    set_user_subscription(state, user_id, free_tier).await
}
/// Writes `subscription` onto the PocketBase `users` record for `user_id`
/// using a superuser token, then invalidates that user's cached tokens.
///
/// # Errors
///
/// Fails if the superuser token cannot be obtained, the request cannot be
/// sent, or PocketBase answers with a non-success status.
async fn set_user_subscription(
    state: &AppState,
    user_id: &str,
    subscription: &str,
) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let url = format!("{pb_url}/api/collections/users/records/{user_id}");
    let payload = serde_json::json!({ "subscription": subscription });

    let request = state
        .http_client
        .patch(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload);
    let resp = request.send().await?;
    ensure_success(resp)
        .await
        .context("PocketBase license update failed")?;

    // Invalidate the user's entries in the token cache after the update.
    state.token_cache.invalidate_by_user_id(user_id);
    Ok(())
}
/// Computes the total the customer is expected to pay for `amount_pence`.
///
/// Without a coupon (or with an empty coupon id) the amount is returned
/// unchanged. With a coupon, `REFERRAL_DISCOUNT_PERCENT` is taken off using
/// integer division (rounding down), and the result is never below 1.
pub(super) fn expected_total_for_checkout(
    amount_pence: u64,
    discount_coupon_id: Option<&str>,
) -> u64 {
    let has_coupon = matches!(discount_coupon_id, Some(id) if !id.is_empty());
    if !has_coupon {
        return amount_pence;
    }
    let payable_percent = 100 - REFERRAL_DISCOUNT_PERCENT;
    std::cmp::max((amount_pence * payable_percent) / 100, 1)
}
#[cfg(test)]
mod tests {
use super::*;
// Pins the discounted-total arithmetic: integer division rounds down, the
// result is clamped to at least 1, and no coupon leaves the amount unchanged.
#[test]
fn expected_total_for_referral_discount_rounds_down_like_stripe_amount_math() {
assert_eq!(expected_total_for_checkout(999, Some("coupon_30")), 699);
assert_eq!(expected_total_for_checkout(1, Some("coupon_30")), 1);
assert_eq!(expected_total_for_checkout(999, None), 999);
}
}

View file

@ -0,0 +1,133 @@
//! Checkout sessions: Stripe-backed lifetime-license purchases reserved and
//! recorded in the PocketBase `checkout_sessions` collection.
//!
//! Split by concern:
//! - [`lifecycle`]: the session state machine (start, verify, complete,
//! reverse, reinstate) and license granting
//! - [`records`]: PocketBase `checkout_sessions` record handling
//! - [`referral`]: referral invite reservation/consumption bookkeeping
//! - [`stripe`]: Stripe API interaction (sessions, coupons, lookups)
mod lifecycle;
mod records;
mod referral;
mod stripe;
#[cfg(test)]
mod tests;
pub use lifecycle::{
complete_verified_checkout, grant_license_with_pricing_lock,
reinstate_license_for_payment_intent, reverse_license_for_payment_intent,
start_license_checkout, verify_checkout_completion,
};
pub use referral::active_referral_checkout_user;
use std::time::{SystemTime, UNIX_EPOCH};
use anyhow::anyhow;
use serde_json::Value;
/// Lowercase currency code stored on every reservation and compared against
/// the lowercased `currency` of the Stripe session during verification.
pub const CHECKOUT_CURRENCY: &str = "gbp";
/// Name of the PocketBase collection that holds checkout reservations.
const CHECKOUT_COLLECTION: &str = "checkout_sessions";
/// Percentage taken off the price when a discount coupon is applied.
const REFERRAL_DISCOUNT_PERCENT: u64 = 30;
/// Result of starting a license checkout.
#[derive(Debug)]
pub enum CheckoutStart {
    /// The computed price was zero, so the license was granted immediately.
    Free,
    /// The customer must pay at the Stripe checkout page at `url`.
    Stripe { url: String },
}

/// Result of verifying a Stripe checkout-session payload against its
/// reservation.
#[derive(Debug)]
pub enum CheckoutCompletion {
    /// The payload matches a reservation that still needs to be completed.
    Grant(VerifiedCheckout),
    /// The payload matches a reservation that is already `completed`.
    AlreadyHandled(VerifiedCheckout),
    /// The payload was rejected; the string is the reason.
    Rejected(String),
}

/// Result of handling a refund or dispute for a payment intent.
#[derive(Debug)]
pub enum PaymentReversalOutcome {
    /// The reservation was marked reversed.
    Applied {
        user_id: String,
    },
    /// The reservation had already been reversed.
    AlreadyHandled {
        user_id: String,
    },
    /// The refunded amount was smaller than the paid amount, so nothing
    /// changed.
    IgnoredPartialRefund {
        user_id: String,
        refunded_amount_pence: u64,
        paid_amount_pence: u64,
    },
    /// No reservation matches the payment intent.
    NoMatchingCheckout,
    /// The reservation is in a status that cannot be reversed.
    NotReversible {
        user_id: String,
        status: String,
    },
}

/// Result of handling a reinstatement request for a payment intent.
#[derive(Debug)]
pub enum PaymentReinstatementOutcome {
    /// The license was granted again and the reservation marked reinstated.
    Applied { user_id: String },
    /// The reservation was already `completed`.
    AlreadyHandled { user_id: String },
    /// Nothing was changed; `reason` explains why.
    Ignored { user_id: String, reason: String },
    /// No reservation matches the payment intent.
    NoMatchingCheckout,
}

/// A Stripe checkout session whose payload matched its reservation.
#[derive(Debug)]
pub struct VerifiedCheckout {
    /// PocketBase id of the checkout reservation.
    pub reservation_id: String,
    /// PocketBase id of the purchasing user.
    pub user_id: String,
    /// Stripe checkout session id taken from the payload.
    pub stripe_session_id: String,
    /// Stripe payment intent id taken from the payload.
    pub payment_intent_id: String,
    /// The payload's `amount_total`, in pence.
    pub paid_amount_pence: u64,
    /// Referral invite attached to the reservation; empty when there is none.
    pub referral_invite_id: String,
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch,
/// or 0 when the system clock reports a time before the epoch.
pub fn now_unix_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
/// Reads `value[field]` as a non-negative whole number.
///
/// Accepts JSON unsigned integers directly, and floats that are finite,
/// non-negative and have no fractional part. Anything else yields `None`.
fn number_field(value: &Value, field: &str) -> Option<u64> {
    let raw = &value[field];
    if let Some(integer) = raw.as_u64() {
        return Some(integer);
    }
    let float = raw.as_f64()?;
    let is_whole_and_non_negative = float.is_finite() && float >= 0.0 && float.fract() == 0.0;
    is_whole_and_non_negative.then(|| float as u64)
}
/// Returns `true` when `id` is 1..=128 bytes long and made only of ASCII
/// letters, ASCII digits, `_` and `-`.
fn is_safe_stripe_session_id(id: &str) -> bool {
    if id.is_empty() || id.len() > 128 {
        return false;
    }
    id.bytes()
        .all(|byte| matches!(byte, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'-'))
}
/// Returns `true` when `id` is 1..=32 bytes long and purely ASCII
/// alphanumeric.
fn is_safe_pocketbase_id(id: &str) -> bool {
    let length_ok = (1..=32).contains(&id.len());
    length_ok && id.bytes().all(|byte| byte.is_ascii_alphanumeric())
}
/// Returns `true` when `reason` is 1..=128 bytes long and made only of ASCII
/// letters, ASCII digits, `_`, `-` and `.`.
fn is_safe_reversal_reason(reason: &str) -> bool {
    if reason.is_empty() || reason.len() > 128 {
        return false;
    }
    reason
        .bytes()
        .all(|byte| byte.is_ascii_alphanumeric() || matches!(byte, b'_' | b'-' | b'.'))
}
/// Consumes `resp` and returns `Ok(())` for a success status. Otherwise the
/// error carries the status and the response body (empty if the body could
/// not be read).
async fn ensure_success(resp: reqwest::Response) -> anyhow::Result<()> {
    let status = resp.status();
    if status.is_success() {
        return Ok(());
    }
    let text = resp.text().await.unwrap_or_default();
    Err(anyhow!("upstream returned {status}: {text}"))
}
/// Checks the status of a borrowed response so the caller can still read the
/// body afterwards. The error carries only the status code.
async fn ensure_success_ref(resp: &reqwest::Response) -> anyhow::Result<()> {
    let status = resp.status();
    if !status.is_success() {
        return Err(anyhow!("upstream returned {}", status));
    }
    Ok(())
}

View file

@ -0,0 +1,564 @@
//! PocketBase `checkout_sessions` record handling: creating reservations,
//! status transitions, and lookups by Stripe session / payment intent.
use anyhow::{anyhow, Context};
use serde_json::Value;
use tracing::warn;
use crate::pocketbase::get_superuser_token;
use crate::state::AppState;
use super::referral::release_referral_invite_reservation;
use super::stripe::fetch_stripe_checkout_session_id_for_payment_intent;
use super::{
ensure_success, ensure_success_ref, is_safe_pocketbase_id, is_safe_stripe_session_id,
now_unix_secs, number_field, CHECKOUT_COLLECTION,
};
/// A `checkout_sessions` record as parsed by `parse_pending_checkout`.
/// Despite the name it represents a reservation in any status, not only
/// `pending` ones.
#[derive(Debug)]
pub(super) struct PendingCheckout {
/// PocketBase record id (`id`); required when parsing.
pub(super) id: String,
/// Owning user's id (`user`); required when parsing.
pub(super) user_id: String,
/// Stripe checkout session id; empty when the record has none.
pub(super) stripe_session_id: String,
/// Stripe-hosted checkout URL; empty when the record has none.
pub(super) checkout_url: String,
/// Price in pence (`amount_pence`); required when parsing.
pub(super) amount_pence: u64,
/// Expected total in pence (`expected_total_pence`); required when parsing.
pub(super) expected_total_pence: u64,
/// Currency code, lowercased while parsing.
pub(super) currency: String,
/// Referral invite id; empty when there is none.
pub(super) referral_invite_id: String,
/// Raw record status string (e.g. `pending`, `completed`, `reversed`).
pub(super) status: String,
/// Stripe payment intent id (`stripe_payment_intent_id`); empty when unset.
pub(super) payment_intent_id: String,
/// Amount actually paid in pence; 0 when the record has no usable value.
pub(super) paid_amount_pence: u64,
/// Reason recorded when the checkout was reversed; empty otherwise.
pub(super) reversal_reason: String,
}
/// Marks a reservation `completed`, recording the paid amount, the completion
/// time (as a string of Unix seconds) and the Stripe payment intent id.
///
/// # Errors
///
/// Fails on an unsafe payment intent id, if the superuser token cannot be
/// obtained, or if the PocketBase update does not succeed.
pub async fn mark_checkout_completed(
    state: &AppState,
    reservation_id: &str,
    paid_amount_pence: u64,
    payment_intent_id: &str,
) -> anyhow::Result<()> {
    if !is_safe_stripe_session_id(payment_intent_id) {
        return Err(anyhow!("invalid Stripe payment intent id"));
    }

    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let payload = serde_json::json!({
        "status": "completed",
        "paid_amount_pence": paid_amount_pence,
        "completed_at_unix": now_unix_secs().to_string(),
        "stripe_payment_intent_id": payment_intent_id,
    });

    let request = state
        .http_client
        .patch(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload);
    let resp = request.send().await?;
    ensure_success(resp)
        .await
        .context("PocketBase checkout completion update failed")
}
/// Counts reservations that are `pending` and not yet expired at `now`, using
/// the `totalItems` field of a one-item page. A missing or non-integer
/// `totalItems` counts as 0.
///
/// # Errors
///
/// Fails if the superuser token cannot be obtained, the request fails, or
/// PocketBase answers with a non-success status or an unparsable body.
pub(super) async fn count_active_pending_checkouts(
    state: &AppState,
    now: u64,
) -> anyhow::Result<u64> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let filter = format!("status=\"pending\" && expires_at_unix>={now}");
    let encoded_filter = urlencoding::encode(&filter);
    let url = format!(
        "{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
        encoded_filter
    );

    let request = state
        .http_client
        .get(&url)
        .header("Authorization", format!("Bearer {token}"));
    let resp = request.send().await?;
    ensure_success_ref(&resp).await?;

    let body = resp.json::<Value>().await?;
    let total_items = body["totalItems"].as_u64().unwrap_or_default();
    Ok(total_items)
}
pub(super) async fn find_active_checkout_for_user(
state: &AppState,
user_id: &str,
discount_coupon_id: &str,
referral_invite_id: &str,
now: u64,
) -> anyhow::Result<Option<PendingCheckout>> {
let token = get_superuser_token(state).await?;
let pb_url = state.pocketbase_url.trim_end_matches('/');
let filter = active_checkout_filter(user_id, discount_coupon_id, referral_invite_id, now)?;
let url = format!(
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
urlencoding::encode(&filter)
);
let resp = state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
ensure_success_ref(&resp).await?;
let body: Value = resp.json().await?;
let item = body["items"]
.as_array()
.and_then(|items| items.first())
.cloned();
item.map(parse_pending_checkout).transpose()
}
/// Builds the PocketBase filter expression for a user's active reservation.
///
/// Every id is validated before being placed inside a quoted filter literal.
/// The coupon and invite ids may be empty, which matches records that have
/// none.
///
/// # Errors
///
/// Fails when the user id, a non-empty coupon id, or a non-empty invite id
/// contains characters outside the allowed sets.
fn active_checkout_filter(
    user_id: &str,
    discount_coupon_id: &str,
    referral_invite_id: &str,
    now: u64,
) -> anyhow::Result<String> {
    if !is_safe_pocketbase_id(user_id) {
        return Err(anyhow!("invalid PocketBase user id"));
    }
    let coupon_ok = discount_coupon_id.is_empty() || is_safe_stripe_session_id(discount_coupon_id);
    if !coupon_ok {
        return Err(anyhow!("invalid Stripe coupon id"));
    }
    let invite_ok = referral_invite_id.is_empty() || is_safe_pocketbase_id(referral_invite_id);
    if !invite_ok {
        return Err(anyhow!("invalid PocketBase referral invite id"));
    }
    let filter = format!(
        "status=\"pending\" && expires_at_unix>={now} && user=\"{user_id}\" && discount_coupon_id=\"{discount_coupon_id}\" && referral_invite_id=\"{referral_invite_id}\""
    );
    Ok(filter)
}
/// Marks up to one page (50) of `pending` reservations whose `expires_at_unix`
/// is before `now` as `expired`, and releases any referral invite they hold.
///
/// Per-record failures are logged and skipped, so one bad record does not
/// stop the sweep.
///
/// # Errors
///
/// Fails only if the superuser token cannot be obtained or the listing
/// request fails, returns a non-success status, or has an unparsable body.
pub(super) async fn expire_stale_pending_checkouts(
    state: &AppState,
    now: u64,
) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let filter = format!("status=\"pending\" && expires_at_unix<{now}");
    let encoded_filter = urlencoding::encode(&filter);
    let url = format!(
        "{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=50",
        encoded_filter
    );

    let request = state
        .http_client
        .get(&url)
        .header("Authorization", format!("Bearer {token}"));
    let resp = request.send().await?;
    ensure_success_ref(&resp).await?;

    let body = resp.json::<Value>().await?;
    let stale_records = match body["items"].as_array() {
        Some(items) => items,
        None => return Ok(()),
    };

    // Records without a string id are skipped.
    let stale_with_ids = stale_records
        .iter()
        .filter_map(|record| record["id"].as_str().map(|id| (id, record)));
    for (stale_id, record) in stale_with_ids {
        if let Err(err) = mark_checkout_status(state, stale_id, "expired").await {
            warn!(
                reservation_id = stale_id,
                "Failed to expire checkout reservation: {err}"
            );
        }

        let held_invite = record["referral_invite_id"]
            .as_str()
            .filter(|invite| !invite.is_empty());
        let Some(held_invite) = held_invite else {
            continue;
        };
        if let Err(err) = release_referral_invite_reservation(state, held_invite, stale_id).await {
            warn!(
                reservation_id = stale_id,
                referral_invite_id = held_invite,
                "Failed to release expired referral invite reservation: {err}"
            );
        }
    }
    Ok(())
}
/// Borrowed field values for creating a new `pending` reservation via
/// `create_pending_checkout`.
pub(super) struct PendingCheckoutInput<'a> {
/// PocketBase id of the purchasing user; written to the `user` field.
pub(super) user_id: &'a str,
/// Price in pence.
pub(super) amount_pence: u64,
/// Total in pence the customer is expected to pay.
pub(super) expected_total_pence: u64,
/// Currency code stored on the record.
pub(super) currency: &'a str,
/// Discount coupon id; empty when no coupon applies.
pub(super) discount_coupon_id: &'a str,
/// Referral invite id; empty when there is none.
pub(super) referral_invite_id: &'a str,
/// Unix time (seconds) at which the reservation expires.
pub(super) expires_at_unix: u64,
}
/// Creates a `pending` reservation record from `input` and returns its
/// PocketBase id. Stripe-related and completion fields start out empty.
///
/// # Errors
///
/// Fails if the superuser token cannot be obtained, the request fails or
/// returns a non-success status, or the response carries no string `id`.
pub(super) async fn create_pending_checkout(
    state: &AppState,
    input: PendingCheckoutInput<'_>,
) -> anyhow::Result<String> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records");
    let payload = serde_json::json!({
        "user": input.user_id,
        "stripe_session_id": "",
        "stripe_payment_intent_id": "",
        "checkout_url": "",
        "amount_pence": input.amount_pence,
        "expected_total_pence": input.expected_total_pence,
        "currency": input.currency,
        "discount_coupon_id": input.discount_coupon_id,
        "referral_invite_id": input.referral_invite_id,
        "status": "pending",
        "expires_at_unix": input.expires_at_unix,
        "paid_amount_pence": 0,
        "completed_at_unix": "",
        "reversal_reason": "",
    });

    let request = state
        .http_client
        .post(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload);
    let resp = request.send().await?;
    ensure_success_ref(&resp).await?;

    let body = resp.json::<Value>().await?;
    match body["id"].as_str() {
        Some(created_id) => Ok(created_id.to_string()),
        None => Err(anyhow!("PocketBase checkout reservation missing id")),
    }
}
/// Stores the Stripe session id and checkout URL on a reservation.
///
/// # Errors
///
/// Fails if the superuser token cannot be obtained or the PocketBase update
/// does not succeed.
pub(super) async fn attach_stripe_session(
    state: &AppState,
    reservation_id: &str,
    stripe_session_id: &str,
    checkout_url: &str,
) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let payload = serde_json::json!({
        "stripe_session_id": stripe_session_id,
        "checkout_url": checkout_url,
    });

    let request = state
        .http_client
        .patch(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload);
    let resp = request.send().await?;
    ensure_success(resp)
        .await
        .context("PocketBase checkout session attach failed")
}
/// Overwrites only the `status` field of a reservation.
///
/// # Errors
///
/// Fails if the superuser token cannot be obtained or the PocketBase update
/// does not succeed; the error context names the reservation.
pub(super) async fn mark_checkout_status(
    state: &AppState,
    reservation_id: &str,
    status: &str,
) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let payload = serde_json::json!({ "status": status });

    let request = state
        .http_client
        .patch(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload);
    let resp = request.send().await?;
    ensure_success(resp)
        .await
        .with_context(|| format!("PocketBase checkout status update failed for {reservation_id}"))
}
/// Marks a reservation `reversed`, recording the reason and the payment
/// intent id that triggered the reversal.
///
/// # Errors
///
/// Fails if the superuser token cannot be obtained or the PocketBase update
/// does not succeed; the error context names the reservation.
pub(super) async fn mark_checkout_reversed(
    state: &AppState,
    reservation_id: &str,
    reason: &str,
    payment_intent_id: &str,
) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let payload = serde_json::json!({
        "status": "reversed",
        "reversal_reason": reason,
        "stripe_payment_intent_id": payment_intent_id,
    });

    let request = state
        .http_client
        .patch(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload);
    let resp = request.send().await?;
    ensure_success(resp)
        .await
        .with_context(|| format!("PocketBase checkout reversal update failed for {reservation_id}"))
}
/// Puts a reversed reservation back to `completed` and clears its reversal
/// reason. The `_reason` argument is accepted but not stored.
///
/// # Errors
///
/// Fails if the superuser token cannot be obtained or the PocketBase update
/// does not succeed; the error context names the reservation.
pub(super) async fn mark_checkout_reinstated(
    state: &AppState,
    reservation_id: &str,
    _reason: &str,
) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let payload = serde_json::json!({
        "status": "completed",
        "reversal_reason": "",
    });

    let request = state
        .http_client
        .patch(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload);
    let resp = request.send().await?;
    ensure_success(resp).await.with_context(|| {
        format!("PocketBase checkout reinstatement update failed for {reservation_id}")
    })
}
pub(super) async fn find_checkout_by_stripe_session(
state: &AppState,
stripe_session_id: &str,
) -> anyhow::Result<Option<PendingCheckout>> {
let token = get_superuser_token(state).await?;
let pb_url = state.pocketbase_url.trim_end_matches('/');
let filter = format!("stripe_session_id=\"{}\"", stripe_session_id);
let url = format!(
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
urlencoding::encode(&filter)
);
let resp = state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
ensure_success_ref(&resp).await?;
let body: Value = resp.json().await?;
let item = body["items"]
.as_array()
.and_then(|items| items.first())
.cloned();
item.map(parse_pending_checkout).transpose()
}
async fn find_checkout_by_payment_intent(
state: &AppState,
payment_intent_id: &str,
) -> anyhow::Result<Option<PendingCheckout>> {
let token = get_superuser_token(state).await?;
let pb_url = state.pocketbase_url.trim_end_matches('/');
let filter = format!("stripe_payment_intent_id=\"{}\"", payment_intent_id);
let url = format!(
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
urlencoding::encode(&filter)
);
let resp = state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
ensure_success_ref(&resp).await?;
let body: Value = resp.json().await?;
let item = body["items"]
.as_array()
.and_then(|items| items.first())
.cloned();
item.map(parse_pending_checkout).transpose()
}
/// Finds the reservation for a payment intent, falling back to the Stripe
/// checkout session when no record stores the intent yet.
///
/// On the fallback path a reservation without a stored payment intent gets
/// `payment_intent_id` attached. A reservation that stores a different intent
/// is marked `invalid` and reported as an error.
///
/// # Errors
///
/// Fails on PocketBase/Stripe lookup or update failures, and when the
/// reservation found via the session already carries another payment intent.
pub(super) async fn find_checkout_by_payment_intent_or_checkout_session(
    state: &AppState,
    payment_intent_id: &str,
) -> anyhow::Result<Option<PendingCheckout>> {
    if let Some(direct_match) = find_checkout_by_payment_intent(state, payment_intent_id).await? {
        return Ok(Some(direct_match));
    }

    let session_id =
        match fetch_stripe_checkout_session_id_for_payment_intent(state, payment_intent_id).await? {
            Some(session_id) => session_id,
            None => return Ok(None),
        };
    let mut checkout = match find_checkout_by_stripe_session(state, &session_id).await? {
        Some(checkout) => checkout,
        None => return Ok(None),
    };

    if !checkout.payment_intent_id.is_empty() {
        if checkout.payment_intent_id == payment_intent_id {
            return Ok(Some(checkout));
        }
        mark_checkout_status(state, &checkout.id, "invalid").await?;
        return Err(anyhow!(
            "checkout reservation payment intent changed before reversal"
        ));
    }

    // No intent stored yet: persist it and mirror the change on the local copy.
    attach_payment_intent_to_checkout(state, &checkout.id, payment_intent_id).await?;
    checkout.payment_intent_id = payment_intent_id.to_owned();
    Ok(Some(checkout))
}
/// Stores `payment_intent_id` on a reservation, leaving other fields alone.
///
/// # Errors
///
/// Fails if the superuser token cannot be obtained or the PocketBase update
/// does not succeed.
async fn attach_payment_intent_to_checkout(
    state: &AppState,
    reservation_id: &str,
    payment_intent_id: &str,
) -> anyhow::Result<()> {
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
    let payload = serde_json::json!({
        "stripe_payment_intent_id": payment_intent_id,
    });

    let request = state
        .http_client
        .patch(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload);
    let resp = request.send().await?;
    ensure_success(resp)
        .await
        .context("PocketBase checkout payment intent attach failed")
}
/// Reports whether the user has a `completed` reservation other than the one
/// being reversed.
///
/// A record counts as "other" only when both its id differs from
/// `reservation_id` and its payment intent differs from `payment_intent_id`.
/// Only the first page (50 records) of the user's completed reservations is
/// inspected.
///
/// # Errors
///
/// Fails on unsafe ids, if the superuser token cannot be obtained, or on a
/// request failure, non-success status or unparsable body.
pub(super) async fn has_other_completed_checkout_for_user(
    state: &AppState,
    user_id: &str,
    reservation_id: &str,
    payment_intent_id: &str,
) -> anyhow::Result<bool> {
    let pocketbase_ids_ok = is_safe_pocketbase_id(user_id) && is_safe_pocketbase_id(reservation_id);
    if !pocketbase_ids_ok {
        return Err(anyhow!("invalid PocketBase id"));
    }
    if !is_safe_stripe_session_id(payment_intent_id) {
        return Err(anyhow!("invalid Stripe payment intent id"));
    }

    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let filter = format!("user=\"{user_id}\" && status=\"completed\"");
    let encoded_filter = urlencoding::encode(&filter);
    let url = format!(
        "{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=50",
        encoded_filter
    );

    let request = state
        .http_client
        .get(&url)
        .header("Authorization", format!("Bearer {token}"));
    let resp = request.send().await?;
    ensure_success_ref(&resp).await?;

    let body = resp.json::<Value>().await?;
    let completed = match body["items"].as_array() {
        Some(items) => items,
        None => return Ok(false),
    };
    let has_other = completed.iter().any(|record| {
        let same_record = record["id"].as_str().unwrap_or_default() == reservation_id;
        let same_payment = record["stripe_payment_intent_id"]
            .as_str()
            .unwrap_or_default()
            == payment_intent_id;
        !(same_record || same_payment)
    });
    Ok(has_other)
}
/// Converts a PocketBase checkout record (as JSON) into a `PendingCheckout`.
///
/// `id`, `user`, `amount_pence` and `expected_total_pence` are mandatory; a
/// missing one produces an error naming the field. Every other text field
/// falls back to an empty string, `paid_amount_pence` falls back to `0`, and
/// `currency` is lower-cased (ASCII).
fn parse_pending_checkout(item: Value) -> anyhow::Result<PendingCheckout> {
    // Optional text fields: absent or non-string values become "".
    let text_or_empty = |key: &str| -> String { item[key].as_str().unwrap_or_default().to_string() };

    // Mandatory fields are checked in the same order the struct lists them so
    // the first missing one determines the error.
    let Some(id) = item["id"].as_str() else {
        return Err(anyhow!("checkout reservation missing id"));
    };
    let Some(user_id) = item["user"].as_str() else {
        return Err(anyhow!("checkout reservation missing user"));
    };
    let Some(amount_pence) = number_field(&item, "amount_pence") else {
        return Err(anyhow!("checkout reservation missing amount_pence"));
    };
    let Some(expected_total_pence) = number_field(&item, "expected_total_pence") else {
        return Err(anyhow!("checkout reservation missing expected_total_pence"));
    };

    Ok(PendingCheckout {
        id: id.to_string(),
        user_id: user_id.to_string(),
        stripe_session_id: text_or_empty("stripe_session_id"),
        checkout_url: text_or_empty("checkout_url"),
        amount_pence,
        expected_total_pence,
        currency: item["currency"]
            .as_str()
            .unwrap_or_default()
            .to_ascii_lowercase(),
        referral_invite_id: text_or_empty("referral_invite_id"),
        status: text_or_empty("status"),
        payment_intent_id: text_or_empty("stripe_payment_intent_id"),
        paid_amount_pence: number_field(&item, "paid_amount_pence").unwrap_or(0),
        reversal_reason: text_or_empty("reversal_reason"),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A standard checkout (no coupon, no invite) still pins both context
    /// fields to the empty string in the filter.
    #[test]
    fn active_checkout_filter_includes_empty_context_for_standard_checkout() {
        let expected = "status=\"pending\" && expires_at_unix>=42 && user=\"abc123\" && discount_coupon_id=\"\" && referral_invite_id=\"\"";
        let actual = active_checkout_filter("abc123", "", "", 42).unwrap();
        assert_eq!(actual, expected);
    }

    /// Coupon and invite ids are carried into the filter verbatim.
    #[test]
    fn active_checkout_filter_includes_referral_context() {
        let expected = "status=\"pending\" && expires_at_unix>=99 && user=\"user123\" && discount_coupon_id=\"coupon_30\" && referral_invite_id=\"invite123\"";
        let actual = active_checkout_filter("user123", "coupon_30", "invite123", 99).unwrap();
        assert_eq!(actual, expected);
    }

    /// Unsafe values in any of the three context slots are refused.
    #[test]
    fn active_checkout_filter_rejects_unsafe_context_values() {
        let unsafe_coupon = active_checkout_filter("user123", "bad\"coupon", "", 1);
        assert!(unsafe_coupon.is_err());

        let unsafe_invite = active_checkout_filter("user123", "", "bad-invite", 1);
        assert!(unsafe_invite.is_err());

        let unsafe_user = active_checkout_filter("bad-user", "", "", 1);
        assert!(unsafe_user.is_err());
    }
}

View file

@ -0,0 +1,312 @@
//! Referral invite bookkeeping: reserving an invite for an in-flight checkout,
//! releasing the reservation on failure/expiry, and recording final usage when
//! a verified payment completes.
use anyhow::{anyhow, Context};
use serde_json::Value;
use tracing::warn;
use crate::pocketbase::get_superuser_token;
use crate::state::AppState;
use super::{
ensure_success, ensure_success_ref, is_safe_pocketbase_id, now_unix_secs, number_field,
CHECKOUT_COLLECTION,
};
/// Records that `user_id` consumed the referral invite `invite_id` as part of
/// checkout `reservation_id`.
///
/// An empty `invite_id` is a no-op. Otherwise the invite is fetched and, based
/// on `referral_invite_completion_action`:
/// - already recorded for this user: nothing is written;
/// - already used by another account: a warning is logged, nothing is written;
/// - otherwise: the invite is patched with the usage (`used_by_id`, `used_at`)
///   and its reservation fields are cleared. A warning is logged first if the
///   reservation currently points at a different user or checkout.
///
/// # Errors
/// Returns an error for invalid ids or when a PocketBase request fails.
pub async fn mark_referral_invite_used(
    state: &AppState,
    invite_id: &str,
    user_id: &str,
    reservation_id: &str,
) -> anyhow::Result<()> {
    if invite_id.is_empty() {
        return Ok(());
    }
    let all_ids_safe = is_safe_pocketbase_id(invite_id)
        && is_safe_pocketbase_id(user_id)
        && is_safe_pocketbase_id(reservation_id);
    if !all_ids_safe {
        return Err(anyhow!("invalid PocketBase id"));
    }

    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let invite = fetch_invite_record(state, pb_url, &token, invite_id).await?;

    // A verified Stripe payment must not lose entitlement just because local
    // invite reservation bookkeeping expired or moved before webhook delivery.
    let action = referral_invite_completion_action(&invite, user_id, reservation_id);
    match action {
        ReferralInviteCompletionAction::AlreadyRecorded => return Ok(()),
        ReferralInviteCompletionAction::AlreadyUsedByAnother => {
            warn!(
                invite_id,
                user_id,
                existing_used_by = invite["used_by_id"].as_str().unwrap_or_default(),
                "Referral invite was already used by another account; preserving verified checkout entitlement"
            );
            return Ok(());
        }
        ReferralInviteCompletionAction::Record {
            reservation_reassigned: true,
        } => {
            warn!(
                invite_id,
                user_id,
                reservation_id,
                reserved_by_id = invite["reserved_by_id"].as_str().unwrap_or_default(),
                reserved_checkout_id = invite["reserved_checkout_id"].as_str().unwrap_or_default(),
                "Referral invite reservation moved before webhook completion; verified checkout will consume it"
            );
        }
        ReferralInviteCompletionAction::Record {
            reservation_reassigned: false,
        } => {}
    }

    // Record the usage and clear any reservation in a single PATCH.
    let usage_update = serde_json::json!({
        "used_by_id": user_id,
        "used_at": now_unix_secs().to_string(),
        "reserved_by_id": "",
        "reserved_checkout_id": "",
        "reserved_until_unix": 0,
    });
    let invite_url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
    let response = state
        .http_client
        .patch(&invite_url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&usage_update)
        .send()
        .await?;

    ensure_success(response)
        .await
        .context("PocketBase invite usage update failed")
}
/// What `mark_referral_invite_used` should do with an invite record, as
/// decided by `referral_invite_completion_action`.
#[derive(Debug, PartialEq, Eq)]
enum ReferralInviteCompletionAction {
/// The invite's `used_by_id` already equals the completing user.
AlreadyRecorded,
/// The invite's `used_by_id` is set to a different user.
AlreadyUsedByAnother,
/// The invite is unused and its usage should be recorded.
/// `reservation_reassigned` is true when the invite's reservation fields
/// point at a different user or checkout than the one completing.
Record { reservation_reassigned: bool },
}
/// Decides how a verified checkout should be reflected on its referral invite.
///
/// # Parameters
/// - `invite`: the invite record as JSON; `used_by_id`, `reserved_by_id` and
///   `reserved_checkout_id` are read, with missing or non-string values
///   treated as empty.
/// - `user_id`: the account whose checkout completed.
/// - `reservation_id`: the checkout reservation that completed.
///
/// # Returns
/// - `AlreadyRecorded` when the invite is already marked as used by `user_id`.
/// - `AlreadyUsedByAnother` when it is marked as used by someone else.
/// - `Record { reservation_reassigned }` when the invite is unused;
///   `reservation_reassigned` is true if a non-empty reservation field names a
///   different user or a different checkout.
fn referral_invite_completion_action(
    invite: &Value,
    user_id: &str,
    reservation_id: &str,
) -> ReferralInviteCompletionAction {
    let existing_used_by = invite["used_by_id"].as_str().unwrap_or_default();
    // Only a non-empty `used_by_id` means the invite has been consumed. Checking
    // emptiness first keeps an unused invite from being reported as
    // `AlreadyRecorded` when `user_id` itself is empty.
    if !existing_used_by.is_empty() {
        return if existing_used_by == user_id {
            ReferralInviteCompletionAction::AlreadyRecorded
        } else {
            ReferralInviteCompletionAction::AlreadyUsedByAnother
        };
    }

    let reserved_by_id = invite["reserved_by_id"].as_str().unwrap_or_default();
    let reserved_checkout_id = invite["reserved_checkout_id"].as_str().unwrap_or_default();
    let reserved_for_other_user = !reserved_by_id.is_empty() && reserved_by_id != user_id;
    let reserved_for_other_checkout =
        !reserved_checkout_id.is_empty() && reserved_checkout_id != reservation_id;

    ReferralInviteCompletionAction::Record {
        reservation_reassigned: reserved_for_other_user || reserved_for_other_checkout,
    }
}
/// Fetches the invite record `invite_id` from the `invites` collection and
/// returns it as raw JSON.
///
/// `pb_url` and `token` are supplied by the caller. The call fails if the
/// request cannot be sent, `ensure_success_ref` rejects the response, or the
/// body is not valid JSON.
async fn fetch_invite_record(
    state: &AppState,
    pb_url: &str,
    token: &str,
    invite_id: &str,
) -> anyhow::Result<Value> {
    let invite_url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
    let response = state
        .http_client
        .get(&invite_url)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;
    ensure_success_ref(&response).await?;

    let record: Value = response.json().await?;
    Ok(record)
}
/// Reserves the referral invite `invite_id` for `user_id`'s checkout
/// `reservation_id` until `reserved_until_unix`.
///
/// The reservation is refused when the invite is already used, or when it
/// carries a live reservation (`reserved_until_unix` not yet in the past) that
/// names a different checkout or a different user. Otherwise the reservation
/// fields are overwritten with the new values.
///
/// # Errors
/// Returns an error for invalid ids, for the refusal cases above, or when a
/// PocketBase request fails.
pub(super) async fn reserve_referral_invite(
    state: &AppState,
    invite_id: &str,
    user_id: &str,
    reservation_id: &str,
    reserved_until_unix: u64,
) -> anyhow::Result<()> {
    let all_ids_safe = is_safe_pocketbase_id(invite_id)
        && is_safe_pocketbase_id(user_id)
        && is_safe_pocketbase_id(reservation_id);
    if !all_ids_safe {
        return Err(anyhow!("invalid PocketBase id"));
    }

    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let invite = fetch_invite_record(state, pb_url, &token, invite_id).await?;

    if !invite["used_by_id"].as_str().unwrap_or_default().is_empty() {
        return Err(anyhow!("referral invite already used"));
    }

    let now = now_unix_secs();
    let current_holder = invite["reserved_by_id"].as_str().unwrap_or_default();
    let current_checkout = invite["reserved_checkout_id"].as_str().unwrap_or_default();
    let current_expiry = number_field(&invite, "reserved_until_unix").unwrap_or(0);

    // An expired reservation never blocks; a live one blocks only when it
    // belongs to some other checkout or account.
    if current_expiry >= now {
        let held_by_other_checkout =
            !current_checkout.is_empty() && current_checkout != reservation_id;
        if held_by_other_checkout {
            return Err(anyhow!("referral invite already has an active checkout"));
        }
        let held_by_other_user = !current_holder.is_empty() && current_holder != user_id;
        if held_by_other_user {
            return Err(anyhow!("referral invite reserved by another account"));
        }
    }

    let reservation_update = serde_json::json!({
        "reserved_by_id": user_id,
        "reserved_checkout_id": reservation_id,
        "reserved_until_unix": reserved_until_unix,
    });
    let invite_url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
    let response = state
        .http_client
        .patch(&invite_url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&reservation_update)
        .send()
        .await?;

    ensure_success(response)
        .await
        .context("PocketBase invite reservation update failed")
}
/// Clears the reservation on invite `invite_id`, but only if it is still held
/// by checkout `reservation_id` and the invite has not been used.
///
/// When the invite is already used or reserved for a different checkout, the
/// function returns `Ok(())` without writing anything.
///
/// # Errors
/// Returns an error for invalid ids or when a PocketBase request fails.
pub(super) async fn release_referral_invite_reservation(
    state: &AppState,
    invite_id: &str,
    reservation_id: &str,
) -> anyhow::Result<()> {
    let ids_are_safe = is_safe_pocketbase_id(invite_id) && is_safe_pocketbase_id(reservation_id);
    if !ids_are_safe {
        return Err(anyhow!("invalid PocketBase id"));
    }

    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let invite = fetch_invite_record(state, pb_url, &token, invite_id).await?;

    let is_unused = invite["used_by_id"].as_str().unwrap_or_default().is_empty();
    let held_by_this_checkout =
        invite["reserved_checkout_id"].as_str().unwrap_or_default() == reservation_id;
    if !(is_unused && held_by_this_checkout) {
        // Not ours to release (or nothing left to release).
        return Ok(());
    }

    let cleared_reservation = serde_json::json!({
        "reserved_by_id": "",
        "reserved_checkout_id": "",
        "reserved_until_unix": 0,
    });
    let invite_url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
    let response = state
        .http_client
        .patch(&invite_url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&cleared_reservation)
        .send()
        .await?;

    ensure_success(response)
        .await
        .context("PocketBase invite reservation release failed")
}
/// Finds the user of a pending, unexpired checkout that references the
/// referral invite `invite_id`.
///
/// Queries the checkout collection for one record (`perPage=1`) with
/// `status="pending"`, `expires_at_unix` at or after the current time, and a
/// matching `referral_invite_id`. Returns that record's `user` field, or
/// `None` when no record matches or the field is not a string.
///
/// # Errors
/// Returns an error for an invalid invite id or a failed PocketBase request.
pub async fn active_referral_checkout_user(
    state: &AppState,
    invite_id: &str,
) -> anyhow::Result<Option<String>> {
    if !is_safe_pocketbase_id(invite_id) {
        return Err(anyhow!("invalid PocketBase invite id"));
    }

    let now = now_unix_secs();
    let token = get_superuser_token(state).await?;
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let filter = format!(
        "status=\"pending\" && expires_at_unix>={now} && referral_invite_id=\"{}\"",
        invite_id
    );
    let query_url = format!(
        "{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
        urlencoding::encode(&filter)
    );

    let response = state
        .http_client
        .get(&query_url)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await?;
    ensure_success_ref(&response).await?;
    let body: Value = response.json().await?;

    let Some(first_match) = body["items"].as_array().and_then(|items| items.first()) else {
        return Ok(None);
    };
    Ok(first_match["user"].as_str().map(str::to_owned))
}
#[cfg(test)]
mod tests {
    use super::*;

    const USER: &str = "user123";
    const CHECKOUT: &str = "checkout123";

    /// Runs the decision function for the fixed test user and checkout.
    fn action_for(invite: serde_json::Value) -> ReferralInviteCompletionAction {
        referral_invite_completion_action(&invite, USER, CHECKOUT)
    }

    /// An untouched invite is simply recorded.
    #[test]
    fn referral_invite_completion_records_available_invite() {
        let action = action_for(serde_json::json!({
            "used_by_id": "",
            "reserved_by_id": "",
            "reserved_checkout_id": "",
        }));
        let expected = ReferralInviteCompletionAction::Record {
            reservation_reassigned: false,
        };
        assert_eq!(action, expected);
    }

    /// A reservation held by someone else is still recorded, but flagged.
    #[test]
    fn referral_invite_completion_records_reassigned_reservation() {
        let action = action_for(serde_json::json!({
            "used_by_id": "",
            "reserved_by_id": "otheruser",
            "reserved_checkout_id": "othercheckout",
        }));
        let expected = ReferralInviteCompletionAction::Record {
            reservation_reassigned: true,
        };
        assert_eq!(action, expected);
    }

    /// Existing usage is classified by whether it belongs to the same user.
    #[test]
    fn referral_invite_completion_detects_existing_usage() {
        let same_user = action_for(serde_json::json!({ "used_by_id": "user123" }));
        assert_eq!(same_user, ReferralInviteCompletionAction::AlreadyRecorded);

        let another_user = action_for(serde_json::json!({ "used_by_id": "otheruser" }));
        assert_eq!(
            another_user,
            ReferralInviteCompletionAction::AlreadyUsedByAnother
        );
    }
}

View file

@ -0,0 +1,175 @@
//! Stripe API interaction: creating checkout sessions, verifying coupon
//! configuration, and looking up sessions by payment intent.
use anyhow::{anyhow, Context};
use serde_json::Value;
use crate::auth::PocketBaseUser;
use crate::state::AppState;
use super::lifecycle::expected_total_for_checkout;
use super::{
ensure_success_ref, is_safe_stripe_session_id, CHECKOUT_CURRENCY, REFERRAL_DISCOUNT_PERCENT,
};
/// Product name sent to Stripe as the checkout line item's
/// `price_data[product_data][name]`.
const CHECKOUT_PRODUCT_NAME: &str = "Perfect Postcodes Lifetime License";
/// Confirms that a Stripe coupon grants exactly the referral discount the
/// server priced for.
///
/// The coupon is fetched from Stripe and rejected unless it is marked `valid`,
/// has no numeric `amount_off`, and its `percent_off` equals
/// `REFERRAL_DISCOUNT_PERCENT` (within 0.001). This keeps a misconfigured or
/// swapped coupon id from silently granting a larger discount than the
/// server's pricing math assumed.
///
/// # Errors
/// Returns an error for an unsafe coupon id, a failed or unparsable Stripe
/// response, or any of the rejection conditions above.
async fn verify_stripe_coupon_discount(state: &AppState, coupon_id: &str) -> anyhow::Result<()> {
    if !is_safe_stripe_session_id(coupon_id) {
        return Err(anyhow!("unsafe stripe coupon id"));
    }

    let coupon_url = format!(
        "https://api.stripe.com/v1/coupons/{}",
        urlencoding::encode(coupon_id)
    );
    let response = state
        .http_client
        .get(&coupon_url)
        .basic_auth(&state.stripe_secret_key, None::<&str>)
        .send()
        .await
        .context("Stripe coupon fetch failed")?;
    ensure_success_ref(&response)
        .await
        .context("Stripe coupon fetch returned error")?;
    let coupon: Value = response
        .json()
        .await
        .context("Failed to parse Stripe coupon response")?;

    if !coupon["valid"].as_bool().unwrap_or(false) {
        return Err(anyhow!("stripe coupon is not valid"));
    }
    if coupon["amount_off"].is_number() {
        return Err(anyhow!(
            "stripe coupon uses amount_off; only percent_off is permitted"
        ));
    }

    let Some(percent_off) = coupon["percent_off"].as_f64() else {
        return Err(anyhow!("stripe coupon missing percent_off"));
    };
    let deviation = (percent_off - REFERRAL_DISCOUNT_PERCENT as f64).abs();
    if percent_off.is_nan() || deviation > 0.001 {
        return Err(anyhow!(
            "stripe coupon percent_off ({percent_off}) does not match expected {REFERRAL_DISCOUNT_PERCENT}"
        ));
    }

    Ok(())
}
/// Creates a one-off Stripe Checkout session for the license purchase.
///
/// A non-empty `discount_coupon_id` is first verified with
/// `verify_stripe_coupon_discount` and then attached to the session as a
/// discount. The reservation id and the expected amount, total and currency
/// are sent as session metadata.
///
/// # Returns
/// `(session_id, checkout_url)` taken from Stripe's response.
///
/// # Errors
/// Returns an error when coupon verification fails, the Stripe request fails,
/// or the response lacks a safe session id or a non-empty URL.
#[allow(clippy::too_many_arguments)]
pub(super) async fn create_stripe_session(
    state: &AppState,
    user: &PocketBaseUser,
    reservation_id: &str,
    price_pence: u64,
    success_url: &str,
    cancel_url: &str,
    expires_at_unix: u64,
    discount_coupon_id: Option<&str>,
) -> anyhow::Result<(String, String)> {
    // An empty coupon id is treated the same as no coupon for the discount.
    let coupon = discount_coupon_id.filter(|id| !id.is_empty());
    if let Some(coupon_id) = coupon {
        verify_stripe_coupon_discount(state, coupon_id).await?;
    }

    let amount = price_pence.to_string();
    let expected_total = expected_total_for_checkout(price_pence, discount_coupon_id).to_string();

    let mut form_params = vec![
        ("mode", String::from("payment")),
        ("payment_method_types[0]", String::from("card")),
        ("line_items[0][price_data][unit_amount]", amount.clone()),
        (
            "line_items[0][price_data][currency]",
            CHECKOUT_CURRENCY.to_string(),
        ),
        (
            "line_items[0][price_data][product_data][name]",
            CHECKOUT_PRODUCT_NAME.to_string(),
        ),
        ("line_items[0][quantity]", String::from("1")),
        ("success_url", success_url.to_string()),
        ("cancel_url", cancel_url.to_string()),
        ("expires_at", expires_at_unix.to_string()),
        ("client_reference_id", user.id.clone()),
        ("customer_email", user.email.clone()),
        ("metadata[pending_checkout_id]", reservation_id.to_string()),
        ("metadata[expected_amount_pence]", amount),
        ("metadata[expected_total_pence]", expected_total),
        ("metadata[expected_currency]", CHECKOUT_CURRENCY.to_string()),
    ];
    if let Some(coupon_id) = coupon {
        form_params.extend([
            ("discounts[0][coupon]", coupon_id.to_string()),
            ("metadata[discount_coupon_id]", coupon_id.to_string()),
        ]);
    }

    let response = state
        .http_client
        .post("https://api.stripe.com/v1/checkout/sessions")
        .basic_auth(&state.stripe_secret_key, None::<&str>)
        .form(&form_params)
        .send()
        .await
        .context("Stripe checkout request failed")?;
    ensure_success_ref(&response)
        .await
        .context("Stripe checkout failed")?;
    let session: Value = response
        .json()
        .await
        .context("Failed to parse Stripe response")?;

    let session_id = match session["id"].as_str() {
        Some(id) if is_safe_stripe_session_id(id) => id.to_string(),
        _ => return Err(anyhow!("Stripe session missing valid id")),
    };
    let checkout_url = match session["url"].as_str() {
        Some(url) if !url.is_empty() => url.to_string(),
        _ => return Err(anyhow!("Stripe session missing URL")),
    };

    Ok((session_id, checkout_url))
}
/// Asks Stripe which checkout session a payment intent belongs to.
///
/// Lists checkout sessions filtered by `payment_intent` (limit 1) and returns
/// the id of the first result, provided it passes
/// `is_safe_stripe_session_id`. Returns `Ok(None)` when Stripe reports no
/// session or the returned id is not safe.
///
/// # Errors
/// Returns an error when `payment_intent_id` fails validation, when the Stripe
/// request fails or returns a non-success status, or when the response cannot
/// be parsed.
pub(super) async fn fetch_stripe_checkout_session_id_for_payment_intent(
    state: &AppState,
    payment_intent_id: &str,
) -> anyhow::Result<Option<String>> {
    // Validate before building the query, mirroring the coupon lookup: an
    // empty or malformed id must never reach Stripe, where it could degrade
    // into a lookup that is not actually filtered by payment intent.
    if !is_safe_stripe_session_id(payment_intent_id) {
        return Err(anyhow::anyhow!("unsafe stripe payment intent id"));
    }

    let url = format!(
        "https://api.stripe.com/v1/checkout/sessions?payment_intent={}&limit=1",
        urlencoding::encode(payment_intent_id)
    );
    let resp = state
        .http_client
        .get(&url)
        .basic_auth(&state.stripe_secret_key, None::<&str>)
        .send()
        .await
        .context("Stripe checkout session lookup failed")?;
    ensure_success_ref(&resp)
        .await
        .context("Stripe checkout session lookup returned error")?;
    let body: Value = resp
        .json()
        .await
        .context("Failed to parse Stripe checkout session lookup")?;

    // Only the first listed session is considered, and its id must be safe to
    // embed in later PocketBase filters.
    Ok(body["data"]
        .as_array()
        .and_then(|items| items.first())
        .and_then(|item| item["id"].as_str())
        .filter(|id| is_safe_stripe_session_id(id))
        .map(str::to_string))
}

View file

@ -0,0 +1,688 @@
//! Integration-style tests for the money paths: Stripe webhook verification →
//! license granting, checkout reservation bookkeeping, and invite redemption.
//!
//! PocketBase (an external HTTP service in production) is replaced by an
//! in-process axum mock listening on an ephemeral local port. The mock keeps
//! records in memory, evaluates the small PocketBase filter subset the server
//! actually uses, and records every mutating request so tests can assert that
//! e.g. a replayed webhook does not grant a license twice.
//!
//! Stripe itself is NOT mocked (its API URL is hardcoded to
//! `https://api.stripe.com`), so tests that reach the Stripe call assert the
//! failure-cleanup behaviour instead: the reservation is marked `failed` and
//! referral invite reservations are released.
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use axum::body::Bytes;
use axum::extract::{Path, Query, State};
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{Extension, Json, Router};
use hmac::{Hmac, KeyInit, Mac};
use serde_json::{json, Value};
use sha2::Sha256;
use crate::auth::{OptionalUser, PocketBaseUser};
use crate::routes::{post_redeem_invite, post_stripe_webhook};
use crate::state::{AppState, SharedState};
use super::start_license_checkout;
// ---------------------------------------------------------------------------
// Mock PocketBase
// ---------------------------------------------------------------------------
/// In-memory stand-in for PocketBase used by the tests in this module: holds
/// the stored records, an id counter, and a log of mutating requests.
#[derive(Default)]
struct MockPocketBase {
/// collection name → records (each a JSON object with an "id").
records: Mutex<HashMap<String, Vec<Value>>>,
/// Counter used to mint `mockidNNNNNN` ids for records stored without one.
next_id: AtomicUsize,
/// Every mutating request: (method, path, body).
log: Mutex<Vec<(String, String, Value)>>,
}
impl MockPocketBase {
fn seed(&self, collection: &str, mut record: Value) -> String {
let id = match record["id"].as_str() {
Some(id) if !id.is_empty() => id.to_string(),
_ => format!("mockid{:06}", self.next_id.fetch_add(1, Ordering::SeqCst)),
};
record["id"] = json!(id);
self.records
.lock()
.unwrap()
.entry(collection.to_string())
.or_default()
.push(record);
id
}
fn record(&self, collection: &str, id: &str) -> Option<Value> {
self.records
.lock()
.unwrap()
.get(collection)?
.iter()
.find(|record| record["id"].as_str() == Some(id))
.cloned()
}
fn records_in(&self, collection: &str) -> Vec<Value> {
self.records
.lock()
.unwrap()
.get(collection)
.cloned()
.unwrap_or_default()
}
/// Number of PATCHes that set the user's subscription to "licensed" —
/// i.e. how many times a license was granted.
fn license_grant_count(&self, user_id: &str) -> usize {
let path = format!("/api/collections/users/records/{user_id}");
self.log
.lock()
.unwrap()
.iter()
.filter(|(method, request_path, body)| {
method == "PATCH"
&& request_path == &path
&& body["subscription"].as_str() == Some("licensed")
})
.count()
}
}
/// Evaluate the PocketBase filter subset used by the server:
/// `a="x" && b>=N && c<N && (d="" || d="y")`.
///
/// An empty filter matches every record; otherwise every ` && `-separated
/// clause must match.
fn record_matches(record: &Value, filter: &str) -> bool {
    if filter.is_empty() {
        return true;
    }
    for clause in filter.split(" && ") {
        if !clause_matches(record, clause.trim()) {
            return false;
        }
    }
    true
}
/// Evaluates one filter clause against `record`.
///
/// Supported forms, tried in this order: a parenthesised group of
/// ` || `-separated alternatives, `field>=N`, `field<N`, and `field="text"`.
/// A numeric bound that fails to parse makes the comparison false.
///
/// # Panics
/// Panics when the clause matches none of the supported forms.
fn clause_matches(record: &Value, clause: &str) -> bool {
    let group = clause
        .strip_prefix('(')
        .and_then(|rest| rest.strip_suffix(')'));
    if let Some(alternatives) = group {
        for alternative in alternatives.split(" || ") {
            if clause_matches(record, alternative.trim()) {
                return true;
            }
        }
        return false;
    }

    // `>=` must be tried before `=` so the operator is not split in half.
    if let Some((field, raw_bound)) = clause.split_once(">=") {
        let bound = raw_bound.trim().parse::<i64>().unwrap_or(i64::MAX);
        return record_number(record, field) >= bound;
    }
    if let Some((field, raw_bound)) = clause.split_once('<') {
        let bound = raw_bound.trim().parse::<i64>().unwrap_or(i64::MIN);
        return record_number(record, field) < bound;
    }

    match clause.split_once('=') {
        Some((field, raw_value)) => {
            record_string(record, field) == raw_value.trim().trim_matches('"')
        }
        None => panic!("mock PocketBase cannot evaluate filter clause: {clause}"),
    }
}
/// Reads `field` from `record` as an integer: JSON integers are used as-is,
/// floats are cast to `i64`, numeric strings are parsed, and anything else
/// (including a missing field) yields `0`.
fn record_number(record: &Value, field: &str) -> i64 {
    let value = &record[field];
    if let Some(integer) = value.as_i64() {
        return integer;
    }
    if let Some(float) = value.as_f64() {
        return float as i64;
    }
    value
        .as_str()
        .and_then(|text| text.parse::<i64>().ok())
        .unwrap_or(0)
}
fn record_string(record: &Value, field: &str) -> String {
let value = &record[field];
value.as_str().map(str::to_string).unwrap_or_else(|| {
if value.is_null() {
String::new()
} else {
value.to_string()
}
})
}
/// Mock superuser login: always answers with a fixed token.
async fn auth_handler() -> Json<Value> {
    let body = json!({ "token": "testsuperusertoken" });
    Json(body)
}
/// Mock record listing: returns the records of `collection` that satisfy the
/// `filter` query parameter, truncated to `perPage` (default 30).
/// `totalItems` reports the number of matches before truncation.
async fn list_records(
    State(pb): State<Arc<MockPocketBase>>,
    Path(collection): Path<String>,
    Query(params): Query<HashMap<String, String>>,
) -> Json<Value> {
    let filter = match params.get("filter") {
        Some(raw) => raw.as_str(),
        None => "",
    };

    let mut items: Vec<Value> = {
        let store = pb.records.lock().unwrap();
        match store.get(&collection) {
            Some(list) => list
                .iter()
                .filter(|record| record_matches(record, filter))
                .cloned()
                .collect(),
            None => Vec::new(),
        }
    };
    let total = items.len();

    let page_size = params
        .get("perPage")
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or(30);
    items.truncate(page_size);

    Json(json!({ "items": items, "totalItems": total }))
}
/// Mock record creation: logs the POST, stores the body as a new record and
/// returns the stored record. For the `checkout_locks` collection a duplicate
/// `name` is rejected with 400 instead.
async fn create_record(
    State(pb): State<Arc<MockPocketBase>>,
    Path(collection): Path<String>,
    Json(body): Json<Value>,
) -> Response {
    let request_path = format!("/api/collections/{collection}/records");
    pb.log
        .lock()
        .unwrap()
        .push(("POST".to_string(), request_path, body.clone()));

    // Emulate the unique `name` constraint on the distributed-lock collection
    // so concurrent acquisitions conflict like they do against real PocketBase.
    if collection == "checkout_locks" {
        let name_taken = {
            let store = pb.records.lock().unwrap();
            match store.get(&collection) {
                Some(list) => list.iter().any(|record| record["name"] == body["name"]),
                None => false,
            }
        };
        if name_taken {
            let conflict = Json(json!({ "message": "name must be unique" }));
            return (StatusCode::BAD_REQUEST, conflict).into_response();
        }
    }

    let id = pb.seed(&collection, body);
    let stored = pb.record(&collection, &id).expect("record just created");
    Json(stored).into_response()
}
/// Mock single-record fetch: the record as JSON, or 404 when it is unknown.
async fn get_record(
    State(pb): State<Arc<MockPocketBase>>,
    Path((collection, id)): Path<(String, String)>,
) -> Response {
    if let Some(found) = pb.record(&collection, &id) {
        Json(found).into_response()
    } else {
        StatusCode::NOT_FOUND.into_response()
    }
}
/// Mock record update: logs the PATCH, then shallow-merges the body's keys
/// into the stored record and returns the result. Unknown records yield 404
/// (the request is still logged).
async fn patch_record(
    State(pb): State<Arc<MockPocketBase>>,
    Path((collection, id)): Path<(String, String)>,
    Json(body): Json<Value>,
) -> Response {
    let request_path = format!("/api/collections/{collection}/records/{id}");
    pb.log
        .lock()
        .unwrap()
        .push(("PATCH".to_string(), request_path, body.clone()));

    let mut store = pb.records.lock().unwrap();
    let target = store.get_mut(&collection).and_then(|list| {
        list.iter_mut()
            .find(|candidate| candidate["id"].as_str() == Some(id.as_str()))
    });
    let Some(record) = target else {
        return StatusCode::NOT_FOUND.into_response();
    };

    // Only object-onto-object merges are applied; anything else leaves the
    // stored record untouched.
    if let (Some(fields), Some(updates)) = (record.as_object_mut(), body.as_object()) {
        for (key, value) in updates.iter() {
            fields.insert(key.to_owned(), value.to_owned());
        }
    }

    Json(record.clone()).into_response()
}
/// Mock record deletion: logs the DELETE, removes the record and answers 204,
/// or 404 when the collection or the record does not exist.
async fn delete_record(
    State(pb): State<Arc<MockPocketBase>>,
    Path((collection, id)): Path<(String, String)>,
) -> Response {
    let request_path = format!("/api/collections/{collection}/records/{id}");
    pb.log
        .lock()
        .unwrap()
        .push(("DELETE".to_string(), request_path, Value::Null));

    let mut store = pb.records.lock().unwrap();
    let removed = match store.get_mut(&collection) {
        None => return StatusCode::NOT_FOUND.into_response(),
        Some(list) => {
            let before = list.len();
            list.retain(|candidate| candidate["id"].as_str() != Some(id.as_str()));
            before - list.len()
        }
    };

    let status = if removed == 0 {
        StatusCode::NOT_FOUND
    } else {
        StatusCode::NO_CONTENT
    };
    status.into_response()
}
/// Builds the axum router that exposes the mock PocketBase endpoints:
/// superuser auth, collection list/create, and record get/patch/delete.
fn mock_pb_router(pb: Arc<MockPocketBase>) -> Router {
    let collection_routes = get(list_records).post(create_record);
    let record_routes = get(get_record).patch(patch_record).delete(delete_record);

    Router::new()
        .route(
            "/api/collections/_superusers/auth-with-password",
            post(auth_handler),
        )
        .route("/api/collections/{collection}/records", collection_routes)
        .route(
            "/api/collections/{collection}/records/{id}",
            record_routes,
        )
        .with_state(pb)
}
// ---------------------------------------------------------------------------
// Test harness
// ---------------------------------------------------------------------------
/// Everything a test needs: the application state under test and a handle to
/// the mock PocketBase backing it.
struct TestEnv {
/// Shared application state, built from an `AppState` pointing at the mock.
shared: Arc<SharedState>,
/// The mock PocketBase, for seeding records and inspecting results.
pb: Arc<MockPocketBase>,
}
/// Starts a mock PocketBase on an ephemeral local port and returns a test
/// environment whose application state points at it.
async fn setup() -> TestEnv {
    let mock = Arc::new(MockPocketBase::default());

    let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind mock PocketBase listener");
    let addr = listener.local_addr().expect("mock PocketBase address");

    // Serve the mock in the background for the lifetime of the test runtime.
    let app = mock_pb_router(Arc::clone(&mock));
    tokio::spawn(async move {
        axum::serve(listener, app)
            .await
            .expect("mock PocketBase serve");
    });

    let base_url = format!("http://{addr}");
    let shared = Arc::new(SharedState::new(AppState::for_tests(base_url)));
    TestEnv { shared, pb: mock }
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the epoch.
fn now_unix() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("clock after epoch");
    since_epoch.as_secs()
}
/// Builds a non-admin user on the "free" subscription with the given id and
/// the email `{id}@test.example`.
fn test_user(id: &str) -> PocketBaseUser {
    let email = format!("{id}@test.example");
    PocketBaseUser {
        id: id.to_owned(),
        email,
        is_admin: false,
        subscription: String::from("free"),
        newsletter: false,
        can_see_listings: false,
    }
}
/// Stores a non-admin "free" user record with the given id in the mock's
/// `users` collection.
fn seed_user(pb: &MockPocketBase, id: &str) {
    let user_record = json!({
        "id": id,
        "email": format!("{id}@test.example"),
        "subscription": "free",
        "is_admin": false,
    });
    let _ = pb.seed("users", user_record);
}
fn seed_pending_checkout(pb: &MockPocketBase, user_id: &str, session_id: &str) -> String {
pb.seed(
"checkout_sessions",
json!({
"user": user_id,
"stripe_session_id": session_id,
"stripe_payment_intent_id": "",
"checkout_url": "https://checkout.stripe.test/session",
"amount_pence": 999,
"expected_total_pence": 999,
"currency": "gbp",
"discount_coupon_id": "",
"referral_invite_id": "",
"status": "pending",
"expires_at_unix": now_unix() + 1800,
"paid_amount_pence": 0,
"completed_at_unix": "",
"reversal_reason": "",
}),
)
}
/// Serialises a `checkout.session.completed` event for the given session and
/// user: paid, GBP, subtotal 999, payment intent `pi_test_1`, and the supplied
/// `amount_total`.
fn checkout_completed_event(session_id: &str, user_id: &str, amount_total: u64) -> Vec<u8> {
    let session = json!({
        "id": session_id,
        "payment_intent": "pi_test_1",
        "client_reference_id": user_id,
        "payment_status": "paid",
        "currency": "gbp",
        "amount_subtotal": 999,
        "amount_total": amount_total,
    });
    let event = json!({
        "id": "evt_test_1",
        "type": "checkout.session.completed",
        "data": { "object": session },
    });
    serde_json::to_vec(&event).expect("event serializes")
}
/// Sign a payload the way Stripe does: HMAC-SHA256 over `{timestamp}.{payload}`
/// with the webhook secret, presented as `t=...,v1=...`.
///
/// The timestamp is the current time, so the header is fresh on every call.
fn stripe_signature_header(payload: &[u8], secret: &str) -> String {
    let timestamp = now_unix();
    let signed_prefix = format!("{timestamp}.");

    let mut signer =
        Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key length");
    signer.update(signed_prefix.as_bytes());
    signer.update(payload);
    let digest = signer.finalize().into_bytes();

    let signature = hex::encode(digest);
    format!("t={timestamp},v1={signature}")
}
/// Calls the Stripe webhook handler directly with `payload` as the body and,
/// when given, `signature` as the `stripe-signature` header.
async fn deliver_webhook(env: &TestEnv, payload: Vec<u8>, signature: Option<&str>) -> Response {
    let mut headers = HeaderMap::new();
    if let Some(raw_signature) = signature {
        let value = HeaderValue::from_str(raw_signature).expect("signature header value");
        headers.insert("stripe-signature", value);
    }

    let shared = State(env.shared.clone());
    let body = Bytes::from(payload);
    post_stripe_webhook(shared, headers, body).await
}
/// Calls the invite-redemption handler directly as `user` with the request
/// body `{"code": code}`.
async fn redeem_invite(env: &TestEnv, user: PocketBaseUser, code: &str) -> Response {
    let request =
        serde_json::from_value(json!({ "code": code })).expect("redeem request deserializes");
    let caller = Extension(OptionalUser(Some(user)));

    post_redeem_invite(State(env.shared.clone()), caller, Json(request)).await
}
/// Reads a response body (up to 1 MiB) and parses it as JSON.
///
/// # Panics
/// Panics if the body cannot be read or is not valid JSON.
async fn response_json(response: Response) -> Value {
    let body = response.into_body();
    let raw = axum::body::to_bytes(body, 1 << 20)
        .await
        .expect("response body");
    serde_json::from_slice(&raw).expect("response body is JSON")
}
// ---------------------------------------------------------------------------
// Stripe webhook → license granting
// ---------------------------------------------------------------------------
/// A correctly signed completion event for a pending reservation licenses the
/// user exactly once and completes the reservation.
#[tokio::test]
async fn webhook_with_valid_signature_grants_license() {
    const USER_ID: &str = "user1";
    const SESSION_ID: &str = "cs_test_abc";

    let env = setup().await;
    seed_user(&env.pb, USER_ID);
    let reservation_id = seed_pending_checkout(&env.pb, USER_ID, SESSION_ID);

    let event = checkout_completed_event(SESSION_ID, USER_ID, 999);
    let header = stripe_signature_header(&event, "whsec_test_secret");
    let response = deliver_webhook(&env, event, Some(header.as_str())).await;
    assert_eq!(response.status(), StatusCode::OK);

    let user = env.pb.record("users", USER_ID).expect("user exists");
    assert_eq!(user["subscription"], json!("licensed"));

    let checkout = env
        .pb
        .record("checkout_sessions", &reservation_id)
        .expect("checkout exists");
    assert_eq!(checkout["status"], json!("completed"));
    assert_eq!(checkout["paid_amount_pence"], json!(999));
    assert_eq!(checkout["stripe_payment_intent_id"], json!("pi_test_1"));

    assert_eq!(env.pb.license_grant_count(USER_ID), 1);
}
/// Events with a wrong or missing signature are answered with 400 and leave
/// the user and the reservation untouched.
#[tokio::test]
async fn webhook_with_invalid_signature_is_rejected() {
    const USER_ID: &str = "user1";
    const SESSION_ID: &str = "cs_test_bad";

    let env = setup().await;
    seed_user(&env.pb, USER_ID);
    let reservation_id = seed_pending_checkout(&env.pb, USER_ID, SESSION_ID);
    let event = checkout_completed_event(SESSION_ID, USER_ID, 999);

    // Signed with the wrong secret.
    let forged_header = stripe_signature_header(&event, "whsec_wrong_secret");
    let forged = deliver_webhook(&env, event.clone(), Some(forged_header.as_str())).await;
    assert_eq!(forged.status(), StatusCode::BAD_REQUEST);

    // Missing signature header entirely.
    let unsigned = deliver_webhook(&env, event, None).await;
    assert_eq!(unsigned.status(), StatusCode::BAD_REQUEST);

    let user = env.pb.record("users", USER_ID).expect("user exists");
    assert_eq!(user["subscription"], json!("free"));

    let checkout = env
        .pb
        .record("checkout_sessions", &reservation_id)
        .expect("checkout exists");
    assert_eq!(checkout["status"], json!("pending"));

    assert_eq!(env.pb.license_grant_count(USER_ID), 0);
}
/// Delivering the same signed event twice is acknowledged both times but
/// grants the license only once.
#[tokio::test]
async fn replayed_webhook_does_not_double_grant() {
    const USER_ID: &str = "user1";
    const SESSION_ID: &str = "cs_test_replay";

    let env = setup().await;
    seed_user(&env.pb, USER_ID);
    let reservation_id = seed_pending_checkout(&env.pb, USER_ID, SESSION_ID);

    let event = checkout_completed_event(SESSION_ID, USER_ID, 999);
    let header = stripe_signature_header(&event, "whsec_test_secret");

    let first = deliver_webhook(&env, event.clone(), Some(header.as_str())).await;
    assert_eq!(first.status(), StatusCode::OK);
    let replay = deliver_webhook(&env, event, Some(header.as_str())).await;
    assert_eq!(replay.status(), StatusCode::OK);

    let user = env.pb.record("users", USER_ID).expect("user exists");
    assert_eq!(user["subscription"], json!("licensed"));

    let checkout = env
        .pb
        .record("checkout_sessions", &reservation_id)
        .expect("checkout exists");
    assert_eq!(checkout["status"], json!("completed"));

    // The replay must be acknowledged without granting a second time.
    assert_eq!(env.pb.license_grant_count(USER_ID), 1);
}
/// A validly signed event whose total disagrees with the reservation grants
/// nothing and flags the reservation as `invalid`.
#[tokio::test]
async fn webhook_with_tampered_amount_is_rejected_and_marks_reservation_invalid() {
    const USER_ID: &str = "user1";
    const SESSION_ID: &str = "cs_test_amount";

    let env = setup().await;
    seed_user(&env.pb, USER_ID);
    let reservation_id = seed_pending_checkout(&env.pb, USER_ID, SESSION_ID);

    // Validly signed event whose amount_total does not match the reservation.
    let event = checkout_completed_event(SESSION_ID, USER_ID, 500);
    let header = stripe_signature_header(&event, "whsec_test_secret");
    let response = deliver_webhook(&env, event, Some(header.as_str())).await;

    // Rejections are acknowledged with 200 so Stripe stops retrying.
    assert_eq!(response.status(), StatusCode::OK);

    let user = env.pb.record("users", USER_ID).expect("user exists");
    assert_eq!(user["subscription"], json!("free"));

    let checkout = env
        .pb
        .record("checkout_sessions", &reservation_id)
        .expect("checkout exists");
    assert_eq!(checkout["status"], json!("invalid"));

    assert_eq!(env.pb.license_grant_count(USER_ID), 0);
}
// ---------------------------------------------------------------------------
// Checkout session creation (up to the hardcoded Stripe API call)
// ---------------------------------------------------------------------------
/// Starting a checkout must record a reservation and, once the Stripe session
/// call fails, mark it failed and leave no pricing lock behind.
#[tokio::test]
async fn checkout_start_reserves_then_marks_failed_when_stripe_is_unreachable() {
    let harness = setup().await;
    seed_user(&harness.pb, "user9");
    let app_state = harness.shared.load_state();
    let buyer = test_user("user9");

    // The Stripe API URL is hardcoded to https://api.stripe.com, so session
    // creation cannot succeed in tests (no network / dummy key). What is under
    // test is the reservation bookkeeping on either side of that call.
    let outcome = start_license_checkout(
        &app_state,
        &buyer,
        "https://x/success",
        "https://x/cancel",
        None,
        None,
    )
    .await;
    assert!(outcome.is_err(), "Stripe call must fail in tests");

    let reservations = harness.pb.records_in("checkout_sessions");
    assert_eq!(reservations.len(), 1, "exactly one reservation created");
    let reservation = &reservations[0];
    assert_eq!(reservation["user"], json!("user9"));
    // 0 licensed users → public count 120 → second tier price (999p).
    assert_eq!(reservation["amount_pence"], json!(999));
    assert_eq!(reservation["expected_total_pence"], json!(999));
    assert_eq!(reservation["currency"], json!("gbp"));
    // A failed Stripe call must not leave a live pending reservation behind.
    assert_eq!(reservation["status"], json!("failed"));

    // The cross-instance pricing lock has been released again.
    assert!(harness.pb.records_in("checkout_locks").is_empty());
    assert_eq!(harness.pb.license_grant_count("user9"), 0);
}
// ---------------------------------------------------------------------------
// Invite redemption
// ---------------------------------------------------------------------------
/// Redeeming an unused admin invite licenses the redeeming user and records
/// them on the invite.
#[tokio::test]
async fn admin_invite_redemption_grants_license() {
    let harness = setup().await;
    seed_user(&harness.pb, "user2");
    let unused_admin_invite = json!({
        "code": "admininvite1",
        "invite_type": "admin",
        "created_by": "adminuser1",
        "used_by_id": "",
        "used_at": "",
    });
    let invite_id = harness.pb.seed("invites", unused_admin_invite);

    let response = redeem_invite(&harness, test_user("user2"), "admininvite1").await;
    assert_eq!(response.status(), StatusCode::OK);
    let payload = response_json(response).await;
    assert_eq!(payload["result"], json!("licensed"));

    let stored_invite = harness
        .pb
        .record("invites", &invite_id)
        .expect("invite exists");
    assert_eq!(stored_invite["used_by_id"], json!("user2"));

    let account = harness.pb.record("users", "user2").expect("user exists");
    assert_eq!(account["subscription"], json!("licensed"));
    assert_eq!(harness.pb.license_grant_count("user2"), 1);
}
/// Malformed invite codes are answered with 400 and never grant a license.
#[tokio::test]
async fn invalid_and_oversized_invite_codes_are_rejected() {
    let harness = setup().await;
    // One character longer than the 20-character limit.
    let oversized = "a".repeat(21);

    // In order: non-alphanumeric characters, over-long code, empty code.
    for malformed in ["bad-code!", oversized.as_str(), ""] {
        let response = redeem_invite(&harness, test_user("user2"), malformed).await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    assert_eq!(harness.pb.license_grant_count("user2"), 0);
}
/// An invite that another user already redeemed is reported as not found and
/// leaves the caller unlicensed.
#[tokio::test]
async fn already_used_invite_is_rejected() {
    let harness = setup().await;
    seed_user(&harness.pb, "user2");
    let spent_invite = json!({
        "code": "usedinvite12",
        "invite_type": "admin",
        "created_by": "adminuser1",
        "used_by_id": "otheruser9",
        "used_at": "1700000000",
    });
    harness.pb.seed("invites", spent_invite);

    let status = redeem_invite(&harness, test_user("user2"), "usedinvite12")
        .await
        .status();
    assert_eq!(status, StatusCode::NOT_FOUND);

    let account = harness.pb.record("users", "user2").expect("user exists");
    assert_eq!(account["subscription"], json!("free"));
    assert_eq!(harness.pb.license_grant_count("user2"), 0);
}
/// A valid referral redemption that fails at the Stripe call must mark its
/// checkout failed and hand the invite back unreserved and unused.
#[tokio::test]
async fn referral_invite_redemption_releases_reservation_when_stripe_is_unreachable() {
    let harness = setup().await;
    seed_user(&harness.pb, "user3");
    let open_referral_invite = json!({
        "code": "refcode12345",
        "invite_type": "referral",
        "created_by": "licenseduser1",
        "used_by_id": "",
        "used_at": "",
        "reserved_by_id": "",
        "reserved_checkout_id": "",
        "reserved_until_unix": 0,
    });
    let invite_id = harness.pb.seed("invites", open_referral_invite);

    // Nothing is wrong with the redemption itself; the only failure is the
    // hardcoded Stripe call (coupon verification / session creation), and that
    // failure has to roll the reservation back cleanly.
    let status = redeem_invite(&harness, test_user("user3"), "refcode12345")
        .await
        .status();
    assert_eq!(status, StatusCode::BAD_GATEWAY);

    let reservations = harness.pb.records_in("checkout_sessions");
    assert_eq!(reservations.len(), 1, "referral reservation was created");
    let reservation = &reservations[0];
    assert_eq!(reservation["referral_invite_id"], json!(invite_id));
    assert_eq!(reservation["status"], json!("failed"));

    // The invite is unreserved again and has not been consumed.
    let stored_invite = harness
        .pb
        .record("invites", &invite_id)
        .expect("invite exists");
    assert_eq!(stored_invite["used_by_id"], json!(""));
    assert_eq!(stored_invite["reserved_checkout_id"], json!(""));
    assert_eq!(stored_invite["reserved_by_id"], json!(""));
    assert_eq!(harness.pb.license_grant_count("user3"), 0);
}

View file

@ -182,8 +182,7 @@ impl CrimeByYearData {
// Force-coverage calendar (optional column: legacy parquets predate it;
// their postcodes are treated as fully covered). A row with an empty
// list is meaningful — zero covered years — so it IS inserted.
let mut covered_years_by_postcode: FxHashMap<String, Vec<i32>> =
FxHashMap::default();
let mut covered_years_by_postcode: FxHashMap<String, Vec<i32>> = FxHashMap::default();
if let Ok(col) = df.column(COVERAGE_COLUMN) {
let list_ca = col
.list()
@ -195,12 +194,12 @@ impl CrimeByYearData {
};
let mut years: Vec<i32> = Vec::with_capacity(inner.len());
if !inner.is_empty() {
let structs = inner.struct_().with_context(|| {
format!("Inner of '{COVERAGE_COLUMN}' is not a struct")
})?;
let year_field = structs.field_by_name("year").with_context(|| {
format!("Missing 'year' field in '{COVERAGE_COLUMN}'")
})?;
let structs = inner
.struct_()
.with_context(|| format!("Inner of '{COVERAGE_COLUMN}' is not a struct"))?;
let year_field = structs
.field_by_name("year")
.with_context(|| format!("Missing 'year' field in '{COVERAGE_COLUMN}'"))?;
for idx in 0..inner.len() {
match year_field.get(idx).ok() {
Some(AnyValue::Int32(y)) => years.push(y),

View file

@ -742,6 +742,29 @@ impl PlaceData {
}
}
#[cfg(test)]
impl PlaceData {
    /// Builds a `PlaceData` with every column and index empty, for integration
    /// tests that must construct an `AppState` but never query place data.
    pub(crate) fn empty_for_tests() -> Self {
        Self {
            name: vec![],
            name_lower: vec![],
            name_search: vec![],
            place_type: InternedColumn::build(&[]),
            type_rank: vec![],
            population: vec![],
            lat: vec![],
            lon: vec![],
            city: vec![],
            travel_destination: vec![],
            token_index: FxHashMap::default(),
            token_prefix_index: FxHashMap::default(),
            fuzzy_trigram_index: FxHashMap::default(),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -588,6 +588,29 @@ impl POIData {
}
}
#[cfg(test)]
impl POIData {
    /// Builds a `POIData` with every column empty, for integration tests that
    /// must construct an `AppState` but never query POI data.
    pub(crate) fn empty_for_tests() -> Self {
        Self {
            id_buffer: String::new(),
            id_offsets: vec![],
            id_lengths: vec![],
            group: InternedColumn::build(&[]),
            category: InternedColumn::build(&[]),
            icon_category: InternedColumn::build(&[]),
            name: vec![],
            lat: vec![],
            lng: vec![],
            emoji: InternedColumn::build(&[]),
            priority: vec![],
            school_meta_idx: vec![],
            school_meta: vec![],
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,973 @@
//! Address search: tokenization, query parsing, inverted/prefix indexes and the
//! ranked per-row search over property addresses.
use rustc_hash::{FxHashMap, FxHashSet};
use super::PropertyData;
/// Upper bound on rows scored per query. Intersection keeps most candidate sets far below
/// this; only a single very common road word (e.g. "high") approaches it, and the in-area
/// priority sort keeps a refined query's matches ahead of the cut.
const ADDRESS_SEARCH_CANDIDATE_LIMIT: usize = 150_000;
/// Shortest prefix length stored in the address prefix index. Query terms shorter than
/// this are never prefix-expanded when seeding candidate rows.
const ADDRESS_SEARCH_PREFIX_MIN_LEN: usize = 4;
/// Longest prefix length stored in the address prefix index. Longer query terms are cut
/// down to this many bytes before the prefix lookup.
const ADDRESS_SEARCH_PREFIX_MAX_LEN: usize = 8;
/// One query word together with its interchangeable spellings (e.g. "road" / "rd").
/// A row satisfies the group when any single alternative matches one of its tokens.
#[derive(Clone, Debug)]
pub(super) struct AddressTermGroup {
/// The typed token first, followed by any aliases of it.
alternatives: Vec<String>,
}
/// A free-text address query broken down into the parts the search consumes.
#[derive(Debug)]
pub(super) struct AddressQuery {
/// Complete postcode found in the query, in canonical uppercase "OUTWARD INWARD" form.
full_postcode: Option<String>,
/// Compact uppercase outward code (optionally with a sector digit) recovered when the
/// user appended a partial postcode like "NW1" or "NW1 6". Used as an additive ranking
/// bias, never as a hard filter — so the disambiguating hint is honoured without
/// excluding the same road in other areas.
postcode_area: Option<String>,
/// Word groups that must ALL match a row for it to be scored at all.
text_groups: Vec<AddressTermGroup>,
/// Sorted, de-duplicated number-like terms; when non-empty, at least one must match a row.
numeric_terms: Vec<String>,
/// Distinctive (non-stop-word) terms used to look up candidate rows in the token index.
candidate_terms: Vec<String>,
}
/// Split free text into lowercase ASCII-alphanumeric tokens.
///
/// Apostrophe-like marks are removed without ending the current token, so
/// "John's" becomes `johns`. Every other non-alphanumeric character (including
/// non-ASCII letters) acts as a separator. Empty tokens are never produced.
fn tokenize_address_text(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    for ch in text.chars() {
        if ch.is_ascii_alphanumeric() {
            current.push(ch.to_ascii_lowercase());
        } else if matches!(ch, '\'' | '\u{2019}' | '`') {
            // Straight apostrophe, typographic apostrophe (U+2019, written as an escape so
            // it cannot be lost to re-encoding) and backtick: drop the mark, keep the word.
            continue;
        } else if !current.is_empty() {
            // `take` hands over the finished token and leaves an empty buffer behind.
            tokens.push(std::mem::take(&mut current));
        }
    }
    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}
/// True when `compact` (no spaces) has the shape of a complete postcode: an
/// outward part of 2-4 alphanumerics that starts with a letter and contains a
/// digit, followed by an inward part of one digit and two letters.
fn is_full_postcode_compact(compact: &str) -> bool {
    let bytes = compact.as_bytes();
    if bytes.len() < 5 || bytes.len() > 7 {
        return false;
    }

    // The inward code is always the final three bytes.
    let (outward, inward) = bytes.split_at(bytes.len() - 3);
    let inward_ok = matches!(
        inward,
        [digit, first, second]
            if digit.is_ascii_digit()
                && first.is_ascii_alphabetic()
                && second.is_ascii_alphabetic()
    );
    if !inward_ok {
        return false;
    }

    if outward.len() < 2 || outward.len() > 4 {
        return false;
    }
    let starts_with_letter = outward[0].is_ascii_alphabetic();
    let only_alphanumeric = outward.iter().all(|byte| byte.is_ascii_alphanumeric());
    let has_digit = outward.iter().any(|byte| byte.is_ascii_digit());
    starts_with_letter && only_alphanumeric && has_digit
}
/// Uppercase a compact postcode and insert the space before its three-character
/// inward code. The input is expected to be at least three bytes long.
fn canonical_postcode_from_compact(compact: &str) -> String {
    let upper = compact.to_ascii_uppercase();
    let (outward, inward) = upper.split_at(upper.len() - 3);
    [outward, inward].join(" ")
}
/// Find the first complete postcode among `tokens`.
///
/// A postcode written as a single token ("sw1a1aa") is preferred; failing that,
/// two adjacent tokens are joined ("sw1a" + "1aa"). Returns the canonical
/// postcode and the indices of the tokens that formed it.
fn extract_full_postcode(tokens: &[String]) -> Option<(String, Vec<usize>)> {
    let as_single_token = tokens.iter().enumerate().find_map(|(idx, token)| {
        let compact = token.to_ascii_uppercase();
        is_full_postcode_compact(&compact)
            .then(|| (canonical_postcode_from_compact(&compact), vec![idx]))
    });

    as_single_token.or_else(|| {
        tokens.windows(2).enumerate().find_map(|(idx, pair)| {
            let compact = format!(
                "{}{}",
                pair[0].to_ascii_uppercase(),
                pair[1].to_ascii_uppercase()
            );
            is_full_postcode_compact(&compact).then(|| {
                (
                    canonical_postcode_from_compact(&compact),
                    vec![idx, idx + 1],
                )
            })
        })
    })
}
/// True for a 2-4 byte token that starts with an ASCII letter, is entirely ASCII
/// alphanumeric and contains at least one digit (e.g. "nw1", "sw1a").
fn looks_like_postcode_fragment(token: &str) -> bool {
    if token.len() < 2 || token.len() > 4 {
        return false;
    }
    if !token.starts_with(|ch: char| ch.is_ascii_alphabetic()) {
        return false;
    }
    let mut saw_digit = false;
    for ch in token.chars() {
        if !ch.is_ascii_alphanumeric() {
            return false;
        }
        saw_digit |= ch.is_ascii_digit();
    }
    saw_digit
}
/// True when every byte of `token` is an ASCII digit (vacuously true for "").
fn is_numeric_address_token(token: &str) -> bool {
    token.bytes().all(|byte| byte.is_ascii_digit())
}
/// Interchangeable spellings for a lowercase address word, with the word itself
/// listed first. Returns an empty list for words that have no known alias.
fn address_token_aliases(token: &str) -> Vec<&'static str> {
    let aliases: &'static [&'static str] = match token {
        "apt" => &["apt", "apartment"],
        "apartment" => &["apartment", "apt"],
        "ave" => &["ave", "avenue"],
        "avenue" => &["avenue", "ave"],
        "blvd" => &["blvd", "boulevard"],
        "boulevard" => &["boulevard", "blvd"],
        "cl" => &["cl", "close"],
        "close" => &["close", "cl"],
        "ct" => &["ct", "court"],
        "court" => &["court", "ct"],
        "cres" => &["cres", "crescent"],
        "crescent" => &["crescent", "cres"],
        "dr" => &["dr", "drive"],
        "drive" => &["drive", "dr"],
        "fl" => &["fl", "flat"],
        "flat" => &["flat", "fl"],
        "gdns" => &["gdns", "gardens", "garden"],
        "garden" => &["garden", "gardens", "gdns"],
        "gardens" => &["gardens", "garden", "gdns"],
        "hse" => &["hse", "house"],
        "house" => &["house", "hse"],
        "ln" => &["ln", "lane"],
        "lane" => &["lane", "ln"],
        "rd" => &["rd", "road"],
        "road" => &["road", "rd"],
        "sq" => &["sq", "square"],
        "square" => &["square", "sq"],
        // "st" is ambiguous, so it expands to both readings.
        "st" => &["st", "street", "saint"],
        "street" => &["street", "st"],
        "saint" => &["saint", "st"],
        "terr" => &["terr", "terrace"],
        "terrace" => &["terrace", "terr"],
        _ => &[],
    };
    aliases.to_vec()
}
/// True for generic address words (articles, unit and building words, road-type
/// words and their abbreviations) that are too common to identify a street.
fn is_address_stop_token(token: &str) -> bool {
    const STOP_TOKENS: &[&str] = &[
        "a",
        "an",
        "and",
        "apartment",
        "apt",
        "avenue",
        "ave",
        "block",
        "building",
        "bungalow",
        "close",
        "cl",
        "court",
        "ct",
        "cres",
        "crescent",
        "drive",
        "dr",
        "estate",
        "flat",
        "fl",
        "floor",
        "garden",
        "gardens",
        "gdns",
        "grove",
        "house",
        "hse",
        "lane",
        "ln",
        "lodge",
        "mansions",
        "mews",
        "of",
        "park",
        "place",
        "road",
        "rd",
        "room",
        "row",
        "saint",
        "sq",
        "square",
        "st",
        "street",
        "terr",
        "terrace",
        "the",
        "unit",
        "view",
        "villas",
        "walk",
        "way",
        "yard",
    ];
    STOP_TOKENS.contains(&token)
}
/// Build the term group for one query token, or `None` when the token cannot
/// act as a text term: it is shorter than three bytes, purely numeric, shaped
/// like a postcode fragment, or every spelling of it is a stop word.
fn address_term_group(token: &str) -> Option<AddressTermGroup> {
    let unusable =
        token.len() < 3 || is_numeric_address_token(token) || looks_like_postcode_fragment(token);
    if unusable {
        return None;
    }

    // The typed token always comes first; aliases follow without repeats.
    let mut alternatives = vec![token.to_string()];
    for alias in address_token_aliases(token) {
        let already_listed = alternatives.iter().any(|existing| existing == alias);
        if !already_listed {
            alternatives.push(alias.to_string());
        }
    }

    let only_stop_words = alternatives
        .iter()
        .all(|alternative| is_address_stop_token(alternative));
    if only_stop_words {
        None
    } else {
        Some(AddressTermGroup { alternatives })
    }
}
/// Tokenize an address and keep only the searchable tokens, returned sorted
/// and de-duplicated.
pub(super) fn address_search_tokens(text: &str) -> Vec<String> {
    let mut kept = tokenize_address_text(text);
    kept.retain(|token| is_address_search_token(token));
    kept.sort_unstable();
    kept.dedup();
    kept
}
/// Decide whether an address token is kept for per-row matching: postcode
/// fragments are dropped, pure numbers are always kept, digit-bearing tokens
/// need at least two bytes and plain words at least three.
fn is_address_search_token(token: &str) -> bool {
    if looks_like_postcode_fragment(token) {
        false
    } else if is_numeric_address_token(token) {
        true
    } else {
        let has_digit = token.chars().any(|ch| ch.is_ascii_digit());
        let min_len = if has_digit { 2 } else { 3 };
        token.len() >= min_len
    }
}
/// True for tokens distinctive enough to look up candidate rows with: never
/// pure numbers or postcode fragments; otherwise any digit-bearing token, or a
/// word of three or more bytes that is not a stop word.
pub(super) fn is_address_candidate_token(token: &str) -> bool {
    if is_numeric_address_token(token) || looks_like_postcode_fragment(token) {
        return false;
    }
    let has_digit = token.chars().any(|ch| ch.is_ascii_digit());
    let distinctive_word = token.len() >= 3 && !is_address_stop_token(token);
    has_digit || distinctive_word
}
fn address_prefix_key(term: &str) -> &str {
if term.len() > ADDRESS_SEARCH_PREFIX_MAX_LEN {
&term[..ADDRESS_SEARCH_PREFIX_MAX_LEN]
} else {
term
}
}
/// Build the prefix → tokens index over the keys of `address_token_index`.
///
/// Every token is registered under each of its prefixes whose length lies
/// between `ADDRESS_SEARCH_PREFIX_MIN_LEN` and `ADDRESS_SEARCH_PREFIX_MAX_LEN`
/// (tokens shorter than the minimum get no entry). Each token list ends up
/// sorted and de-duplicated.
pub(super) fn build_address_prefix_index(
    address_token_index: &FxHashMap<String, Vec<u32>>,
) -> FxHashMap<String, Vec<String>> {
    let mut by_prefix: FxHashMap<String, Vec<String>> = FxHashMap::default();

    for token in address_token_index.keys() {
        let longest = token.len().min(ADDRESS_SEARCH_PREFIX_MAX_LEN);
        for end in ADDRESS_SEARCH_PREFIX_MIN_LEN..=longest {
            let prefix = token[..end].to_string();
            by_prefix.entry(prefix).or_default().push(token.clone());
        }
    }

    by_prefix.values_mut().for_each(|expansions| {
        expansions.sort_unstable();
        expansions.dedup();
    });
    by_prefix
}
/// Row ids present in both ascending-sorted slices, in ascending order.
fn intersect_sorted(left: &[u32], right: &[u32]) -> Vec<u32> {
    use std::cmp::Ordering;

    let mut common = Vec::new();
    let (mut lhs, mut rhs) = (left, right);
    // Walk both slices from the front, always advancing past the smaller head.
    while let (Some((&a, lhs_rest)), Some((&b, rhs_rest))) = (lhs.split_first(), rhs.split_first())
    {
        match a.cmp(&b) {
            Ordering::Less => lhs = lhs_rest,
            Ordering::Greater => rhs = rhs_rest,
            Ordering::Equal => {
                common.push(a);
                lhs = lhs_rest;
                rhs = rhs_rest;
            }
        }
    }
    common
}
/// Merge two ascending-sorted row-id slices into one ascending list, emitting an
/// id that appears in both inputs only once.
fn union_sorted(left: &[u32], right: &[u32]) -> Vec<u32> {
    use std::cmp::Ordering;

    let mut merged = Vec::with_capacity(left.len() + right.len());
    let (mut lhs, mut rhs) = (left, right);
    while let (Some((&a, lhs_rest)), Some((&b, rhs_rest))) = (lhs.split_first(), rhs.split_first())
    {
        match a.cmp(&b) {
            Ordering::Less => {
                merged.push(a);
                lhs = lhs_rest;
            }
            Ordering::Greater => {
                merged.push(b);
                rhs = rhs_rest;
            }
            Ordering::Equal => {
                merged.push(a);
                lhs = lhs_rest;
                rhs = rhs_rest;
            }
        }
    }
    // At most one side still has elements; append whatever remains.
    merged.extend_from_slice(lhs);
    merged.extend_from_slice(rhs);
    merged
}
/// True for ordinals such as "1st", "2nd", "3rd" or "21st": one or more digits
/// followed by a two-letter ordinal ending. These belong to the street name
/// ("2nd Avenue") rather than to a house-number prefix.
fn is_ordinal_token(token: &str) -> bool {
    let cut = token.len().saturating_sub(2);
    let (number, ending) = token.split_at(cut);
    if number.is_empty() || !matches!(ending, "st" | "nd" | "rd" | "th") {
        return false;
    }
    number.chars().all(|ch| ch.is_ascii_digit())
}
/// True for a leading token that names a unit or house number rather than the
/// street: a unit keyword, any single-byte token, a pure number, or a number
/// with a letter suffix such as "221b". Ordinals ("2nd") never qualify.
fn is_house_prefix_token(token: &str) -> bool {
    if is_ordinal_token(token) {
        return false;
    }

    let unit_keyword = matches!(
        token,
        "flat" | "fl" | "apartment" | "apt" | "unit" | "no" | "block" | "floor" | "room"
    );
    let single_byte = token.len() == 1;
    let all_digits = token.chars().all(|ch| ch.is_ascii_digit());
    let number_with_letters = token.chars().next().is_some_and(|ch| ch.is_ascii_digit())
        && token.chars().any(|ch| ch.is_ascii_alphabetic());

    unit_keyword || single_byte || all_digits || number_with_letters
}
/// Street-level key for an address: the lowercase tokens with any leading
/// house-number / flat prefix removed, so "12 Baker Street" and "5 Baker Street"
/// both give "baker street". If every token is a prefix token, all are kept.
fn street_key(address: &str) -> String {
    let tokens = tokenize_address_text(address);
    let prefix_len = tokens
        .iter()
        .take_while(|token| is_house_prefix_token(token))
        .count();
    let street = if prefix_len == tokens.len() {
        &tokens[..]
    } else {
        &tokens[prefix_len..]
    };
    street.join(" ")
}
/// Road-type words. Their presence (with no house number) marks a road browse, which we
/// collapse to one result per street.
///
/// Entries are lowercase because they are compared against tokenized (lowercased) query
/// words; common abbreviations are listed next to their full forms.
const ROAD_TYPE_TOKENS: &[&str] = &[
"street",
"st",
"road",
"rd",
"lane",
"ln",
"avenue",
"ave",
"close",
"cl",
"drive",
"dr",
"way",
"court",
"ct",
"crescent",
"cres",
"place",
"terrace",
"terr",
"grove",
"gardens",
"gdns",
"walk",
"row",
"square",
"sq",
"hill",
"parade",
"mews",
"embankment",
"broadway",
"boulevard",
"blvd",
];
/// True when any token of `query` is one of the `ROAD_TYPE_TOKENS`.
fn query_has_road_type(query: &str) -> bool {
    for token in tokenize_address_text(query) {
        if ROAD_TYPE_TOKENS.contains(&token.as_str()) {
            return true;
        }
    }
    false
}
/// The outward code of a canonical postcode: the text before its first space,
/// or the whole string when it contains no space.
fn outcode_of(postcode: &str) -> &str {
    match postcode.split_once(' ') {
        Some((outward, _inward)) => outward,
        None => postcode,
    }
}
/// Parse a free-text address query into its searchable parts.
///
/// The query is tokenized, a complete postcode (one token or two adjacent tokens) is
/// pulled out if present, and — only when no complete postcode was found — a trailing
/// partial postcode is recovered as `postcode_area`. Every remaining token is then
/// classified as a numeric term, a text group (with its aliases), or dropped.
fn parse_address_query(query: &str) -> AddressQuery {
let tokens = tokenize_address_text(query);
let (full_postcode, postcode_token_indices) = extract_full_postcode(&tokens)
.map(|(postcode, indices)| (Some(postcode), indices))
.unwrap_or((None, Vec::new()));
// Token positions already consumed by the complete postcode.
let skip_postcode_tokens: FxHashSet<usize> = postcode_token_indices.into_iter().collect();
// Recover an appended partial postcode (outcode, or outcode + sector digit) as a ranking
// bias rather than discarding it — but only from the TRAILING position, so a leading road
// designation like "A4 Great West Road" is not mistaken for an area refinement.
let mut postcode_area: Option<String> = None;
let mut consumed_partial_tokens: FxHashSet<usize> = FxHashSet::default();
if full_postcode.is_none() && !tokens.is_empty() {
let last = tokens.len() - 1;
if !skip_postcode_tokens.contains(&last) {
// A lone trailing digit is read as a sector digit only if an outcode precedes it.
let sector_digit =
tokens[last].len() == 1 && tokens[last].chars().all(|ch| ch.is_ascii_digit());
if last >= 1
&& sector_digit
&& !skip_postcode_tokens.contains(&(last - 1))
&& looks_like_postcode_fragment(&tokens[last - 1])
{
// "<outcode> <digit>" → compact "<OUTCODE><digit>"; both tokens are consumed so the
// digit is not later mistaken for a house number.
postcode_area = Some(format!(
"{}{}",
tokens[last - 1].to_ascii_uppercase(),
tokens[last]
));
consumed_partial_tokens.insert(last);
consumed_partial_tokens.insert(last - 1);
} else if looks_like_postcode_fragment(&tokens[last]) {
postcode_area = Some(tokens[last].to_ascii_uppercase());
consumed_partial_tokens.insert(last);
}
}
}
let mut text_groups = Vec::new();
let mut numeric_terms = Vec::new();
let mut candidate_terms = Vec::new();
for (idx, token) in tokens.iter().enumerate() {
// Postcode tokens, and anything shaped like a postcode fragment in ANY position, never
// become search terms.
if skip_postcode_tokens.contains(&idx)
|| consumed_partial_tokens.contains(&idx)
|| looks_like_postcode_fragment(token)
{
continue;
}
if is_numeric_address_token(token) {
numeric_terms.push(token.clone());
continue;
}
if let Some(group) = address_term_group(token) {
// Only the non-stop-word spellings are worth a posting-list lookup.
for alternative in &group.alternatives {
if !is_address_stop_token(alternative)
&& !candidate_terms.iter().any(|term| term == alternative)
{
candidate_terms.push(alternative.clone());
}
}
text_groups.push(group);
} else if token.chars().any(|ch| ch.is_ascii_digit()) && token.len() >= 2 {
// No term group but still digit-bearing (e.g. a two-character "2a"): treated as a
// number-like term and also as a candidate lookup term.
numeric_terms.push(token.clone());
if !candidate_terms.iter().any(|term| term == token) {
candidate_terms.push(token.clone());
}
}
}
// `dedup_by` only merges ADJACENT identical groups; query order is otherwise preserved.
text_groups.dedup_by(|left, right| left.alternatives == right.alternatives);
numeric_terms.sort_unstable();
numeric_terms.dedup();
AddressQuery {
full_postcode,
postcode_area,
text_groups,
numeric_terms,
candidate_terms,
}
}
/// A row token matches a query term when they are equal, or when the term is at
/// least three bytes long and is a prefix of the token.
fn token_matches_query_term(token: &str, query_term: &str) -> bool {
    if token == query_term {
        return true;
    }
    query_term.len() >= 3 && token.starts_with(query_term)
}
/// A row token matches a numeric query term when it starts with that term; an
/// exact match is the case where the prefix is the whole token.
fn token_matches_numeric_term(token: &str, query_term: &str) -> bool {
    token.starts_with(query_term)
}
/// Test-only twin of the row matcher that works on plain string tokens: true
/// when any alternative of `group` matches any of `tokens`.
#[cfg(test)]
fn address_tokens_match_group(tokens: &[String], group: &AddressTermGroup) -> bool {
    for alternative in &group.alternatives {
        for token in tokens {
            if token_matches_query_term(token, alternative) {
                return true;
            }
        }
    }
    false
}
impl PropertyData {
/// Interned search-token keys belonging to `row`, read from the flat key
/// buffer through the row's offset and length entries.
fn row_address_search_tokens(&self, row: usize) -> &[lasso::Spur] {
    let start = self.address_search_token_offsets[row] as usize;
    let end = start + self.address_search_token_lengths[row] as usize;
    &self.address_search_token_keys[start..end]
}
/// Search individual property addresses, returning `(row, score)` ranked best-first.
///
/// Candidate rows come from intersecting the posting lists of the distinctive words the
/// user typed in full (so "Cherry Hinton Road" narrows to rows containing both), unioned
/// with the exact-postcode rows when a complete postcode is present (so a postcode is a
/// boost, not an all-or-nothing gate). An appended partial postcode keeps in-area rows
/// ahead of the candidate cut and adds a scoring bias. With a road-type word and no house
/// number, results collapse to one row per street.
///
/// Returns at most `limit` entries, and an empty list when `limit` is 0, when the query
/// has no postcode, text or numeric term, or when no candidate row is found. Ties in
/// score are broken by shorter address first, then by lower row index. Rows with a blank
/// address are never returned, and duplicate address/postcode pairs are returned once.
pub fn search_addresses(&self, query: &str, limit: usize) -> Vec<(usize, i32)> {
    if limit == 0 {
        return Vec::new();
    }
    let parsed = parse_address_query(query);
    if parsed.full_postcode.is_none()
        && parsed.text_groups.is_empty()
        && parsed.numeric_terms.is_empty()
    {
        return Vec::new();
    }

    let mut candidate_rows = self.address_candidate_rows(&parsed.candidate_terms);
    // A complete postcode contributes its rows too, instead of replacing the road match.
    if let Some(postcode) = parsed.full_postcode.as_deref() {
        if let Some(rows) = self
            .postcode_interner
            .get(postcode)
            .and_then(|key| self.postcode_row_index.get(&key))
        {
            candidate_rows = if candidate_rows.is_empty() {
                rows.clone()
            } else {
                union_sorted(&candidate_rows, rows)
            };
        }
    }
    if candidate_rows.is_empty() {
        return Vec::new();
    }

    // When the user appended a partial postcode, keep in-area rows ahead of the cut so the
    // refinement still surfaces even for very common roads. Single pass (stable partition) so
    // the postcode check — which allocates — runs exactly once per candidate.
    if let Some(area) = parsed.postcode_area.as_deref() {
        let (mut in_area, others): (Vec<u32>, Vec<u32>) = candidate_rows
            .iter()
            .copied()
            .partition(|&row| self.row_postcode_in_area(row as usize, area));
        in_area.extend(others);
        candidate_rows = in_area;
    }
    candidate_rows.truncate(ADDRESS_SEARCH_CANDIDATE_LIMIT);

    // (score, address length, row): the two trailing fields are tie-breakers only.
    let mut scored: Vec<(i32, usize, usize)> = candidate_rows
        .into_iter()
        .filter_map(|row| {
            let row = row as usize;
            self.address_match_score(row, &parsed)
                .map(|score| (score, self.address(row).len(), row))
        })
        .collect();
    scored.sort_unstable_by(|left, right| {
        right
            .0
            .cmp(&left.0)
            .then(left.1.cmp(&right.1))
            .then(left.2.cmp(&right.2))
    });

    // Collapse a road browse (road-type word, no house number) to one row per street.
    let collapse_streets = parsed.numeric_terms.is_empty() && query_has_road_type(query);
    let mut seen = FxHashSet::default();
    // Never reserve more than can be returned: `limit` comes from the caller, and an
    // oversized value would otherwise over-allocate or panic with a capacity overflow.
    let mut results = Vec::with_capacity(limit.min(scored.len()));
    for (score, _, row) in scored {
        let address = self.address(row).trim();
        if address.is_empty() {
            continue;
        }
        // De-duplication key: street + outcode when collapsing, else full address + postcode.
        let key = if collapse_streets {
            format!(
                "{}\n{}",
                street_key(address),
                outcode_of(self.postcode(row))
            )
        } else {
            format!("{}\n{}", address.to_ascii_lowercase(), self.postcode(row))
        };
        if !seen.insert(key) {
            continue;
        }
        results.push((row, score));
        if results.len() == limit {
            break;
        }
    }
    results
}
/// True when the row's postcode, with whitespace removed and ASCII letters
/// uppercased, begins with the compact partial postcode `area`
/// (e.g. "NW1" or "NW16" matches "NW1 6XE").
fn row_postcode_in_area(&self, row: usize, area: &str) -> bool {
    let compact: String = self
        .postcode(row)
        .chars()
        .filter(|ch| !ch.is_whitespace())
        .map(|ch| ch.to_ascii_uppercase())
        .collect();
    compact.starts_with(area)
}
/// Candidate rows for the distinctive query words.
///
/// When at least one word has an exact posting list, the result is the
/// intersection of all exact lists, starting from the shortest; words without
/// an exact entry are ignored here. When no word matches exactly (the final
/// word is still being typed), rows are seeded from prefix expansion instead.
fn address_candidate_rows(&self, terms: &[String]) -> Vec<u32> {
    let mut postings: Vec<&[u32]> = terms
        .iter()
        .filter_map(|term| self.address_token_index.get(term))
        .map(Vec::as_slice)
        .collect();
    postings.sort_by_key(|rows| rows.len());

    let Some((smallest, rest)) = postings.split_first() else {
        return self.prefix_seed_rows(terms);
    };

    let mut narrowed = smallest.to_vec();
    for rows in rest {
        // Nothing left to narrow; further intersections cannot add rows.
        if narrowed.is_empty() {
            break;
        }
        narrowed = intersect_sorted(&narrowed, rows);
    }
    narrowed
}
/// Seed rows by prefix expansion, for queries where no word matched an indexed
/// token exactly (i.e. the user is still typing the final word).
///
/// Each term of at least `ADDRESS_SEARCH_PREFIX_MIN_LEN` bytes is expanded to
/// the indexed tokens that start with it and their posting lists are unioned;
/// the smallest non-empty union across terms is returned (the first one seen on
/// a tie), or an empty list when no term expands to anything.
fn prefix_seed_rows(&self, terms: &[String]) -> Vec<u32> {
    let mut smallest: Option<Vec<u32>> = None;

    let expandable = terms
        .iter()
        .filter(|term| term.len() >= ADDRESS_SEARCH_PREFIX_MIN_LEN);
    for term in expandable {
        let Some(expansions) = self.address_prefix_index.get(address_prefix_key(term)) else {
            continue;
        };

        // The index key is a truncated prefix, so re-check against the full term.
        let posting_lists = expansions
            .iter()
            .filter(|token| token.starts_with(term.as_str()))
            .filter_map(|token| self.address_token_index.get(token));
        let mut merged: Vec<u32> = Vec::new();
        for rows in posting_lists {
            merged = if merged.is_empty() {
                rows.clone()
            } else {
                union_sorted(&merged, rows)
            };
        }

        if merged.is_empty() {
            continue;
        }
        let is_smaller = smallest
            .as_ref()
            .is_none_or(|current| merged.len() < current.len());
        if is_smaller {
            smallest = Some(merged);
        }
    }

    smallest.unwrap_or_default()
}
/// Score `row` against the parsed query, or return `None` when the row must not be
/// returned at all.
///
/// A row is rejected when its address is blank, when any text group fails to match one
/// of its tokens, or when the query has numeric terms and none of them matches. Otherwise
/// the score is the sum of:
/// - 1000 if the query holds a complete postcode and the row's postcode is that postcode;
/// - 200 per text group (all of them matched, or the row would have been rejected);
/// - 90 per matched numeric term, plus 50 when every numeric term matched;
/// - 400 if the row lies in the appended partial postcode area.
fn address_match_score(&self, row: usize, parsed: &AddressQuery) -> Option<i32> {
    if self.address(row).trim().is_empty() {
        return None;
    }
    let tokens = self.row_address_search_tokens(row);
    if parsed
        .text_groups
        .iter()
        .any(|group| !self.address_tokens_match_group(tokens, group))
    {
        return None;
    }
    let numeric_matches = parsed
        .numeric_terms
        .iter()
        .filter(|term| {
            tokens.iter().any(|token| {
                token_matches_numeric_term(self.address_search_interner.resolve(token), term)
            })
        })
        .count();
    if !parsed.numeric_terms.is_empty() && numeric_matches == 0 {
        return None;
    }

    let mut score = 0;
    // The postcode bonus must depend on the ROW: candidates are the union of road matches
    // and postcode rows, so a bonus given to every row would never rank the queried
    // postcode ahead of the same road elsewhere. Whitespace and ASCII case are ignored.
    if let Some(postcode) = parsed.full_postcode.as_deref() {
        let row_compact = self
            .postcode(row)
            .chars()
            .filter(|ch| !ch.is_whitespace())
            .map(|ch| ch.to_ascii_uppercase());
        let query_compact = postcode
            .chars()
            .filter(|ch| !ch.is_whitespace())
            .map(|ch| ch.to_ascii_uppercase());
        if row_compact.eq(query_compact) {
            score += 1_000;
        }
    }
    score += (parsed.text_groups.len() as i32) * 200;
    score += (numeric_matches as i32) * 90;
    if numeric_matches == parsed.numeric_terms.len() && numeric_matches > 0 {
        score += 50;
    }
    // Additive bias (never a filter) when the row sits in the appended partial postcode.
    if let Some(area) = parsed.postcode_area.as_deref() {
        if self.row_postcode_in_area(row, area) {
            score += 400;
        }
    }
    Some(score)
}
/// True when any alternative of `group` matches any of the row's interned
/// tokens; each key is resolved to its string before the comparison.
fn address_tokens_match_group(&self, tokens: &[lasso::Spur], group: &AddressTermGroup) -> bool {
    for alternative in &group.alternatives {
        for token in tokens {
            let text = self.address_search_interner.resolve(token);
            if token_matches_query_term(text, alternative) {
                return true;
            }
        }
    }
    false
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn full_postcode_detection_accepts_common_formats() {
assert!(is_full_postcode_compact("SW1A1AA"));
assert!(is_full_postcode_compact("E142DG"));
assert!(is_full_postcode_compact("M11AE"));
assert!(!is_full_postcode_compact("E14"));
assert!(!is_full_postcode_compact("DOWNING"));
assert!(!is_full_postcode_compact("10A"));
}
#[test]
fn address_query_parsing_skips_postcodes_and_street_suffixes() {
let parsed = parse_address_query("Flat 2, 10 Downing St, SW1A 2AA");
assert_eq!(parsed.full_postcode.as_deref(), Some("SW1A 2AA"));
assert_eq!(
parsed.numeric_terms,
vec!["10".to_string(), "2".to_string()]
);
assert_eq!(parsed.candidate_terms, vec!["downing".to_string()]);
assert_eq!(parsed.text_groups.len(), 1);
assert_eq!(
parsed.text_groups[0].alternatives,
vec!["downing".to_string()]
);
}
#[test]
fn address_query_parsing_handles_compact_postcodes() {
let parsed = parse_address_query("10 downing street sw1a1aa");
assert_eq!(parsed.full_postcode.as_deref(), Some("SW1A 1AA"));
assert_eq!(parsed.numeric_terms, vec!["10".to_string()]);
assert_eq!(parsed.candidate_terms, vec!["downing".to_string()]);
}
#[test]
fn address_query_recovers_appended_partial_postcode_as_bias() {
let parsed = parse_address_query("Baker Street NW1");
assert_eq!(parsed.full_postcode, None);
assert_eq!(parsed.postcode_area.as_deref(), Some("NW1"));
// The road words are still searchable; the postcode fragment did not consume them.
assert_eq!(parsed.candidate_terms, vec!["baker".to_string()]);
assert!(parsed.numeric_terms.is_empty());
}
#[test]
fn address_query_recovers_outcode_plus_sector_without_a_phantom_house_number() {
let parsed = parse_address_query("High Street CR0 2");
assert_eq!(parsed.postcode_area.as_deref(), Some("CR02"));
// The lone sector digit must not be treated as a house number.
assert!(parsed.numeric_terms.is_empty());
assert_eq!(parsed.candidate_terms, vec!["high".to_string()]);
}
#[test]
fn full_postcode_takes_precedence_over_partial_bias() {
let parsed = parse_address_query("Baker Street NW1 6XE");
assert_eq!(parsed.full_postcode.as_deref(), Some("NW1 6XE"));
assert_eq!(parsed.postcode_area, None);
}
#[test]
fn intersect_and_union_sorted_row_ids() {
assert_eq!(
intersect_sorted(&[1, 2, 3, 5], &[2, 3, 4, 5]),
vec![2, 3, 5]
);
assert_eq!(intersect_sorted(&[1, 2], &[3, 4]), Vec::<u32>::new());
assert_eq!(union_sorted(&[1, 3, 5], &[2, 3, 4]), vec![1, 2, 3, 4, 5]);
assert_eq!(union_sorted(&[], &[2, 4]), vec![2, 4]);
}
#[test]
fn street_key_collapses_house_numbers_and_flats() {
assert_eq!(street_key("12 Baker Street"), "baker street");
assert_eq!(street_key("5 Baker Street"), "baker street");
assert_eq!(street_key("Flat 2, 10 Downing Street"), "downing street");
assert_eq!(street_key("221B Baker Street"), "baker street");
}
#[test]
fn street_key_keeps_ordinal_street_names() {
// Ordinals are part of the street name, not a house-number prefix.
assert_eq!(street_key("2nd Avenue"), "2nd avenue");
assert_eq!(street_key("12 3rd Avenue"), "3rd avenue");
assert!(is_ordinal_token("21st"));
assert!(!is_ordinal_token("21"));
assert!(!is_ordinal_token("221b"));
}
#[test]
fn postcode_area_recovered_only_from_the_trailing_position() {
// A leading road designation must NOT be taken as an area refinement.
let parsed = parse_address_query("A4 Great West Road");
assert_eq!(parsed.postcode_area, None);
// A genuine trailing outcode still is.
let trailing = parse_address_query("Great West Road W4");
assert_eq!(trailing.postcode_area.as_deref(), Some("W4"));
}
#[test]
fn road_type_detection() {
assert!(query_has_road_type("high street"));
assert!(query_has_road_type("acacia avenue"));
assert!(!query_has_road_type("acacia"));
assert!(!query_has_road_type("london"));
}
#[test]
fn address_query_parsing_keeps_partial_terms_for_row_matching() {
let parsed = parse_address_query("settlers cour");
assert_eq!(parsed.full_postcode, None);
assert_eq!(parsed.numeric_terms, Vec::<String>::new());
assert_eq!(
parsed.candidate_terms,
vec!["settlers".to_string(), "cour".to_string()]
);
assert_eq!(parsed.text_groups.len(), 2);
assert_eq!(
parsed.text_groups[0].alternatives,
vec!["settlers".to_string()]
);
assert_eq!(parsed.text_groups[1].alternatives, vec!["cour".to_string()]);
}
#[test]
fn address_search_tokens_keep_actual_address_terms_for_scoring() {
let tokens = address_search_tokens("Flat 2, 10 Downing Cour");
assert_eq!(
tokens,
vec![
"10".to_string(),
"2".to_string(),
"cour".to_string(),
"downing".to_string(),
"flat".to_string()
]
);
}
#[test]
fn address_prefix_index_finds_partial_address_terms() {
    // Three indexed tokens, two of which share the prefix "down".
    let mut token_index: FxHashMap<String, Vec<u32>> = FxHashMap::default();
    for (token, row) in [("downing", 1), ("downton", 2), ("market", 3)] {
        token_index.insert(token.to_string(), vec![row]);
    }

    let prefix_index = build_address_prefix_index(&token_index);

    // Each typed prefix resolves to exactly the full tokens that start with it.
    let expectations: [(&str, &[&str]); 3] = [
        ("down", &["downing", "downton"]),
        ("downi", &["downing"]),
        ("downt", &["downton"]),
    ];
    for (prefix, full_tokens) in expectations {
        let expected: Vec<String> = full_tokens.iter().map(|t| t.to_string()).collect();
        assert_eq!(
            prefix_index.get(prefix).cloned().unwrap_or_default(),
            expected
        );
    }

    // Two-character prefixes are not indexed.
    assert!(!prefix_index.contains_key("do"));
}
#[test]
fn address_term_matching_allows_prefixes_and_aliases() {
    let row_tokens = tokenize_address_text("10 Downing Street");

    // A typed prefix ("down") matches the full token "downing".
    let by_prefix = address_term_group("down").expect("prefix term should be searchable");
    assert!(address_tokens_match_group(&row_tokens, &by_prefix));

    // A group of alternatives matches when any alternative is present.
    let by_alias = AddressTermGroup {
        alternatives: ["st", "street"].iter().map(|alt| alt.to_string()).collect(),
    };
    assert!(address_tokens_match_group(&row_tokens, &by_alias));
}
#[test]
fn address_term_matching_uses_actual_token_prefixes() {
    // "cou" is a prefix of the row's real token "court".
    let row_tokens = tokenize_address_text("12 Settlers Court");
    let partial = address_term_group("cou").expect("partial term should be searchable");
    let matched = address_tokens_match_group(&row_tokens, &partial);
    assert!(matched);
}
}

View file

@ -0,0 +1,34 @@
//! H3 spatial cell precomputation for property rows.
use anyhow::Context;
use rayon::prelude::*;
use crate::consts::H3_PRECOMPUTE_MAX;
/// Precompute H3 cell IDs for all rows at the maximum resolution only.
/// Parent cells for lower resolutions are derived on the fly via `CellIndex::parent()`.
///
/// `lat` and `lon` are parallel slices: entry `i` of each describes row `i`.
/// The returned vector has one `u64` cell index per row, in row order.
///
/// # Errors
/// Returns an error (instead of panicking inside the worker pool) when:
/// - `lat` and `lon` differ in length,
/// - `H3_PRECOMPUTE_MAX` is not a valid H3 resolution, or
/// - any coordinate pair is rejected by `h3o::LatLng::new`. When several rows
///   are invalid, which one is reported is not specified because rows are
///   processed in parallel.
pub fn precompute_h3(lat: &[f32], lon: &[f32]) -> anyhow::Result<Vec<u64>> {
    // `zip` would silently stop at the shorter slice and drop rows.
    anyhow::ensure!(
        lat.len() == lon.len(),
        "Latitude/longitude length mismatch: {} vs {}",
        lat.len(),
        lon.len()
    );

    let res = H3_PRECOMPUTE_MAX;
    tracing::info!("Precomputing H3 cells at resolution {}", res);
    let h3_res =
        h3o::Resolution::try_from(res).with_context(|| format!("Invalid H3 resolution: {res}"))?;

    let cells = lat
        .par_iter()
        .zip(lon.par_iter())
        .enumerate()
        .map(|(i, (&latitude, &longitude))| {
            let coord = h3o::LatLng::new(f64::from(latitude), f64::from(longitude)).map_err(
                |err| {
                    anyhow::anyhow!(
                        "Invalid coordinates at row {}: lat={}, lon={}: {}",
                        i,
                        latitude,
                        longitude,
                        err
                    )
                },
            )?;
            Ok(u64::from(coord.to_cell(h3_res)))
        })
        .collect::<anyhow::Result<Vec<u64>>>()?;

    tracing::info!("H3 precomputation complete ({} cells)", cells.len());
    Ok(cells)
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,238 @@
//! Property data: the row-major quantized feature matrix plus the side tables
//! (addresses, postcodes, renovation/price history, POI metrics) built from the
//! properties + postcode parquet files.
//!
//! Split by concern:
//! - [`loading`]: parquet ingestion, validation, spatial sort and matrix build
//! - [`stats`]: histograms, percentiles and slider-bound computation
//! - [`quant`]: u16 quantization encode/decode
//! - [`poi_metrics`]: postcode-level POI metric side table
//! - [`address_search`]: address tokenization, indexing and ranked search
//! - [`h3`]: H3 cell precomputation
mod address_search;
mod h3;
mod loading;
mod poi_metrics;
mod quant;
mod stats;
pub use h3::precompute_h3;
pub use poi_metrics::PostcodePoiMetrics;
pub use quant::QuantRef;
pub use stats::{FeatureStats, Histogram};
use rustc_hash::FxHashMap;
use serde::Serialize;
use crate::consts::NAN_U16;
/// One renovation event attached to a property row (see
/// `PropertyData::renovation_history`). Serialized as-is via serde.
#[derive(Serialize, Clone)]
pub struct RenovationEvent {
/// Year the event is recorded against.
pub year: i32,
/// Free-form event description.
/// NOTE(review): the set of possible values is not visible here — confirm against the loader.
pub event: String,
}
/// One historical sale transaction for a property row (see
/// `PropertyData::historical_prices`). Serialized as-is via serde.
#[derive(Serialize, Clone)]
pub struct HistoricalPrice {
/// Year of the transaction.
pub year: i32,
/// Month of the transaction.
/// NOTE(review): presumably 1-12 — confirm against the loader.
pub month: u8,
/// Sale price.
/// NOTE(review): currency/unit is not visible here — confirm against the loader.
pub price: i64,
}
/// In-memory property dataset: a row-major quantized feature matrix plus the
/// per-row side tables (addresses, postcodes, history, search indexes).
///
/// Private fields are exposed through the accessor methods on this type.
pub struct PropertyData {
/// Per-row latitude.
/// NOTE(review): presumably degrees, one entry per property row — confirm against the loader.
pub lat: Vec<f32>,
/// Per-row longitude (same caveats as `lat`).
pub lon: Vec<f32>,
/// Feature column names.
/// NOTE(review): presumably indexed by the same `feat_idx` used for `feature_data` — confirm.
pub feature_names: Vec<String>,
/// Number of feature columns; this is the row stride of `feature_data`.
pub num_features: usize,
/// Number of numeric features (enum features start at this index).
pub num_numeric: usize,
/// Row-major flat array: feature_data[row * num_features + feat_idx].
/// Quantized to u16. NaN sentinel = u16::MAX (65535).
/// Numeric features: encoded via (val - min) / range * 65534.
/// Enum features: stored directly as u16 cast of the f32 index.
pub feature_data: Vec<u16>,
/// Per-feature: range / QUANT_SCALE for fast decode.
dequant_a: Vec<f32>,
/// Per-feature: minimum value (offset for dequantization).
quant_min: Vec<f32>,
/// Per-feature: max - min (for encoding filter bounds).
quant_range: Vec<f32>,
/// Per-feature statistics (slider bounds and histogram).
/// NOTE(review): presumably parallel to `feature_names` — confirm.
pub feature_stats: Vec<FeatureStats>,
/// Postcode-level POI metric side table.
pub poi_metrics: PostcodePoiMetrics,
/// Unquantized last sale price used by the price-history chart.
last_known_price_raw: Vec<f32>,
/// Contiguous buffer holding all address strings end-to-end.
address_buffer: String,
/// Byte offset into `address_buffer` where each row's address starts.
address_offsets: Vec<u32>,
/// Length in bytes of each row's address.
address_lengths: Vec<u16>,
/// Interned postcodes: reader is thread-safe, keys index into it.
postcode_interner: lasso::RodeoReader,
/// Per-row interned postcode key, resolved through `postcode_interner`.
postcode_keys: Vec<lasso::Spur>,
/// Rows for each postcode, keyed by the interned postcode key.
postcode_row_index: FxHashMap<lasso::Spur, Vec<u32>>,
/// Inverted index from address tokens to property rows.
address_token_index: FxHashMap<String, Vec<u32>>,
/// Prefix lookup from typed address-token prefix to indexed full address tokens.
address_prefix_index: FxHashMap<String, Vec<String>>,
/// Interned normalized address-search tokens used for per-row scoring.
address_search_interner: lasso::RodeoReader,
/// Flat per-row normalized address-search token keys.
address_search_token_keys: Vec<lasso::Spur>,
/// Offset into `address_search_token_keys` for each row.
address_search_token_offsets: Vec<u32>,
/// Number of normalized address-search token keys for each row.
address_search_token_lengths: Vec<u16>,
/// For enum features: maps feature index to list of possible string values.
/// Index in values list corresponds to the u16 value stored in feature_data.
pub enum_values: rustc_hash::FxHashMap<usize, Vec<String>>,
/// For enum features: maps feature index to per-value global counts (same order as enum_values).
pub enum_counts: rustc_hash::FxHashMap<usize, Vec<u64>>,
/// Per-row flag: true = construction date is approximate (from EPC band),
/// false = exact (from new-build transaction date).
/// Bit-packed: byte `row / 8`, bit `row % 8`. 8x smaller than Vec<bool>.
approx_build_date_bits: Vec<u8>,
/// Per-row renovation events. Keyed by (permuted) row index.
/// Only rows with events are present in the map.
renovation_history: FxHashMap<u32, Vec<RenovationEvent>>,
/// Per-row historical sale transactions (Land Registry price-paid).
/// Keyed by (permuted) row index. Only rows with prices are present.
historical_prices: FxHashMap<u32, Vec<HistoricalPrice>>,
/// Property sub-type string per row, keyed by row index.
/// Sparse: rows without an entry yield `None` from the accessor.
property_sub_type: FxHashMap<u32, String>,
/// Price qualifier string per row, keyed by row index.
/// Sparse: rows without an entry yield `None` from the accessor.
price_qualifier: FxHashMap<u32, String>,
}
impl PropertyData {
    /// Get the address string for a given row.
    ///
    /// # Panics
    /// Panics if `row` is out of bounds for the address tables.
    pub fn address(&self, row: usize) -> &str {
        let offset = self.address_offsets[row] as usize;
        let length = self.address_lengths[row] as usize;
        &self.address_buffer[offset..offset + length]
    }

    /// Get the postcode string for a given row.
    ///
    /// # Panics
    /// Panics if `row` is out of bounds for the postcode keys.
    pub fn postcode(&self, row: usize) -> &str {
        self.postcode_interner.resolve(&self.postcode_keys[row])
    }

    /// Get postcode components for field-level borrowing (avoids conflicting borrows with feature_data).
    pub fn postcode_parts(&self) -> (&lasso::RodeoReader, &[lasso::Spur]) {
        (&self.postcode_interner, &self.postcode_keys)
    }

    /// Property rows for a given postcode string, or empty if unknown.
    pub fn rows_for_postcode(&self, postcode: &str) -> &[u32] {
        self.postcode_interner
            .get(postcode)
            .and_then(|key| self.postcode_row_index.get(&key))
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }

    /// Get the is_approx_build_date flag for a given row (bit-packed).
    ///
    /// # Panics
    /// Panics if `row / 8` is out of bounds for the bit vector.
    pub fn is_approx_build_date(&self, row: usize) -> bool {
        let byte = self.approx_build_date_bits[row / 8];
        byte & (1 << (row % 8)) != 0
    }

    /// Get renovation events for a given row (empty slice if none).
    ///
    /// The side table is keyed by `u32`; a row index that does not fit in
    /// `u32` yields an empty slice rather than wrapping onto another row.
    pub fn renovation_history(&self, row: usize) -> &[RenovationEvent] {
        u32::try_from(row)
            .ok()
            .and_then(|key| self.renovation_history.get(&key))
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }

    /// Get historical sale transactions for a given row (empty slice if none).
    ///
    /// A row index that does not fit in `u32` yields an empty slice rather
    /// than wrapping onto another row.
    pub fn historical_prices(&self, row: usize) -> &[HistoricalPrice] {
        u32::try_from(row)
            .ok()
            .and_then(|key| self.historical_prices.get(&key))
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }

    /// Get property sub-type for a given row (`None` if the row has no entry
    /// or the index does not fit in `u32`).
    pub fn property_sub_type(&self, row: usize) -> Option<&str> {
        let key = u32::try_from(row).ok()?;
        self.property_sub_type.get(&key).map(String::as_str)
    }

    /// Get price qualifier for a given row (`None` if the row has no entry
    /// or the index does not fit in `u32`).
    pub fn price_qualifier(&self, row: usize) -> Option<&str> {
        let key = u32::try_from(row).ok()?;
        self.price_qualifier.get(&key).map(String::as_str)
    }

    /// Get the unquantized last sale price for charting.
    ///
    /// # Panics
    /// Panics if `row` is out of bounds.
    #[inline]
    pub fn last_known_price_raw(&self, row: usize) -> f32 {
        self.last_known_price_raw[row]
    }

    /// Decode a single feature value from quantized u16 storage.
    ///
    /// Returns `NaN` for the `NAN_U16` sentinel. Enum features
    /// (`feat_idx >= num_numeric`) are returned as the stored index cast to
    /// `f32`; numeric features are dequantized with the per-feature scale and
    /// offset.
    ///
    /// # Panics
    /// Panics if `row * num_features + feat_idx` is out of bounds.
    #[inline]
    pub fn get_feature(&self, row: usize, feat_idx: usize) -> f32 {
        let raw = self.feature_data[row * self.num_features + feat_idx];
        if raw == NAN_U16 {
            return f32::NAN;
        }
        if feat_idx >= self.num_numeric {
            raw as f32
        } else {
            raw as f32 * self.dequant_a[feat_idx] + self.quant_min[feat_idx]
        }
    }

    /// Get a QuantRef for passing to aggregation/filter functions.
    pub fn quant_ref(&self) -> QuantRef<'_> {
        QuantRef {
            dequant_a: &self.dequant_a,
            quant_min: &self.quant_min,
            quant_range: &self.quant_range,
            num_numeric: self.num_numeric,
        }
    }
}
#[cfg(test)]
impl PropertyData {
    /// Build a `PropertyData` with zero rows and zero features.
    ///
    /// Intended for integration tests that must construct an `AppState` but
    /// never read property data (e.g. checkout/webhook/invite flows).
    pub(crate) fn empty_for_tests() -> Self {
        // Two independent, empty string interners.
        let postcode_interner = lasso::Rodeo::default().into_reader();
        let address_search_interner = lasso::Rodeo::default().into_reader();

        Self {
            // Coordinates and feature matrix.
            lat: Vec::new(),
            lon: Vec::new(),
            feature_names: Vec::new(),
            num_features: 0,
            num_numeric: 0,
            feature_data: Vec::new(),
            // Quantization parameters.
            dequant_a: Vec::new(),
            quant_min: Vec::new(),
            quant_range: Vec::new(),
            feature_stats: Vec::new(),
            poi_metrics: PostcodePoiMetrics::empty(0),
            last_known_price_raw: Vec::new(),
            // Address and postcode storage.
            address_buffer: String::new(),
            address_offsets: Vec::new(),
            address_lengths: Vec::new(),
            postcode_interner,
            postcode_keys: Vec::new(),
            postcode_row_index: FxHashMap::default(),
            // Address search indexes.
            address_token_index: FxHashMap::default(),
            address_prefix_index: FxHashMap::default(),
            address_search_interner,
            address_search_token_keys: Vec::new(),
            address_search_token_offsets: Vec::new(),
            address_search_token_lengths: Vec::new(),
            // Enum metadata and sparse per-row side tables.
            enum_values: FxHashMap::default(),
            enum_counts: FxHashMap::default(),
            approx_build_date_bits: Vec::new(),
            renovation_history: FxHashMap::default(),
            historical_prices: FxHashMap::default(),
            property_sub_type: FxHashMap::default(),
            price_qualifier: FxHashMap::default(),
        }
    }
}

View file

@ -0,0 +1,200 @@
//! Postcode-level POI metric side table: dynamic POI features are stored once
//! per postcode (not per property row) to keep the hot row-major feature matrix
//! narrow, with a per-property row mapping for lookups.
use anyhow::Context;
use polars::prelude::*;
use rayon::prelude::*;
use rustc_hash::FxHashMap;
use crate::consts::{NAN_U16, QUANT_SCALE};
use crate::features::{self, Bounds};
use super::quant::QuantRef;
use super::stats::{column_to_f32_vec, compute_feature_stats, FeatureStats};
/// Sentinel stored in the per-property row mapping for a property that has no
/// row in the postcode metric table; lookups treat it as "no metrics".
pub(super) const NO_POI_METRIC_ROW: u32 = u32::MAX;
/// Postcode-level POI metrics, quantized to u16 and stored once per postcode
/// metric row, with a per-property mapping into that table.
pub struct PostcodePoiMetrics {
/// Metric names, in metric-index order.
pub feature_names: Vec<String>,
/// Reverse lookup from metric name to its index in `feature_names`.
pub name_to_index: FxHashMap<String, usize>,
/// Metric-major storage: columns[metric_idx][postcode_metric_idx].
pub columns: Vec<Vec<u16>>,
/// Per-metric statistics (slider bounds and histogram), in metric-index order.
pub feature_stats: Vec<FeatureStats>,
/// Per-property row lookup into the postcode metric table.
/// Holds `NO_POI_METRIC_ROW` for properties without a metric row.
row_to_metric_idx: Vec<u32>,
/// Per-metric decode scale: `quant_range / QUANT_SCALE`, or 0 when the range is 0.
dequant_a: Vec<f32>,
/// Per-metric decode offset (the value that raw 0 decodes to).
quant_min: Vec<f32>,
/// Per-metric quantization range (`max - min`, or 0 when `max <= min`).
quant_range: Vec<f32>,
}
impl PostcodePoiMetrics {
pub(super) fn empty(row_count: usize) -> Self {
Self {
feature_names: Vec::new(),
name_to_index: FxHashMap::default(),
columns: Vec::new(),
feature_stats: Vec::new(),
row_to_metric_idx: vec![NO_POI_METRIC_ROW; row_count],
dequant_a: Vec::new(),
quant_min: Vec::new(),
quant_range: Vec::new(),
}
}
pub(super) fn from_postcode_df(
df: &DataFrame,
feature_names: Vec<String>,
) -> anyhow::Result<Self> {
if feature_names.is_empty() {
return Ok(Self::empty(0));
}
tracing::info!(
metrics = feature_names.len(),
postcodes = df.height(),
"Building postcode POI metric side table"
);
let col_major: Vec<Vec<f32>> = feature_names
.par_iter()
.map(|name| {
let column = df
.column(name.as_str())
.with_context(|| format!("Missing POI metric column '{name}'"))?;
column_to_f32_vec(column)
})
.collect::<anyhow::Result<Vec<_>>>()?;
let feature_stats: Vec<FeatureStats> = col_major
.par_iter()
.enumerate()
.map(|(metric_idx, vals)| {
let name = feature_names[metric_idx].as_str();
let bounds = features::bounds_for(name)
.with_context(|| format!("No bounds config for POI metric '{name}'"))?;
Ok(compute_feature_stats(
vals,
&bounds,
features::has_integer_bins(name),
))
})
.collect::<anyhow::Result<Vec<_>>>()?;
let mut quant_min = Vec::with_capacity(feature_names.len());
let mut quant_range = Vec::with_capacity(feature_names.len());
for (metric_idx, stats) in feature_stats.iter().enumerate() {
let (min, max) = match features::bounds_for(feature_names[metric_idx].as_str()) {
Some(Bounds::Fixed { min, max }) => (min, max),
_ => (stats.histogram.min, stats.histogram.max),
};
quant_min.push(min);
quant_range.push(if max > min { max - min } else { 0.0 });
}
let dequant_a: Vec<f32> = quant_range
.iter()
.map(|&range| {
if range > 0.0 {
range / QUANT_SCALE
} else {
0.0
}
})
.collect();
let columns: Vec<Vec<u16>> = col_major
.par_iter()
.enumerate()
.map(|(metric_idx, vals)| {
let range = quant_range[metric_idx];
let min = quant_min[metric_idx];
vals.iter()
.map(|&value| {
if !value.is_finite() {
NAN_U16
} else if range > 0.0 {
let normalized = (value - min) / range;
(normalized * QUANT_SCALE).round().clamp(0.0, QUANT_SCALE) as u16
} else {
0
}
})
.collect()
})
.collect();
let name_to_index = feature_names
.iter()
.enumerate()
.map(|(idx, name)| (name.clone(), idx))
.collect();
Ok(Self {
feature_names,
name_to_index,
columns,
feature_stats,
row_to_metric_idx: Vec::new(),
dequant_a,
quant_min,
quant_range,
})
}
pub(super) fn set_row_mapping(&mut self, row_to_metric_idx: Vec<u32>) {
self.row_to_metric_idx = row_to_metric_idx;
}
pub fn is_empty(&self) -> bool {
self.feature_names.is_empty()
}
pub fn num_features(&self) -> usize {
self.feature_names.len()
}
pub fn quant_ref(&self) -> QuantRef<'_> {
QuantRef {
dequant_a: &self.dequant_a,
quant_min: &self.quant_min,
quant_range: &self.quant_range,
num_numeric: self.feature_names.len(),
}
}
#[inline]
pub fn metric_row_for_property(&self, row: usize) -> Option<usize> {
self.row_to_metric_idx
.get(row)
.copied()
.filter(|&idx| idx != NO_POI_METRIC_ROW)
.map(|idx| idx as usize)
}
#[inline]
pub fn raw_for_metric_row(&self, metric_row: usize, metric_idx: usize) -> u16 {
self.columns[metric_idx][metric_row]
}
#[inline]
pub fn raw_for_property_row(&self, row: usize, metric_idx: usize) -> u16 {
let Some(metric_row) = self.metric_row_for_property(row) else {
return NAN_U16;
};
self.raw_for_metric_row(metric_row, metric_idx)
}
#[inline]
pub fn decode_raw(&self, metric_idx: usize, raw: u16) -> f32 {
if raw == NAN_U16 {
f32::NAN
} else {
raw as f32 * self.dequant_a[metric_idx] + self.quant_min[metric_idx]
}
}
#[inline]
pub fn get_for_property_row(&self, row: usize, metric_idx: usize) -> f32 {
self.decode_raw(metric_idx, self.raw_for_property_row(row, metric_idx))
}
}

View file

@ -0,0 +1,46 @@
//! u16 quantization: decoding stored feature values and encoding filter bounds.
use crate::consts::{NAN_U16, QUANT_SCALE};
/// Lightweight reference to quantization parameters for decoding u16 feature data.
///
/// All slices are indexed by feature index.
pub struct QuantRef<'a> {
/// Per-feature decode scale: a raw value is multiplied by this.
pub dequant_a: &'a [f32],
/// Per-feature decode offset, added after scaling.
pub quant_min: &'a [f32],
/// Per-feature range used when encoding filter bounds; 0 disables encoding.
pub quant_range: &'a [f32],
/// Features at or above this index are decoded as the raw value cast to `f32`.
pub num_numeric: usize,
}
impl QuantRef<'_> {
    /// Turn a stored u16 back into an f32.
    ///
    /// The `NAN_U16` sentinel yields NaN. Features at or above `num_numeric`
    /// yield the raw value as-is; all others are scaled and offset.
    #[inline]
    pub fn decode(&self, feat_idx: usize, raw: u16) -> f32 {
        if raw == NAN_U16 {
            f32::NAN
        } else if feat_idx < self.num_numeric {
            raw as f32 * self.dequant_a[feat_idx] + self.quant_min[feat_idx]
        } else {
            raw as f32
        }
    }

    /// Quantize a lower filter bound, rounding down so values on the boundary
    /// stay included. Non-finite bounds and zero-range features map to 0.
    #[inline]
    pub fn encode_min(&self, feat_idx: usize, value: f32) -> u16 {
        // The finiteness check comes first so a non-finite bound never
        // touches the per-feature tables.
        if !value.is_finite() {
            return 0;
        }
        let range = self.quant_range[feat_idx];
        if range == 0.0 {
            return 0;
        }
        let scaled = (value - self.quant_min[feat_idx]) / range * QUANT_SCALE;
        scaled.floor().clamp(0.0, QUANT_SCALE) as u16
    }

    /// Quantize an upper filter bound, rounding up so values on the boundary
    /// stay included. Non-finite bounds and zero-range features map to the
    /// top of the quantized range.
    #[inline]
    pub fn encode_max(&self, feat_idx: usize, value: f32) -> u16 {
        // Same ordering as `encode_min`: finiteness before table access.
        if !value.is_finite() {
            return QUANT_SCALE as u16;
        }
        let range = self.quant_range[feat_idx];
        if range == 0.0 {
            return QUANT_SCALE as u16;
        }
        let scaled = (value - self.quant_min[feat_idx]) / range * QUANT_SCALE;
        scaled.ceil().clamp(0.0, QUANT_SCALE) as u16
    }
}

View file

@ -0,0 +1,544 @@
//! Feature statistics: outlier-bracketed histograms, percentile estimation and
//! slider-bound computation.
use anyhow::Context;
use polars::prelude::*;
use serde::Serialize;
use crate::consts::HISTOGRAM_BINS;
use crate::features::Bounds;
/// Histogram with outlier buckets at the edges.
/// - Bin 0: [min, p1) — low outliers
/// - Bins 1 to n-2: [p1, p99) — main distribution, evenly divided
/// - Bin n-1: [p99, max] — high outliers
#[derive(Serialize, Clone)]
pub struct Histogram {
/// Lower edge of the value range the percentiles were estimated over.
/// After outlier refinement this can be tighter than the raw data minimum.
pub min: f32,
/// Upper edge of the value range the percentiles were estimated over.
/// After outlier refinement this can differ from the raw data maximum.
pub max: f32,
/// 1st percentile (left edge of main distribution)
pub p1: f32,
/// 99th percentile (right edge of main distribution)
pub p99: f32,
/// Number of finite values that fell into each bin.
pub counts: Vec<u64>,
}
impl Histogram {
    /// Map a value to its bin under the outlier-bracket layout: bin 0 below
    /// `p1`, the last bin at or above `p99`, evenly divided middle bins
    /// otherwise. Falls back to the centre bin when there are no usable
    /// middle bins.
    #[cfg(test)]
    pub fn bin_for_value(&self, value: f32) -> usize {
        let num_bins = self.counts.len();
        if value < self.p1 {
            return 0;
        }
        if value >= self.p99 {
            return num_bins - 1;
        }
        let width = match num_bins.saturating_sub(2) {
            middle if middle > 0 && self.p99 > self.p1 => (self.p99 - self.p1) / middle as f32,
            _ => return num_bins / 2,
        };
        let offset = ((value - self.p1) / width) as usize;
        (1 + offset).min(num_bins - 2)
    }

    /// Width of one middle bin (bins 1..n-2); 0.0 when there are no middle
    /// bins or `p99` does not exceed `p1`.
    #[cfg(test)]
    pub fn middle_bin_width(&self) -> f32 {
        match self.counts.len().saturating_sub(2) {
            0 => 0.0,
            middle if self.p99 > self.p1 => (self.p99 - self.p1) / middle as f32,
            _ => 0.0,
        }
    }
}
/// Per-feature statistics produced by `compute_feature_stats`.
pub struct FeatureStats {
/// Lower slider bound: the configured minimum for fixed bounds, otherwise
/// an estimated percentile of the data.
pub slider_min: f32,
/// Upper slider bound: the configured maximum for fixed bounds, otherwise
/// an estimated percentile of the data.
pub slider_max: f32,
/// Outlier-bracketed histogram of the finite values.
pub histogram: Histogram,
}
/// Estimate a percentile from a histogram whose bins evenly divide [min, max].
///
/// `count` is the number of samples behind `prelim_counts`, and `percentile`
/// is on a 0-100 scale. The result is linearly interpolated inside the bin
/// that contains the target rank. Returns `min` when there is no data and
/// `max` when the target rank lies beyond every bin.
fn percentile_from_uniform_histogram(
    count: usize,
    min: f32,
    max: f32,
    prelim_counts: &[u64],
    percentile: f32,
) -> f32 {
    if prelim_counts.is_empty() || count == 0 {
        return min;
    }

    // Zero-based rank of the sample we are looking for.
    let target = (count as f64 * percentile as f64 / 100.0).floor() as u64;
    let bin_width = (max - min) / prelim_counts.len() as f32;

    let mut seen_before = 0u64;
    for (bin_idx, &in_bin) in prelim_counts.iter().enumerate() {
        let seen_through = seen_before + in_bin;
        if seen_through > target {
            // The target rank is inside this bin: interpolate by how far into
            // the bin's samples it falls.
            let fraction = if in_bin > 0 {
                (target - seen_before) as f32 / in_bin as f32
            } else {
                0.0
            };
            return min + bin_idx as f32 * bin_width + fraction * bin_width;
        }
        seen_before = seen_through;
    }
    max
}
/// Build a histogram and compute slider bounds based on the feature's Bounds config.
///
/// Only finite values in `vals` are used; NaN and ±infinity are skipped.
///
/// Steps:
/// 1. Find min, max and count of the finite values.
/// 2. Estimate p1/p99 from a uniform `HISTOGRAM_BINS`-bin histogram.
/// 3. If the range is dominated by outliers, zoom into the estimated data
///    region and re-estimate (at most 3 rounds).
/// 4. Build the final outlier-bracketed histogram over ALL finite values.
/// 5. Derive slider bounds: fixed bounds are passed through unchanged,
///    percentile bounds are read from the (refined) uniform histogram.
///
/// With no finite values the histogram is all zeros with `HISTOGRAM_BINS`
/// bins; the slider is the fixed bounds, or (0, 0) for percentile bounds.
///
/// When `integer_bins` is set, p1/p99 are snapped outward to integers so each
/// middle bin is one unit wide.
pub fn compute_feature_stats(vals: &[f32], bounds: &Bounds, integer_bins: bool) -> FeatureStats {
// Single pass: min, max, count (skipping NaN and infinity)
let mut min = f32::INFINITY;
let mut max = f32::NEG_INFINITY;
let mut count = 0usize;
for &value in vals {
if value.is_finite() {
if value < min {
min = value;
}
if value > max {
max = value;
}
count += 1;
}
}
// No usable data: return an empty histogram and bounds that do not depend on data.
if count == 0 {
let (slider_min, slider_max) = match bounds {
Bounds::Fixed {
min: fmin,
max: fmax,
} => (*fmin, *fmax),
Bounds::Percentile { .. } => (0.0, 0.0),
};
return FeatureStats {
slider_min,
slider_max,
histogram: Histogram {
min: 0.0,
max: 0.0,
p1: 0.0,
p99: 0.0,
counts: vec![0; HISTOGRAM_BINS],
},
};
}
// Build preliminary histogram with uniform bins to compute percentiles
// Use full HISTOGRAM_BINS for percentile precision
// A degenerate range (all values equal) is widened to 1.0 so the bin width is non-zero.
let range = if max == min { 1.0 } else { max - min };
// Stretch the upper edge slightly so `max` itself does not land past the last bin.
let prelim_max = min + range * (1.0 + 1e-6);
let prelim_bin_width = (prelim_max - min) / HISTOGRAM_BINS as f32;
let mut prelim_counts = vec![0u64; HISTOGRAM_BINS];
for &value in vals {
if value.is_finite() {
let bin = ((value - min) / prelim_bin_width) as usize;
prelim_counts[bin.min(HISTOGRAM_BINS - 1)] += 1;
}
}
// Compute p1 and p99 from preliminary histogram
let mut p1 = percentile_from_uniform_histogram(count, min, max, &prelim_counts, 1.0);
let mut p99 = percentile_from_uniform_histogram(count, min, max, &prelim_counts, 99.0);
// Iterative refinement for outlier-dominated distributions.
// When extreme outliers (e.g. 317M sqm from web scraping) dominate the range,
// the uniform histogram puts all real data in one bin, making percentile
// estimation useless. Zoom into the estimated data region and recompute.
let mut refined_counts = prelim_counts;
let mut refined_count = count;
let mut refined_min = min;
let mut refined_max = max;
for _ in 0..3 {
// NOTE: despite the name, this is the p1–p99 span, not an interquartile range.
let iqr = p99 - p1;
// Stop once the current window is no wider than 5x that span.
if iqr <= 0.0 || (refined_max - refined_min) <= 5.0 * iqr {
break;
}
// New window: one span of margin on each side, never below the data minimum.
let new_min = (p1 - iqr).max(min);
let new_max = p99 + iqr;
if new_max <= new_min {
break;
}
let bin_width = (new_max - new_min) / HISTOGRAM_BINS as f32;
let mut counts = vec![0u64; HISTOGRAM_BINS];
let mut cnt = 0usize;
for &value in vals {
if value.is_finite() && value >= new_min && value <= new_max {
let bin = ((value - new_min) / bin_width) as usize;
counts[bin.min(HISTOGRAM_BINS - 1)] += 1;
cnt += 1;
}
}
if cnt == 0 {
break;
}
p1 = percentile_from_uniform_histogram(cnt, new_min, new_max, &counts, 1.0);
p99 = percentile_from_uniform_histogram(cnt, new_min, new_max, &counts, 99.0);
refined_counts = counts;
refined_count = cnt;
refined_min = new_min;
refined_max = new_max;
}
// For integer-binned features, snap p1/p99 to integer boundaries
// so each middle bin is exactly 1 unit wide.
if integer_bins {
p1 = p1.floor();
p99 = p99.ceil();
}
// Determine number of histogram bins
let num_bins = if integer_bins && p99 > p1 {
// One middle bin per integer + 2 outlier bins
(p99 - p1) as usize + 2
} else {
// Count unique values within the p1–p99 range to cap histogram bins.
// Using the full-range cardinality would over-allocate bins when outliers
// inflate it (e.g. bedrooms: 1137 unique values but only ~10 within p1–p99).
let cardinality = {
let mut unique_set = rustc_hash::FxHashSet::default();
for &val in vals {
if val.is_finite() && val >= p1 && val <= p99 {
unique_set.insert(val.to_bits());
}
}
unique_set.len()
};
// At least 3 bins so there is always one middle bin between the outlier bins.
HISTOGRAM_BINS.min(cardinality).max(3)
};
// Build final histogram with outlier bins at edges:
// - Bin 0: [min, p1) — low outliers
// - Bins 1 to n-2: [p1, p99) — main distribution, evenly divided
// - Bin n-1: [p99, max] — high outliers
let mut counts = vec![0u64; num_bins];
let middle_bins = num_bins.saturating_sub(2);
let middle_width = if middle_bins > 0 && p99 > p1 {
(p99 - p1) / middle_bins as f32
} else {
0.0
};
// Every finite value is counted, including those outside the refined window.
for &value in vals {
if value.is_finite() {
let bin = if value < p1 {
0 // Low outlier bin
} else if value >= p99 {
num_bins - 1 // High outlier bin
} else if middle_width > 0.0 {
// Middle bins (1 to n-2)
let middle_bin = ((value - p1) / middle_width) as usize;
(1 + middle_bin).min(num_bins - 2)
} else {
num_bins / 2 // Fallback if p1 == p99
};
counts[bin] += 1;
}
}
// The reported min/max are the refined window, not the raw data extremes.
let histogram = Histogram {
min: refined_min,
max: refined_max,
p1,
p99,
counts,
};
// Compute slider bounds (use refined histogram for accurate percentiles)
let (slider_min, slider_max) = match bounds {
Bounds::Fixed {
min: fmin,
max: fmax,
} => (*fmin, *fmax),
Bounds::Percentile { low, high } => {
let p_low = percentile_from_uniform_histogram(
refined_count,
refined_min,
refined_max,
&refined_counts,
*low as f32,
);
let p_high = percentile_from_uniform_histogram(
refined_count,
refined_min,
refined_max,
&refined_counts,
*high as f32,
);
(p_low, p_high)
}
};
FeatureStats {
slider_min,
slider_max,
histogram,
}
}
/// Convert a dataframe column into a plain `Vec<f32>`.
///
/// The column is cast to Float32 first; null entries become NaN.
///
/// # Errors
/// Fails when the cast is not possible or the cast result cannot be viewed
/// as an f32 chunked array.
pub(super) fn column_to_f32_vec(column: &Column) -> anyhow::Result<Vec<f32>> {
    let as_f32 = column
        .cast(&DataType::Float32)
        .context("Failed to cast column to Float32")?;
    let values = as_f32.f32().context("Failed to get f32 chunked array")?;

    let mut out: Vec<f32> = Vec::new();
    out.extend(values.into_iter().map(|entry| match entry {
        Some(number) => number,
        None => f32::NAN,
    }));
    Ok(out)
}
#[cfg(test)]
mod tests {
    //! Tests for histogram construction, percentile-based slider bounds and
    //! outlier refinement in `compute_feature_stats`.
    use super::*;
    use crate::consts::QUANT_SCALE;
    use crate::features::Bounds;

    /// Fixed slider bounds helper.
    fn make_fixed_bounds(min: f32, max: f32) -> Bounds {
        Bounds::Fixed { min, max }
    }

    /// Percentile slider bounds helper.
    fn make_percentile_bounds(low: f64, high: f64) -> Bounds {
        Bounds::Percentile { low, high }
    }

    #[test]
    fn histogram_empty_data() {
        let data: Vec<f32> = vec![];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);
        assert_eq!(stats.slider_min, 0.0);
        assert_eq!(stats.slider_max, 100.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 0);
    }

    #[test]
    fn histogram_single_value() {
        let data = vec![50.0_f32];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);
        assert_eq!(stats.histogram.min, 50.0);
        assert_eq!(stats.histogram.max, 50.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1);
    }

    #[test]
    fn histogram_uniform_distribution() {
        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);
        assert_eq!(stats.histogram.min, 0.0);
        assert_eq!(stats.histogram.max, 99.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 100);
    }

    #[test]
    fn histogram_with_nan_values() {
        let data = vec![10.0_f32, f32::NAN, 20.0, f32::NAN, 30.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 3);
        assert_eq!(stats.histogram.min, 10.0);
        assert_eq!(stats.histogram.max, 30.0);
    }

    #[test]
    fn histogram_all_nan() {
        let data = vec![f32::NAN, f32::NAN, f32::NAN];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 0);
    }

    #[test]
    fn histogram_all_same_value() {
        let data = vec![42.0_f32; 1000];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);
        assert_eq!(stats.histogram.min, 42.0);
        assert_eq!(stats.histogram.max, 42.0);
        assert_eq!(stats.histogram.p1, 42.0);
        assert_eq!(stats.histogram.p99, 42.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1000);
    }

    #[test]
    fn histogram_percentile_bounds() {
        let mut data: Vec<f32> = vec![0.0]; // Low outlier
        data.extend((1..99).map(|i| 50.0 + i as f32 * 0.01));
        data.push(1000.0); // High outlier
        let bounds = make_percentile_bounds(2.0, 98.0);
        let stats = compute_feature_stats(&data, &bounds, false);
        assert!(stats.slider_min > 0.0);
        assert!(stats.slider_max < 1000.0);
    }

    #[test]
    fn fixed_price_bounds_keep_slider_cap() {
        let data = vec![400_000.0_f32, 2_500_000.0, 3_750_000.0];
        let bounds = make_fixed_bounds(0.0, 2_500_000.0);
        let stats = compute_feature_stats(&data, &bounds, false);
        assert_eq!(stats.slider_min, 0.0);
        assert_eq!(stats.slider_max, 2_500_000.0);
    }

    #[test]
    fn histogram_bin_for_value() {
        let hist = Histogram {
            min: 0.0,
            max: 100.0,
            p1: 10.0,
            p99: 90.0,
            counts: vec![0; 10],
        };
        assert_eq!(hist.bin_for_value(5.0), 0); // Low outlier bin
        assert_eq!(hist.bin_for_value(95.0), 9); // High outlier bin
        let mid_value = 50.0;
        let bin = hist.bin_for_value(mid_value);
        assert!((1..=8).contains(&bin));
    }

    #[test]
    fn histogram_middle_bin_width() {
        let hist = Histogram {
            min: 0.0,
            max: 100.0,
            p1: 10.0,
            p99: 90.0,
            counts: vec![0; 10],
        };
        let expected_width = (90.0 - 10.0) / 8.0;
        assert!((hist.middle_bin_width() - expected_width).abs() < 0.001);
    }

    #[test]
    fn histogram_cardinality_caps_bins() {
        let data = vec![1.0_f32, 1.0, 2.0, 2.0, 3.0, 3.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);
        assert_eq!(stats.histogram.counts.len(), 3);
    }

    #[test]
    fn min_max_skips_nan() {
        // Exercise the real min/max scan: NaN entries must not affect the
        // reported range, and the smallest value need not come first.
        let values = vec![10.0_f32, f32::NAN, 20.0, f32::NAN, 5.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&values, &bounds, false);
        assert_eq!(stats.histogram.min, 5.0);
        assert_eq!(stats.histogram.max, 20.0);
    }

    #[test]
    fn count_skips_nan() {
        // Every finite value lands in exactly one bin; NaN lands in none.
        let values = [1.0_f32, f32::NAN, 2.0, f32::NAN, 3.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&values, &bounds, false);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 3);
    }

    #[test]
    fn infinity_values_excluded() {
        let data = vec![f32::INFINITY, f32::NEG_INFINITY, 50.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);
        assert_eq!(stats.histogram.min, 50.0);
        assert_eq!(stats.histogram.max, 50.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1);
    }

    #[test]
    fn only_finite_values() {
        let data = vec![10.0_f32, 20.0, 30.0];
        let bounds = make_fixed_bounds(0.0, 100.0);
        let stats = compute_feature_stats(&data, &bounds, false);
        assert_eq!(stats.histogram.min, 10.0);
        assert_eq!(stats.histogram.max, 30.0);
        assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 3);
    }

    #[test]
    fn extreme_outlier_does_not_destroy_quantization() {
        // Simulate floor area: 10k normal values (50-200 sqm) + one 317M outlier
        let mut data: Vec<f32> = (0..10_000).map(|i| 50.0 + (i % 150) as f32).collect();
        data.push(317_000_000.0); // Extreme outlier from web scraping
        let bounds = make_percentile_bounds(0.0, 98.0);
        let stats = compute_feature_stats(&data, &bounds, false);
        // After refinement, histogram range should be much tighter than 317M
        assert!(
            stats.histogram.max < 1_000_000.0,
            "histogram.max should be refined, got {}",
            stats.histogram.max,
        );
        // p1 should be near 50, not millions
        assert!(
            stats.histogram.p1 < 100.0,
            "p1 should be near real data, got {}",
            stats.histogram.p1,
        );
        // Slider min should reflect actual data range
        assert!(
            stats.slider_min < 100.0,
            "slider_min should be near real data, got {}",
            stats.slider_min,
        );
        // Quantization using histogram.min/max should give usable range
        let qmin = stats.histogram.min;
        let qrange = stats.histogram.max - stats.histogram.min;
        assert!(qrange > 0.0 && qrange < 1_000_000.0);
        // A typical floor area (100 sqm) should be distinguishable from min
        let normalized = (100.0 - qmin) / qrange;
        let encoded = (normalized * QUANT_SCALE).round() as u16;
        assert!(
            encoded > 100,
            "100 sqm should encode to a meaningful u16 value, got {}",
            encoded,
        );
    }
}

View file

@ -272,6 +272,21 @@ pub fn slugify(name: &str) -> String {
result
}
#[cfg(test)]
impl TravelTimeStore {
    /// Build a store with no modes, no destinations and an empty base path.
    ///
    /// Intended for integration tests that must construct an `AppState` but
    /// never read travel time data.
    pub(crate) fn empty_for_tests() -> Self {
        // The cache is never read in these tests; it is constructed only
        // because the field is required.
        let cache = Mutex::new(LruCache::new(1));
        Self {
            base_dir: PathBuf::new(),
            available_modes: Vec::new(),
            destinations: FxHashMap::default(),
            slug_to_file: FxHashMap::default(),
            cache,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -1042,7 +1042,44 @@ async fn main() -> anyhow::Result<()> {
listener,
app.into_make_service_with_connect_info::<std::net::SocketAddr>(),
)
.with_graceful_shutdown(shutdown_signal())
.await
.context("Server error")?;
info!("Server shut down cleanly");
Ok(())
}
/// Completes when the process receives SIGINT (Ctrl-C) or, on Unix, SIGTERM.
///
/// Used as the graceful-shutdown trigger so in-flight requests (exports,
/// checkouts) can drain. The realtime SSE proxy connections never complete,
/// so a watchdog task is started that force-exits the process after 8 seconds,
/// ahead of Docker's default 10s stop grace period and its SIGKILL.
///
/// # Panics
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install SIGINT handler");
    };

    #[cfg(unix)]
    let terminate = async {
        let mut sigterm =
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                .expect("Failed to install SIGTERM handler");
        sigterm.recv().await;
    };
    // Platforms without SIGTERM only ever react to Ctrl-C.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    // Whichever signal arrives first wins.
    tokio::select! {
        _ = interrupt => {}
        _ = terminate => {}
    }

    info!("Shutdown signal received; draining in-flight requests");

    // Watchdog: long-lived connections would otherwise block the drain forever.
    tokio::spawn(async move {
        let drain_budget = Duration::from_secs(8);
        tokio::time::sleep(drain_budget).await;
        tracing::warn!("Graceful shutdown drain timed out after 8s; forcing exit");
        std::process::exit(0);
    });
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,448 @@
//! The `POST /api/ai-filters` route handler: rate limiting, the Gemini
//! function-calling conversation loop, and zero-match refinement.
use std::sync::Arc;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::Json;
use axum::Extension;
use metrics::counter;
use serde_json::{json, Value};
use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WEEKLY_TOKEN_LIMIT};
use crate::pocketbase::log_ai_query;
use crate::state::SharedState;
use crate::utils::gemini_chat;
use super::matching::count_matching_rows;
use super::parsing::{
normalize_context_filters, strip_markdown_fences, validate_and_convert,
validate_travel_time_filters,
};
use super::tools::{build_tool_declarations, execute_destination_search};
use super::usage::{current_week_number, fetch_ai_usage, record_ai_request_usage};
use super::{AiFiltersRequest, AiFiltersResponse};
/// Budget limits for the Gemini conversation loop. Separate counters prevent
/// tool calls (destination searches) from starving JSON retries or zero-match
/// refinements.
///
/// Maximum number of tool calls executed per request; once exceeded, the model
/// is told to output JSON without calling further tools.
const MAX_TOOL_CALLS: usize = 4;
/// Maximum number of retries after empty or unparseable model output before
/// the request fails.
const MAX_RETRIES: usize = 3;
/// Maximum number of times the model is asked to relax filters that matched
/// zero properties before the zero-match result is returned as-is.
const MAX_REFINEMENTS: u32 = 3;
/// Hard cap on Gemini round trips per request, regardless of which of the
/// budgets above is being consumed.
const MAX_TOTAL_ROUNDS: usize = 10;
/// Maximum accepted query length, counted in `char`s (not bytes).
const MAX_AI_QUERY_CHARS: usize = 5000;
/// Handles `POST /api/ai-filters`: turns the user's natural-language query
/// into validated property and travel time filters through a multi-round
/// Gemini conversation.
///
/// Flow: require a logged-in user, reject over-long queries, enforce the
/// weekly token limit, then loop for at most `MAX_TOTAL_ROUNDS` rounds. Each
/// round either executes a tool call requested by the model, retries after
/// empty/invalid JSON, asks the model to relax filters that match zero rows,
/// or returns the validated result together with its match count and bounds.
///
/// # Errors
/// - `UNAUTHORIZED` when no user is logged in
/// - `PAYLOAD_TOO_LARGE` when the query exceeds `MAX_AI_QUERY_CHARS` chars
/// - `TOO_MANY_REQUESTS` when the weekly token limit is already reached
/// - `BAD_GATEWAY` for malformed, empty or invalid model output, or when the
///   round budget is exhausted without a final answer
/// - any error returned by `fetch_ai_usage` or `gemini_chat` is passed through
pub async fn post_ai_filters(
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<AiFiltersRequest>,
) -> Result<Json<AiFiltersResponse>, (StatusCode, String)> {
let state = shared.load_state();
// Auth check
let user = user
.0
.ok_or((StatusCode::UNAUTHORIZED, "Login required".into()))?;
// Length is measured in chars, so multi-byte text is not penalised.
if req.query.chars().count() > MAX_AI_QUERY_CHARS {
counter!("ai_requests_total", "status" => "query_too_long").increment(1);
return Err((
StatusCode::PAYLOAD_TOO_LARGE,
format!("Query too long (max {MAX_AI_QUERY_CHARS} chars)"),
));
}
// Check weekly token usage
let current_week = current_week_number();
let (stored_tokens, stored_week) = fetch_ai_usage(&state, &user.id).await?;
// A stored counter from a different week is treated as zero usage.
let tokens_used = if stored_week == current_week {
stored_tokens
} else {
0
};
if tokens_used >= AI_FILTERS_WEEKLY_TOKEN_LIMIT {
counter!("ai_requests_total", "status" => "rate_limited").increment(1);
return Err((
StatusCode::TOO_MANY_REQUESTS,
"Weekly AI usage limit reached. Resets next week.".into(),
));
}
info!(query = %req.query, user_id = %user.id, "POST /api/ai-filters");
let tools = build_tool_declarations(&state);
// Build user message with optional context for conversational refinement
let user_text = if let Some(ref ctx) = req.context {
let mut msg = String::new();
msg.push_str("Currently active filters:\n");
// Frontend synthetic keys are mapped to backend feature names first.
let normalized_filters = normalize_context_filters(&ctx.filters);
msg.push_str(&serde_json::to_string(&normalized_filters).unwrap_or_default());
if !ctx.travel_time.is_empty() {
msg.push_str("\nCurrently active travel time filters:\n");
for tt in &ctx.travel_time {
let bounds = match (tt.min, tt.max) {
(Some(min), Some(max)) => format!("{}-{} min", min, max),
(Some(min), None) => format!("min {} min", min),
(None, Some(max)) => format!("max {} min", max),
(None, None) => "no range".to_string(),
};
msg.push_str(&format!("- {} to {} ({})\n", tt.mode, tt.label, bounds));
}
}
msg.push_str(&format!("\nUser request: {}", req.query));
msg
} else {
req.query.clone()
};
// Conversation history sent to Gemini; grows with every round.
let mut contents = vec![json!({
"role": "user",
"parts": [{ "text": user_text }]
})];
// Tokens spent by this request across all rounds.
let mut total_tokens_accumulated: u64 = 0;
// Independent budget counters (see the MAX_* constants).
let mut tool_call_count = 0usize;
let mut retry_count = 0usize;
let mut refinement_attempts = 0u32;
// Function calling loop: model may call search_destinations, we execute and feed back
for round in 0..MAX_TOTAL_ROUNDS {
let body = json!({
"systemInstruction": {
"parts": [{ "text": state.ai_filters_system_prompt }]
},
"contents": contents,
"tools": tools,
"generationConfig": {
"temperature": AI_FILTERS_TEMPERATURE,
"maxOutputTokens": AI_FILTERS_MAX_TOKENS,
"thinkingConfig": { "thinkingLevel": "LOW" },
}
});
let json_resp = match gemini_chat(
&state.http_client,
&state.gemini_api_key,
&state.gemini_model,
&body,
)
.await
{
Ok(resp) => resp,
Err(err) => {
// Record tokens from earlier rounds before propagating the error.
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"llm_error",
)
.await;
return Err(err);
}
};
// Accumulate token usage
// A missing or non-integer `usageMetadata.totalTokenCount` counts as 0.
total_tokens_accumulated += json_resp
.get("usageMetadata")
.and_then(|md| md.get("totalTokenCount"))
.and_then(|tc| tc.as_u64())
.unwrap_or(0);
// Only the first candidate's `content` object is considered.
let candidate = match json_resp
.get("candidates")
.and_then(|cs| cs.get(0))
.and_then(|c| c.get("content"))
{
Some(candidate) => candidate,
None => {
warn!("Malformed Gemini response: missing candidates[0].content");
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"malformed_response",
)
.await;
return Err((StatusCode::BAD_GATEWAY, "Malformed Gemini response".into()));
}
};
let parts = match candidate.get("parts").and_then(|p| p.as_array()) {
Some(parts) => parts,
None => {
warn!("Malformed Gemini response: missing parts array");
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"malformed_response",
)
.await;
return Err((StatusCode::BAD_GATEWAY, "Malformed Gemini response".into()));
}
};
// Check if the model made a function call.
// Find the full part (includes thoughtSignature required by Gemini 3 models).
// Only the first part carrying a `functionCall` is handled per round.
if let Some(fc_part) = parts.iter().find(|part| part.get("functionCall").is_some()) {
// Cannot panic: the `find` predicate above guarantees the key exists.
let fc = fc_part.get("functionCall").unwrap();
let fn_name = fc.get("name").and_then(|n| n.as_str()).unwrap_or("");
let fn_args = fc.get("args").cloned().unwrap_or(json!({}));
tool_call_count += 1;
info!(
function = fn_name,
round = round,
tool_call = tool_call_count,
"AI called tool"
);
if tool_call_count > MAX_TOOL_CALLS {
warn!("Tool call budget exhausted, forcing text output");
// NOTE(review): the pushed candidate still contains the functionCall
// part but no functionResponse follows it; confirm the Gemini API
// accepts this turn sequence.
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": "Tool call limit reached. Output your best JSON now using the destinations you already found. Do not call any more tools." }]
}));
continue;
}
// Any function other than `search_destinations` gets an error payload
// as its result; the conversation still continues.
let fn_result = if fn_name == "search_destinations" {
let query = fn_args.get("query").and_then(|q| q.as_str()).unwrap_or("");
let mode = fn_args
.get("mode")
.and_then(|m| m.as_str())
.unwrap_or("transit");
execute_destination_search(&state, query, mode)
} else {
json!({"error": "unknown function"})
};
// Append the model's full response (preserves thoughtSignature) + our function result
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{
"functionResponse": {
"name": fn_name,
"response": fn_result
}
}]
}));
// Continue the loop — model will process the results
continue;
}
// Model returned text — extract and parse as JSON
// The first part that has a string `text` field is used.
let text = parts
.iter()
.find_map(|part| part.get("text").and_then(|t| t.as_str()))
.unwrap_or("");
let text = strip_markdown_fences(text);
let text = text.trim();
if text.is_empty() {
retry_count += 1;
warn!(
"Gemini returned empty text content (round {}, retry {})",
round, retry_count
);
if retry_count > MAX_RETRIES {
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"empty_response",
)
.await;
return Err((
StatusCode::BAD_GATEWAY,
"AI returned empty responses".into(),
));
}
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": "Your response was empty. Please output the JSON object." }]
}));
continue;
}
// Empty output and invalid JSON share the same retry budget.
let raw: Value = match serde_json::from_str(text) {
Ok(val) => val,
Err(err) => {
retry_count += 1;
warn!(error = %err, round = round, retry = retry_count, "Failed to parse Gemini JSON output");
if retry_count > MAX_RETRIES {
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"invalid_json",
)
.await;
return Err((StatusCode::BAD_GATEWAY, "AI returned invalid JSON".into()));
}
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": "That was not valid JSON. Please output ONLY the JSON object with numeric_filters, enum_filters, travel_time_filters, and notes." }]
}));
continue;
}
};
// Validate the proposal against feature metadata and known destinations.
let filters = validate_and_convert(&raw, &state.features_response);
let travel_time_filters = validate_travel_time_filters(&raw, &state);
let notes = raw
.get("notes")
.and_then(|val| val.as_str())
.unwrap_or("")
.to_string();
// Count matching properties and refine if too restrictive
let (match_count, match_bounds) =
count_matching_rows(&state, &filters, &travel_time_filters);
info!(
match_count = match_count,
round = round,
"AI filter match count"
);
if match_count == 0 {
refinement_attempts += 1;
let total_rows = state.data.lat.len();
info!(
attempt = refinement_attempts,
"0 matches out of {total_rows} — asking AI to relax filters"
);
if refinement_attempts > MAX_REFINEMENTS {
// Give up refining: return the zero-match filters with an explanatory
// note. This path does not log the query via `log_ai_query`.
warn!("Refinement budget exhausted, returning filters with 0 matches");
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"zero_matches",
)
.await;
let notes = if notes.is_empty() {
"No properties match these filters. Try relaxing some constraints.".to_string()
} else {
format!(
"{}. No properties match. Try relaxing some constraints.",
notes
)
};
return Ok(Json(AiFiltersResponse {
filters,
travel_time_filters,
notes,
match_count: 0,
match_bounds: None,
}));
}
// Feedback wording escalates with each attempt: widen, widen further,
// then remove filters.
let feedback = match refinement_attempts {
1 => format!(
"Your proposed filters matched 0 properties out of {total_rows} total. \
The combination is too restrictive. Please widen some numeric ranges \
or add more enum values while keeping the user's intent. \
Output the adjusted JSON."
),
2 => format!(
"Still 0 matches out of {total_rows}. Please widen ranges further. \
Output the adjusted JSON."
),
_ => format!(
"Still 0 matches out of {total_rows}. Please remove additional filters \
until some properties match, keeping the user's core priority. \
Output the adjusted JSON."
),
};
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": feedback }]
}));
continue;
}
// At least one property matches: record usage and return the result.
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"success",
)
.await;
// Log the query to PocketBase (fire-and-forget)
// Owned copies are made because the spawned task outlives this handler.
let filters_json = serde_json::to_string(&filters).unwrap_or_default();
let log_state = state.clone();
let log_user_id = user.id.clone();
let log_query = req.query.clone();
let log_notes = notes.clone();
let log_rounds = (round + 1) as u64;
tokio::spawn(async move {
log_ai_query(
&log_state,
&log_user_id,
&log_query,
&filters_json,
&log_notes,
total_tokens_accumulated,
log_rounds,
)
.await;
});
return Ok(Json(AiFiltersResponse {
filters,
travel_time_filters,
notes,
match_count,
match_bounds,
}));
}
// Exhausted total round budget without getting a valid response
warn!(
"AI exhausted {} total rounds without final response (tools={}, retries={}, refinements={})",
MAX_TOTAL_ROUNDS, tool_call_count, retry_count, refinement_attempts
);
record_ai_request_usage(
&state,
&user.id,
tokens_used,
current_week,
total_tokens_accumulated,
"incomplete",
)
.await;
Err((
StatusCode::BAD_GATEWAY,
"AI could not complete the request".into(),
))
}

View file

@ -0,0 +1,158 @@
//! Counting properties that match the AI-proposed property and travel time
//! filters, and computing a camera-friendly bounding box of the matches.
use serde_json::Value;
use tracing::warn;
use crate::data::travel_time::TravelData;
use crate::parsing::{parse_filters_with_poi, row_passes_filters, row_passes_poi_filters};
use crate::state::AppState;
use super::{MatchBounds, TravelTimeFilter};
/// Bounding box over matched coordinates, trimmed to the 5th–95th percentile
/// per axis (when there are enough points) so a handful of remote outliers
/// doesn't zoom the camera out to all of England.
///
/// Each axis is sorted independently, so the corners are per-axis percentiles
/// rather than coordinates of actual matched points. Returns `None` when there
/// are no points or the two vectors differ in length.
fn percentile_trimmed_bounds(mut lats: Vec<f32>, mut lons: Vec<f32>) -> Option<MatchBounds> {
    let point_count = lats.len();
    if point_count == 0 || point_count != lons.len() {
        return None;
    }

    for axis in [&mut lats, &mut lons] {
        axis.sort_unstable_by(f32::total_cmp);
    }

    // With fewer than 20 points nothing is trimmed and the full extent is used.
    let trim = if point_count >= 20 { point_count / 20 } else { 0 };
    let lo = trim;
    let hi = point_count - 1 - trim;

    Some(MatchBounds {
        south: lats[lo],
        west: lons[lo],
        north: lats[hi],
        east: lons[hi],
    })
}
/// Convert validated filter JSON back to the `;;`-separated filter string format
/// that `parse_filters` expects.
///
/// Numeric: `{"name": [min, max]}` → `name:min:max`
/// Enum: `{"name": ["val1", "val2"]}` → `name:val1|val2`
fn filters_to_filter_string(filters: &Value) -> String {
let obj = match filters.as_object() {
Some(obj) => obj,
None => return String::new(),
};
let mut parts = Vec::new();
for (name, value) in obj {
if let Some(arr) = value.as_array() {
if arr.len() == 2 && arr[0].is_number() && arr[1].is_number() {
let min = arr[0].as_f64().unwrap_or(0.0);
let max = arr[1].as_f64().unwrap_or(0.0);
parts.push(format!("{name}:{min}:{max}"));
} else if !arr.is_empty() && arr[0].is_string() {
let values: Vec<&str> = arr.iter().filter_map(|v| v.as_str()).collect();
if !values.is_empty() {
parts.push(format!("{name}:{}", values.join("|")));
}
}
}
}
parts.join(";;")
}
/// Count how many rows in the property dataset pass the given property filters
/// AND travel time filters. Travel time data is loaded from the TravelTimeStore
/// and checked per-postcode (same logic as hexagons.rs).
///
/// Returns the number of matching rows together with a percentile-trimmed
/// bounding box of their coordinates (`None` when nothing matches).
///
/// If the property filters cannot be parsed, a warning is logged and
/// `(0, None)` is returned. A travel time filter whose data cannot be loaded
/// is left out of the check; this is logged as a warning so an overstated
/// match count can be traced back to the missing data.
pub(super) fn count_matching_rows(
    state: &AppState,
    filters: &Value,
    travel_time_filters: &[TravelTimeFilter],
) -> (usize, Option<MatchBounds>) {
    let filter_str = filters_to_filter_string(filters);
    let quant = state.data.quant_ref();
    let poi_quant = state.data.poi_metrics.quant_ref();
    let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = if filter_str.is_empty() {
        (Vec::new(), Vec::new(), Vec::new())
    } else {
        match parse_filters_with_poi(
            Some(&filter_str),
            &state.feature_name_to_index,
            &state.data.enum_values,
            &quant,
            &state.data.poi_metrics.name_to_index,
            &poi_quant,
        ) {
            Ok(f) => f,
            Err(err) => {
                warn!("Failed to parse filters for match count: {err}");
                return (0, None);
            }
        }
    };

    // Load travel time data for each filter entry. Entries that fail to load
    // are skipped (best effort), but no longer silently.
    let travel_data: Vec<(TravelData, Option<f32>, Option<f32>)> = travel_time_filters
        .iter()
        .filter_map(
            |ttf| match state.travel_time_store.get(&ttf.mode, &ttf.slug) {
                Ok(data) => Some((data, ttf.min, ttf.max)),
                Err(_) => {
                    warn!(
                        mode = %ttf.mode,
                        slug = %ttf.slug,
                        "Failed to load travel time data; ignoring this travel time filter for match count"
                    );
                    None
                }
            },
        )
        .collect();
    let has_travel = !travel_data.is_empty();

    let feature_data = &state.data.feature_data;
    let num_features = state.data.num_features;
    let num_rows = state.data.lat.len();
    let (pc_interner, pc_keys) = state.data.postcode_parts();
    let has_poi_filters = !parsed_poi_filters.is_empty();

    let mut count = 0usize;
    let mut matched_lats: Vec<f32> = Vec::new();
    let mut matched_lons: Vec<f32> = Vec::new();

    for (row, pc_key) in pc_keys.iter().enumerate().take(num_rows) {
        if !row_passes_filters(
            row,
            &parsed_filters,
            &parsed_enum_filters,
            feature_data,
            num_features,
        ) {
            continue;
        }
        if has_poi_filters
            && !row_passes_poi_filters(row, &parsed_poi_filters, &state.data.poi_metrics)
        {
            continue;
        }
        if has_travel {
            let postcode = pc_interner.resolve(pc_key);
            let mut passes_travel = true;
            for (data, fmin, fmax) in &travel_data {
                let pass = if let Some(mins) = data.get(postcode).map(|r| r.minutes as f32) {
                    fmin.is_none_or(|min| mins >= min) && fmax.is_none_or(|max| mins <= max)
                } else {
                    false // no travel data → postcode not reachable
                };
                if !pass {
                    passes_travel = false;
                    break;
                }
            }
            if !passes_travel {
                continue;
            }
        }
        count += 1;
        matched_lats.push(state.data.lat[row]);
        matched_lons.push(state.data.lon[row]);
    }

    (count, percentile_trimmed_bounds(matched_lats, matched_lons))
}

View file

@ -0,0 +1,81 @@
//! AI filters: translate a natural-language property query into validated
//! filter settings via Gemini.
//!
//! Split by concern:
//! - [`handler`]: the `POST /api/ai-filters` route handler and Gemini
//! conversation loop
//! - [`prompt`]: system prompt building (precomputed at startup)
//! - [`tools`]: the `search_destinations` tool declaration and execution
//! - [`parsing`]: LLM response parsing and validation against feature metadata
//! - [`matching`]: counting properties that match the proposed filters
//! - [`usage`]: weekly token usage tracking / rate limiting
mod handler;
mod matching;
mod parsing;
mod prompt;
mod tools;
mod usage;
pub use handler::post_ai_filters;
pub use prompt::build_system_prompt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Currently active filter state supplied with a request so that a follow-up
/// query (e.g. "make it cheaper") can refine it. The handler renders it into
/// the user message sent to the model.
#[derive(Deserialize)]
pub struct AiFiltersContext {
/// Active property filters as raw JSON, keyed by filter name.
filters: Value,
/// Active travel time filters; empty when the field is absent.
#[serde(default)]
travel_time: Vec<AiTravelTimeContext>,
}
/// One currently active travel time filter, as supplied in the request
/// context. The handler renders it as `- {mode} to {label} ({bounds})`.
#[derive(Deserialize)]
pub struct AiTravelTimeContext {
/// Travel mode string.
mode: String,
/// Human-readable destination name.
label: String,
/// Lower bound in minutes, if set.
min: Option<f32>,
/// Upper bound in minutes, if set.
max: Option<f32>,
}
/// JSON request body of `POST /api/ai-filters`.
#[derive(Deserialize)]
pub struct AiFiltersRequest {
/// The user's natural-language description of what they are looking for.
query: String,
/// Current filters for conversational refinement (e.g. "make it cheaper")
context: Option<AiFiltersContext>,
}
/// A travel time filter proposed by the model, as returned in the response.
/// `min` / `max` are omitted from the serialized JSON when unset.
#[derive(Serialize)]
pub struct TravelTimeFilter {
/// Travel mode string.
mode: String,
/// Destination identifier used to look up travel time data.
slug: String,
/// Human-readable destination name.
label: String,
/// Lower bound in minutes, if set.
#[serde(skip_serializing_if = "Option::is_none")]
min: Option<f32>,
/// Upper bound in minutes, if set.
#[serde(skip_serializing_if = "Option::is_none")]
max: Option<f32>,
}
/// JSON response body of `POST /api/ai-filters`. Empty collections, empty
/// notes and absent bounds are omitted from the serialized output.
#[derive(Serialize)]
pub struct AiFiltersResponse {
/// Validated property filters, keyed by feature name: `[min, max]` for
/// numeric features, a list of selected values for enum features.
filters: Value,
/// Validated travel time filters.
#[serde(skip_serializing_if = "Vec::is_empty")]
travel_time_filters: Vec<TravelTimeFilter>,
/// What the LLM couldn't map to existing filters (empty if everything matched)
#[serde(skip_serializing_if = "String::is_empty")]
notes: String,
/// Number of properties matching the proposed property and travel time filters.
match_count: usize,
/// Bounding box of the matching properties so the client can move the
/// camera to where matches actually are. Absent when nothing matches.
#[serde(skip_serializing_if = "Option::is_none")]
match_bounds: Option<MatchBounds>,
}
/// Geographic bounding box of the matched properties.
#[derive(Serialize)]
pub struct MatchBounds {
/// Lower latitude edge.
south: f32,
/// Lower longitude edge.
west: f32,
/// Upper latitude edge.
north: f32,
/// Upper longitude edge.
east: f32,
}

View file

@ -0,0 +1,385 @@
//! LLM response parsing: stripping markdown fences, normalizing frontend
//! synthetic filter keys, and validating proposed filters against feature
//! metadata and available travel destinations.
use serde_json::{json, Map, Value};
use tracing::warn;
use crate::routes::{FeatureInfo, FeaturesResponse};
use crate::state::AppState;
use super::TravelTimeFilter;
/// Strip markdown code fences (```json ... ``` or ``` ... ```) from LLM output.
/// Models occasionally wrap JSON in markdown fencing even when told not to.
///
/// The fence is only removed when the text starts with an opening fence line
/// (with an optional language tag) AND ends with a closing fence; otherwise
/// the input is returned with just its surrounding whitespace trimmed.
pub(super) fn strip_markdown_fences(text: &str) -> &str {
    let trimmed = text.trim();
    trimmed
        .strip_prefix("```")
        // Everything up to the first newline is the optional language tag.
        .and_then(|after_open| after_open.split_once('\n'))
        .and_then(|(_language_tag, body)| body.strip_suffix("```"))
        .map_or(trimmed, |body| body.trim())
}
/// Maps a frontend school filter key of the form `Schools:<phase>:<rating>…`
/// to the backend catchment feature name.
///
/// Only the first two `:`-separated segments after the prefix are inspected;
/// any further segments are ignored. Returns `None` for keys without the
/// `Schools:` prefix, with fewer than two segments, or with an unknown
/// phase/rating combination.
fn school_feature_name_from_key(name: &str) -> Option<&'static str> {
    const CATCHMENT_FEATURES: [(&str, &str, &str); 4] = [
        ("primary", "good", "Good+ primary school catchments"),
        ("secondary", "good", "Good+ secondary school catchments"),
        ("primary", "outstanding", "Outstanding primary school catchments"),
        ("secondary", "outstanding", "Outstanding secondary school catchments"),
    ];

    let (phase, tail) = name.strip_prefix("Schools:")?.split_once(':')?;
    let rating = tail.split(':').next()?;

    CATCHMENT_FEATURES
        .iter()
        .find(|(known_phase, known_rating, _)| *known_phase == phase && *known_rating == rating)
        .map(|(_, _, feature)| *feature)
}
/// Extracts the percent-encoded feature name from a synthetic key shaped like
/// `<prefix><encoded name>:<id>` and returns it decoded.
///
/// Returns `None` when the key lacks `prefix`, has no `:` after it, or the
/// encoded segment fails to decode.
fn decode_synthetic_feature_key(name: &str, prefix: &str) -> Option<String> {
    // Split at the LAST colon: the trailing segment is the card id.
    let (encoded, _id) = name.strip_prefix(prefix)?.rsplit_once(':')?;
    match urlencoding::decode(encoded) {
        Ok(decoded) => Some(decoded.into_owned()),
        Err(_) => None,
    }
}
/// Convert frontend synthetic filter keys back to backend feature names.
///
/// The React filter UI stores configurable cards under keys such as
/// `Political vote share:%25%20Labour:0`. The LLM and backend validators need
/// the real feature name (`% Labour`) instead.
///
/// School keys are tried first, then each known synthetic prefix in turn.
/// Returns `None` when the key matches none of them.
fn backend_filter_name(name: &str) -> Option<String> {
    const SYNTHETIC_PREFIXES: [&str; 7] = [
        "Specific crimes:",
        "Political vote share:",
        "Ethnicities:",
        "Amenity distance:",
        "Transport distance:",
        "Amenities within 2km:",
        "Amenities within 5km:",
    ];

    school_feature_name_from_key(name)
        .map(String::from)
        .or_else(|| {
            SYNTHETIC_PREFIXES
                .iter()
                .find_map(|prefix| decode_synthetic_feature_key(name, prefix))
        })
}
/// Returns the backend feature name for `name`: the decoded name when `name`
/// is a frontend synthetic key, otherwise `name` itself unchanged.
fn canonical_filter_name(name: &str) -> String {
    match backend_filter_name(name) {
        Some(feature_name) => feature_name,
        None => name.to_string(),
    }
}
/// Rewrites the keys of a context filter object to canonical backend feature
/// names, leaving every value untouched. Input that is not a JSON object is
/// returned as an unchanged clone.
pub(super) fn normalize_context_filters(filters: &Value) -> Value {
    match filters.as_object() {
        None => filters.clone(),
        Some(entries) => {
            let mut renamed = Map::with_capacity(entries.len());
            renamed.extend(
                entries
                    .iter()
                    .map(|(key, value)| (canonical_filter_name(key), value.clone())),
            );
            Value::Object(renamed)
        }
    }
}
/// Maximum travel-time minutes the data can contain. Matches the Java pipeline's
/// MAX_TRIP_DURATION_MINUTES and the frontend's MAX_TRAVEL_MINUTES.
const TRAVEL_TIME_MAX_MINUTES: f64 = 90.0;

/// Reads `key` from `item` as a number of minutes, clamped to
/// `0..=TRAVEL_TIME_MAX_MINUTES`. Returns `None` when the key is missing, not
/// a number, or not finite.
fn travel_time_minute_field(item: &Value, key: &str) -> Option<f32> {
    let minutes = item.get(key)?.as_f64()?;
    if !minutes.is_finite() {
        return None;
    }
    Some(minutes.clamp(0.0, TRAVEL_TIME_MAX_MINUTES) as f32)
}
fn parse_travel_time_bounds(item: &Value) -> (Option<f32>, Option<f32>) {
let explicit_min = travel_time_minute_field(item, "min");
let explicit_max = travel_time_minute_field(item, "max");
let (mut min, mut max) = if explicit_min.is_some() || explicit_max.is_some() {
(explicit_min, explicit_max)
} else {
let value = travel_time_minute_field(item, "value");
match (item.get("bound").and_then(|val| val.as_str()), value) {
(Some("min"), Some(val)) => (Some(val), None),
(Some("max"), Some(val)) => (None, Some(val)),
_ => (None, None),
}
};
if let (Some(min_val), Some(max_val)) = (min, max) {
if min_val > max_val {
min = Some(max_val);
max = Some(min_val);
}
}
(min, max)
}
/// Validate travel time filters from LLM output against available destinations.
pub(super) fn validate_travel_time_filters(raw: &Value, state: &AppState) -> Vec<TravelTimeFilter> {
let arr = match raw
.get("travel_time_filters")
.and_then(|val| val.as_array())
{
Some(arr) => arr,
None => return Vec::new(),
};
let tt_store = &state.travel_time_store;
let mut results = Vec::new();
for item in arr {
let mode = match item.get("mode").and_then(|val| val.as_str()) {
Some(mode) => mode,
None => continue,
};
let slug = match item.get("slug").and_then(|val| val.as_str()) {
Some(slug) => slug,
None => continue,
};
let label = item
.get("label")
.and_then(|val| val.as_str())
.unwrap_or(slug);
// Verify this destination actually exists
if !tt_store.has_destination(mode, slug) {
warn!(
mode = mode,
slug = slug,
"AI suggested non-existent destination"
);
continue;
}
let (min, max) = parse_travel_time_bounds(item);
// Only include if at least one bound is set
if min.is_some() || max.is_some() {
results.push(TravelTimeFilter {
mode: mode.to_string(),
slug: slug.to_string(),
label: label.to_string(),
min,
max,
});
}
}
results
}
/// Validate LLM output against feature metadata and convert to FeatureFilters format.
///
/// Input format (array-based, each numeric filter sets one bound):
/// ```json
/// {
///   "numeric_filters": [{"name": "Last known price", "bound": "max", "value": 300000}],
///   "enum_filters": [{"name": "Leasehold/Freehold", "values": ["Freehold"]}]
/// }
/// ```
///
/// Output format (FeatureFilters):
/// ```json
/// { "Last known price": [0, 300000], "Leasehold/Freehold": ["Freehold"] }
/// ```
///
/// Entries that name an unknown feature, lack a required field, or would not
/// narrow the selection at all are dropped without error.
pub(super) fn validate_and_convert(raw: &Value, features: &FeaturesResponse) -> Value {
    // Lookup tables built from feature metadata.
    // Numeric entries hold (slider_min, slider_max, data_min, data_max): the
    // slider bounds come from percentiles, the data bounds from the histogram,
    // so one-sided AI filters can span the full data range.
    let mut numeric_features: rustc_hash::FxHashMap<&str, (f32, f32, f32, f32)> =
        rustc_hash::FxHashMap::default();
    let mut enum_features: rustc_hash::FxHashMap<&str, &[String]> =
        rustc_hash::FxHashMap::default();
    for group in &features.groups {
        for feature in &group.features {
            match feature {
                FeatureInfo::Numeric {
                    name,
                    min,
                    max,
                    histogram,
                    ..
                } => {
                    numeric_features.insert(name, (*min, *max, histogram.min, histogram.max));
                }
                FeatureInfo::Enum { name, values, .. } => {
                    enum_features.insert(name, values);
                }
            }
        }
    }

    let mut accepted = serde_json::Map::new();

    // Numeric filters: each sets one bound (min or max). The unset side uses
    // the true data min/max (from the histogram), not the percentile-based
    // slider bound, so a "max" filter for crime produces [0, value] rather
    // than [2nd-percentile, value].
    let numeric_items = raw.get("numeric_filters").and_then(Value::as_array);
    for item in numeric_items.into_iter().flatten() {
        let Some(raw_name) = item.get("name").and_then(Value::as_str) else {
            continue;
        };
        let name = canonical_filter_name(raw_name);
        let Some(&(slider_min, slider_max, data_min, data_max)) =
            numeric_features.get(name.as_str())
        else {
            continue;
        };
        let Some(bound) = item.get("bound").and_then(Value::as_str) else {
            continue;
        };
        let Some(requested) = item.get("value").and_then(Value::as_f64) else {
            continue;
        };
        // Clamp to the true data range (not the slider range).
        let value = requested.max(data_min as f64).min(data_max as f64) as f32;
        let (filter_min, filter_max) = match bound {
            "min" => (value, data_max),
            "max" => (data_min, value),
            _ => continue,
        };
        // Only include if the range is narrower than the full slider range.
        if filter_min > slider_min || filter_max < slider_max {
            accepted.insert(name, json!([filter_min, filter_max]));
        }
    }

    // Enum filters: keep only known values, and skip selections that are
    // empty or that include every possible value.
    let enum_items = raw.get("enum_filters").and_then(Value::as_array);
    for item in enum_items.into_iter().flatten() {
        let Some(raw_name) = item.get("name").and_then(Value::as_str) else {
            continue;
        };
        let name = canonical_filter_name(raw_name);
        let Some(&valid_values) = enum_features.get(name.as_str()) else {
            continue;
        };
        let Some(selected) = item.get("values").and_then(Value::as_array) else {
            continue;
        };
        let valid: Vec<&str> = selected
            .iter()
            .filter_map(Value::as_str)
            .filter(|candidate| valid_values.iter().any(|known| known == candidate))
            .collect();
        if !valid.is_empty() && valid.len() < valid_values.len() {
            accepted.insert(name, json!(valid));
        }
    }

    Value::Object(accepted)
}
// Unit tests for fence stripping, synthetic key normalization and travel time
// bound parsing.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn strip_fences_json_tag() {
let input = "```json\n{\"a\": 1}\n```";
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
}
#[test]
fn strip_fences_no_tag() {
let input = "```\n{\"a\": 1}\n```";
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
}
#[test]
fn strip_fences_passthrough() {
let input = "{\"a\": 1}";
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
}
#[test]
fn strip_fences_whitespace() {
let input = " ```json\n {\"a\": 1} \n``` ";
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
}
#[test]
fn synthetic_filter_keys_are_normalized_to_backend_names() {
assert_eq!(
canonical_filter_name("Schools:primary:good:0"),
"Good+ primary school catchments"
);
// Legacy keys still carry a distance segment; it is ignored.
assert_eq!(
canonical_filter_name("Schools:primary:good:2:0"),
"Good+ primary school catchments"
);
assert_eq!(
canonical_filter_name("Specific crimes:Burglary%20%28avg%2Fyr%29:1"),
"Burglary (avg/yr)"
);
assert_eq!(
canonical_filter_name("Political vote share:%25%20Labour:0"),
"% Labour"
);
assert_eq!(
canonical_filter_name(
"Transport distance:Distance%20to%20nearest%20amenity%20%28Bus%20stop%29%20%28km%29:0"
),
"Distance to nearest amenity (Bus stop) (km)"
);
}
#[test]
fn context_filters_are_normalized_before_prompting() {
// One synthetic key that must be decoded, one plain key that must pass through.
let filters = json!({
"Political vote share:%25%20Green:0": [40, 100],
"Estimated current price": [0, 500000],
});
let normalized = normalize_context_filters(&filters);
assert_eq!(normalized["% Green"], json!([40, 100]));
assert_eq!(normalized["Estimated current price"], json!([0, 500000]));
}
#[test]
fn travel_time_bounds_accept_min_max_schema() {
let item = json!({ "min": 30, "max": 45 });
assert_eq!(parse_travel_time_bounds(&item), (Some(30.0), Some(45.0)));
}
#[test]
fn travel_time_bounds_accept_legacy_bound_value_schema() {
let item = json!({ "bound": "max", "value": 30 });
assert_eq!(parse_travel_time_bounds(&item), (None, Some(30.0)));
}
#[test]
fn travel_time_bounds_clamp_and_order_range() {
// Data ceiling is 90 (matches Java MAX_TRIP_DURATION_MINUTES).
// Inputs outside [0, 90] clamp first (min 150 -> 90, max -10 -> 0); the
// resulting inverted pair is then swapped into ascending order.
let item = json!({ "min": 150, "max": -10 });
assert_eq!(parse_travel_time_bounds(&item), (Some(0.0), Some(90.0)));
}
}

View file

@ -0,0 +1,282 @@
//! System prompt building for the AI filters assistant.
use crate::routes::{FeatureInfo, FeaturesResponse};
/// Build the complete system prompt for AI filters.
///
/// Contains: role instructions, feature catalogue, travel time info,
/// few-shot examples, output rules.
/// Precomputed at startup and cached in AppState.
pub fn build_system_prompt(
features: &FeaturesResponse,
mode_destinations: &[(String, usize)],
) -> String {
let mut parts = Vec::new();
parts.push(
"You are a UK property search assistant. \
The user describes their ideal property or area in natural language. \
Translate their description into filter settings using ONLY the features listed below.\n\
\n\
Rules:\n\
- ONLY set filters the user explicitly mentioned or clearly implied.\n\
- Leave out any filter the user did not mention. Empty arrays are fine.\n\
- Each numeric filter sets ONE bound only: \"min\" (at least this value) \
or \"max\" (at most this value). Never set two filters on the same feature.\n\
- Use EXACT feature names from the list spelling, capitalisation, and punctuation must match.\n\
- \"cheap\" / \"affordable\" = lower price range. \"expensive\" = higher price range.\n\
- \"low crime\" / \"safe\" = low values on the Serious crime (avg/yr) and Minor crime (avg/yr) \
features (area-normalised incident density near the postcode). Prefer these aggregates for broad \
area safety; use specific crime features only when the user names a crime type.\n\
- \"quiet\" = low Noise (dB). \"green\" / \"near parks\" = high Number of amenities (Park) within 2km \
or low Distance to nearest park (km), depending on wording.\n\
- \"good schools\" = Good+ school features. \"outstanding schools\" = Outstanding school features.\n\
- Amenities and transport stops are normal filters in the feature catalogue. \
For \"near a bus stop\", \"near a station\", \"near shops\", etc., use the exact \
Distance to nearest amenity (...) or Number of amenities (...) feature when available.\n\
- Politics/elections are normal filters in the Neighbours group. Use exact vote share \
features such as % Labour, % Conservative, % Liberal Democrat, % Reform UK, % Green, \
% Other parties, or Voter turnout (%) when the user asks for political character.\n\
- When the user says a number like \"under 400k\", interpret it as 400000.\n\
- When the user says \"3 bed\" or \"3 bedroom\", use Number of bedrooms & living rooms \
(note: this counts bedrooms + living rooms combined, so 3 bed ~ min 4).\n\
- If the user mentions something that has no matching filter, put it in \"notes\" \
as a short phrase (e.g. \"No filter for: garden, sea view\"). \
If everything was matched, set \"notes\" to an empty string.\n\
\n\
CONVERSATIONAL REFINEMENT:\n\
The user's message may include their currently active filters as context. \
When context is provided:\n\
- \"make it cheaper\" / \"lower the price\" = adjust the existing price filter down\n\
- \"also add ...\" / \"and good schools\" = keep existing filters and add new ones\n\
- \"remove the ...\" / \"drop the ...\" = return filters WITHOUT the mentioned one\n\
- If the request is a completely new search (not a refinement), ignore the context \
and build filters from scratch.\n\
- Always output the COMPLETE set of filters (existing + modified), not just the changes."
.to_string(),
);
// Travel time section with available modes
let modes_list = mode_destinations
.iter()
.map(|(mode, count)| format!("- {} ({} destinations available)", mode, count))
.collect::<Vec<_>>()
.join("\n");
parts.push(format!(
"\n--- TRAVEL TIME FILTERS ---\n\
You can add travel time filters when the user mentions commute times, \
proximity to places, or wanting to be near/within X minutes of somewhere.\n\
\n\
Available travel-time modes (only use modes that have destinations):\n\
{}\n\
- \"car\" / \"drive\" / \"driving\" = car mode\n\
- \"cycle\" / \"bike\" / \"cycling\" = bicycle mode\n\
- \"walk\" / \"walking\" / \"on foot\" = walking mode\n\
- \"train\" / \"tube\" / \"bus\" / \"public transport\" / \"commute\" = transit mode\n\
- \"without buses\" / \"no bus\" / \"rail only\" = transit-no-bus mode\n\
- \"no change\" / \"no transfer\" / \"direct\" / \"single bus/train\" = transit-no-change mode\n\
- \"no change and no bus\" / \"direct rail/tube\" = transit-no-change-no-bus mode\n\
- If a mode appears in the available mode list but is not named above, you may still \
use the exact mode string from the list.\n\
\n\
When the user mentions a specific place, you MUST call the search_destinations \
tool to find the exact slug. Use the name and slug from the search results.\n\
If search_destinations returns an empty array, the destination is not available \
mention it in \"notes\" (e.g. \"No travel data for: Gatwick Airport\") and do NOT \
include a travel_time_filter for it.\n\
\n\
Travel time values are in MINUTES (0-90 range; data is capped at 90 min).\n\
- \"within 30 minutes\" = set \"max\": 30\n\
- \"at least 10 minutes\" = set \"min\": 10\n\
- \"30-45 minute commute\" = set \"min\": 30 and \"max\": 45 on the same travel_time_filter\n\
- If only a max is given, omit min (and vice versa). Do not use bound/value for travel time.\n\
\n\
INFERRING TRANSPORT MODE (when the user does not specify one explicitly):\n\
- \"commute\" to a major city centre or station = transit\n\
- \"near\" / \"close to\" a city centre or station = transit\n\
- \"near\" / \"close to\" a smaller town, village, or rural area = car\n\
- \"drive\" / \"driving distance\" / \"driving time\" = always car\n\
- If multiple modes are plausible, prefer transit for urban destinations \
(London, Manchester, Birmingham, Leeds, etc.) and car for everything else.",
modes_list,
));
// Feature guidance
parts.push(
"\n--- DATA SOURCE ---\n\
The data is historical property sales from the Land Registry.\n\
\n\
Use these features for price queries:\n\
- For purchase price: use \"Estimated current price\" or \"Last known price\"\n\
- For price per sqm: use \"Est. price per sqm\"\n\
- For rent estimates: use \"Estimated monthly rent\""
.to_string(),
);
// Feature catalogue
parts.push("\n--- AVAILABLE FEATURES ---\n".to_string());
for group in &features.groups {
parts.push(format!("## {}", group.name));
for feature in &group.features {
match feature {
FeatureInfo::Numeric {
name,
min,
max,
description,
prefix,
suffix,
..
} => {
parts.push(format!(
"- \"{}\" (numeric, {}{:.0}{} to {}{:.0}{}): {}",
name, prefix, min, suffix, prefix, max, suffix, description
));
}
FeatureInfo::Enum {
name,
values,
description,
..
} => {
parts.push(format!(
"- \"{}\" (enum, values: [{}]): {}",
name,
values
.iter()
.map(|val| format!("\"{}\"", val))
.collect::<Vec<_>>()
.join(", "),
description
));
}
}
}
}
// Few-shot examples
parts.push("\n--- EXAMPLES ---\n".to_string());
parts.push(
"User: \"cheap freehold house under 400k\"\n\
Output: {\"numeric_filters\": [{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 400000}], \
\"enum_filters\": [{\"name\": \"Leasehold/Freehold\", \"values\": [\"Freehold\"]}, \
{\"name\": \"Property type\", \"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
\"travel_time_filters\": [], \
\"notes\": \"\"}"
.to_string(),
);
parts.push(
"\nUser: \"safe quiet area with good schools and parks\"\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Serious crime (avg/yr)\", \"bound\": \"max\", \"value\": 5}, \
{\"name\": \"Minor crime (avg/yr)\", \"bound\": \"max\", \"value\": 20}, \
{\"name\": \"Noise (dB)\", \"bound\": \"max\", \"value\": 55}, \
{\"name\": \"Good+ primary school catchments\", \"bound\": \"min\", \"value\": 2}, \
{\"name\": \"Good+ secondary school catchments\", \"bound\": \"min\", \"value\": 1}, \
{\"name\": \"Number of amenities (Park) within 2km\", \"bound\": \"min\", \"value\": 3}], \
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
.to_string(),
);
parts.push(
"\nUser: \"quiet area with outstanding schools\"\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Noise (dB)\", \"bound\": \"max\", \"value\": 55}, \
{\"name\": \"Outstanding primary school catchments\", \"bound\": \"min\", \"value\": 1}, \
{\"name\": \"Outstanding secondary school catchments\", \"bound\": \"min\", \"value\": 1}], \
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
.to_string(),
);
parts.push(
"\nUser: \"3 bed flat under 300k with fast broadband near the beach\"\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 300000}, \
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 4}], \
\"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}, \
{\"name\": \"Max available download speed (Mbps)\", \"values\": [\"100\", \"300\", \"1000\"]}], \
\"travel_time_filters\": [], \
\"notes\": \"No filter for: beach proximity\"}"
.to_string(),
);
parts.push(
"\nUser: \"within 30 minutes commute of Kings Cross, under 500k\"\n\
(After calling search_destinations for \"Kings Cross\" with mode \"transit\" \
and getting [{\"name\": \"Kings Cross\", \"slug\": \"kings-cross\", \"place_type\": \"station\"}])\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 500000}], \
\"enum_filters\": [], \
\"travel_time_filters\": [{\"mode\": \"transit\", \"slug\": \"kings-cross\", \
\"label\": \"Kings Cross\", \"max\": 30}], \
\"notes\": \"\"}"
.to_string(),
);
parts.push(
"\nUser: \"family home with garden, 45 min drive from Manchester, good schools\"\n\
(After calling search_destinations for \"Manchester\" with mode \"car\" \
and getting [{\"name\": \"Manchester\", \"slug\": \"manchester\", \"place_type\": \"city\"}])\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Total floor area (sqm)\", \"bound\": \"min\", \"value\": 100}, \
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 5}, \
{\"name\": \"Good+ primary school catchments\", \"bound\": \"min\", \"value\": 2}, \
{\"name\": \"Good+ secondary school catchments\", \"bound\": \"min\", \"value\": 1}], \
\"enum_filters\": [{\"name\": \"Property type\", \
\"values\": [\"Detached\", \"Semi-Detached\"]}], \
\"travel_time_filters\": [{\"mode\": \"car\", \"slug\": \"manchester\", \
\"label\": \"Manchester\", \"max\": 45}], \
\"notes\": \"No filter for: garden\"}"
.to_string(),
);
parts.push(
"\nUser: \"Labour-voting area with low burglary and a station nearby\"\n\
Output: {\"numeric_filters\": [\
{\"name\": \"% Labour\", \"bound\": \"min\", \"value\": 40}, \
{\"name\": \"Burglary (avg/yr)\", \"bound\": \"max\", \"value\": 10}, \
{\"name\": \"Distance to nearest amenity (Rail station) (km)\", \"bound\": \"max\", \"value\": 1}], \
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
.to_string(),
);
// Examples showing rent and price features
parts.push(
"\nUser: \"2 bed flat with rent under £1500/month\"\n\
Output: {\
\"numeric_filters\": [{\"name\": \"Estimated monthly rent\", \"bound\": \"max\", \"value\": 1500}], \
\"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}], \
\"travel_time_filters\": [], \
\"notes\": \"\"}"
.to_string(),
);
parts.push(
"\nUser: \"3 bed house under 500k with good schools\"\n\
Output: {\
\"numeric_filters\": [{\"name\": \"Estimated current price\", \"bound\": \"max\", \"value\": 500000}, \
{\"name\": \"Good+ primary school catchments\", \"bound\": \"min\", \"value\": 2}], \
\"enum_filters\": [{\"name\": \"Property type\", \
\"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
\"travel_time_filters\": [], \
\"notes\": \"\"}"
.to_string(),
);
// Output format reminder
parts.push(
"\n--- OUTPUT FORMAT ---\n\
{\"numeric_filters\": [...], \"enum_filters\": [...], \
\"travel_time_filters\": [{\"mode\": \"...\", \"slug\": \"...\", \"label\": \"...\", \
\"min\": N, \"max\": N}, ...], \"notes\": \"...\"}\n\
- travel_time_filters: min and max are both optional, but include at least one. \
Use ONLY slugs returned by search_destinations. If a place isn't found, mention it in notes.\n\
Respond with ONLY the JSON object. No explanation."
.to_string(),
);
parts.join("\n")
}

View file

@ -0,0 +1,188 @@
//! The `search_destinations` Gemini tool: declaration and execution against
//! PlaceData + TravelTimeStore.
use serde_json::{json, Value};
use tracing::info;
use crate::data::slugify;
use crate::state::AppState;
/// Assemble the Gemini `tools` payload that exposes `search_destinations`.
///
/// The `mode` parameter's `enum` is filled from
/// `state.travel_time_store.available_modes`, so the declared schema only
/// offers the transport modes listed there. Both `query` and `mode` are
/// declared as required.
pub(super) fn build_tool_declarations(state: &AppState) -> Value {
    // Collect the mode names up front so they can be embedded as the enum.
    let mut mode_names: Vec<&str> = Vec::new();
    for mode in state.travel_time_store.available_modes.iter() {
        mode_names.push(mode.as_str());
    }

    let query_param = json!({
        "type": "string",
        "description": "Place name to search for (e.g. 'Manchester', 'Kings Cross', 'Heathrow')"
    });
    let mode_param = json!({
        "type": "string",
        "enum": mode_names,
        "description": "Transport mode to search destinations for"
    });

    let search_destinations = json!({
        "name": "search_destinations",
        "description": "Search for available travel time destinations (cities, stations, towns) that have precomputed travel time data. Call this when the user mentions wanting to be near, close to, or within a certain travel time of a specific place.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": query_param,
                "mode": mode_param
            },
            "required": ["query", "mode"]
        }
    });

    json!([{ "functionDeclarations": [search_destinations] }])
}
/// Maximum number of destinations returned to the model for a single search.
const MAX_DESTINATION_RESULTS: usize = 10;

/// Candidate destination: `(place index, slug, type rank, population)`.
type DestinationCandidate = (usize, String, u8, u32);

/// Decide whether a place satisfies a destination query.
///
/// A place matches when every query word occurs in its lower-cased name, or
/// when its slug and the query slug contain one another. Empty inputs never
/// match: `Iterator::all` is vacuously true for an empty word list and
/// `str::contains("")` is always true, so without these guards a blank (or,
/// presumably, punctuation-only) query would match every place.
fn place_matches_query(
    name_lower: &str,
    slug: &str,
    query_words: &[&str],
    query_slug: &str,
) -> bool {
    let words_match =
        !query_words.is_empty() && query_words.iter().all(|word| name_lower.contains(*word));
    if words_match {
        return true;
    }
    !query_slug.is_empty()
        && !slug.is_empty()
        && (slug.contains(query_slug) || query_slug.contains(slug))
}

/// Order candidates by type rank ascending, then population descending, and
/// keep at most `MAX_DESTINATION_RESULTS` of them.
fn rank_and_cap(mut candidates: Vec<DestinationCandidate>) -> Vec<DestinationCandidate> {
    candidates.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(b.3.cmp(&a.3)));
    candidates.truncate(MAX_DESTINATION_RESULTS);
    candidates
}

/// Execute a destination search against PlaceData + TravelTimeStore.
/// Returns matching destinations as a JSON value with `results` and optional `message`.
///
/// Uses word-based matching: all words in the query must appear somewhere in the
/// place name (order-independent). Slugs are matched as well, in either
/// containment direction. A blank query matches nothing.
///
/// Outcomes:
/// - unknown `mode`: empty `results` plus a message naming the mode;
/// - direct matches: up to 10 results, no message;
/// - no direct match but the query names a rank-0 place (treated as a city):
///   destinations whose `city` equals that place, with a message asking the
///   model to pick one;
/// - otherwise: empty `results` plus a "cannot be used" message.
pub(super) fn execute_destination_search(state: &AppState, query: &str, mode: &str) -> Value {
    let query_lower = query.to_lowercase();
    let query_words: Vec<&str> = query_lower.split_whitespace().collect();
    let query_slug = slugify(query);
    let tt_store = &state.travel_time_store;
    let pd = &state.place_data;

    let Some(slug_set) = tt_store.destinations.get(mode) else {
        return json!({ "results": [], "message": format!("No travel data available for mode '{}'", mode) });
    };

    // Shared JSON shape for both the direct and the city-fallback results.
    let to_results = |candidates: Vec<DestinationCandidate>| -> Vec<Value> {
        candidates
            .into_iter()
            .map(|(idx, slug, ..)| {
                json!({
                    "name": pd.name[idx],
                    "slug": slug,
                    "place_type": pd.place_type.get(idx).to_string(),
                })
            })
            .collect()
    };

    // Find places matching the query that have travel time data for this mode.
    let matches = rank_and_cap(
        pd.name_lower
            .iter()
            .enumerate()
            .filter_map(|(idx, name_lower)| {
                if !pd.travel_destination[idx] {
                    return None;
                }
                let slug = slugify(&pd.name[idx]);
                if !place_matches_query(name_lower, &slug, &query_words, &query_slug) {
                    return None;
                }
                if slug_set.contains(&slug) {
                    Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
                } else {
                    None
                }
            })
            .collect(),
    );

    if !matches.is_empty() {
        let results = to_results(matches);
        return json!({ "results": results });
    }

    // Check if the query matched a city that lacks its own travel data.
    // If so, return nearby stations within that city as suggestions.
    let matched_city_name: Option<&str> =
        pd.name_lower
            .iter()
            .enumerate()
            .find_map(|(idx, name_lower)| {
                // Cheap flag checks first; only slugify plausible candidates.
                if !pd.travel_destination[idx] || pd.type_rank[idx] != 0 {
                    return None;
                }
                let slug = slugify(&pd.name[idx]);
                if place_matches_query(name_lower, &slug, &query_words, &query_slug) {
                    Some(pd.name[idx].as_str())
                } else {
                    None
                }
            });

    if let Some(city_name) = matched_city_name {
        let city_lower = city_name.to_lowercase();
        let city_matches = rank_and_cap(
            pd.city
                .iter()
                .enumerate()
                .filter_map(|(idx, city_opt)| {
                    if !pd.travel_destination[idx] {
                        return None;
                    }
                    let city = city_opt.as_deref()?;
                    if city.to_lowercase() != city_lower {
                        return None;
                    }
                    let slug = slugify(&pd.name[idx]);
                    if slug_set.contains(&slug) {
                        Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
                    } else {
                        None
                    }
                })
                .collect(),
        );

        if !city_matches.is_empty() {
            let results = to_results(city_matches);
            info!(
                query = query,
                city = city_name,
                results = results.len(),
                "Destination search fell back to city stations"
            );
            return json!({
                "results": results,
                "message": format!(
                    "No travel data for '{}' directly. Pick one of these nearby stations:",
                    city_name
                )
            });
        }
    }

    info!(
        query = query,
        mode = mode,
        "Destination search returned no results"
    );
    json!({
        "results": [],
        "message": format!("No travel time data available for '{}' by {}. This destination cannot be used as a travel time filter.", query, mode)
    })
}

View file

@ -0,0 +1,119 @@
//! Weekly AI token usage tracking and rate limiting, persisted on the user's
//! PocketBase record.
use axum::http::StatusCode;
use metrics::counter;
use serde_json::{json, Value};
use tracing::warn;
use crate::pocketbase::get_superuser_token;
use crate::state::AppState;
/// Index of the current 7-day window, counted from the Unix epoch.
///
/// The value increases by one every 604 800 seconds and is used for weekly
/// rate limiting. If the system clock reports a time before 1970 the
/// function yields week 0 instead of panicking inside a request handler.
pub(super) fn current_week_number() -> u64 {
    const SECS_PER_WEEK: u64 = 604_800;

    let since_epoch = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        // Clock set before the epoch: treat as zero elapsed time.
        Err(_) => std::time::Duration::default(),
    };
    since_epoch.as_secs() / SECS_PER_WEEK
}
/// Fetch the user's current AI token usage from PocketBase.
/// Returns `(tokens_used, week_number)`.
pub(super) async fn fetch_ai_usage(
state: &AppState,
user_id: &str,
) -> Result<(u64, u64), (StatusCode, String)> {
let token = get_superuser_token(state).await.map_err(|err| {
warn!("Failed to auth superuser for AI usage check: {err}");
(StatusCode::BAD_GATEWAY, "Internal error".into())
})?;
let pb_url = state.pocketbase_url.trim_end_matches('/');
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
let resp = state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await
.map_err(|err| {
warn!("Failed to fetch user record for AI usage: {err}");
(StatusCode::BAD_GATEWAY, "Internal error".into())
})?;
if !resp.status().is_success() {
let status = resp.status();
warn!("PocketBase user fetch failed ({status})");
return Err((StatusCode::BAD_GATEWAY, "Internal error".into()));
}
let body: Value = resp.json().await.map_err(|err| {
warn!("Failed to parse user record: {err}");
(StatusCode::BAD_GATEWAY, "Internal error".into())
})?;
let tokens_used = body
.get("ai_tokens_used")
.and_then(|val| val.as_u64())
.unwrap_or(0);
let week = body
.get("ai_tokens_week")
.and_then(|val| val.as_u64())
.unwrap_or(0);
Ok((tokens_used, week))
}
/// Write the user's AI token total and week number back to PocketBase.
///
/// Best-effort: every failure (superuser auth, transport error, non-success
/// status) is logged as a warning and otherwise ignored, so the caller never
/// sees an error from this function.
async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week: u64) {
    let Ok(token) = get_superuser_token(state)
        .await
        .map_err(|err| warn!("Failed to auth superuser for AI usage update: {err}"))
    else {
        return;
    };

    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let url = format!("{pb_url}/api/collections/users/records/{user_id}");
    let payload = json!({
        "ai_tokens_used": tokens_used,
        "ai_tokens_week": week,
    });

    let outcome = state
        .http_client
        .patch(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload)
        .send()
        .await;

    let resp = match outcome {
        Ok(resp) => resp,
        Err(err) => {
            warn!("Failed to update AI usage: {err}");
            return;
        }
    };

    let status = resp.status();
    if !status.is_success() {
        warn!("Failed to update AI usage ({status})");
    }
}
/// Account for one finished AI request: persist the token total and bump metrics.
///
/// - `existing_tokens_used`: the user's total before this request.
/// - `week`: week number stored alongside the new total.
/// - `request_tokens_used`: tokens consumed by this request; when zero, the
///   stored total and the `ai_tokens_total` counter are left untouched.
/// - `status`: label attached to the `ai_requests_total` counter, which is
///   incremented for every call.
///
/// The new total is `existing_tokens_used + request_tokens_used` (saturating),
/// based on the value the caller supplies rather than a fresh read.
/// NOTE(review): overlapping requests for one user could presumably overwrite
/// each other's totals; confirm whether callers serialise per user.
pub(super) async fn record_ai_request_usage(
    state: &AppState,
    user_id: &str,
    existing_tokens_used: u64,
    week: u64,
    request_tokens_used: u64,
    status: &'static str,
) {
    let spent = request_tokens_used;
    if spent != 0 {
        let updated_total = existing_tokens_used.saturating_add(spent);
        update_ai_usage(state, user_id, updated_total, week).await;
        counter!("ai_tokens_total").increment(spent);
    }

    // Requests are counted regardless of whether they consumed any tokens.
    counter!("ai_requests_total", "status" => status).increment(1);
}

View file

@ -780,31 +780,28 @@ pub async fn get_export(
// groups themselves; postcodes within a group are sorted alphabetically.
// Each group carries a rolled-up summary aggregate for its header row.
let outcode_groups: Vec<OutcodeGroup> = {
let mut order: Vec<String> = Vec::new();
let mut by_outcode: FxHashMap<String, OutcodeGroup> = FxHashMap::default();
let mut groups: Vec<OutcodeGroup> = Vec::new();
let mut idx_by_outcode: FxHashMap<String, usize> = FxHashMap::default();
for (i, (pc_idx, agg)) in postcode_aggs.iter().enumerate() {
let outcode = outcode_of(&postcode_data.postcodes[*pc_idx]).to_string();
let group = by_outcode.entry(outcode.clone()).or_insert_with(|| {
order.push(outcode.clone());
OutcodeGroup {
outcode: outcode.clone(),
let idx = *idx_by_outcode.entry(outcode.clone()).or_insert_with(|| {
groups.push(OutcodeGroup {
outcode,
members: Vec::new(),
summary: PostcodeExportAgg::new(total_export_features),
}
});
groups.len() - 1
});
group.members.push(i);
group.summary.merge_from(agg);
groups[idx].members.push(i);
groups[idx].summary.merge_from(agg);
}
for group in by_outcode.values_mut() {
for group in &mut groups {
group.members.sort_by(|&a, &b| {
postcode_data.postcodes[postcode_aggs[a].0]
.cmp(&postcode_data.postcodes[postcode_aggs[b].0])
});
}
order
.into_iter()
.map(|outcode| by_outcode.remove(&outcode).unwrap())
.collect()
groups
};
// Build Excel workbook with two sheets

View file

@ -130,6 +130,11 @@ pub struct HexagonStatsResponse {
pub price_history: Vec<PricePoint>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub crime_by_year: Vec<CrimeYearStats>,
/// Latest year in the crime dataset as a whole. When a selection's series
/// end earlier (force-level publication gap, e.g. Greater Manchester),
/// the client captions the data as stale.
#[serde(skip_serializing_if = "Option::is_none")]
pub crime_latest_year: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub central_postcode: Option<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
@ -645,12 +650,19 @@ pub async fn get_hexagon_stats(
"GET /api/hexagon-stats"
);
let crime_latest_year = if crime_by_year.is_empty() {
None
} else {
stats::crime_latest_available_year(&state.crime_by_year)
};
Ok(HexagonStatsResponse {
count: total_count,
numeric_features,
enum_features: enum_features_out,
price_history,
crime_by_year,
crime_latest_year,
central_postcode,
filter_exclusions,
})

View file

@ -36,8 +36,9 @@ fn is_allowed_pb_path(path: &str) -> bool {
/// Dedicated HTTP client for proxying — does not follow redirects so 3xx
/// responses are passed through to the browser (needed for OAuth flows).
/// No overall timeout because SSE (Server-Sent Events) connections used by
/// PocketBase realtime/OAuth2 are long-lived streams.
/// No client-wide timeout because SSE (Server-Sent Events) connections used
/// by PocketBase realtime/OAuth2 are long-lived streams; non-realtime
/// requests get a per-request timeout instead (see PROXY_REQUEST_TIMEOUT).
static PROXY_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
@ -47,6 +48,11 @@ static PROXY_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
.expect("Failed to build proxy HTTP client")
});
/// Timeout for proxied requests other than the realtime SSE stream, so a hung
/// PocketBase cannot pile up handlers indefinitely. Generous enough for file
/// uploads/downloads over slow links.
///
/// NOTE(review): reqwest documents a per-request timeout as running from the
/// start of connecting until the response body has finished, so a streamed
/// download that takes longer than this would presumably be cut off
/// mid-body. Confirm 30 s is enough for the largest proxied files.
const PROXY_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
pub async fn proxy_to_pocketbase(
State(shared): State<Arc<SharedState>>,
req: Request,
@ -58,10 +64,7 @@ pub async fn proxy_to_pocketbase(
let target_path = path.strip_prefix("/pb").unwrap_or(path);
if !is_allowed_pb_path(target_path) {
warn!(path = %target_path, "Rejected PocketBase proxy request to disallowed path");
return Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::empty())
.unwrap();
return StatusCode::NOT_FOUND.into_response();
}
let query = req
.uri()
@ -73,6 +76,12 @@ pub async fn proxy_to_pocketbase(
let method = req.method().clone();
let mut builder = PROXY_CLIENT.request(method, &url);
// The realtime SSE stream is intentionally unbounded; everything else
// must complete within the timeout.
if target_path != "/api/realtime" {
builder = builder.timeout(PROXY_REQUEST_TIMEOUT);
}
// Forward only safe headers (allowlist)
const ALLOWED_HEADERS: &[&str] = &[
"content-type",
@ -96,10 +105,7 @@ pub async fn proxy_to_pocketbase(
Ok(bytes) => bytes,
Err(err) => {
warn!("Failed to read request body: {err}");
return Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Failed to read request body"))
.unwrap();
return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response();
}
};
builder = builder.body(body_bytes);
@ -129,14 +135,14 @@ pub async fn proxy_to_pocketbase(
// realtime system and OAuth2 flow — buffering would hang forever
// since SSE responses never complete.
let body = Body::from_stream(upstream.bytes_stream());
response.body(body).unwrap()
response.body(body).unwrap_or_else(|err| {
warn!("Failed to build proxied response: {err}");
(StatusCode::BAD_GATEWAY, "Invalid upstream response").into_response()
})
}
Err(err) => {
warn!("PocketBase proxy error: {err}");
Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from("PocketBase unavailable"))
.unwrap()
(StatusCode::BAD_GATEWAY, "PocketBase unavailable").into_response()
}
}
}

View file

@ -184,12 +184,19 @@ pub async fn get_postcode_stats(
"GET /api/postcode-stats"
);
let crime_latest_year = if crime_by_year.is_empty() {
None
} else {
stats::crime_latest_available_year(&state.crime_by_year)
};
Ok(HexagonStatsResponse {
count: total_count,
numeric_features,
enum_features: enum_features_out,
price_history,
crime_by_year,
crime_latest_year,
central_postcode: None,
filter_exclusions,
})

View file

@ -340,16 +340,14 @@ pub fn compute_crime_by_year(
let points: Vec<CrimeYearPoint> = years
.iter()
.filter_map(|&year| {
let denom = fully_covered_rows
+ covered_counts.get(&year).copied().unwrap_or(0);
let denom = fully_covered_rows + covered_counts.get(&year).copied().unwrap_or(0);
if denom == 0 {
// No selected postcode has published data for this year.
return None;
}
Some(CrimeYearPoint {
year,
count: (sums.get(&year).copied().unwrap_or(0.0) / denom as f64)
as f32,
count: (sums.get(&year).copied().unwrap_or(0.0) / denom as f64) as f32,
})
})
.collect();
@ -365,6 +363,19 @@ pub fn compute_crime_by_year(
out
}
/// Most recent year found in any entry of `crime_by_year.years_by_type`, or
/// `None` when no years are recorded at all.
///
/// The client compares each selection's last charted year against this value
/// so that force-level publication gaps (e.g. Greater Manchester ends
/// mid-2019) are captioned as stale data rather than read as current.
pub fn crime_latest_available_year(crime_by_year: &CrimeByYearData) -> Option<i32> {
    // Flatten the per-type year lists into one stream and take its maximum.
    let all_years = crime_by_year.years_by_type.iter().flatten();
    all_years.max().copied()
}
pub fn compute_poi_feature_stats(
matching_rows: &[usize],
poi_metrics: &PostcodePoiMetrics,

View file

@ -87,6 +87,105 @@ pub struct AppState {
pub bugsink_frontend_config: Option<BugsinkFrontendConfig>,
}
#[cfg(test)]
impl AppState {
    /// Build a bare-bones `AppState` for integration tests of the
    /// PocketBase/Stripe money paths (checkout, webhook, licensing, invites).
    ///
    /// Every map/property dataset is empty. Only the HTTP-facing
    /// configuration carries meaningful values: the supplied
    /// `pocketbase_url`, fixed dummy credentials and Stripe secrets, fresh
    /// caches, and an HTTP client with short timeouts.
    pub(crate) fn for_tests(pocketbase_url: String) -> Self {
        use std::time::Duration;

        use crate::data::{
            ActualListingData, CrimeByYearData, OutcodeData, POIData, PlaceData, PostcodeData,
            PropertyData, TravelTimeStore,
        };
        use crate::utils::InternedColumn;

        // Short timeouts so a test cannot wait long on an unresponsive server.
        let http_client = reqwest::Client::builder()
            .connect_timeout(Duration::from_secs(2))
            .timeout(Duration::from_secs(5))
            .build()
            .expect("test HTTP client should build");

        // Empty datasets, built up front to keep the final literal readable.
        let postcode_data = PostcodeData {
            postcodes: Vec::new(),
            centroids: Vec::new(),
            aabbs: Vec::new(),
            polygons: Vec::new(),
            postcode_to_idx: FxHashMap::default(),
        };
        let outcode_data = OutcodeData {
            names: Vec::new(),
            name_lower: Vec::new(),
            centroids: Vec::new(),
            cities: Vec::new(),
        };
        let actual_listings = ActualListingData {
            lat: Vec::new(),
            lon: Vec::new(),
            postcode: Vec::new(),
            address: Vec::new(),
            property_type: InternedColumn::build(&[]),
            property_sub_type: InternedColumn::build(&[]),
            leasehold_freehold: InternedColumn::build(&[]),
            price_qualifier: InternedColumn::build(&[]),
            bedrooms: Vec::new(),
            bathrooms: Vec::new(),
            rooms_total: Vec::new(),
            floor_area_sqm: Vec::new(),
            asking_price: Vec::new(),
            asking_price_per_sqm: Vec::new(),
            listing_url: Vec::new(),
            listing_status: InternedColumn::build(&[]),
            listing_date_iso: Vec::new(),
            features: Vec::new(),
            filter_feature_data: Vec::new(),
            poi_filter_feature_data: Vec::new(),
            grid: GridIndex::build(&[], &[], 0.01),
        };
        let crime_by_year = CrimeByYearData {
            crime_types: Vec::new(),
            years_by_type: Vec::new(),
            series_by_postcode: FxHashMap::default(),
            covered_years_by_postcode: FxHashMap::default(),
        };

        Self {
            // Property and map data: all empty.
            data: PropertyData::empty_for_tests(),
            grid: GridIndex::build(&[], &[], 0.01),
            h3_cells: Vec::new(),
            feature_name_to_index: FxHashMap::default(),
            min_keys: Vec::new(),
            max_keys: Vec::new(),
            avg_keys: Vec::new(),
            features_response: FeaturesResponse { groups: Vec::new() },
            ai_filters_system_prompt: String::new(),
            poi_data: Arc::new(POIData::empty_for_tests()),
            poi_grid: Arc::new(GridIndex::build(&[], &[], 0.01)),
            place_data: Arc::new(PlaceData::empty_for_tests()),
            postcode_data: Arc::new(postcode_data),
            outcode_data: Arc::new(outcode_data),
            poi_category_groups: Arc::new(Vec::new()),
            travel_time_store: Arc::new(TravelTimeStore::empty_for_tests()),
            actual_listings: Arc::new(actual_listings),
            crime_by_year: Arc::new(crime_by_year),
            // Caches start out fresh.
            token_cache: Arc::new(TokenCache::new()),
            superuser_token_cache: Arc::new(SuperuserTokenCache::new()),
            share_cache: Arc::new(ShareBoundsCache::new()),
            // HTTP-facing configuration with fixed dummy values.
            screenshot_url: String::from("http://127.0.0.1:1/screenshot"),
            public_url: String::from("https://test.example"),
            is_dev: false,
            http_client,
            pocketbase_url,
            pocketbase_admin_email: String::from("admin@test.example"),
            pocketbase_admin_password: String::from("test-admin-password"),
            gemini_api_key: String::from("test-gemini-key"),
            gemini_model: String::from("test-model"),
            google_maps_api_key: String::from("test-maps-key"),
            stripe_secret_key: String::from("sk_test_dummy"),
            stripe_webhook_secret: String::from("whsec_test_secret"),
            stripe_referral_coupon_id: String::from("couponTest30"),
            bugsink_frontend_config: None,
        }
    }
}
/// Wraps AppState for shared access across route handlers.
/// Route handlers call `load_state()` to get the current snapshot.
pub struct SharedState {