SPlit up
This commit is contained in:
parent
cf39ad754e
commit
f59d01227b
91 changed files with 10370 additions and 7562 deletions
File diff suppressed because it is too large
Load diff
589
server-rs/src/checkout_sessions/lifecycle.rs
Normal file
589
server-rs/src/checkout_sessions/lifecycle.rs
Normal file
|
|
@ -0,0 +1,589 @@
|
|||
//! The checkout session state machine: starting a checkout (with pricing and
|
||||
//! reservation under a cross-instance lock), verifying Stripe's completion
|
||||
//! payload, completing/granting, and reversing/reinstating after refunds or
|
||||
//! disputes.
|
||||
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::auth::PocketBaseUser;
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::pocketbase_locks::acquire_pocketbase_lock;
|
||||
use crate::routes::pricing::{count_licensed_users, price_for_count};
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::records::{
|
||||
attach_stripe_session, count_active_pending_checkouts, create_pending_checkout,
|
||||
expire_stale_pending_checkouts, find_active_checkout_for_user,
|
||||
find_checkout_by_payment_intent_or_checkout_session, find_checkout_by_stripe_session,
|
||||
has_other_completed_checkout_for_user, mark_checkout_completed, mark_checkout_reinstated,
|
||||
mark_checkout_reversed, mark_checkout_status, PendingCheckoutInput,
|
||||
};
|
||||
use super::referral::{
|
||||
mark_referral_invite_used, release_referral_invite_reservation, reserve_referral_invite,
|
||||
};
|
||||
use super::stripe::create_stripe_session;
|
||||
use super::{
|
||||
ensure_success, is_safe_reversal_reason, is_safe_stripe_session_id, now_unix_secs,
|
||||
number_field, CheckoutCompletion, CheckoutStart, PaymentReinstatementOutcome,
|
||||
PaymentReversalOutcome, VerifiedCheckout, CHECKOUT_CURRENCY, REFERRAL_DISCOUNT_PERCENT,
|
||||
};
|
||||
|
||||
const CHECKOUT_SESSION_TTL_SECS: u64 = 31 * 60;
|
||||
const CHECKOUT_PRICING_LOCK_NAME: &str = "checkout:pricing";
|
||||
const CHECKOUT_PRICING_LOCK_TTL_SECS: u64 = 5 * 60;
|
||||
|
||||
static CHECKOUT_RESERVATION_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
|
||||
|
||||
pub async fn start_license_checkout(
|
||||
state: &AppState,
|
||||
user: &PocketBaseUser,
|
||||
success_url: &str,
|
||||
cancel_url: &str,
|
||||
discount_coupon_id: Option<&str>,
|
||||
referral_invite_id: Option<&str>,
|
||||
) -> anyhow::Result<CheckoutStart> {
|
||||
let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
|
||||
let pricing_lock = acquire_pocketbase_lock(
|
||||
state,
|
||||
CHECKOUT_PRICING_LOCK_NAME,
|
||||
CHECKOUT_PRICING_LOCK_TTL_SECS,
|
||||
)
|
||||
.await?;
|
||||
let result = start_license_checkout_locked(
|
||||
state,
|
||||
user,
|
||||
success_url,
|
||||
cancel_url,
|
||||
discount_coupon_id,
|
||||
referral_invite_id,
|
||||
)
|
||||
.await;
|
||||
if let Err(err) = pricing_lock.release().await {
|
||||
warn!("Failed to release checkout pricing lock: {err}");
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
async fn start_license_checkout_locked(
|
||||
state: &AppState,
|
||||
user: &PocketBaseUser,
|
||||
success_url: &str,
|
||||
cancel_url: &str,
|
||||
discount_coupon_id: Option<&str>,
|
||||
referral_invite_id: Option<&str>,
|
||||
) -> anyhow::Result<CheckoutStart> {
|
||||
let now = now_unix_secs();
|
||||
expire_stale_pending_checkouts(state, now).await?;
|
||||
|
||||
if let Some(existing) = find_active_checkout_for_user(
|
||||
state,
|
||||
&user.id,
|
||||
discount_coupon_id.unwrap_or_default(),
|
||||
referral_invite_id.unwrap_or_default(),
|
||||
now,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
if !existing.checkout_url.is_empty() {
|
||||
return Ok(CheckoutStart::Stripe {
|
||||
url: existing.checkout_url,
|
||||
});
|
||||
}
|
||||
if let Err(err) = mark_checkout_status(state, &existing.id, "failed").await {
|
||||
warn!(
|
||||
reservation_id = %existing.id,
|
||||
"Failed to fail incomplete checkout reservation: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let licensed_count = count_licensed_users(state).await?;
|
||||
let pending_count = count_active_pending_checkouts(state, now).await?;
|
||||
let price_pence = price_for_count(licensed_count + pending_count);
|
||||
|
||||
if price_pence == 0 {
|
||||
grant_license(state, &user.id).await?;
|
||||
return Ok(CheckoutStart::Free);
|
||||
}
|
||||
|
||||
let expires_at_unix = now + CHECKOUT_SESSION_TTL_SECS;
|
||||
let expected_total_pence = expected_total_for_checkout(price_pence, discount_coupon_id);
|
||||
let reservation_id = create_pending_checkout(
|
||||
state,
|
||||
PendingCheckoutInput {
|
||||
user_id: &user.id,
|
||||
amount_pence: price_pence,
|
||||
expected_total_pence,
|
||||
currency: CHECKOUT_CURRENCY,
|
||||
discount_coupon_id: discount_coupon_id.unwrap_or_default(),
|
||||
referral_invite_id: referral_invite_id.unwrap_or_default(),
|
||||
expires_at_unix,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let Some(invite_id) = referral_invite_id.filter(|id| !id.is_empty()) {
|
||||
if let Err(err) =
|
||||
reserve_referral_invite(state, invite_id, &user.id, &reservation_id, expires_at_unix)
|
||||
.await
|
||||
{
|
||||
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
|
||||
warn!(
|
||||
reservation_id,
|
||||
"Failed to mark checkout reservation failed: {mark_err}"
|
||||
);
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
|
||||
let stripe_result = create_stripe_session(
|
||||
state,
|
||||
user,
|
||||
&reservation_id,
|
||||
price_pence,
|
||||
success_url,
|
||||
cancel_url,
|
||||
expires_at_unix,
|
||||
discount_coupon_id,
|
||||
)
|
||||
.await;
|
||||
|
||||
let (stripe_session_id, url) = match stripe_result {
|
||||
Ok(session) => session,
|
||||
Err(err) => {
|
||||
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
|
||||
warn!(
|
||||
reservation_id,
|
||||
"Failed to mark checkout reservation failed: {mark_err}"
|
||||
);
|
||||
}
|
||||
if let Some(invite_id) = referral_invite_id.filter(|id| !id.is_empty()) {
|
||||
if let Err(release_err) =
|
||||
release_referral_invite_reservation(state, invite_id, &reservation_id).await
|
||||
{
|
||||
warn!(
|
||||
reservation_id,
|
||||
referral_invite_id = invite_id,
|
||||
"Failed to release referral invite reservation: {release_err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = attach_stripe_session(state, &reservation_id, &stripe_session_id, &url).await
|
||||
{
|
||||
if let Err(mark_err) = mark_checkout_status(state, &reservation_id, "failed").await {
|
||||
warn!(
|
||||
reservation_id,
|
||||
"Failed to mark checkout reservation failed: {mark_err}"
|
||||
);
|
||||
}
|
||||
if let Some(invite_id) = referral_invite_id.filter(|id| !id.is_empty()) {
|
||||
if let Err(release_err) =
|
||||
release_referral_invite_reservation(state, invite_id, &reservation_id).await
|
||||
{
|
||||
warn!(
|
||||
reservation_id,
|
||||
referral_invite_id = invite_id,
|
||||
"Failed to release referral invite reservation: {release_err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
Ok(CheckoutStart::Stripe { url })
|
||||
}
|
||||
|
||||
pub async fn verify_checkout_completion(
|
||||
state: &AppState,
|
||||
session: &Value,
|
||||
) -> anyhow::Result<CheckoutCompletion> {
|
||||
let session_id = match session["id"].as_str() {
|
||||
Some(id) if is_safe_stripe_session_id(id) => id,
|
||||
_ => {
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"missing or invalid session id".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
let payment_intent_id = match session["payment_intent"].as_str() {
|
||||
Some(id) if is_safe_stripe_session_id(id) => id,
|
||||
_ => {
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"missing or invalid payment intent id".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let checkout = match find_checkout_by_stripe_session(state, session_id).await? {
|
||||
Some(checkout) => checkout,
|
||||
None => {
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout session has no reservation".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let already_completed = checkout.status == "completed";
|
||||
if !already_completed && checkout.status != "pending" && checkout.status != "expired" {
|
||||
return Ok(CheckoutCompletion::Rejected(format!(
|
||||
"checkout reservation is {}",
|
||||
checkout.status
|
||||
)));
|
||||
}
|
||||
if checkout.stripe_session_id != session_id {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout reservation session id mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let client_reference_id = session["client_reference_id"].as_str().unwrap_or_default();
|
||||
if client_reference_id != checkout.user_id {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout client_reference_id mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let payment_status = session["payment_status"].as_str().unwrap_or_default();
|
||||
if payment_status != "paid" {
|
||||
return Ok(CheckoutCompletion::Rejected(format!(
|
||||
"checkout payment_status is {payment_status}"
|
||||
)));
|
||||
}
|
||||
|
||||
let currency = session["currency"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_ascii_lowercase();
|
||||
if currency != checkout.currency {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout currency mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let amount_subtotal = match number_field(session, "amount_subtotal") {
|
||||
Some(amount) => amount,
|
||||
None => {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout amount_subtotal missing".into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
if amount_subtotal != checkout.amount_pence {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout amount_subtotal mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let amount_total = match number_field(session, "amount_total") {
|
||||
Some(amount) => amount,
|
||||
None => {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout amount_total missing".into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
if amount_total != checkout.expected_total_pence {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Ok(CheckoutCompletion::Rejected(
|
||||
"checkout amount_total mismatch".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let verified = VerifiedCheckout {
|
||||
reservation_id: checkout.id,
|
||||
user_id: checkout.user_id,
|
||||
stripe_session_id: session_id.to_string(),
|
||||
payment_intent_id: payment_intent_id.to_string(),
|
||||
paid_amount_pence: amount_total,
|
||||
referral_invite_id: checkout.referral_invite_id,
|
||||
};
|
||||
|
||||
if already_completed {
|
||||
Ok(CheckoutCompletion::AlreadyHandled(verified))
|
||||
} else {
|
||||
Ok(CheckoutCompletion::Grant(verified))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn complete_verified_checkout(
|
||||
state: &AppState,
|
||||
checkout: &VerifiedCheckout,
|
||||
) -> anyhow::Result<()> {
|
||||
let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
|
||||
let pricing_lock = acquire_pocketbase_lock(
|
||||
state,
|
||||
CHECKOUT_PRICING_LOCK_NAME,
|
||||
CHECKOUT_PRICING_LOCK_TTL_SECS,
|
||||
)
|
||||
.await?;
|
||||
let result = complete_verified_checkout_locked(state, checkout).await;
|
||||
if let Err(err) = pricing_lock.release().await {
|
||||
warn!("Failed to release checkout pricing lock: {err}");
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
async fn complete_verified_checkout_locked(
|
||||
state: &AppState,
|
||||
checkout: &VerifiedCheckout,
|
||||
) -> anyhow::Result<()> {
|
||||
let live_checkout = find_checkout_by_stripe_session(state, &checkout.stripe_session_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("checkout reservation disappeared before completion"))?;
|
||||
|
||||
if live_checkout.status == "completed" {
|
||||
if !checkout.referral_invite_id.is_empty() {
|
||||
mark_referral_invite_used(
|
||||
state,
|
||||
&checkout.referral_invite_id,
|
||||
&checkout.user_id,
|
||||
&checkout.reservation_id,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
if live_checkout.id != checkout.reservation_id
|
||||
|| live_checkout.user_id != checkout.user_id
|
||||
|| live_checkout.referral_invite_id != checkout.referral_invite_id
|
||||
{
|
||||
mark_checkout_status(state, &checkout.reservation_id, "invalid").await?;
|
||||
return Err(anyhow!("checkout reservation changed before completion"));
|
||||
}
|
||||
if live_checkout.status != "pending" && live_checkout.status != "expired" {
|
||||
return Err(anyhow!("checkout reservation is {}", live_checkout.status));
|
||||
}
|
||||
|
||||
grant_license(state, &checkout.user_id).await?;
|
||||
mark_checkout_completed(
|
||||
state,
|
||||
&checkout.reservation_id,
|
||||
checkout.paid_amount_pence,
|
||||
&checkout.payment_intent_id,
|
||||
)
|
||||
.await?;
|
||||
if !checkout.referral_invite_id.is_empty() {
|
||||
mark_referral_invite_used(
|
||||
state,
|
||||
&checkout.referral_invite_id,
|
||||
&checkout.user_id,
|
||||
&checkout.reservation_id,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn grant_license_with_pricing_lock(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
|
||||
let pricing_lock = acquire_pocketbase_lock(
|
||||
state,
|
||||
CHECKOUT_PRICING_LOCK_NAME,
|
||||
CHECKOUT_PRICING_LOCK_TTL_SECS,
|
||||
)
|
||||
.await?;
|
||||
let result = grant_license(state, user_id).await;
|
||||
if let Err(err) = pricing_lock.release().await {
|
||||
warn!("Failed to release checkout pricing lock: {err}");
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
|
||||
set_user_subscription(state, user_id, "licensed").await
|
||||
}
|
||||
|
||||
pub async fn reverse_license_for_payment_intent(
|
||||
state: &AppState,
|
||||
payment_intent_id: &str,
|
||||
reason: &str,
|
||||
refunded_amount_pence: Option<u64>,
|
||||
) -> anyhow::Result<PaymentReversalOutcome> {
|
||||
if !is_safe_stripe_session_id(payment_intent_id) {
|
||||
return Err(anyhow!("invalid Stripe payment intent id"));
|
||||
}
|
||||
if !is_safe_reversal_reason(reason) {
|
||||
return Err(anyhow!("invalid Stripe reversal reason"));
|
||||
}
|
||||
|
||||
let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
|
||||
let checkout = match find_checkout_by_payment_intent_or_checkout_session(
|
||||
state,
|
||||
payment_intent_id,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Some(checkout) => checkout,
|
||||
None => return Ok(PaymentReversalOutcome::NoMatchingCheckout),
|
||||
};
|
||||
|
||||
let paid_amount_pence = checkout
|
||||
.paid_amount_pence
|
||||
.max(checkout.expected_total_pence);
|
||||
if let Some(refunded_amount_pence) = refunded_amount_pence {
|
||||
if refunded_amount_pence < paid_amount_pence {
|
||||
return Ok(PaymentReversalOutcome::IgnoredPartialRefund {
|
||||
user_id: checkout.user_id,
|
||||
refunded_amount_pence,
|
||||
paid_amount_pence,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if checkout.status == "reversed" {
|
||||
return Ok(PaymentReversalOutcome::AlreadyHandled {
|
||||
user_id: checkout.user_id,
|
||||
});
|
||||
}
|
||||
|
||||
if matches!(checkout.status.as_str(), "pending" | "expired" | "failed") {
|
||||
mark_checkout_reversed(state, &checkout.id, reason, payment_intent_id).await?;
|
||||
return Ok(PaymentReversalOutcome::Applied {
|
||||
user_id: checkout.user_id,
|
||||
});
|
||||
}
|
||||
|
||||
if checkout.status != "completed" {
|
||||
return Ok(PaymentReversalOutcome::NotReversible {
|
||||
user_id: checkout.user_id,
|
||||
status: checkout.status,
|
||||
});
|
||||
}
|
||||
|
||||
let has_other_license = has_other_completed_checkout_for_user(
|
||||
state,
|
||||
&checkout.user_id,
|
||||
&checkout.id,
|
||||
payment_intent_id,
|
||||
)
|
||||
.await?;
|
||||
if !has_other_license {
|
||||
revoke_license(state, &checkout.user_id).await?;
|
||||
}
|
||||
mark_checkout_reversed(state, &checkout.id, reason, payment_intent_id).await?;
|
||||
|
||||
Ok(PaymentReversalOutcome::Applied {
|
||||
user_id: checkout.user_id,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn reinstate_license_for_payment_intent(
|
||||
state: &AppState,
|
||||
payment_intent_id: &str,
|
||||
reason: &str,
|
||||
) -> anyhow::Result<PaymentReinstatementOutcome> {
|
||||
if !is_safe_stripe_session_id(payment_intent_id) {
|
||||
return Err(anyhow!("invalid Stripe payment intent id"));
|
||||
}
|
||||
if !is_safe_reversal_reason(reason) {
|
||||
return Err(anyhow!("invalid Stripe reinstatement reason"));
|
||||
}
|
||||
|
||||
let _guard = CHECKOUT_RESERVATION_LOCK.lock().await;
|
||||
let checkout = match find_checkout_by_payment_intent_or_checkout_session(
|
||||
state,
|
||||
payment_intent_id,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Some(checkout) => checkout,
|
||||
None => return Ok(PaymentReinstatementOutcome::NoMatchingCheckout),
|
||||
};
|
||||
|
||||
if checkout.status == "completed" {
|
||||
return Ok(PaymentReinstatementOutcome::AlreadyHandled {
|
||||
user_id: checkout.user_id,
|
||||
});
|
||||
}
|
||||
|
||||
if checkout.status != "reversed" {
|
||||
return Ok(PaymentReinstatementOutcome::Ignored {
|
||||
user_id: checkout.user_id,
|
||||
reason: format!("checkout status is {}", checkout.status),
|
||||
});
|
||||
}
|
||||
|
||||
if !checkout.reversal_reason.starts_with("charge.dispute.") {
|
||||
return Ok(PaymentReinstatementOutcome::Ignored {
|
||||
user_id: checkout.user_id,
|
||||
reason: format!("checkout was reversed by {}", checkout.reversal_reason),
|
||||
});
|
||||
}
|
||||
|
||||
grant_license(state, &checkout.user_id).await?;
|
||||
mark_checkout_reinstated(state, &checkout.id, reason).await?;
|
||||
|
||||
Ok(PaymentReinstatementOutcome::Applied {
|
||||
user_id: checkout.user_id,
|
||||
})
|
||||
}
|
||||
|
||||
async fn revoke_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
|
||||
set_user_subscription(state, user_id, "free").await
|
||||
}
|
||||
|
||||
async fn set_user_subscription(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
subscription: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "subscription": subscription }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase license update failed")?;
|
||||
|
||||
state.token_cache.invalidate_by_user_id(user_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn expected_total_for_checkout(
|
||||
amount_pence: u64,
|
||||
discount_coupon_id: Option<&str>,
|
||||
) -> u64 {
|
||||
if discount_coupon_id.is_some_and(|id| !id.is_empty()) {
|
||||
return ((amount_pence * (100 - REFERRAL_DISCOUNT_PERCENT)) / 100).max(1);
|
||||
}
|
||||
amount_pence
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn expected_total_for_referral_discount_rounds_down_like_stripe_amount_math() {
|
||||
assert_eq!(expected_total_for_checkout(999, Some("coupon_30")), 699);
|
||||
assert_eq!(expected_total_for_checkout(1, Some("coupon_30")), 1);
|
||||
assert_eq!(expected_total_for_checkout(999, None), 999);
|
||||
}
|
||||
}
|
||||
133
server-rs/src/checkout_sessions/mod.rs
Normal file
133
server-rs/src/checkout_sessions/mod.rs
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
//! Checkout sessions: Stripe-backed lifetime-license purchases reserved and
|
||||
//! recorded in the PocketBase `checkout_sessions` collection.
|
||||
//!
|
||||
//! Split by concern:
|
||||
//! - [`lifecycle`]: the session state machine (start, verify, complete,
|
||||
//! reverse, reinstate) and license granting
|
||||
//! - [`records`]: PocketBase `checkout_sessions` record handling
|
||||
//! - [`referral`]: referral invite reservation/consumption bookkeeping
|
||||
//! - [`stripe`]: Stripe API interaction (sessions, coupons, lookups)
|
||||
|
||||
mod lifecycle;
|
||||
mod records;
|
||||
mod referral;
|
||||
mod stripe;
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use lifecycle::{
|
||||
complete_verified_checkout, grant_license_with_pricing_lock,
|
||||
reinstate_license_for_payment_intent, reverse_license_for_payment_intent,
|
||||
start_license_checkout, verify_checkout_completion,
|
||||
};
|
||||
pub use referral::active_referral_checkout_user;
|
||||
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use serde_json::Value;
|
||||
|
||||
pub const CHECKOUT_CURRENCY: &str = "gbp";
|
||||
|
||||
const CHECKOUT_COLLECTION: &str = "checkout_sessions";
|
||||
const REFERRAL_DISCOUNT_PERCENT: u64 = 30;
|
||||
|
||||
pub enum CheckoutStart {
|
||||
Free,
|
||||
Stripe { url: String },
|
||||
}
|
||||
|
||||
pub enum CheckoutCompletion {
|
||||
Grant(VerifiedCheckout),
|
||||
AlreadyHandled(VerifiedCheckout),
|
||||
Rejected(String),
|
||||
}
|
||||
|
||||
pub enum PaymentReversalOutcome {
|
||||
Applied {
|
||||
user_id: String,
|
||||
},
|
||||
AlreadyHandled {
|
||||
user_id: String,
|
||||
},
|
||||
IgnoredPartialRefund {
|
||||
user_id: String,
|
||||
refunded_amount_pence: u64,
|
||||
paid_amount_pence: u64,
|
||||
},
|
||||
NoMatchingCheckout,
|
||||
NotReversible {
|
||||
user_id: String,
|
||||
status: String,
|
||||
},
|
||||
}
|
||||
|
||||
pub enum PaymentReinstatementOutcome {
|
||||
Applied { user_id: String },
|
||||
AlreadyHandled { user_id: String },
|
||||
Ignored { user_id: String, reason: String },
|
||||
NoMatchingCheckout,
|
||||
}
|
||||
|
||||
pub struct VerifiedCheckout {
|
||||
pub reservation_id: String,
|
||||
pub user_id: String,
|
||||
pub stripe_session_id: String,
|
||||
pub payment_intent_id: String,
|
||||
pub paid_amount_pence: u64,
|
||||
pub referral_invite_id: String,
|
||||
}
|
||||
|
||||
pub fn now_unix_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
fn number_field(value: &Value, field: &str) -> Option<u64> {
|
||||
value[field].as_u64().or_else(|| {
|
||||
value[field]
|
||||
.as_f64()
|
||||
.filter(|n| n.is_finite() && *n >= 0.0 && n.fract() == 0.0)
|
||||
.map(|n| n as u64)
|
||||
})
|
||||
}
|
||||
|
||||
fn is_safe_stripe_session_id(id: &str) -> bool {
|
||||
!id.is_empty()
|
||||
&& id.len() <= 128
|
||||
&& id
|
||||
.bytes()
|
||||
.all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'-')
|
||||
}
|
||||
|
||||
fn is_safe_pocketbase_id(id: &str) -> bool {
|
||||
!id.is_empty() && id.len() <= 32 && id.bytes().all(|b| b.is_ascii_alphanumeric())
|
||||
}
|
||||
|
||||
fn is_safe_reversal_reason(reason: &str) -> bool {
|
||||
!reason.is_empty()
|
||||
&& reason.len() <= 128
|
||||
&& reason
|
||||
.bytes()
|
||||
.all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'-' || b == b'.')
|
||||
}
|
||||
|
||||
async fn ensure_success(resp: reqwest::Response) -> anyhow::Result<()> {
|
||||
if resp.status().is_success() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
Err(anyhow!("upstream returned {status}: {text}"))
|
||||
}
|
||||
|
||||
async fn ensure_success_ref(resp: &reqwest::Response) -> anyhow::Result<()> {
|
||||
if resp.status().is_success() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err(anyhow!("upstream returned {}", resp.status()))
|
||||
}
|
||||
564
server-rs/src/checkout_sessions/records.rs
Normal file
564
server-rs/src/checkout_sessions/records.rs
Normal file
|
|
@ -0,0 +1,564 @@
|
|||
//! PocketBase `checkout_sessions` record handling: creating reservations,
|
||||
//! status transitions, and lookups by Stripe session / payment intent.
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use serde_json::Value;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::referral::release_referral_invite_reservation;
|
||||
use super::stripe::fetch_stripe_checkout_session_id_for_payment_intent;
|
||||
use super::{
|
||||
ensure_success, ensure_success_ref, is_safe_pocketbase_id, is_safe_stripe_session_id,
|
||||
now_unix_secs, number_field, CHECKOUT_COLLECTION,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct PendingCheckout {
|
||||
pub(super) id: String,
|
||||
pub(super) user_id: String,
|
||||
pub(super) stripe_session_id: String,
|
||||
pub(super) checkout_url: String,
|
||||
pub(super) amount_pence: u64,
|
||||
pub(super) expected_total_pence: u64,
|
||||
pub(super) currency: String,
|
||||
pub(super) referral_invite_id: String,
|
||||
pub(super) status: String,
|
||||
pub(super) payment_intent_id: String,
|
||||
pub(super) paid_amount_pence: u64,
|
||||
pub(super) reversal_reason: String,
|
||||
}
|
||||
|
||||
pub async fn mark_checkout_completed(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
paid_amount_pence: u64,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
if !is_safe_stripe_session_id(payment_intent_id) {
|
||||
return Err(anyhow!("invalid Stripe payment intent id"));
|
||||
}
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"status": "completed",
|
||||
"paid_amount_pence": paid_amount_pence,
|
||||
"completed_at_unix": now_unix_secs().to_string(),
|
||||
"stripe_payment_intent_id": payment_intent_id,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase checkout completion update failed")
|
||||
}
|
||||
|
||||
pub(super) async fn count_active_pending_checkouts(
|
||||
state: &AppState,
|
||||
now: u64,
|
||||
) -> anyhow::Result<u64> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("status=\"pending\" && expires_at_unix>={now}");
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
Ok(body["totalItems"].as_u64().unwrap_or(0))
|
||||
}
|
||||
|
||||
pub(super) async fn find_active_checkout_for_user(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
discount_coupon_id: &str,
|
||||
referral_invite_id: &str,
|
||||
now: u64,
|
||||
) -> anyhow::Result<Option<PendingCheckout>> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = active_checkout_filter(user_id, discount_coupon_id, referral_invite_id, now)?;
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let item = body["items"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.cloned();
|
||||
|
||||
item.map(parse_pending_checkout).transpose()
|
||||
}
|
||||
|
||||
fn active_checkout_filter(
|
||||
user_id: &str,
|
||||
discount_coupon_id: &str,
|
||||
referral_invite_id: &str,
|
||||
now: u64,
|
||||
) -> anyhow::Result<String> {
|
||||
if !is_safe_pocketbase_id(user_id) {
|
||||
return Err(anyhow!("invalid PocketBase user id"));
|
||||
}
|
||||
if !discount_coupon_id.is_empty() && !is_safe_stripe_session_id(discount_coupon_id) {
|
||||
return Err(anyhow!("invalid Stripe coupon id"));
|
||||
}
|
||||
if !referral_invite_id.is_empty() && !is_safe_pocketbase_id(referral_invite_id) {
|
||||
return Err(anyhow!("invalid PocketBase referral invite id"));
|
||||
}
|
||||
|
||||
Ok(format!(
|
||||
"status=\"pending\" && expires_at_unix>={now} && user=\"{user_id}\" && discount_coupon_id=\"{discount_coupon_id}\" && referral_invite_id=\"{referral_invite_id}\""
|
||||
))
|
||||
}
|
||||
|
||||
pub(super) async fn expire_stale_pending_checkouts(
|
||||
state: &AppState,
|
||||
now: u64,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("status=\"pending\" && expires_at_unix<{now}");
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=50",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let Some(items) = body["items"].as_array() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
for item in items {
|
||||
let Some(id) = item["id"].as_str() else {
|
||||
continue;
|
||||
};
|
||||
if let Err(err) = mark_checkout_status(state, id, "expired").await {
|
||||
warn!(
|
||||
reservation_id = id,
|
||||
"Failed to expire checkout reservation: {err}"
|
||||
);
|
||||
}
|
||||
if let Some(invite_id) = item["referral_invite_id"]
|
||||
.as_str()
|
||||
.filter(|invite_id| !invite_id.is_empty())
|
||||
{
|
||||
if let Err(err) = release_referral_invite_reservation(state, invite_id, id).await {
|
||||
warn!(
|
||||
reservation_id = id,
|
||||
referral_invite_id = invite_id,
|
||||
"Failed to release expired referral invite reservation: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) struct PendingCheckoutInput<'a> {
|
||||
pub(super) user_id: &'a str,
|
||||
pub(super) amount_pence: u64,
|
||||
pub(super) expected_total_pence: u64,
|
||||
pub(super) currency: &'a str,
|
||||
pub(super) discount_coupon_id: &'a str,
|
||||
pub(super) referral_invite_id: &'a str,
|
||||
pub(super) expires_at_unix: u64,
|
||||
}
|
||||
|
||||
pub(super) async fn create_pending_checkout(
|
||||
state: &AppState,
|
||||
input: PendingCheckoutInput<'_>,
|
||||
) -> anyhow::Result<String> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records");
|
||||
let resp = state
|
||||
.http_client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"user": input.user_id,
|
||||
"stripe_session_id": "",
|
||||
"stripe_payment_intent_id": "",
|
||||
"checkout_url": "",
|
||||
"amount_pence": input.amount_pence,
|
||||
"expected_total_pence": input.expected_total_pence,
|
||||
"currency": input.currency,
|
||||
"discount_coupon_id": input.discount_coupon_id,
|
||||
"referral_invite_id": input.referral_invite_id,
|
||||
"status": "pending",
|
||||
"expires_at_unix": input.expires_at_unix,
|
||||
"paid_amount_pence": 0,
|
||||
"completed_at_unix": "",
|
||||
"reversal_reason": "",
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
body["id"]
|
||||
.as_str()
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| anyhow!("PocketBase checkout reservation missing id"))
|
||||
}
|
||||
|
||||
pub(super) async fn attach_stripe_session(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
stripe_session_id: &str,
|
||||
checkout_url: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"stripe_session_id": stripe_session_id,
|
||||
"checkout_url": checkout_url,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase checkout session attach failed")
|
||||
}
|
||||
|
||||
pub(super) async fn mark_checkout_status(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
status: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "status": status }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.with_context(|| format!("PocketBase checkout status update failed for {reservation_id}"))
|
||||
}
|
||||
|
||||
pub(super) async fn mark_checkout_reversed(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
reason: &str,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"status": "reversed",
|
||||
"reversal_reason": reason,
|
||||
"stripe_payment_intent_id": payment_intent_id,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.with_context(|| format!("PocketBase checkout reversal update failed for {reservation_id}"))
|
||||
}
|
||||
|
||||
pub(super) async fn mark_checkout_reinstated(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
_reason: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"status": "completed",
|
||||
"reversal_reason": "",
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp).await.with_context(|| {
|
||||
format!("PocketBase checkout reinstatement update failed for {reservation_id}")
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) async fn find_checkout_by_stripe_session(
|
||||
state: &AppState,
|
||||
stripe_session_id: &str,
|
||||
) -> anyhow::Result<Option<PendingCheckout>> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("stripe_session_id=\"{}\"", stripe_session_id);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let item = body["items"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.cloned();
|
||||
|
||||
item.map(parse_pending_checkout).transpose()
|
||||
}
|
||||
|
||||
async fn find_checkout_by_payment_intent(
|
||||
state: &AppState,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<Option<PendingCheckout>> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("stripe_payment_intent_id=\"{}\"", payment_intent_id);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let item = body["items"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.cloned();
|
||||
|
||||
item.map(parse_pending_checkout).transpose()
|
||||
}
|
||||
|
||||
pub(super) async fn find_checkout_by_payment_intent_or_checkout_session(
|
||||
state: &AppState,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<Option<PendingCheckout>> {
|
||||
if let Some(checkout) = find_checkout_by_payment_intent(state, payment_intent_id).await? {
|
||||
return Ok(Some(checkout));
|
||||
}
|
||||
|
||||
let Some(session_id) =
|
||||
fetch_stripe_checkout_session_id_for_payment_intent(state, payment_intent_id).await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let Some(mut checkout) = find_checkout_by_stripe_session(state, &session_id).await? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
if checkout.payment_intent_id.is_empty() {
|
||||
attach_payment_intent_to_checkout(state, &checkout.id, payment_intent_id).await?;
|
||||
checkout.payment_intent_id = payment_intent_id.to_string();
|
||||
} else if checkout.payment_intent_id != payment_intent_id {
|
||||
mark_checkout_status(state, &checkout.id, "invalid").await?;
|
||||
return Err(anyhow!(
|
||||
"checkout reservation payment intent changed before reversal"
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Some(checkout))
|
||||
}
|
||||
|
||||
async fn attach_payment_intent_to_checkout(
|
||||
state: &AppState,
|
||||
reservation_id: &str,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records/{reservation_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"stripe_payment_intent_id": payment_intent_id,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase checkout payment intent attach failed")
|
||||
}
|
||||
|
||||
pub(super) async fn has_other_completed_checkout_for_user(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
reservation_id: &str,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<bool> {
|
||||
if !is_safe_pocketbase_id(user_id) || !is_safe_pocketbase_id(reservation_id) {
|
||||
return Err(anyhow!("invalid PocketBase id"));
|
||||
}
|
||||
if !is_safe_stripe_session_id(payment_intent_id) {
|
||||
return Err(anyhow!("invalid Stripe payment intent id"));
|
||||
}
|
||||
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("user=\"{user_id}\" && status=\"completed\"");
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=50",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
let Some(items) = body["items"].as_array() else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
Ok(items.iter().any(|item| {
|
||||
let other_id = item["id"].as_str().unwrap_or_default();
|
||||
let other_payment_intent = item["stripe_payment_intent_id"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
other_id != reservation_id && other_payment_intent != payment_intent_id
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_pending_checkout(item: Value) -> anyhow::Result<PendingCheckout> {
|
||||
Ok(PendingCheckout {
|
||||
id: item["id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("checkout reservation missing id"))?
|
||||
.to_string(),
|
||||
user_id: item["user"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("checkout reservation missing user"))?
|
||||
.to_string(),
|
||||
stripe_session_id: item["stripe_session_id"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
checkout_url: item["checkout_url"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
amount_pence: number_field(&item, "amount_pence")
|
||||
.ok_or_else(|| anyhow!("checkout reservation missing amount_pence"))?,
|
||||
expected_total_pence: number_field(&item, "expected_total_pence")
|
||||
.ok_or_else(|| anyhow!("checkout reservation missing expected_total_pence"))?,
|
||||
currency: item["currency"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_ascii_lowercase(),
|
||||
referral_invite_id: item["referral_invite_id"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
status: item["status"].as_str().unwrap_or_default().to_string(),
|
||||
payment_intent_id: item["stripe_payment_intent_id"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
paid_amount_pence: number_field(&item, "paid_amount_pence").unwrap_or(0),
|
||||
reversal_reason: item["reversal_reason"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn active_checkout_filter_includes_empty_context_for_standard_checkout() {
|
||||
let filter = active_checkout_filter("abc123", "", "", 42).unwrap();
|
||||
assert_eq!(
|
||||
filter,
|
||||
"status=\"pending\" && expires_at_unix>=42 && user=\"abc123\" && discount_coupon_id=\"\" && referral_invite_id=\"\""
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn active_checkout_filter_includes_referral_context() {
|
||||
let filter = active_checkout_filter("user123", "coupon_30", "invite123", 99).unwrap();
|
||||
assert_eq!(
|
||||
filter,
|
||||
"status=\"pending\" && expires_at_unix>=99 && user=\"user123\" && discount_coupon_id=\"coupon_30\" && referral_invite_id=\"invite123\""
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn active_checkout_filter_rejects_unsafe_context_values() {
|
||||
assert!(active_checkout_filter("user123", "bad\"coupon", "", 1).is_err());
|
||||
assert!(active_checkout_filter("user123", "", "bad-invite", 1).is_err());
|
||||
assert!(active_checkout_filter("bad-user", "", "", 1).is_err());
|
||||
}
|
||||
}
|
||||
312
server-rs/src/checkout_sessions/referral.rs
Normal file
312
server-rs/src/checkout_sessions/referral.rs
Normal file
|
|
@ -0,0 +1,312 @@
|
|||
//! Referral invite bookkeeping: reserving an invite for an in-flight checkout,
|
||||
//! releasing the reservation on failure/expiry, and recording final usage when
|
||||
//! a verified payment completes.
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use serde_json::Value;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::{
|
||||
ensure_success, ensure_success_ref, is_safe_pocketbase_id, now_unix_secs, number_field,
|
||||
CHECKOUT_COLLECTION,
|
||||
};
|
||||
|
||||
pub async fn mark_referral_invite_used(
|
||||
state: &AppState,
|
||||
invite_id: &str,
|
||||
user_id: &str,
|
||||
reservation_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
if invite_id.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
if !is_safe_pocketbase_id(invite_id)
|
||||
|| !is_safe_pocketbase_id(user_id)
|
||||
|| !is_safe_pocketbase_id(reservation_id)
|
||||
{
|
||||
return Err(anyhow!("invalid PocketBase id"));
|
||||
}
|
||||
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let invite = fetch_invite_record(state, pb_url, &token, invite_id).await?;
|
||||
|
||||
// A verified Stripe payment must not lose entitlement just because local
|
||||
// invite reservation bookkeeping expired or moved before webhook delivery.
|
||||
match referral_invite_completion_action(&invite, user_id, reservation_id) {
|
||||
ReferralInviteCompletionAction::AlreadyRecorded => return Ok(()),
|
||||
ReferralInviteCompletionAction::AlreadyUsedByAnother => {
|
||||
warn!(
|
||||
invite_id,
|
||||
user_id,
|
||||
existing_used_by = invite["used_by_id"].as_str().unwrap_or_default(),
|
||||
"Referral invite was already used by another account; preserving verified checkout entitlement"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
ReferralInviteCompletionAction::Record {
|
||||
reservation_reassigned,
|
||||
} => {
|
||||
if reservation_reassigned {
|
||||
warn!(
|
||||
invite_id,
|
||||
user_id,
|
||||
reservation_id,
|
||||
reserved_by_id = invite["reserved_by_id"].as_str().unwrap_or_default(),
|
||||
reserved_checkout_id = invite["reserved_checkout_id"].as_str().unwrap_or_default(),
|
||||
"Referral invite reservation moved before webhook completion; verified checkout will consume it"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"used_by_id": user_id,
|
||||
"used_at": now_unix_secs().to_string(),
|
||||
"reserved_by_id": "",
|
||||
"reserved_checkout_id": "",
|
||||
"reserved_until_unix": 0,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase invite usage update failed")
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
enum ReferralInviteCompletionAction {
|
||||
AlreadyRecorded,
|
||||
AlreadyUsedByAnother,
|
||||
Record { reservation_reassigned: bool },
|
||||
}
|
||||
|
||||
fn referral_invite_completion_action(
|
||||
invite: &Value,
|
||||
user_id: &str,
|
||||
reservation_id: &str,
|
||||
) -> ReferralInviteCompletionAction {
|
||||
let existing_used_by = invite["used_by_id"].as_str().unwrap_or_default();
|
||||
if existing_used_by == user_id {
|
||||
return ReferralInviteCompletionAction::AlreadyRecorded;
|
||||
}
|
||||
if !existing_used_by.is_empty() {
|
||||
return ReferralInviteCompletionAction::AlreadyUsedByAnother;
|
||||
}
|
||||
|
||||
let reserved_by_id = invite["reserved_by_id"].as_str().unwrap_or_default();
|
||||
let reserved_checkout_id = invite["reserved_checkout_id"].as_str().unwrap_or_default();
|
||||
let reservation_reassigned = (!reserved_by_id.is_empty() && reserved_by_id != user_id)
|
||||
|| (!reserved_checkout_id.is_empty() && reserved_checkout_id != reservation_id);
|
||||
|
||||
ReferralInviteCompletionAction::Record {
|
||||
reservation_reassigned,
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_invite_record(
|
||||
state: &AppState,
|
||||
pb_url: &str,
|
||||
token: &str,
|
||||
invite_id: &str,
|
||||
) -> anyhow::Result<Value> {
|
||||
let url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
resp.json().await.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub(super) async fn reserve_referral_invite(
|
||||
state: &AppState,
|
||||
invite_id: &str,
|
||||
user_id: &str,
|
||||
reservation_id: &str,
|
||||
reserved_until_unix: u64,
|
||||
) -> anyhow::Result<()> {
|
||||
if !is_safe_pocketbase_id(invite_id)
|
||||
|| !is_safe_pocketbase_id(user_id)
|
||||
|| !is_safe_pocketbase_id(reservation_id)
|
||||
{
|
||||
return Err(anyhow!("invalid PocketBase id"));
|
||||
}
|
||||
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let invite = fetch_invite_record(state, pb_url, &token, invite_id).await?;
|
||||
let used_by = invite["used_by_id"].as_str().unwrap_or_default();
|
||||
if !used_by.is_empty() {
|
||||
return Err(anyhow!("referral invite already used"));
|
||||
}
|
||||
|
||||
let now = now_unix_secs();
|
||||
let reserved_by_id = invite["reserved_by_id"].as_str().unwrap_or_default();
|
||||
let reserved_checkout_id = invite["reserved_checkout_id"].as_str().unwrap_or_default();
|
||||
let existing_reserved_until = number_field(&invite, "reserved_until_unix").unwrap_or(0);
|
||||
let reservation_is_live = existing_reserved_until >= now;
|
||||
if reservation_is_live
|
||||
&& !reserved_checkout_id.is_empty()
|
||||
&& reserved_checkout_id != reservation_id
|
||||
{
|
||||
return Err(anyhow!("referral invite already has an active checkout"));
|
||||
}
|
||||
if reservation_is_live && !reserved_by_id.is_empty() && reserved_by_id != user_id {
|
||||
return Err(anyhow!("referral invite reserved by another account"));
|
||||
}
|
||||
|
||||
let url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"reserved_by_id": user_id,
|
||||
"reserved_checkout_id": reservation_id,
|
||||
"reserved_until_unix": reserved_until_unix,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase invite reservation update failed")
|
||||
}
|
||||
|
||||
pub(super) async fn release_referral_invite_reservation(
|
||||
state: &AppState,
|
||||
invite_id: &str,
|
||||
reservation_id: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
if !is_safe_pocketbase_id(invite_id) || !is_safe_pocketbase_id(reservation_id) {
|
||||
return Err(anyhow!("invalid PocketBase id"));
|
||||
}
|
||||
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let invite = fetch_invite_record(state, pb_url, &token, invite_id).await?;
|
||||
let used_by = invite["used_by_id"].as_str().unwrap_or_default();
|
||||
let reserved_checkout_id = invite["reserved_checkout_id"].as_str().unwrap_or_default();
|
||||
if !used_by.is_empty() || reserved_checkout_id != reservation_id {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"reserved_by_id": "",
|
||||
"reserved_checkout_id": "",
|
||||
"reserved_until_unix": 0,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success(resp)
|
||||
.await
|
||||
.context("PocketBase invite reservation release failed")
|
||||
}
|
||||
|
||||
pub async fn active_referral_checkout_user(
|
||||
state: &AppState,
|
||||
invite_id: &str,
|
||||
) -> anyhow::Result<Option<String>> {
|
||||
if !is_safe_pocketbase_id(invite_id) {
|
||||
return Err(anyhow!("invalid PocketBase invite id"));
|
||||
}
|
||||
|
||||
let now = now_unix_secs();
|
||||
let token = get_superuser_token(state).await?;
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!(
|
||||
"status=\"pending\" && expires_at_unix>={now} && referral_invite_id=\"{}\"",
|
||||
invite_id
|
||||
);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/{CHECKOUT_COLLECTION}/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
ensure_success_ref(&resp).await?;
|
||||
|
||||
let body: Value = resp.json().await?;
|
||||
Ok(body["items"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.and_then(|item| item["user"].as_str())
|
||||
.map(str::to_string))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn referral_invite_completion_records_available_invite() {
|
||||
let invite = serde_json::json!({
|
||||
"used_by_id": "",
|
||||
"reserved_by_id": "",
|
||||
"reserved_checkout_id": "",
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
referral_invite_completion_action(&invite, "user123", "checkout123"),
|
||||
ReferralInviteCompletionAction::Record {
|
||||
reservation_reassigned: false
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn referral_invite_completion_records_reassigned_reservation() {
|
||||
let invite = serde_json::json!({
|
||||
"used_by_id": "",
|
||||
"reserved_by_id": "otheruser",
|
||||
"reserved_checkout_id": "othercheckout",
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
referral_invite_completion_action(&invite, "user123", "checkout123"),
|
||||
ReferralInviteCompletionAction::Record {
|
||||
reservation_reassigned: true
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn referral_invite_completion_detects_existing_usage() {
|
||||
let used_by_same_user = serde_json::json!({ "used_by_id": "user123" });
|
||||
let used_by_another_user = serde_json::json!({ "used_by_id": "otheruser" });
|
||||
|
||||
assert_eq!(
|
||||
referral_invite_completion_action(&used_by_same_user, "user123", "checkout123"),
|
||||
ReferralInviteCompletionAction::AlreadyRecorded
|
||||
);
|
||||
assert_eq!(
|
||||
referral_invite_completion_action(&used_by_another_user, "user123", "checkout123"),
|
||||
ReferralInviteCompletionAction::AlreadyUsedByAnother
|
||||
);
|
||||
}
|
||||
}
|
||||
175
server-rs/src/checkout_sessions/stripe.rs
Normal file
175
server-rs/src/checkout_sessions/stripe.rs
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
//! Stripe API interaction: creating checkout sessions, verifying coupon
|
||||
//! configuration, and looking up sessions by payment intent.
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::auth::PocketBaseUser;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::lifecycle::expected_total_for_checkout;
|
||||
use super::{
|
||||
ensure_success_ref, is_safe_stripe_session_id, CHECKOUT_CURRENCY, REFERRAL_DISCOUNT_PERCENT,
|
||||
};
|
||||
|
||||
const CHECKOUT_PRODUCT_NAME: &str = "Perfect Postcodes Lifetime License";
|
||||
|
||||
/// Fetch a Stripe coupon and ensure its `percent_off` matches the expected
|
||||
/// referral discount AND that it has no `amount_off` override. This blocks a
|
||||
/// misconfigured (or maliciously swapped) coupon ID from quietly granting a
|
||||
/// larger discount than the server's pricing math assumed.
|
||||
async fn verify_stripe_coupon_discount(state: &AppState, coupon_id: &str) -> anyhow::Result<()> {
|
||||
if !is_safe_stripe_session_id(coupon_id) {
|
||||
return Err(anyhow!("unsafe stripe coupon id"));
|
||||
}
|
||||
let url = format!(
|
||||
"https://api.stripe.com/v1/coupons/{}",
|
||||
urlencoding::encode(coupon_id)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.basic_auth(&state.stripe_secret_key, None::<&str>)
|
||||
.send()
|
||||
.await
|
||||
.context("Stripe coupon fetch failed")?;
|
||||
ensure_success_ref(&resp)
|
||||
.await
|
||||
.context("Stripe coupon fetch returned error")?;
|
||||
let body: Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Stripe coupon response")?;
|
||||
|
||||
let valid = body["valid"].as_bool().unwrap_or(false);
|
||||
if !valid {
|
||||
return Err(anyhow!("stripe coupon is not valid"));
|
||||
}
|
||||
if body["amount_off"].is_number() {
|
||||
return Err(anyhow!(
|
||||
"stripe coupon uses amount_off; only percent_off is permitted"
|
||||
));
|
||||
}
|
||||
let percent_off = body["percent_off"]
|
||||
.as_f64()
|
||||
.ok_or_else(|| anyhow!("stripe coupon missing percent_off"))?;
|
||||
if percent_off.is_nan() || (percent_off - REFERRAL_DISCOUNT_PERCENT as f64).abs() > 0.001 {
|
||||
return Err(anyhow!(
|
||||
"stripe coupon percent_off ({percent_off}) does not match expected {REFERRAL_DISCOUNT_PERCENT}"
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(super) async fn create_stripe_session(
|
||||
state: &AppState,
|
||||
user: &PocketBaseUser,
|
||||
reservation_id: &str,
|
||||
price_pence: u64,
|
||||
success_url: &str,
|
||||
cancel_url: &str,
|
||||
expires_at_unix: u64,
|
||||
discount_coupon_id: Option<&str>,
|
||||
) -> anyhow::Result<(String, String)> {
|
||||
if let Some(coupon_id) = discount_coupon_id.filter(|id| !id.is_empty()) {
|
||||
verify_stripe_coupon_discount(state, coupon_id).await?;
|
||||
}
|
||||
|
||||
let mut form_params = vec![
|
||||
("mode", "payment".to_string()),
|
||||
("payment_method_types[0]", "card".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][unit_amount]",
|
||||
price_pence.to_string(),
|
||||
),
|
||||
(
|
||||
"line_items[0][price_data][currency]",
|
||||
CHECKOUT_CURRENCY.to_string(),
|
||||
),
|
||||
(
|
||||
"line_items[0][price_data][product_data][name]",
|
||||
CHECKOUT_PRODUCT_NAME.to_string(),
|
||||
),
|
||||
("line_items[0][quantity]", "1".to_string()),
|
||||
("success_url", success_url.to_string()),
|
||||
("cancel_url", cancel_url.to_string()),
|
||||
("expires_at", expires_at_unix.to_string()),
|
||||
("client_reference_id", user.id.clone()),
|
||||
("customer_email", user.email.clone()),
|
||||
("metadata[pending_checkout_id]", reservation_id.to_string()),
|
||||
("metadata[expected_amount_pence]", price_pence.to_string()),
|
||||
(
|
||||
"metadata[expected_total_pence]",
|
||||
expected_total_for_checkout(price_pence, discount_coupon_id).to_string(),
|
||||
),
|
||||
("metadata[expected_currency]", CHECKOUT_CURRENCY.to_string()),
|
||||
];
|
||||
|
||||
if let Some(coupon_id) = discount_coupon_id.filter(|id| !id.is_empty()) {
|
||||
form_params.push(("discounts[0][coupon]", coupon_id.to_string()));
|
||||
form_params.push(("metadata[discount_coupon_id]", coupon_id.to_string()));
|
||||
}
|
||||
|
||||
let resp = state
|
||||
.http_client
|
||||
.post("https://api.stripe.com/v1/checkout/sessions")
|
||||
.basic_auth(&state.stripe_secret_key, None::<&str>)
|
||||
.form(&form_params)
|
||||
.send()
|
||||
.await
|
||||
.context("Stripe checkout request failed")?;
|
||||
|
||||
ensure_success_ref(&resp)
|
||||
.await
|
||||
.context("Stripe checkout failed")?;
|
||||
|
||||
let body: Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Stripe response")?;
|
||||
let session_id = body["id"]
|
||||
.as_str()
|
||||
.filter(|id| is_safe_stripe_session_id(id))
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| anyhow!("Stripe session missing valid id"))?;
|
||||
let url = body["url"]
|
||||
.as_str()
|
||||
.map(str::to_string)
|
||||
.filter(|url| !url.is_empty())
|
||||
.ok_or_else(|| anyhow!("Stripe session missing URL"))?;
|
||||
|
||||
Ok((session_id, url))
|
||||
}
|
||||
|
||||
pub(super) async fn fetch_stripe_checkout_session_id_for_payment_intent(
|
||||
state: &AppState,
|
||||
payment_intent_id: &str,
|
||||
) -> anyhow::Result<Option<String>> {
|
||||
let url = format!(
|
||||
"https://api.stripe.com/v1/checkout/sessions?payment_intent={}&limit=1",
|
||||
urlencoding::encode(payment_intent_id)
|
||||
);
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.basic_auth(&state.stripe_secret_key, None::<&str>)
|
||||
.send()
|
||||
.await
|
||||
.context("Stripe checkout session lookup failed")?;
|
||||
|
||||
ensure_success_ref(&resp)
|
||||
.await
|
||||
.context("Stripe checkout session lookup returned error")?;
|
||||
|
||||
let body: Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Stripe checkout session lookup")?;
|
||||
Ok(body["data"]
|
||||
.as_array()
|
||||
.and_then(|items| items.first())
|
||||
.and_then(|item| item["id"].as_str())
|
||||
.filter(|id| is_safe_stripe_session_id(id))
|
||||
.map(str::to_string))
|
||||
}
|
||||
688
server-rs/src/checkout_sessions/tests.rs
Normal file
688
server-rs/src/checkout_sessions/tests.rs
Normal file
|
|
@ -0,0 +1,688 @@
|
|||
//! Integration-style tests for the money paths: Stripe webhook verification →
|
||||
//! license granting, checkout reservation bookkeeping, and invite redemption.
|
||||
//!
|
||||
//! PocketBase (an external HTTP service in production) is replaced by an
|
||||
//! in-process axum mock listening on an ephemeral local port. The mock keeps
|
||||
//! records in memory, evaluates the small PocketBase filter subset the server
|
||||
//! actually uses, and records every mutating request so tests can assert that
|
||||
//! e.g. a replayed webhook does not grant a license twice.
|
||||
//!
|
||||
//! Stripe itself is NOT mocked (its API URL is hardcoded to
|
||||
//! `https://api.stripe.com`), so tests that reach the Stripe call assert the
|
||||
//! failure-cleanup behaviour instead: the reservation is marked `failed` and
|
||||
//! referral invite reservations are released.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use axum::body::Bytes;
|
||||
use axum::extract::{Path, Query, State};
|
||||
use axum::http::{HeaderMap, HeaderValue, StatusCode};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Extension, Json, Router};
|
||||
use hmac::{Hmac, KeyInit, Mac};
|
||||
use serde_json::{json, Value};
|
||||
use sha2::Sha256;
|
||||
|
||||
use crate::auth::{OptionalUser, PocketBaseUser};
|
||||
use crate::routes::{post_redeem_invite, post_stripe_webhook};
|
||||
use crate::state::{AppState, SharedState};
|
||||
|
||||
use super::start_license_checkout;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock PocketBase
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Default)]
|
||||
struct MockPocketBase {
|
||||
/// collection name → records (each a JSON object with an "id").
|
||||
records: Mutex<HashMap<String, Vec<Value>>>,
|
||||
next_id: AtomicUsize,
|
||||
/// Every mutating request: (method, path, body).
|
||||
log: Mutex<Vec<(String, String, Value)>>,
|
||||
}
|
||||
|
||||
impl MockPocketBase {
|
||||
fn seed(&self, collection: &str, mut record: Value) -> String {
|
||||
let id = match record["id"].as_str() {
|
||||
Some(id) if !id.is_empty() => id.to_string(),
|
||||
_ => format!("mockid{:06}", self.next_id.fetch_add(1, Ordering::SeqCst)),
|
||||
};
|
||||
record["id"] = json!(id);
|
||||
self.records
|
||||
.lock()
|
||||
.unwrap()
|
||||
.entry(collection.to_string())
|
||||
.or_default()
|
||||
.push(record);
|
||||
id
|
||||
}
|
||||
|
||||
fn record(&self, collection: &str, id: &str) -> Option<Value> {
|
||||
self.records
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(collection)?
|
||||
.iter()
|
||||
.find(|record| record["id"].as_str() == Some(id))
|
||||
.cloned()
|
||||
}
|
||||
|
||||
fn records_in(&self, collection: &str) -> Vec<Value> {
|
||||
self.records
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(collection)
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Number of PATCHes that set the user's subscription to "licensed" —
|
||||
/// i.e. how many times a license was granted.
|
||||
fn license_grant_count(&self, user_id: &str) -> usize {
|
||||
let path = format!("/api/collections/users/records/{user_id}");
|
||||
self.log
|
||||
.lock()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.filter(|(method, request_path, body)| {
|
||||
method == "PATCH"
|
||||
&& request_path == &path
|
||||
&& body["subscription"].as_str() == Some("licensed")
|
||||
})
|
||||
.count()
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate the PocketBase filter subset used by the server:
|
||||
/// `a="x" && b>=N && c<N && (d="" || d="y")`.
|
||||
fn record_matches(record: &Value, filter: &str) -> bool {
|
||||
if filter.is_empty() {
|
||||
return true;
|
||||
}
|
||||
filter
|
||||
.split(" && ")
|
||||
.all(|clause| clause_matches(record, clause.trim()))
|
||||
}
|
||||
|
||||
fn clause_matches(record: &Value, clause: &str) -> bool {
|
||||
if let Some(inner) = clause.strip_prefix('(').and_then(|c| c.strip_suffix(')')) {
|
||||
return inner
|
||||
.split(" || ")
|
||||
.any(|alternative| clause_matches(record, alternative.trim()));
|
||||
}
|
||||
if let Some((field, value)) = clause.split_once(">=") {
|
||||
return record_number(record, field) >= value.trim().parse::<i64>().unwrap_or(i64::MAX);
|
||||
}
|
||||
if let Some((field, value)) = clause.split_once('<') {
|
||||
return record_number(record, field) < value.trim().parse::<i64>().unwrap_or(i64::MIN);
|
||||
}
|
||||
if let Some((field, value)) = clause.split_once('=') {
|
||||
let expected = value.trim().trim_matches('"');
|
||||
return record_string(record, field) == expected;
|
||||
}
|
||||
panic!("mock PocketBase cannot evaluate filter clause: {clause}");
|
||||
}
|
||||
|
||||
fn record_number(record: &Value, field: &str) -> i64 {
|
||||
let value = &record[field];
|
||||
value
|
||||
.as_i64()
|
||||
.or_else(|| value.as_f64().map(|float| float as i64))
|
||||
.or_else(|| value.as_str().and_then(|text| text.parse().ok()))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
fn record_string(record: &Value, field: &str) -> String {
|
||||
let value = &record[field];
|
||||
value.as_str().map(str::to_string).unwrap_or_else(|| {
|
||||
if value.is_null() {
|
||||
String::new()
|
||||
} else {
|
||||
value.to_string()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn auth_handler() -> Json<Value> {
|
||||
Json(json!({ "token": "testsuperusertoken" }))
|
||||
}
|
||||
|
||||
async fn list_records(
|
||||
State(pb): State<Arc<MockPocketBase>>,
|
||||
Path(collection): Path<String>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Json<Value> {
|
||||
let filter = params.get("filter").map(String::as_str).unwrap_or("");
|
||||
let matching: Vec<Value> = pb
|
||||
.records
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(&collection)
|
||||
.map(|records| {
|
||||
records
|
||||
.iter()
|
||||
.filter(|record| record_matches(record, filter))
|
||||
.cloned()
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let total = matching.len();
|
||||
let per_page = params
|
||||
.get("perPage")
|
||||
.and_then(|raw| raw.parse::<usize>().ok())
|
||||
.unwrap_or(30);
|
||||
let items: Vec<Value> = matching.into_iter().take(per_page).collect();
|
||||
Json(json!({ "items": items, "totalItems": total }))
|
||||
}
|
||||
|
||||
async fn create_record(
|
||||
State(pb): State<Arc<MockPocketBase>>,
|
||||
Path(collection): Path<String>,
|
||||
Json(body): Json<Value>,
|
||||
) -> Response {
|
||||
pb.log.lock().unwrap().push((
|
||||
"POST".to_string(),
|
||||
format!("/api/collections/{collection}/records"),
|
||||
body.clone(),
|
||||
));
|
||||
|
||||
// Emulate the unique `name` constraint on the distributed-lock collection
|
||||
// so concurrent acquisitions conflict like they do against real PocketBase.
|
||||
if collection == "checkout_locks" {
|
||||
let exists = pb
|
||||
.records
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(&collection)
|
||||
.is_some_and(|records| records.iter().any(|record| record["name"] == body["name"]));
|
||||
if exists {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "message": "name must be unique" })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
let id = pb.seed(&collection, body);
|
||||
Json(pb.record(&collection, &id).expect("record just created")).into_response()
|
||||
}
|
||||
|
||||
async fn get_record(
|
||||
State(pb): State<Arc<MockPocketBase>>,
|
||||
Path((collection, id)): Path<(String, String)>,
|
||||
) -> Response {
|
||||
match pb.record(&collection, &id) {
|
||||
Some(record) => Json(record).into_response(),
|
||||
None => StatusCode::NOT_FOUND.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn patch_record(
|
||||
State(pb): State<Arc<MockPocketBase>>,
|
||||
Path((collection, id)): Path<(String, String)>,
|
||||
Json(body): Json<Value>,
|
||||
) -> Response {
|
||||
pb.log.lock().unwrap().push((
|
||||
"PATCH".to_string(),
|
||||
format!("/api/collections/{collection}/records/{id}"),
|
||||
body.clone(),
|
||||
));
|
||||
|
||||
let mut records = pb.records.lock().unwrap();
|
||||
let Some(record) = records.get_mut(&collection).and_then(|list| {
|
||||
list.iter_mut()
|
||||
.find(|record| record["id"].as_str() == Some(&id))
|
||||
}) else {
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
};
|
||||
if let (Some(target), Some(updates)) = (record.as_object_mut(), body.as_object()) {
|
||||
for (key, value) in updates {
|
||||
target.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
Json(record.clone()).into_response()
|
||||
}
|
||||
|
||||
async fn delete_record(
|
||||
State(pb): State<Arc<MockPocketBase>>,
|
||||
Path((collection, id)): Path<(String, String)>,
|
||||
) -> Response {
|
||||
pb.log.lock().unwrap().push((
|
||||
"DELETE".to_string(),
|
||||
format!("/api/collections/{collection}/records/{id}"),
|
||||
Value::Null,
|
||||
));
|
||||
|
||||
let mut records = pb.records.lock().unwrap();
|
||||
let Some(list) = records.get_mut(&collection) else {
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
};
|
||||
let before = list.len();
|
||||
list.retain(|record| record["id"].as_str() != Some(&id));
|
||||
if list.len() == before {
|
||||
StatusCode::NOT_FOUND.into_response()
|
||||
} else {
|
||||
StatusCode::NO_CONTENT.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
fn mock_pb_router(pb: Arc<MockPocketBase>) -> Router {
|
||||
Router::new()
|
||||
.route(
|
||||
"/api/collections/_superusers/auth-with-password",
|
||||
post(auth_handler),
|
||||
)
|
||||
.route(
|
||||
"/api/collections/{collection}/records",
|
||||
get(list_records).post(create_record),
|
||||
)
|
||||
.route(
|
||||
"/api/collections/{collection}/records/{id}",
|
||||
get(get_record).patch(patch_record).delete(delete_record),
|
||||
)
|
||||
.with_state(pb)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test harness
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct TestEnv {
|
||||
shared: Arc<SharedState>,
|
||||
pb: Arc<MockPocketBase>,
|
||||
}
|
||||
|
||||
async fn setup() -> TestEnv {
|
||||
let pb = Arc::new(MockPocketBase::default());
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("bind mock PocketBase listener");
|
||||
let addr = listener.local_addr().expect("mock PocketBase address");
|
||||
let router = mock_pb_router(pb.clone());
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, router)
|
||||
.await
|
||||
.expect("mock PocketBase serve");
|
||||
});
|
||||
|
||||
let state = AppState::for_tests(format!("http://{addr}"));
|
||||
TestEnv {
|
||||
shared: Arc::new(SharedState::new(state)),
|
||||
pb,
|
||||
}
|
||||
}
|
||||
|
||||
fn now_unix() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("clock after epoch")
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
fn test_user(id: &str) -> PocketBaseUser {
|
||||
PocketBaseUser {
|
||||
id: id.to_string(),
|
||||
email: format!("{id}@test.example"),
|
||||
is_admin: false,
|
||||
subscription: "free".to_string(),
|
||||
newsletter: false,
|
||||
can_see_listings: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn seed_user(pb: &MockPocketBase, id: &str) {
|
||||
pb.seed(
|
||||
"users",
|
||||
json!({
|
||||
"id": id,
|
||||
"email": format!("{id}@test.example"),
|
||||
"subscription": "free",
|
||||
"is_admin": false,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
fn seed_pending_checkout(pb: &MockPocketBase, user_id: &str, session_id: &str) -> String {
|
||||
pb.seed(
|
||||
"checkout_sessions",
|
||||
json!({
|
||||
"user": user_id,
|
||||
"stripe_session_id": session_id,
|
||||
"stripe_payment_intent_id": "",
|
||||
"checkout_url": "https://checkout.stripe.test/session",
|
||||
"amount_pence": 999,
|
||||
"expected_total_pence": 999,
|
||||
"currency": "gbp",
|
||||
"discount_coupon_id": "",
|
||||
"referral_invite_id": "",
|
||||
"status": "pending",
|
||||
"expires_at_unix": now_unix() + 1800,
|
||||
"paid_amount_pence": 0,
|
||||
"completed_at_unix": "",
|
||||
"reversal_reason": "",
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn checkout_completed_event(session_id: &str, user_id: &str, amount_total: u64) -> Vec<u8> {
|
||||
serde_json::to_vec(&json!({
|
||||
"id": "evt_test_1",
|
||||
"type": "checkout.session.completed",
|
||||
"data": { "object": {
|
||||
"id": session_id,
|
||||
"payment_intent": "pi_test_1",
|
||||
"client_reference_id": user_id,
|
||||
"payment_status": "paid",
|
||||
"currency": "gbp",
|
||||
"amount_subtotal": 999,
|
||||
"amount_total": amount_total,
|
||||
}}
|
||||
}))
|
||||
.expect("event serializes")
|
||||
}
|
||||
|
||||
/// Sign a payload the way Stripe does: HMAC-SHA256 over `{timestamp}.{payload}`
|
||||
/// with the webhook secret, presented as `t=...,v1=...`.
|
||||
fn stripe_signature_header(payload: &[u8], secret: &str) -> String {
|
||||
let timestamp = now_unix();
|
||||
let mut mac =
|
||||
Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key length");
|
||||
mac.update(format!("{timestamp}.").as_bytes());
|
||||
mac.update(payload);
|
||||
let signature = hex::encode(mac.finalize().into_bytes());
|
||||
format!("t={timestamp},v1={signature}")
|
||||
}
|
||||
|
||||
async fn deliver_webhook(env: &TestEnv, payload: Vec<u8>, signature: Option<&str>) -> Response {
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(signature) = signature {
|
||||
headers.insert(
|
||||
"stripe-signature",
|
||||
HeaderValue::from_str(signature).expect("signature header value"),
|
||||
);
|
||||
}
|
||||
post_stripe_webhook(State(env.shared.clone()), headers, Bytes::from(payload)).await
|
||||
}
|
||||
|
||||
async fn redeem_invite(env: &TestEnv, user: PocketBaseUser, code: &str) -> Response {
|
||||
post_redeem_invite(
|
||||
State(env.shared.clone()),
|
||||
Extension(OptionalUser(Some(user))),
|
||||
Json(serde_json::from_value(json!({ "code": code })).expect("redeem request deserializes")),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn response_json(response: Response) -> Value {
|
||||
let bytes = axum::body::to_bytes(response.into_body(), 1 << 20)
|
||||
.await
|
||||
.expect("response body");
|
||||
serde_json::from_slice(&bytes).expect("response body is JSON")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stripe webhook → license granting
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_with_valid_signature_grants_license() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user1");
|
||||
let reservation_id = seed_pending_checkout(&env.pb, "user1", "cs_test_abc");
|
||||
|
||||
let payload = checkout_completed_event("cs_test_abc", "user1", 999);
|
||||
let signature = stripe_signature_header(&payload, "whsec_test_secret");
|
||||
let response = deliver_webhook(&env, payload, Some(&signature)).await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let user = env.pb.record("users", "user1").expect("user exists");
|
||||
assert_eq!(user["subscription"], json!("licensed"));
|
||||
|
||||
let checkout = env
|
||||
.pb
|
||||
.record("checkout_sessions", &reservation_id)
|
||||
.expect("checkout exists");
|
||||
assert_eq!(checkout["status"], json!("completed"));
|
||||
assert_eq!(checkout["paid_amount_pence"], json!(999));
|
||||
assert_eq!(checkout["stripe_payment_intent_id"], json!("pi_test_1"));
|
||||
|
||||
assert_eq!(env.pb.license_grant_count("user1"), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_with_invalid_signature_is_rejected() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user1");
|
||||
let reservation_id = seed_pending_checkout(&env.pb, "user1", "cs_test_bad");
|
||||
|
||||
let payload = checkout_completed_event("cs_test_bad", "user1", 999);
|
||||
|
||||
// Signed with the wrong secret.
|
||||
let wrong_signature = stripe_signature_header(&payload, "whsec_wrong_secret");
|
||||
let response = deliver_webhook(&env, payload.clone(), Some(&wrong_signature)).await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Missing signature header entirely.
|
||||
let response = deliver_webhook(&env, payload, None).await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
let user = env.pb.record("users", "user1").expect("user exists");
|
||||
assert_eq!(user["subscription"], json!("free"));
|
||||
let checkout = env
|
||||
.pb
|
||||
.record("checkout_sessions", &reservation_id)
|
||||
.expect("checkout exists");
|
||||
assert_eq!(checkout["status"], json!("pending"));
|
||||
assert_eq!(env.pb.license_grant_count("user1"), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn replayed_webhook_does_not_double_grant() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user1");
|
||||
let reservation_id = seed_pending_checkout(&env.pb, "user1", "cs_test_replay");
|
||||
|
||||
let payload = checkout_completed_event("cs_test_replay", "user1", 999);
|
||||
let signature = stripe_signature_header(&payload, "whsec_test_secret");
|
||||
|
||||
let first = deliver_webhook(&env, payload.clone(), Some(&signature)).await;
|
||||
assert_eq!(first.status(), StatusCode::OK);
|
||||
let replay = deliver_webhook(&env, payload, Some(&signature)).await;
|
||||
assert_eq!(replay.status(), StatusCode::OK);
|
||||
|
||||
let user = env.pb.record("users", "user1").expect("user exists");
|
||||
assert_eq!(user["subscription"], json!("licensed"));
|
||||
let checkout = env
|
||||
.pb
|
||||
.record("checkout_sessions", &reservation_id)
|
||||
.expect("checkout exists");
|
||||
assert_eq!(checkout["status"], json!("completed"));
|
||||
|
||||
// The replay must be acknowledged without granting a second time.
|
||||
assert_eq!(env.pb.license_grant_count("user1"), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_with_tampered_amount_is_rejected_and_marks_reservation_invalid() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user1");
|
||||
let reservation_id = seed_pending_checkout(&env.pb, "user1", "cs_test_amount");
|
||||
|
||||
// Validly signed event whose amount_total does not match the reservation.
|
||||
let payload = checkout_completed_event("cs_test_amount", "user1", 500);
|
||||
let signature = stripe_signature_header(&payload, "whsec_test_secret");
|
||||
let response = deliver_webhook(&env, payload, Some(&signature)).await;
|
||||
|
||||
// Rejections are acknowledged with 200 so Stripe stops retrying.
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let user = env.pb.record("users", "user1").expect("user exists");
|
||||
assert_eq!(user["subscription"], json!("free"));
|
||||
let checkout = env
|
||||
.pb
|
||||
.record("checkout_sessions", &reservation_id)
|
||||
.expect("checkout exists");
|
||||
assert_eq!(checkout["status"], json!("invalid"));
|
||||
assert_eq!(env.pb.license_grant_count("user1"), 0);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Checkout session creation (up to the hardcoded Stripe API call)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tokio::test]
|
||||
async fn checkout_start_reserves_then_marks_failed_when_stripe_is_unreachable() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user9");
|
||||
let state = env.shared.load_state();
|
||||
let user = test_user("user9");
|
||||
|
||||
// The Stripe API URL is hardcoded to https://api.stripe.com, so the
|
||||
// session-creation call fails in tests (no network / dummy key). The
|
||||
// reservation bookkeeping before and after that call is what we assert.
|
||||
let result = start_license_checkout(
|
||||
&state,
|
||||
&user,
|
||||
"https://x/success",
|
||||
"https://x/cancel",
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
assert!(result.is_err(), "Stripe call must fail in tests");
|
||||
|
||||
let checkouts = env.pb.records_in("checkout_sessions");
|
||||
assert_eq!(checkouts.len(), 1, "exactly one reservation created");
|
||||
let checkout = &checkouts[0];
|
||||
assert_eq!(checkout["user"], json!("user9"));
|
||||
// 0 licensed users → public count 120 → second tier price (999p).
|
||||
assert_eq!(checkout["amount_pence"], json!(999));
|
||||
assert_eq!(checkout["expected_total_pence"], json!(999));
|
||||
assert_eq!(checkout["currency"], json!("gbp"));
|
||||
// The failed Stripe call must not leave a live pending reservation.
|
||||
assert_eq!(checkout["status"], json!("failed"));
|
||||
|
||||
// The cross-instance pricing lock was released.
|
||||
assert!(env.pb.records_in("checkout_locks").is_empty());
|
||||
assert_eq!(env.pb.license_grant_count("user9"), 0);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Invite redemption
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tokio::test]
|
||||
async fn admin_invite_redemption_grants_license() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user2");
|
||||
let invite_id = env.pb.seed(
|
||||
"invites",
|
||||
json!({
|
||||
"code": "admininvite1",
|
||||
"invite_type": "admin",
|
||||
"created_by": "adminuser1",
|
||||
"used_by_id": "",
|
||||
"used_at": "",
|
||||
}),
|
||||
);
|
||||
|
||||
let response = redeem_invite(&env, test_user("user2"), "admininvite1").await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
let body = response_json(response).await;
|
||||
assert_eq!(body["result"], json!("licensed"));
|
||||
|
||||
let invite = env.pb.record("invites", &invite_id).expect("invite exists");
|
||||
assert_eq!(invite["used_by_id"], json!("user2"));
|
||||
|
||||
let user = env.pb.record("users", "user2").expect("user exists");
|
||||
assert_eq!(user["subscription"], json!("licensed"));
|
||||
assert_eq!(env.pb.license_grant_count("user2"), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_and_oversized_invite_codes_are_rejected() {
|
||||
let env = setup().await;
|
||||
|
||||
// Non-alphanumeric characters.
|
||||
let response = redeem_invite(&env, test_user("user2"), "bad-code!").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Longer than the 20-character limit.
|
||||
let oversized = "a".repeat(21);
|
||||
let response = redeem_invite(&env, test_user("user2"), &oversized).await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Empty code.
|
||||
let response = redeem_invite(&env, test_user("user2"), "").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
assert_eq!(env.pb.license_grant_count("user2"), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn already_used_invite_is_rejected() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user2");
|
||||
env.pb.seed(
|
||||
"invites",
|
||||
json!({
|
||||
"code": "usedinvite12",
|
||||
"invite_type": "admin",
|
||||
"created_by": "adminuser1",
|
||||
"used_by_id": "otheruser9",
|
||||
"used_at": "1700000000",
|
||||
}),
|
||||
);
|
||||
|
||||
let response = redeem_invite(&env, test_user("user2"), "usedinvite12").await;
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
let user = env.pb.record("users", "user2").expect("user exists");
|
||||
assert_eq!(user["subscription"], json!("free"));
|
||||
assert_eq!(env.pb.license_grant_count("user2"), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn referral_invite_redemption_releases_reservation_when_stripe_is_unreachable() {
|
||||
let env = setup().await;
|
||||
seed_user(&env.pb, "user3");
|
||||
let invite_id = env.pb.seed(
|
||||
"invites",
|
||||
json!({
|
||||
"code": "refcode12345",
|
||||
"invite_type": "referral",
|
||||
"created_by": "licenseduser1",
|
||||
"used_by_id": "",
|
||||
"used_at": "",
|
||||
"reserved_by_id": "",
|
||||
"reserved_checkout_id": "",
|
||||
"reserved_until_unix": 0,
|
||||
}),
|
||||
);
|
||||
|
||||
// The redemption itself is valid; it fails only at the hardcoded Stripe
|
||||
// call (coupon verification / session creation), which must roll back the
|
||||
// reservation cleanly.
|
||||
let response = redeem_invite(&env, test_user("user3"), "refcode12345").await;
|
||||
assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
|
||||
|
||||
let checkouts = env.pb.records_in("checkout_sessions");
|
||||
assert_eq!(checkouts.len(), 1, "referral reservation was created");
|
||||
assert_eq!(checkouts[0]["referral_invite_id"], json!(invite_id));
|
||||
assert_eq!(checkouts[0]["status"], json!("failed"));
|
||||
|
||||
// The invite reservation was released and the invite is still unused.
|
||||
let invite = env.pb.record("invites", &invite_id).expect("invite exists");
|
||||
assert_eq!(invite["used_by_id"], json!(""));
|
||||
assert_eq!(invite["reserved_checkout_id"], json!(""));
|
||||
assert_eq!(invite["reserved_by_id"], json!(""));
|
||||
|
||||
assert_eq!(env.pb.license_grant_count("user3"), 0);
|
||||
}
|
||||
|
|
@ -182,8 +182,7 @@ impl CrimeByYearData {
|
|||
// Force-coverage calendar (optional column: legacy parquets predate it;
|
||||
// their postcodes are treated as fully covered). A row with an empty
|
||||
// list is meaningful — zero covered years — so it IS inserted.
|
||||
let mut covered_years_by_postcode: FxHashMap<String, Vec<i32>> =
|
||||
FxHashMap::default();
|
||||
let mut covered_years_by_postcode: FxHashMap<String, Vec<i32>> = FxHashMap::default();
|
||||
if let Ok(col) = df.column(COVERAGE_COLUMN) {
|
||||
let list_ca = col
|
||||
.list()
|
||||
|
|
@ -195,12 +194,12 @@ impl CrimeByYearData {
|
|||
};
|
||||
let mut years: Vec<i32> = Vec::with_capacity(inner.len());
|
||||
if !inner.is_empty() {
|
||||
let structs = inner.struct_().with_context(|| {
|
||||
format!("Inner of '{COVERAGE_COLUMN}' is not a struct")
|
||||
})?;
|
||||
let year_field = structs.field_by_name("year").with_context(|| {
|
||||
format!("Missing 'year' field in '{COVERAGE_COLUMN}'")
|
||||
})?;
|
||||
let structs = inner
|
||||
.struct_()
|
||||
.with_context(|| format!("Inner of '{COVERAGE_COLUMN}' is not a struct"))?;
|
||||
let year_field = structs
|
||||
.field_by_name("year")
|
||||
.with_context(|| format!("Missing 'year' field in '{COVERAGE_COLUMN}'"))?;
|
||||
for idx in 0..inner.len() {
|
||||
match year_field.get(idx).ok() {
|
||||
Some(AnyValue::Int32(y)) => years.push(y),
|
||||
|
|
|
|||
|
|
@ -742,6 +742,29 @@ impl PlaceData {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl PlaceData {
|
||||
/// Minimal empty instance for integration tests that need an `AppState`
|
||||
/// but never touch place data.
|
||||
pub(crate) fn empty_for_tests() -> Self {
|
||||
PlaceData {
|
||||
name: Vec::new(),
|
||||
name_lower: Vec::new(),
|
||||
name_search: Vec::new(),
|
||||
place_type: InternedColumn::build(&[]),
|
||||
type_rank: Vec::new(),
|
||||
population: Vec::new(),
|
||||
lat: Vec::new(),
|
||||
lon: Vec::new(),
|
||||
city: Vec::new(),
|
||||
travel_destination: Vec::new(),
|
||||
token_index: FxHashMap::default(),
|
||||
token_prefix_index: FxHashMap::default(),
|
||||
fuzzy_trigram_index: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
|
|
@ -588,6 +588,29 @@ impl POIData {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl POIData {
|
||||
/// Minimal empty instance for integration tests that need an `AppState`
|
||||
/// but never touch POI data.
|
||||
pub(crate) fn empty_for_tests() -> Self {
|
||||
POIData {
|
||||
id_buffer: String::new(),
|
||||
id_offsets: Vec::new(),
|
||||
id_lengths: Vec::new(),
|
||||
group: InternedColumn::build(&[]),
|
||||
category: InternedColumn::build(&[]),
|
||||
icon_category: InternedColumn::build(&[]),
|
||||
name: Vec::new(),
|
||||
lat: Vec::new(),
|
||||
lng: Vec::new(),
|
||||
emoji: InternedColumn::build(&[]),
|
||||
priority: Vec::new(),
|
||||
school_meta_idx: Vec::new(),
|
||||
school_meta: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
973
server-rs/src/data/property/address_search.rs
Normal file
973
server-rs/src/data/property/address_search.rs
Normal file
|
|
@ -0,0 +1,973 @@
|
|||
//! Address search: tokenization, query parsing, inverted/prefix indexes and the
|
||||
//! ranked per-row search over property addresses.
|
||||
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
|
||||
use super::PropertyData;
|
||||
|
||||
/// Upper bound on rows scored per query. Intersection keeps most candidate sets far below
|
||||
/// this; only a single very common road word (e.g. "high") approaches it, and the in-area
|
||||
/// priority sort keeps a refined query's matches ahead of the cut.
|
||||
const ADDRESS_SEARCH_CANDIDATE_LIMIT: usize = 150_000;
|
||||
const ADDRESS_SEARCH_PREFIX_MIN_LEN: usize = 4;
|
||||
const ADDRESS_SEARCH_PREFIX_MAX_LEN: usize = 8;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(super) struct AddressTermGroup {
|
||||
alternatives: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct AddressQuery {
|
||||
full_postcode: Option<String>,
|
||||
/// Compact uppercase outward code (optionally with a sector digit) recovered when the
|
||||
/// user appended a partial postcode like "NW1" or "NW1 6". Used as an additive ranking
|
||||
/// bias, never as a hard filter — so the disambiguating hint is honoured without
|
||||
/// excluding the same road in other areas.
|
||||
postcode_area: Option<String>,
|
||||
text_groups: Vec<AddressTermGroup>,
|
||||
numeric_terms: Vec<String>,
|
||||
candidate_terms: Vec<String>,
|
||||
}
|
||||
|
||||
fn tokenize_address_text(text: &str) -> Vec<String> {
|
||||
let mut tokens = Vec::new();
|
||||
let mut current = String::new();
|
||||
|
||||
for ch in text.chars() {
|
||||
if ch.is_ascii_alphanumeric() {
|
||||
current.push(ch.to_ascii_lowercase());
|
||||
} else if matches!(ch, '\'' | '’' | '`') {
|
||||
continue;
|
||||
} else if !current.is_empty() {
|
||||
tokens.push(std::mem::take(&mut current));
|
||||
}
|
||||
}
|
||||
|
||||
if !current.is_empty() {
|
||||
tokens.push(current);
|
||||
}
|
||||
|
||||
tokens
|
||||
}
|
||||
|
||||
fn is_full_postcode_compact(compact: &str) -> bool {
|
||||
let bytes = compact.as_bytes();
|
||||
let len = bytes.len();
|
||||
if !(5..=7).contains(&len) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let inward = &bytes[len - 3..];
|
||||
if !inward[0].is_ascii_digit()
|
||||
|| !inward[1].is_ascii_alphabetic()
|
||||
|| !inward[2].is_ascii_alphabetic()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
let outward = &bytes[..len - 3];
|
||||
if !(2..=4).contains(&outward.len()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
outward[0].is_ascii_alphabetic()
|
||||
&& outward.iter().all(u8::is_ascii_alphanumeric)
|
||||
&& outward.iter().any(u8::is_ascii_digit)
|
||||
}
|
||||
|
||||
fn canonical_postcode_from_compact(compact: &str) -> String {
|
||||
let upper = compact.to_ascii_uppercase();
|
||||
let split = upper.len() - 3;
|
||||
format!("{} {}", &upper[..split], &upper[split..])
|
||||
}
|
||||
|
||||
fn extract_full_postcode(tokens: &[String]) -> Option<(String, Vec<usize>)> {
|
||||
for (idx, token) in tokens.iter().enumerate() {
|
||||
let compact = token.to_ascii_uppercase();
|
||||
if is_full_postcode_compact(&compact) {
|
||||
return Some((canonical_postcode_from_compact(&compact), vec![idx]));
|
||||
}
|
||||
}
|
||||
|
||||
for idx in 0..tokens.len().saturating_sub(1) {
|
||||
let compact = format!(
|
||||
"{}{}",
|
||||
tokens[idx].to_ascii_uppercase(),
|
||||
tokens[idx + 1].to_ascii_uppercase()
|
||||
);
|
||||
if is_full_postcode_compact(&compact) {
|
||||
return Some((
|
||||
canonical_postcode_from_compact(&compact),
|
||||
vec![idx, idx + 1],
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn looks_like_postcode_fragment(token: &str) -> bool {
|
||||
(2..=4).contains(&token.len())
|
||||
&& token
|
||||
.chars()
|
||||
.next()
|
||||
.is_some_and(|ch| ch.is_ascii_alphabetic())
|
||||
&& token.chars().any(|ch| ch.is_ascii_digit())
|
||||
&& token.chars().all(|ch| ch.is_ascii_alphanumeric())
|
||||
}
|
||||
|
||||
fn is_numeric_address_token(token: &str) -> bool {
|
||||
token.chars().all(|ch| ch.is_ascii_digit())
|
||||
}
|
||||
|
||||
fn address_token_aliases(token: &str) -> Vec<&'static str> {
|
||||
match token {
|
||||
"apt" => vec!["apt", "apartment"],
|
||||
"apartment" => vec!["apartment", "apt"],
|
||||
"ave" => vec!["ave", "avenue"],
|
||||
"avenue" => vec!["avenue", "ave"],
|
||||
"blvd" => vec!["blvd", "boulevard"],
|
||||
"boulevard" => vec!["boulevard", "blvd"],
|
||||
"cl" => vec!["cl", "close"],
|
||||
"close" => vec!["close", "cl"],
|
||||
"ct" => vec!["ct", "court"],
|
||||
"court" => vec!["court", "ct"],
|
||||
"cres" => vec!["cres", "crescent"],
|
||||
"crescent" => vec!["crescent", "cres"],
|
||||
"dr" => vec!["dr", "drive"],
|
||||
"drive" => vec!["drive", "dr"],
|
||||
"fl" => vec!["fl", "flat"],
|
||||
"flat" => vec!["flat", "fl"],
|
||||
"gdns" => vec!["gdns", "gardens", "garden"],
|
||||
"garden" => vec!["garden", "gardens", "gdns"],
|
||||
"gardens" => vec!["gardens", "garden", "gdns"],
|
||||
"hse" => vec!["hse", "house"],
|
||||
"house" => vec!["house", "hse"],
|
||||
"ln" => vec!["ln", "lane"],
|
||||
"lane" => vec!["lane", "ln"],
|
||||
"rd" => vec!["rd", "road"],
|
||||
"road" => vec!["road", "rd"],
|
||||
"sq" => vec!["sq", "square"],
|
||||
"square" => vec!["square", "sq"],
|
||||
"st" => vec!["st", "street", "saint"],
|
||||
"street" => vec!["street", "st"],
|
||||
"saint" => vec!["saint", "st"],
|
||||
"terr" => vec!["terr", "terrace"],
|
||||
"terrace" => vec!["terrace", "terr"],
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_address_stop_token(token: &str) -> bool {
|
||||
matches!(
|
||||
token,
|
||||
"a" | "an"
|
||||
| "and"
|
||||
| "apartment"
|
||||
| "apt"
|
||||
| "avenue"
|
||||
| "ave"
|
||||
| "block"
|
||||
| "building"
|
||||
| "bungalow"
|
||||
| "close"
|
||||
| "cl"
|
||||
| "court"
|
||||
| "ct"
|
||||
| "cres"
|
||||
| "crescent"
|
||||
| "drive"
|
||||
| "dr"
|
||||
| "estate"
|
||||
| "flat"
|
||||
| "fl"
|
||||
| "floor"
|
||||
| "garden"
|
||||
| "gardens"
|
||||
| "gdns"
|
||||
| "grove"
|
||||
| "house"
|
||||
| "hse"
|
||||
| "lane"
|
||||
| "ln"
|
||||
| "lodge"
|
||||
| "mansions"
|
||||
| "mews"
|
||||
| "of"
|
||||
| "park"
|
||||
| "place"
|
||||
| "road"
|
||||
| "rd"
|
||||
| "room"
|
||||
| "row"
|
||||
| "saint"
|
||||
| "sq"
|
||||
| "square"
|
||||
| "st"
|
||||
| "street"
|
||||
| "terr"
|
||||
| "terrace"
|
||||
| "the"
|
||||
| "unit"
|
||||
| "view"
|
||||
| "villas"
|
||||
| "walk"
|
||||
| "way"
|
||||
| "yard"
|
||||
)
|
||||
}
|
||||
|
||||
fn address_term_group(token: &str) -> Option<AddressTermGroup> {
|
||||
if token.len() < 3 || is_numeric_address_token(token) || looks_like_postcode_fragment(token) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut alternatives = Vec::new();
|
||||
alternatives.push(token.to_string());
|
||||
for alias in address_token_aliases(token) {
|
||||
if !alternatives.iter().any(|existing| existing == alias) {
|
||||
alternatives.push(alias.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if alternatives
|
||||
.iter()
|
||||
.all(|alternative| is_address_stop_token(alternative))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(AddressTermGroup { alternatives })
|
||||
}
|
||||
|
||||
pub(super) fn address_search_tokens(text: &str) -> Vec<String> {
|
||||
let mut tokens: Vec<String> = tokenize_address_text(text)
|
||||
.into_iter()
|
||||
.filter(|token| is_address_search_token(token))
|
||||
.collect();
|
||||
tokens.sort_unstable();
|
||||
tokens.dedup();
|
||||
tokens
|
||||
}
|
||||
|
||||
fn is_address_search_token(token: &str) -> bool {
|
||||
if looks_like_postcode_fragment(token) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if is_numeric_address_token(token) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if token.chars().any(|ch| ch.is_ascii_digit()) {
|
||||
return token.len() >= 2;
|
||||
}
|
||||
|
||||
token.len() >= 3
|
||||
}
|
||||
|
||||
pub(super) fn is_address_candidate_token(token: &str) -> bool {
|
||||
!is_numeric_address_token(token)
|
||||
&& !looks_like_postcode_fragment(token)
|
||||
&& (token.chars().any(|ch| ch.is_ascii_digit())
|
||||
|| (token.len() >= 3 && !is_address_stop_token(token)))
|
||||
}
|
||||
|
||||
fn address_prefix_key(term: &str) -> &str {
|
||||
if term.len() > ADDRESS_SEARCH_PREFIX_MAX_LEN {
|
||||
&term[..ADDRESS_SEARCH_PREFIX_MAX_LEN]
|
||||
} else {
|
||||
term
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn build_address_prefix_index(
|
||||
address_token_index: &FxHashMap<String, Vec<u32>>,
|
||||
) -> FxHashMap<String, Vec<String>> {
|
||||
let mut prefix_index: FxHashMap<String, Vec<String>> = FxHashMap::default();
|
||||
|
||||
for token in address_token_index.keys() {
|
||||
let max_prefix_len = token.len().min(ADDRESS_SEARCH_PREFIX_MAX_LEN);
|
||||
for prefix_len in ADDRESS_SEARCH_PREFIX_MIN_LEN..=max_prefix_len {
|
||||
prefix_index
|
||||
.entry(token[..prefix_len].to_string())
|
||||
.or_default()
|
||||
.push(token.clone());
|
||||
}
|
||||
}
|
||||
|
||||
for tokens in prefix_index.values_mut() {
|
||||
tokens.sort_unstable();
|
||||
tokens.dedup();
|
||||
}
|
||||
|
||||
prefix_index
|
||||
}
|
||||
|
||||
/// Intersect two ascending-sorted row-id slices.
|
||||
fn intersect_sorted(left: &[u32], right: &[u32]) -> Vec<u32> {
|
||||
let mut out = Vec::new();
|
||||
let (mut i, mut j) = (0, 0);
|
||||
while i < left.len() && j < right.len() {
|
||||
match left[i].cmp(&right[j]) {
|
||||
std::cmp::Ordering::Less => i += 1,
|
||||
std::cmp::Ordering::Greater => j += 1,
|
||||
std::cmp::Ordering::Equal => {
|
||||
out.push(left[i]);
|
||||
i += 1;
|
||||
j += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Union two ascending-sorted row-id slices (deduplicated, stays sorted).
|
||||
fn union_sorted(left: &[u32], right: &[u32]) -> Vec<u32> {
|
||||
let mut out = Vec::with_capacity(left.len() + right.len());
|
||||
let (mut i, mut j) = (0, 0);
|
||||
while i < left.len() && j < right.len() {
|
||||
match left[i].cmp(&right[j]) {
|
||||
std::cmp::Ordering::Less => {
|
||||
out.push(left[i]);
|
||||
i += 1;
|
||||
}
|
||||
std::cmp::Ordering::Greater => {
|
||||
out.push(right[j]);
|
||||
j += 1;
|
||||
}
|
||||
std::cmp::Ordering::Equal => {
|
||||
out.push(left[i]);
|
||||
i += 1;
|
||||
j += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
out.extend_from_slice(&left[i..]);
|
||||
out.extend_from_slice(&right[j..]);
|
||||
out
|
||||
}
|
||||
|
||||
/// An ordinal like "1st", "2nd", "3rd", "21st" — part of the street name ("2nd Avenue"), not a
|
||||
/// house-number prefix.
|
||||
fn is_ordinal_token(token: &str) -> bool {
|
||||
let split = token.len().saturating_sub(2);
|
||||
let (digits, suffix) = token.split_at(split);
|
||||
!digits.is_empty()
|
||||
&& digits.chars().all(|ch| ch.is_ascii_digit())
|
||||
&& matches!(suffix, "st" | "nd" | "rd" | "th")
|
||||
}
|
||||
|
||||
/// Leading address tokens that denote a unit/house number rather than the street itself.
|
||||
fn is_house_prefix_token(token: &str) -> bool {
|
||||
if is_ordinal_token(token) {
|
||||
return false;
|
||||
}
|
||||
matches!(
|
||||
token,
|
||||
"flat" | "fl" | "apartment" | "apt" | "unit" | "no" | "block" | "floor" | "room"
|
||||
) || token.len() == 1
|
||||
|| token.chars().all(|ch| ch.is_ascii_digit())
|
||||
|| (token.chars().next().is_some_and(|ch| ch.is_ascii_digit())
|
||||
&& token.chars().any(|ch| ch.is_ascii_alphabetic()))
|
||||
}
|
||||
|
||||
/// Street-level key for an address: drops the leading house-number / flat prefix so that
|
||||
/// "12 Baker Street" and "5 Baker Street" collapse to a single street entry.
|
||||
fn street_key(address: &str) -> String {
|
||||
let tokens = tokenize_address_text(address);
|
||||
let mut start = 0;
|
||||
while start < tokens.len() && is_house_prefix_token(&tokens[start]) {
|
||||
start += 1;
|
||||
}
|
||||
if start >= tokens.len() {
|
||||
return tokens.join(" ");
|
||||
}
|
||||
tokens[start..].join(" ")
|
||||
}
|
||||
|
||||
/// Road-type words. Their presence (with no house number) marks a road browse, which we
|
||||
/// collapse to one result per street.
|
||||
const ROAD_TYPE_TOKENS: &[&str] = &[
|
||||
"street",
|
||||
"st",
|
||||
"road",
|
||||
"rd",
|
||||
"lane",
|
||||
"ln",
|
||||
"avenue",
|
||||
"ave",
|
||||
"close",
|
||||
"cl",
|
||||
"drive",
|
||||
"dr",
|
||||
"way",
|
||||
"court",
|
||||
"ct",
|
||||
"crescent",
|
||||
"cres",
|
||||
"place",
|
||||
"terrace",
|
||||
"terr",
|
||||
"grove",
|
||||
"gardens",
|
||||
"gdns",
|
||||
"walk",
|
||||
"row",
|
||||
"square",
|
||||
"sq",
|
||||
"hill",
|
||||
"parade",
|
||||
"mews",
|
||||
"embankment",
|
||||
"broadway",
|
||||
"boulevard",
|
||||
"blvd",
|
||||
];
|
||||
|
||||
fn query_has_road_type(query: &str) -> bool {
|
||||
tokenize_address_text(query)
|
||||
.iter()
|
||||
.any(|token| ROAD_TYPE_TOKENS.contains(&token.as_str()))
|
||||
}
|
||||
|
||||
/// The outward code (everything before the space) of a canonical postcode.
|
||||
fn outcode_of(postcode: &str) -> &str {
|
||||
postcode.split(' ').next().unwrap_or(postcode)
|
||||
}
|
||||
|
||||
fn parse_address_query(query: &str) -> AddressQuery {
|
||||
let tokens = tokenize_address_text(query);
|
||||
let (full_postcode, postcode_token_indices) = extract_full_postcode(&tokens)
|
||||
.map(|(postcode, indices)| (Some(postcode), indices))
|
||||
.unwrap_or((None, Vec::new()));
|
||||
|
||||
let skip_postcode_tokens: FxHashSet<usize> = postcode_token_indices.into_iter().collect();
|
||||
|
||||
// Recover an appended partial postcode (outcode, or outcode + sector digit) as a ranking
|
||||
// bias rather than discarding it — but only from the TRAILING position, so a leading road
|
||||
// designation like "A4 Great West Road" is not mistaken for an area refinement.
|
||||
let mut postcode_area: Option<String> = None;
|
||||
let mut consumed_partial_tokens: FxHashSet<usize> = FxHashSet::default();
|
||||
if full_postcode.is_none() && !tokens.is_empty() {
|
||||
let last = tokens.len() - 1;
|
||||
if !skip_postcode_tokens.contains(&last) {
|
||||
let sector_digit =
|
||||
tokens[last].len() == 1 && tokens[last].chars().all(|ch| ch.is_ascii_digit());
|
||||
if last >= 1
|
||||
&& sector_digit
|
||||
&& !skip_postcode_tokens.contains(&(last - 1))
|
||||
&& looks_like_postcode_fragment(&tokens[last - 1])
|
||||
{
|
||||
postcode_area = Some(format!(
|
||||
"{}{}",
|
||||
tokens[last - 1].to_ascii_uppercase(),
|
||||
tokens[last]
|
||||
));
|
||||
consumed_partial_tokens.insert(last);
|
||||
consumed_partial_tokens.insert(last - 1);
|
||||
} else if looks_like_postcode_fragment(&tokens[last]) {
|
||||
postcode_area = Some(tokens[last].to_ascii_uppercase());
|
||||
consumed_partial_tokens.insert(last);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut text_groups = Vec::new();
|
||||
let mut numeric_terms = Vec::new();
|
||||
let mut candidate_terms = Vec::new();
|
||||
|
||||
for (idx, token) in tokens.iter().enumerate() {
|
||||
if skip_postcode_tokens.contains(&idx)
|
||||
|| consumed_partial_tokens.contains(&idx)
|
||||
|| looks_like_postcode_fragment(token)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if is_numeric_address_token(token) {
|
||||
numeric_terms.push(token.clone());
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(group) = address_term_group(token) {
|
||||
for alternative in &group.alternatives {
|
||||
if !is_address_stop_token(alternative)
|
||||
&& !candidate_terms.iter().any(|term| term == alternative)
|
||||
{
|
||||
candidate_terms.push(alternative.clone());
|
||||
}
|
||||
}
|
||||
text_groups.push(group);
|
||||
} else if token.chars().any(|ch| ch.is_ascii_digit()) && token.len() >= 2 {
|
||||
numeric_terms.push(token.clone());
|
||||
if !candidate_terms.iter().any(|term| term == token) {
|
||||
candidate_terms.push(token.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text_groups.dedup_by(|left, right| left.alternatives == right.alternatives);
|
||||
numeric_terms.sort_unstable();
|
||||
numeric_terms.dedup();
|
||||
|
||||
AddressQuery {
|
||||
full_postcode,
|
||||
postcode_area,
|
||||
text_groups,
|
||||
numeric_terms,
|
||||
candidate_terms,
|
||||
}
|
||||
}
|
||||
|
||||
fn token_matches_query_term(token: &str, query_term: &str) -> bool {
|
||||
token == query_term || (query_term.len() >= 3 && token.starts_with(query_term))
|
||||
}
|
||||
|
||||
fn token_matches_numeric_term(token: &str, query_term: &str) -> bool {
|
||||
token == query_term || token.starts_with(query_term)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn address_tokens_match_group(tokens: &[String], group: &AddressTermGroup) -> bool {
|
||||
group.alternatives.iter().any(|alternative| {
|
||||
tokens
|
||||
.iter()
|
||||
.any(|token| token_matches_query_term(token, alternative))
|
||||
})
|
||||
}
|
||||
|
||||
impl PropertyData {
|
||||
fn row_address_search_tokens(&self, row: usize) -> &[lasso::Spur] {
|
||||
let offset = self.address_search_token_offsets[row] as usize;
|
||||
let length = self.address_search_token_lengths[row] as usize;
|
||||
&self.address_search_token_keys[offset..offset + length]
|
||||
}
|
||||
|
||||
/// Search individual property addresses, returning `(row, score)` ranked best-first.
|
||||
///
|
||||
/// Candidate rows come from intersecting the posting lists of the distinctive words the
|
||||
/// user typed in full (so "Cherry Hinton Road" narrows to rows containing both), unioned
|
||||
/// with the exact-postcode rows when a complete postcode is present (so a postcode is a
|
||||
/// boost, not an all-or-nothing gate). An appended partial postcode keeps in-area rows
|
||||
/// ahead of the candidate cut and adds a scoring bias. With a road-type word and no house
|
||||
/// number, results collapse to one row per street.
|
||||
pub fn search_addresses(&self, query: &str, limit: usize) -> Vec<(usize, i32)> {
|
||||
if limit == 0 {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let parsed = parse_address_query(query);
|
||||
if parsed.full_postcode.is_none()
|
||||
&& parsed.text_groups.is_empty()
|
||||
&& parsed.numeric_terms.is_empty()
|
||||
{
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut candidate_rows = self.address_candidate_rows(&parsed.candidate_terms);
|
||||
|
||||
// A complete postcode contributes its rows too, instead of replacing the road match.
|
||||
if let Some(postcode) = parsed.full_postcode.as_deref() {
|
||||
if let Some(rows) = self
|
||||
.postcode_interner
|
||||
.get(postcode)
|
||||
.and_then(|key| self.postcode_row_index.get(&key))
|
||||
{
|
||||
candidate_rows = if candidate_rows.is_empty() {
|
||||
rows.clone()
|
||||
} else {
|
||||
union_sorted(&candidate_rows, rows)
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if candidate_rows.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
// When the user appended a partial postcode, keep in-area rows ahead of the cut so the
|
||||
// refinement still surfaces even for very common roads. Single pass (stable partition) so
|
||||
// the postcode check — which allocates — runs exactly once per candidate.
|
||||
if let Some(area) = parsed.postcode_area.as_deref() {
|
||||
let mut in_area = Vec::new();
|
||||
let mut others = Vec::new();
|
||||
for &row in &candidate_rows {
|
||||
if self.row_postcode_in_area(row as usize, area) {
|
||||
in_area.push(row);
|
||||
} else {
|
||||
others.push(row);
|
||||
}
|
||||
}
|
||||
in_area.extend(others);
|
||||
candidate_rows = in_area;
|
||||
}
|
||||
candidate_rows.truncate(ADDRESS_SEARCH_CANDIDATE_LIMIT);
|
||||
|
||||
let mut scored: Vec<(i32, usize, usize)> = candidate_rows
|
||||
.into_iter()
|
||||
.filter_map(|row| {
|
||||
let row = row as usize;
|
||||
self.address_match_score(row, &parsed)
|
||||
.map(|score| (score, self.address(row).len(), row))
|
||||
})
|
||||
.collect();
|
||||
|
||||
scored.sort_unstable_by(|left, right| {
|
||||
right
|
||||
.0
|
||||
.cmp(&left.0)
|
||||
.then(left.1.cmp(&right.1))
|
||||
.then(left.2.cmp(&right.2))
|
||||
});
|
||||
|
||||
// Collapse a road browse (road-type word, no house number) to one row per street.
|
||||
let collapse_streets = parsed.numeric_terms.is_empty() && query_has_road_type(query);
|
||||
|
||||
let mut seen = FxHashSet::default();
|
||||
let mut results = Vec::with_capacity(limit);
|
||||
for (score, _, row) in scored {
|
||||
let address = self.address(row).trim();
|
||||
if address.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let key = if collapse_streets {
|
||||
format!(
|
||||
"{}\n{}",
|
||||
street_key(address),
|
||||
outcode_of(self.postcode(row))
|
||||
)
|
||||
} else {
|
||||
format!("{}\n{}", address.to_ascii_lowercase(), self.postcode(row))
|
||||
};
|
||||
if !seen.insert(key) {
|
||||
continue;
|
||||
}
|
||||
results.push((row, score));
|
||||
if results.len() == limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// True when the row's postcode begins with the compact partial-postcode `area`
|
||||
/// (e.g. "NW1" or "NW16" matches "NW1 6XE").
|
||||
fn row_postcode_in_area(&self, row: usize, area: &str) -> bool {
|
||||
let mut compact = String::new();
|
||||
for ch in self.postcode(row).chars() {
|
||||
if !ch.is_whitespace() {
|
||||
compact.push(ch.to_ascii_uppercase());
|
||||
}
|
||||
}
|
||||
compact.starts_with(area)
|
||||
}
|
||||
|
||||
/// Candidate rows for the distinctive query words. Words typed in full intersect by their
|
||||
/// exact posting lists (precise); a still-being-typed final word with no exact match seeds
|
||||
/// from the smallest prefix-expanded posting list (so partial typing keeps working).
|
||||
fn address_candidate_rows(&self, terms: &[String]) -> Vec<u32> {
|
||||
let mut exact: Vec<&[u32]> = terms
|
||||
.iter()
|
||||
.filter_map(|term| self.address_token_index.get(term).map(Vec::as_slice))
|
||||
.collect();
|
||||
|
||||
if !exact.is_empty() {
|
||||
exact.sort_by_key(|rows| rows.len());
|
||||
let mut acc = exact[0].to_vec();
|
||||
for rows in &exact[1..] {
|
||||
if acc.is_empty() {
|
||||
break;
|
||||
}
|
||||
acc = intersect_sorted(&acc, rows);
|
||||
}
|
||||
return acc;
|
||||
}
|
||||
|
||||
self.prefix_seed_rows(terms)
|
||||
}
|
||||
|
||||
/// Seed rows from the smallest prefix-expanded term — used only when no word matched an
|
||||
/// indexed token exactly (i.e. the user is still typing the final word).
|
||||
fn prefix_seed_rows(&self, terms: &[String]) -> Vec<u32> {
|
||||
let mut best: Option<Vec<u32>> = None;
|
||||
for term in terms {
|
||||
if term.len() < ADDRESS_SEARCH_PREFIX_MIN_LEN {
|
||||
continue;
|
||||
}
|
||||
let Some(tokens) = self.address_prefix_index.get(address_prefix_key(term)) else {
|
||||
continue;
|
||||
};
|
||||
let mut union: Vec<u32> = Vec::new();
|
||||
for token in tokens {
|
||||
if !token.starts_with(term) {
|
||||
continue;
|
||||
}
|
||||
if let Some(rows) = self.address_token_index.get(token) {
|
||||
union = if union.is_empty() {
|
||||
rows.clone()
|
||||
} else {
|
||||
union_sorted(&union, rows)
|
||||
};
|
||||
}
|
||||
}
|
||||
if !union.is_empty()
|
||||
&& best
|
||||
.as_ref()
|
||||
.is_none_or(|current| union.len() < current.len())
|
||||
{
|
||||
best = Some(union);
|
||||
}
|
||||
}
|
||||
best.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn address_match_score(&self, row: usize, parsed: &AddressQuery) -> Option<i32> {
|
||||
if self.address(row).trim().is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let tokens = self.row_address_search_tokens(row);
|
||||
if parsed
|
||||
.text_groups
|
||||
.iter()
|
||||
.any(|group| !self.address_tokens_match_group(tokens, group))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let numeric_matches = parsed
|
||||
.numeric_terms
|
||||
.iter()
|
||||
.filter(|term| {
|
||||
tokens.iter().any(|token| {
|
||||
token_matches_numeric_term(self.address_search_interner.resolve(token), term)
|
||||
})
|
||||
})
|
||||
.count();
|
||||
|
||||
if !parsed.numeric_terms.is_empty() && numeric_matches == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut score = 0;
|
||||
if parsed.full_postcode.is_some() {
|
||||
score += 1_000;
|
||||
}
|
||||
score += (parsed.text_groups.len() as i32) * 200;
|
||||
score += (numeric_matches as i32) * 90;
|
||||
if numeric_matches == parsed.numeric_terms.len() && numeric_matches > 0 {
|
||||
score += 50;
|
||||
}
|
||||
// Additive bias (never a filter) when the row sits in the appended partial postcode.
|
||||
if let Some(area) = parsed.postcode_area.as_deref() {
|
||||
if self.row_postcode_in_area(row, area) {
|
||||
score += 400;
|
||||
}
|
||||
}
|
||||
|
||||
Some(score)
|
||||
}
|
||||
|
||||
fn address_tokens_match_group(&self, tokens: &[lasso::Spur], group: &AddressTermGroup) -> bool {
|
||||
group.alternatives.iter().any(|alternative| {
|
||||
tokens.iter().any(|token| {
|
||||
token_matches_query_term(self.address_search_interner.resolve(token), alternative)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn full_postcode_detection_accepts_common_formats() {
|
||||
assert!(is_full_postcode_compact("SW1A1AA"));
|
||||
assert!(is_full_postcode_compact("E142DG"));
|
||||
assert!(is_full_postcode_compact("M11AE"));
|
||||
assert!(!is_full_postcode_compact("E14"));
|
||||
assert!(!is_full_postcode_compact("DOWNING"));
|
||||
assert!(!is_full_postcode_compact("10A"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_query_parsing_skips_postcodes_and_street_suffixes() {
|
||||
let parsed = parse_address_query("Flat 2, 10 Downing St, SW1A 2AA");
|
||||
|
||||
assert_eq!(parsed.full_postcode.as_deref(), Some("SW1A 2AA"));
|
||||
assert_eq!(
|
||||
parsed.numeric_terms,
|
||||
vec!["10".to_string(), "2".to_string()]
|
||||
);
|
||||
assert_eq!(parsed.candidate_terms, vec!["downing".to_string()]);
|
||||
assert_eq!(parsed.text_groups.len(), 1);
|
||||
assert_eq!(
|
||||
parsed.text_groups[0].alternatives,
|
||||
vec!["downing".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_query_parsing_handles_compact_postcodes() {
|
||||
let parsed = parse_address_query("10 downing street sw1a1aa");
|
||||
|
||||
assert_eq!(parsed.full_postcode.as_deref(), Some("SW1A 1AA"));
|
||||
assert_eq!(parsed.numeric_terms, vec!["10".to_string()]);
|
||||
assert_eq!(parsed.candidate_terms, vec!["downing".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_query_recovers_appended_partial_postcode_as_bias() {
|
||||
let parsed = parse_address_query("Baker Street NW1");
|
||||
assert_eq!(parsed.full_postcode, None);
|
||||
assert_eq!(parsed.postcode_area.as_deref(), Some("NW1"));
|
||||
// The road words are still searchable; the postcode fragment did not consume them.
|
||||
assert_eq!(parsed.candidate_terms, vec!["baker".to_string()]);
|
||||
assert!(parsed.numeric_terms.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_query_recovers_outcode_plus_sector_without_a_phantom_house_number() {
|
||||
let parsed = parse_address_query("High Street CR0 2");
|
||||
assert_eq!(parsed.postcode_area.as_deref(), Some("CR02"));
|
||||
// The lone sector digit must not be treated as a house number.
|
||||
assert!(parsed.numeric_terms.is_empty());
|
||||
assert_eq!(parsed.candidate_terms, vec!["high".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_postcode_takes_precedence_over_partial_bias() {
|
||||
let parsed = parse_address_query("Baker Street NW1 6XE");
|
||||
assert_eq!(parsed.full_postcode.as_deref(), Some("NW1 6XE"));
|
||||
assert_eq!(parsed.postcode_area, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn intersect_and_union_sorted_row_ids() {
|
||||
assert_eq!(
|
||||
intersect_sorted(&[1, 2, 3, 5], &[2, 3, 4, 5]),
|
||||
vec![2, 3, 5]
|
||||
);
|
||||
assert_eq!(intersect_sorted(&[1, 2], &[3, 4]), Vec::<u32>::new());
|
||||
assert_eq!(union_sorted(&[1, 3, 5], &[2, 3, 4]), vec![1, 2, 3, 4, 5]);
|
||||
assert_eq!(union_sorted(&[], &[2, 4]), vec![2, 4]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn street_key_collapses_house_numbers_and_flats() {
|
||||
assert_eq!(street_key("12 Baker Street"), "baker street");
|
||||
assert_eq!(street_key("5 Baker Street"), "baker street");
|
||||
assert_eq!(street_key("Flat 2, 10 Downing Street"), "downing street");
|
||||
assert_eq!(street_key("221B Baker Street"), "baker street");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn street_key_keeps_ordinal_street_names() {
|
||||
// Ordinals are part of the street name, not a house-number prefix.
|
||||
assert_eq!(street_key("2nd Avenue"), "2nd avenue");
|
||||
assert_eq!(street_key("12 3rd Avenue"), "3rd avenue");
|
||||
assert!(is_ordinal_token("21st"));
|
||||
assert!(!is_ordinal_token("21"));
|
||||
assert!(!is_ordinal_token("221b"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn postcode_area_recovered_only_from_the_trailing_position() {
|
||||
// A leading road designation must NOT be taken as an area refinement.
|
||||
let parsed = parse_address_query("A4 Great West Road");
|
||||
assert_eq!(parsed.postcode_area, None);
|
||||
// A genuine trailing outcode still is.
|
||||
let trailing = parse_address_query("Great West Road W4");
|
||||
assert_eq!(trailing.postcode_area.as_deref(), Some("W4"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn road_type_detection() {
|
||||
assert!(query_has_road_type("high street"));
|
||||
assert!(query_has_road_type("acacia avenue"));
|
||||
assert!(!query_has_road_type("acacia"));
|
||||
assert!(!query_has_road_type("london"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_query_parsing_keeps_partial_terms_for_row_matching() {
|
||||
let parsed = parse_address_query("settlers cour");
|
||||
|
||||
assert_eq!(parsed.full_postcode, None);
|
||||
assert_eq!(parsed.numeric_terms, Vec::<String>::new());
|
||||
assert_eq!(
|
||||
parsed.candidate_terms,
|
||||
vec!["settlers".to_string(), "cour".to_string()]
|
||||
);
|
||||
assert_eq!(parsed.text_groups.len(), 2);
|
||||
assert_eq!(
|
||||
parsed.text_groups[0].alternatives,
|
||||
vec!["settlers".to_string()]
|
||||
);
|
||||
assert_eq!(parsed.text_groups[1].alternatives, vec!["cour".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_search_tokens_keep_actual_address_terms_for_scoring() {
|
||||
let tokens = address_search_tokens("Flat 2, 10 Downing Cour");
|
||||
|
||||
assert_eq!(
|
||||
tokens,
|
||||
vec![
|
||||
"10".to_string(),
|
||||
"2".to_string(),
|
||||
"cour".to_string(),
|
||||
"downing".to_string(),
|
||||
"flat".to_string()
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_prefix_index_finds_partial_address_terms() {
|
||||
let mut token_index: FxHashMap<String, Vec<u32>> = FxHashMap::default();
|
||||
token_index.insert("downing".to_string(), vec![1]);
|
||||
token_index.insert("downton".to_string(), vec![2]);
|
||||
token_index.insert("market".to_string(), vec![3]);
|
||||
|
||||
let prefix_index = build_address_prefix_index(&token_index);
|
||||
|
||||
assert_eq!(
|
||||
prefix_index.get("down").cloned().unwrap_or_default(),
|
||||
vec!["downing".to_string(), "downton".to_string()]
|
||||
);
|
||||
assert_eq!(
|
||||
prefix_index.get("downi").cloned().unwrap_or_default(),
|
||||
vec!["downing".to_string()]
|
||||
);
|
||||
assert_eq!(
|
||||
prefix_index.get("downt").cloned().unwrap_or_default(),
|
||||
vec!["downton".to_string()]
|
||||
);
|
||||
assert!(!prefix_index.contains_key("do"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_term_matching_allows_prefixes_and_aliases() {
|
||||
let tokens = tokenize_address_text("10 Downing Street");
|
||||
let prefix_group = address_term_group("down").expect("prefix term should be searchable");
|
||||
let alias_group = AddressTermGroup {
|
||||
alternatives: vec!["st".to_string(), "street".to_string()],
|
||||
};
|
||||
|
||||
assert!(address_tokens_match_group(&tokens, &prefix_group));
|
||||
assert!(address_tokens_match_group(&tokens, &alias_group));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn address_term_matching_uses_actual_token_prefixes() {
|
||||
let tokens = tokenize_address_text("12 Settlers Court");
|
||||
let prefix_group = address_term_group("cou").expect("partial term should be searchable");
|
||||
|
||||
assert!(address_tokens_match_group(&tokens, &prefix_group));
|
||||
}
|
||||
}
|
||||
34
server-rs/src/data/property/h3.rs
Normal file
34
server-rs/src/data/property/h3.rs
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
//! H3 spatial cell precomputation for property rows.
|
||||
|
||||
use anyhow::Context;
|
||||
use rayon::prelude::*;
|
||||
|
||||
use crate::consts::H3_PRECOMPUTE_MAX;
|
||||
|
||||
/// Precompute H3 cell IDs for all rows at the maximum resolution only.
|
||||
/// Parent cells for lower resolutions are derived on the fly via `CellIndex::parent()`.
|
||||
pub fn precompute_h3(lat: &[f32], lon: &[f32]) -> anyhow::Result<Vec<u64>> {
|
||||
let res = H3_PRECOMPUTE_MAX;
|
||||
tracing::info!("Precomputing H3 cells at resolution {}", res);
|
||||
|
||||
let h3_res =
|
||||
h3o::Resolution::try_from(res).with_context(|| format!("Invalid H3 resolution: {res}"))?;
|
||||
|
||||
let cells: Vec<u64> = lat
|
||||
.par_iter()
|
||||
.zip(lon.par_iter())
|
||||
.enumerate()
|
||||
.map(|(i, (&latitude, &longitude))| {
|
||||
let coord = h3o::LatLng::new(latitude as f64, longitude as f64).unwrap_or_else(|err| {
|
||||
panic!(
|
||||
"Invalid coordinates at row {}: lat={}, lon={}: {}",
|
||||
i, latitude, longitude, err
|
||||
)
|
||||
});
|
||||
u64::from(coord.to_cell(h3_res))
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!("H3 precomputation complete ({} cells)", cells.len());
|
||||
Ok(cells)
|
||||
}
|
||||
1105
server-rs/src/data/property/loading.rs
Normal file
1105
server-rs/src/data/property/loading.rs
Normal file
File diff suppressed because it is too large
Load diff
238
server-rs/src/data/property/mod.rs
Normal file
238
server-rs/src/data/property/mod.rs
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
//! Property data: the row-major quantized feature matrix plus the side tables
|
||||
//! (addresses, postcodes, renovation/price history, POI metrics) built from the
|
||||
//! properties + postcode parquet files.
|
||||
//!
|
||||
//! Split by concern:
|
||||
//! - [`loading`]: parquet ingestion, validation, spatial sort and matrix build
|
||||
//! - [`stats`]: histograms, percentiles and slider-bound computation
|
||||
//! - [`quant`]: u16 quantization encode/decode
|
||||
//! - [`poi_metrics`]: postcode-level POI metric side table
|
||||
//! - [`address_search`]: address tokenization, indexing and ranked search
|
||||
//! - [`h3`]: H3 cell precomputation
|
||||
|
||||
mod address_search;
|
||||
mod h3;
|
||||
mod loading;
|
||||
mod poi_metrics;
|
||||
mod quant;
|
||||
mod stats;
|
||||
|
||||
pub use h3::precompute_h3;
|
||||
pub use poi_metrics::PostcodePoiMetrics;
|
||||
pub use quant::QuantRef;
|
||||
pub use stats::{FeatureStats, Histogram};
|
||||
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::consts::NAN_U16;
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct RenovationEvent {
|
||||
pub year: i32,
|
||||
pub event: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct HistoricalPrice {
|
||||
pub year: i32,
|
||||
pub month: u8,
|
||||
pub price: i64,
|
||||
}
|
||||
|
||||
pub struct PropertyData {
|
||||
pub lat: Vec<f32>,
|
||||
pub lon: Vec<f32>,
|
||||
pub feature_names: Vec<String>,
|
||||
pub num_features: usize,
|
||||
/// Number of numeric features (enum features start at this index).
|
||||
pub num_numeric: usize,
|
||||
/// Row-major flat array: feature_data[row * num_features + feat_idx].
|
||||
/// Quantized to u16. NaN sentinel = u16::MAX (65535).
|
||||
/// Numeric features: encoded via (val - min) / range * 65534.
|
||||
/// Enum features: stored directly as u16 cast of the f32 index.
|
||||
pub feature_data: Vec<u16>,
|
||||
/// Per-feature: range / QUANT_SCALE for fast decode.
|
||||
dequant_a: Vec<f32>,
|
||||
/// Per-feature: minimum value (offset for dequantization).
|
||||
quant_min: Vec<f32>,
|
||||
/// Per-feature: max - min (for encoding filter bounds).
|
||||
quant_range: Vec<f32>,
|
||||
pub feature_stats: Vec<FeatureStats>,
|
||||
pub poi_metrics: PostcodePoiMetrics,
|
||||
/// Unquantized last sale price used by the price-history chart.
|
||||
last_known_price_raw: Vec<f32>,
|
||||
/// Contiguous buffer holding all address strings end-to-end.
|
||||
address_buffer: String,
|
||||
/// Byte offset into `address_buffer` where each row's address starts.
|
||||
address_offsets: Vec<u32>,
|
||||
/// Length in bytes of each row's address.
|
||||
address_lengths: Vec<u16>,
|
||||
/// Interned postcodes: reader is thread-safe, keys index into it.
|
||||
postcode_interner: lasso::RodeoReader,
|
||||
postcode_keys: Vec<lasso::Spur>,
|
||||
/// Rows for each postcode, keyed by the interned postcode key.
|
||||
postcode_row_index: FxHashMap<lasso::Spur, Vec<u32>>,
|
||||
/// Inverted index from address tokens to property rows.
|
||||
address_token_index: FxHashMap<String, Vec<u32>>,
|
||||
/// Prefix lookup from typed address-token prefix to indexed full address tokens.
|
||||
address_prefix_index: FxHashMap<String, Vec<String>>,
|
||||
/// Interned normalized address-search tokens used for per-row scoring.
|
||||
address_search_interner: lasso::RodeoReader,
|
||||
/// Flat per-row normalized address-search token keys.
|
||||
address_search_token_keys: Vec<lasso::Spur>,
|
||||
/// Offset into `address_search_token_keys` for each row.
|
||||
address_search_token_offsets: Vec<u32>,
|
||||
/// Number of normalized address-search token keys for each row.
|
||||
address_search_token_lengths: Vec<u16>,
|
||||
/// For enum features: maps feature index to list of possible string values.
|
||||
/// Index in values list corresponds to the u16 value stored in feature_data.
|
||||
pub enum_values: rustc_hash::FxHashMap<usize, Vec<String>>,
|
||||
/// For enum features: maps feature index to per-value global counts (same order as enum_values).
|
||||
pub enum_counts: rustc_hash::FxHashMap<usize, Vec<u64>>,
|
||||
/// Per-row flag: true = construction date is approximate (from EPC band),
|
||||
/// false = exact (from new-build transaction date).
|
||||
/// Bit-packed: byte `row / 8`, bit `row % 8`. 8x smaller than Vec<bool>.
|
||||
approx_build_date_bits: Vec<u8>,
|
||||
/// Per-row renovation events. Keyed by (permuted) row index.
|
||||
/// Only rows with events are present in the map.
|
||||
renovation_history: FxHashMap<u32, Vec<RenovationEvent>>,
|
||||
/// Per-row historical sale transactions (Land Registry price-paid).
|
||||
/// Keyed by (permuted) row index. Only rows with prices are present.
|
||||
historical_prices: FxHashMap<u32, Vec<HistoricalPrice>>,
|
||||
property_sub_type: FxHashMap<u32, String>,
|
||||
price_qualifier: FxHashMap<u32, String>,
|
||||
}
|
||||
|
||||
impl PropertyData {
|
||||
/// Get the address string for a given row.
|
||||
pub fn address(&self, row: usize) -> &str {
|
||||
let offset = self.address_offsets[row] as usize;
|
||||
let length = self.address_lengths[row] as usize;
|
||||
&self.address_buffer[offset..offset + length]
|
||||
}
|
||||
|
||||
/// Get the postcode string for a given row.
|
||||
pub fn postcode(&self, row: usize) -> &str {
|
||||
self.postcode_interner.resolve(&self.postcode_keys[row])
|
||||
}
|
||||
|
||||
/// Get postcode components for field-level borrowing (avoids conflicting borrows with feature_data).
|
||||
pub fn postcode_parts(&self) -> (&lasso::RodeoReader, &[lasso::Spur]) {
|
||||
(&self.postcode_interner, &self.postcode_keys)
|
||||
}
|
||||
|
||||
/// Property rows for a given postcode string, or empty if unknown.
|
||||
pub fn rows_for_postcode(&self, postcode: &str) -> &[u32] {
|
||||
self.postcode_interner
|
||||
.get(postcode)
|
||||
.and_then(|key| self.postcode_row_index.get(&key))
|
||||
.map(Vec::as_slice)
|
||||
.unwrap_or(&[])
|
||||
}
|
||||
|
||||
/// Get the is_approx_build_date flag for a given row (bit-packed).
|
||||
pub fn is_approx_build_date(&self, row: usize) -> bool {
|
||||
let byte = self.approx_build_date_bits[row / 8];
|
||||
byte & (1 << (row % 8)) != 0
|
||||
}
|
||||
|
||||
/// Get renovation events for a given row (empty slice if none).
|
||||
pub fn renovation_history(&self, row: usize) -> &[RenovationEvent] {
|
||||
self.renovation_history
|
||||
.get(&(row as u32))
|
||||
.map(|v| v.as_slice())
|
||||
.unwrap_or(&[])
|
||||
}
|
||||
|
||||
/// Get historical sale transactions for a given row (empty slice if none).
|
||||
pub fn historical_prices(&self, row: usize) -> &[HistoricalPrice] {
|
||||
self.historical_prices
|
||||
.get(&(row as u32))
|
||||
.map(|v| v.as_slice())
|
||||
.unwrap_or(&[])
|
||||
}
|
||||
|
||||
/// Get property sub-type for a given row.
|
||||
pub fn property_sub_type(&self, row: usize) -> Option<&str> {
|
||||
self.property_sub_type
|
||||
.get(&(row as u32))
|
||||
.map(String::as_str)
|
||||
}
|
||||
|
||||
/// Get price qualifier for a given row.
|
||||
pub fn price_qualifier(&self, row: usize) -> Option<&str> {
|
||||
self.price_qualifier.get(&(row as u32)).map(String::as_str)
|
||||
}
|
||||
|
||||
/// Get the unquantized last sale price for charting.
|
||||
#[inline]
|
||||
pub fn last_known_price_raw(&self, row: usize) -> f32 {
|
||||
self.last_known_price_raw[row]
|
||||
}
|
||||
|
||||
/// Decode a single feature value from quantized u16 storage.
|
||||
#[inline]
|
||||
pub fn get_feature(&self, row: usize, feat_idx: usize) -> f32 {
|
||||
let raw = self.feature_data[row * self.num_features + feat_idx];
|
||||
if raw == NAN_U16 {
|
||||
return f32::NAN;
|
||||
}
|
||||
if feat_idx >= self.num_numeric {
|
||||
raw as f32
|
||||
} else {
|
||||
raw as f32 * self.dequant_a[feat_idx] + self.quant_min[feat_idx]
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a QuantRef for passing to aggregation/filter functions.
|
||||
pub fn quant_ref(&self) -> QuantRef<'_> {
|
||||
QuantRef {
|
||||
dequant_a: &self.dequant_a,
|
||||
quant_min: &self.quant_min,
|
||||
quant_range: &self.quant_range,
|
||||
num_numeric: self.num_numeric,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl PropertyData {
|
||||
/// Minimal empty instance for integration tests that need an `AppState`
|
||||
/// but never touch property data (e.g. checkout/webhook/invite flows).
|
||||
pub(crate) fn empty_for_tests() -> Self {
|
||||
PropertyData {
|
||||
lat: Vec::new(),
|
||||
lon: Vec::new(),
|
||||
feature_names: Vec::new(),
|
||||
num_features: 0,
|
||||
num_numeric: 0,
|
||||
feature_data: Vec::new(),
|
||||
dequant_a: Vec::new(),
|
||||
quant_min: Vec::new(),
|
||||
quant_range: Vec::new(),
|
||||
feature_stats: Vec::new(),
|
||||
poi_metrics: PostcodePoiMetrics::empty(0),
|
||||
last_known_price_raw: Vec::new(),
|
||||
address_buffer: String::new(),
|
||||
address_offsets: Vec::new(),
|
||||
address_lengths: Vec::new(),
|
||||
postcode_interner: lasso::Rodeo::default().into_reader(),
|
||||
postcode_keys: Vec::new(),
|
||||
postcode_row_index: FxHashMap::default(),
|
||||
address_token_index: FxHashMap::default(),
|
||||
address_prefix_index: FxHashMap::default(),
|
||||
address_search_interner: lasso::Rodeo::default().into_reader(),
|
||||
address_search_token_keys: Vec::new(),
|
||||
address_search_token_offsets: Vec::new(),
|
||||
address_search_token_lengths: Vec::new(),
|
||||
enum_values: rustc_hash::FxHashMap::default(),
|
||||
enum_counts: rustc_hash::FxHashMap::default(),
|
||||
approx_build_date_bits: Vec::new(),
|
||||
renovation_history: FxHashMap::default(),
|
||||
historical_prices: FxHashMap::default(),
|
||||
property_sub_type: FxHashMap::default(),
|
||||
price_qualifier: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
200
server-rs/src/data/property/poi_metrics.rs
Normal file
200
server-rs/src/data/property/poi_metrics.rs
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
//! Postcode-level POI metric side table: dynamic POI features are stored once
|
||||
//! per postcode (not per property row) to keep the hot row-major feature matrix
|
||||
//! narrow, with a per-property row mapping for lookups.
|
||||
|
||||
use anyhow::Context;
|
||||
use polars::prelude::*;
|
||||
use rayon::prelude::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::consts::{NAN_U16, QUANT_SCALE};
|
||||
use crate::features::{self, Bounds};
|
||||
|
||||
use super::quant::QuantRef;
|
||||
use super::stats::{column_to_f32_vec, compute_feature_stats, FeatureStats};
|
||||
|
||||
pub(super) const NO_POI_METRIC_ROW: u32 = u32::MAX;
|
||||
|
||||
pub struct PostcodePoiMetrics {
|
||||
pub feature_names: Vec<String>,
|
||||
pub name_to_index: FxHashMap<String, usize>,
|
||||
/// Metric-major storage: columns[metric_idx][postcode_metric_idx].
|
||||
pub columns: Vec<Vec<u16>>,
|
||||
pub feature_stats: Vec<FeatureStats>,
|
||||
/// Per-property row lookup into the postcode metric table.
|
||||
row_to_metric_idx: Vec<u32>,
|
||||
dequant_a: Vec<f32>,
|
||||
quant_min: Vec<f32>,
|
||||
quant_range: Vec<f32>,
|
||||
}
|
||||
|
||||
impl PostcodePoiMetrics {
|
||||
pub(super) fn empty(row_count: usize) -> Self {
|
||||
Self {
|
||||
feature_names: Vec::new(),
|
||||
name_to_index: FxHashMap::default(),
|
||||
columns: Vec::new(),
|
||||
feature_stats: Vec::new(),
|
||||
row_to_metric_idx: vec![NO_POI_METRIC_ROW; row_count],
|
||||
dequant_a: Vec::new(),
|
||||
quant_min: Vec::new(),
|
||||
quant_range: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn from_postcode_df(
|
||||
df: &DataFrame,
|
||||
feature_names: Vec<String>,
|
||||
) -> anyhow::Result<Self> {
|
||||
if feature_names.is_empty() {
|
||||
return Ok(Self::empty(0));
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
metrics = feature_names.len(),
|
||||
postcodes = df.height(),
|
||||
"Building postcode POI metric side table"
|
||||
);
|
||||
|
||||
let col_major: Vec<Vec<f32>> = feature_names
|
||||
.par_iter()
|
||||
.map(|name| {
|
||||
let column = df
|
||||
.column(name.as_str())
|
||||
.with_context(|| format!("Missing POI metric column '{name}'"))?;
|
||||
column_to_f32_vec(column)
|
||||
})
|
||||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||||
|
||||
let feature_stats: Vec<FeatureStats> = col_major
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(metric_idx, vals)| {
|
||||
let name = feature_names[metric_idx].as_str();
|
||||
let bounds = features::bounds_for(name)
|
||||
.with_context(|| format!("No bounds config for POI metric '{name}'"))?;
|
||||
Ok(compute_feature_stats(
|
||||
vals,
|
||||
&bounds,
|
||||
features::has_integer_bins(name),
|
||||
))
|
||||
})
|
||||
.collect::<anyhow::Result<Vec<_>>>()?;
|
||||
|
||||
let mut quant_min = Vec::with_capacity(feature_names.len());
|
||||
let mut quant_range = Vec::with_capacity(feature_names.len());
|
||||
for (metric_idx, stats) in feature_stats.iter().enumerate() {
|
||||
let (min, max) = match features::bounds_for(feature_names[metric_idx].as_str()) {
|
||||
Some(Bounds::Fixed { min, max }) => (min, max),
|
||||
_ => (stats.histogram.min, stats.histogram.max),
|
||||
};
|
||||
quant_min.push(min);
|
||||
quant_range.push(if max > min { max - min } else { 0.0 });
|
||||
}
|
||||
let dequant_a: Vec<f32> = quant_range
|
||||
.iter()
|
||||
.map(|&range| {
|
||||
if range > 0.0 {
|
||||
range / QUANT_SCALE
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let columns: Vec<Vec<u16>> = col_major
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(metric_idx, vals)| {
|
||||
let range = quant_range[metric_idx];
|
||||
let min = quant_min[metric_idx];
|
||||
vals.iter()
|
||||
.map(|&value| {
|
||||
if !value.is_finite() {
|
||||
NAN_U16
|
||||
} else if range > 0.0 {
|
||||
let normalized = (value - min) / range;
|
||||
(normalized * QUANT_SCALE).round().clamp(0.0, QUANT_SCALE) as u16
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let name_to_index = feature_names
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, name)| (name.clone(), idx))
|
||||
.collect();
|
||||
|
||||
Ok(Self {
|
||||
feature_names,
|
||||
name_to_index,
|
||||
columns,
|
||||
feature_stats,
|
||||
row_to_metric_idx: Vec::new(),
|
||||
dequant_a,
|
||||
quant_min,
|
||||
quant_range,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn set_row_mapping(&mut self, row_to_metric_idx: Vec<u32>) {
|
||||
self.row_to_metric_idx = row_to_metric_idx;
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.feature_names.is_empty()
|
||||
}
|
||||
|
||||
pub fn num_features(&self) -> usize {
|
||||
self.feature_names.len()
|
||||
}
|
||||
|
||||
pub fn quant_ref(&self) -> QuantRef<'_> {
|
||||
QuantRef {
|
||||
dequant_a: &self.dequant_a,
|
||||
quant_min: &self.quant_min,
|
||||
quant_range: &self.quant_range,
|
||||
num_numeric: self.feature_names.len(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn metric_row_for_property(&self, row: usize) -> Option<usize> {
|
||||
self.row_to_metric_idx
|
||||
.get(row)
|
||||
.copied()
|
||||
.filter(|&idx| idx != NO_POI_METRIC_ROW)
|
||||
.map(|idx| idx as usize)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn raw_for_metric_row(&self, metric_row: usize, metric_idx: usize) -> u16 {
|
||||
self.columns[metric_idx][metric_row]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn raw_for_property_row(&self, row: usize, metric_idx: usize) -> u16 {
|
||||
let Some(metric_row) = self.metric_row_for_property(row) else {
|
||||
return NAN_U16;
|
||||
};
|
||||
self.raw_for_metric_row(metric_row, metric_idx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn decode_raw(&self, metric_idx: usize, raw: u16) -> f32 {
|
||||
if raw == NAN_U16 {
|
||||
f32::NAN
|
||||
} else {
|
||||
raw as f32 * self.dequant_a[metric_idx] + self.quant_min[metric_idx]
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_for_property_row(&self, row: usize, metric_idx: usize) -> f32 {
|
||||
self.decode_raw(metric_idx, self.raw_for_property_row(row, metric_idx))
|
||||
}
|
||||
}
|
||||
46
server-rs/src/data/property/quant.rs
Normal file
46
server-rs/src/data/property/quant.rs
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
//! u16 quantization: decoding stored feature values and encoding filter bounds.
|
||||
|
||||
use crate::consts::{NAN_U16, QUANT_SCALE};
|
||||
|
||||
/// Lightweight reference to quantization parameters for decoding u16 feature data.
|
||||
pub struct QuantRef<'a> {
|
||||
pub dequant_a: &'a [f32],
|
||||
pub quant_min: &'a [f32],
|
||||
pub quant_range: &'a [f32],
|
||||
pub num_numeric: usize,
|
||||
}
|
||||
|
||||
impl QuantRef<'_> {
|
||||
/// Decode a raw u16 value back to f32.
|
||||
#[inline]
|
||||
pub fn decode(&self, feat_idx: usize, raw: u16) -> f32 {
|
||||
if raw == NAN_U16 {
|
||||
return f32::NAN;
|
||||
}
|
||||
if feat_idx >= self.num_numeric {
|
||||
raw as f32
|
||||
} else {
|
||||
raw as f32 * self.dequant_a[feat_idx] + self.quant_min[feat_idx]
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode a filter minimum bound to u16 (floors to include boundary values).
|
||||
#[inline]
|
||||
pub fn encode_min(&self, feat_idx: usize, value: f32) -> u16 {
|
||||
if !value.is_finite() || self.quant_range[feat_idx] == 0.0 {
|
||||
return 0;
|
||||
}
|
||||
let norm = (value - self.quant_min[feat_idx]) / self.quant_range[feat_idx];
|
||||
(norm * QUANT_SCALE).floor().clamp(0.0, QUANT_SCALE) as u16
|
||||
}
|
||||
|
||||
/// Encode a filter maximum bound to u16 (ceils to include boundary values).
|
||||
#[inline]
|
||||
pub fn encode_max(&self, feat_idx: usize, value: f32) -> u16 {
|
||||
if !value.is_finite() || self.quant_range[feat_idx] == 0.0 {
|
||||
return QUANT_SCALE as u16;
|
||||
}
|
||||
let norm = (value - self.quant_min[feat_idx]) / self.quant_range[feat_idx];
|
||||
(norm * QUANT_SCALE).ceil().clamp(0.0, QUANT_SCALE) as u16
|
||||
}
|
||||
}
|
||||
544
server-rs/src/data/property/stats.rs
Normal file
544
server-rs/src/data/property/stats.rs
Normal file
|
|
@ -0,0 +1,544 @@
|
|||
//! Feature statistics: outlier-bracketed histograms, percentile estimation and
|
||||
//! slider-bound computation.
|
||||
|
||||
use anyhow::Context;
|
||||
use polars::prelude::*;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::consts::HISTOGRAM_BINS;
|
||||
use crate::features::Bounds;
|
||||
|
||||
/// Histogram with outlier buckets at the edges.
|
||||
/// - Bin 0: [min, p1) — low outliers
|
||||
/// - Bins 1 to n-2: [p1, p99) — main distribution, evenly divided
|
||||
/// - Bin n-1: [p99, max] — high outliers
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct Histogram {
|
||||
pub min: f32,
|
||||
pub max: f32,
|
||||
/// 1st percentile (left edge of main distribution)
|
||||
pub p1: f32,
|
||||
/// 99th percentile (right edge of main distribution)
|
||||
pub p99: f32,
|
||||
pub counts: Vec<u64>,
|
||||
}
|
||||
|
||||
impl Histogram {
|
||||
/// Return the bin index for a given value using the outlier-bracket layout.
|
||||
#[cfg(test)]
|
||||
pub fn bin_for_value(&self, value: f32) -> usize {
|
||||
let num_bins = self.counts.len();
|
||||
if value < self.p1 {
|
||||
0
|
||||
} else if value >= self.p99 {
|
||||
num_bins - 1
|
||||
} else {
|
||||
let middle_bins = num_bins.saturating_sub(2);
|
||||
if middle_bins > 0 && self.p99 > self.p1 {
|
||||
let width = (self.p99 - self.p1) / middle_bins as f32;
|
||||
let middle_bin = ((value - self.p1) / width) as usize;
|
||||
(1 + middle_bin).min(num_bins - 2)
|
||||
} else {
|
||||
num_bins / 2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Width of a single middle bin (bins 1..n-2).
|
||||
#[cfg(test)]
|
||||
pub fn middle_bin_width(&self) -> f32 {
|
||||
let middle_bins = self.counts.len().saturating_sub(2);
|
||||
if middle_bins > 0 && self.p99 > self.p1 {
|
||||
(self.p99 - self.p1) / middle_bins as f32
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FeatureStats {
|
||||
pub slider_min: f32,
|
||||
pub slider_max: f32,
|
||||
pub histogram: Histogram,
|
||||
}
|
||||
|
||||
/// Compute a percentile from a uniformly-binned histogram.
|
||||
/// `prelim_counts` are uniform bins over [min, max].
|
||||
fn percentile_from_uniform_histogram(
|
||||
count: usize,
|
||||
min: f32,
|
||||
max: f32,
|
||||
prelim_counts: &[u64],
|
||||
percentile: f32,
|
||||
) -> f32 {
|
||||
if count == 0 || prelim_counts.is_empty() {
|
||||
return min;
|
||||
}
|
||||
let target = (count as f64 * percentile as f64 / 100.0).floor() as u64;
|
||||
let bin_width = (max - min) / prelim_counts.len() as f32;
|
||||
let mut cumulative = 0u64;
|
||||
for (i, &bin_count) in prelim_counts.iter().enumerate() {
|
||||
let prev_cumulative = cumulative;
|
||||
cumulative += bin_count;
|
||||
if cumulative > target {
|
||||
// Interpolate within this bin
|
||||
let bin_start = min + i as f32 * bin_width;
|
||||
let fraction = if bin_count > 0 {
|
||||
(target - prev_cumulative) as f32 / bin_count as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
return bin_start + fraction * bin_width;
|
||||
}
|
||||
}
|
||||
max
|
||||
}
|
||||
|
||||
/// Build a histogram and compute slider bounds based on the feature's Bounds config.
|
||||
pub fn compute_feature_stats(vals: &[f32], bounds: &Bounds, integer_bins: bool) -> FeatureStats {
|
||||
// Single pass: min, max, count (skipping NaN and infinity)
|
||||
let mut min = f32::INFINITY;
|
||||
let mut max = f32::NEG_INFINITY;
|
||||
let mut count = 0usize;
|
||||
for &value in vals {
|
||||
if value.is_finite() {
|
||||
if value < min {
|
||||
min = value;
|
||||
}
|
||||
if value > max {
|
||||
max = value;
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
let (slider_min, slider_max) = match bounds {
|
||||
Bounds::Fixed {
|
||||
min: fmin,
|
||||
max: fmax,
|
||||
} => (*fmin, *fmax),
|
||||
Bounds::Percentile { .. } => (0.0, 0.0),
|
||||
};
|
||||
return FeatureStats {
|
||||
slider_min,
|
||||
slider_max,
|
||||
histogram: Histogram {
|
||||
min: 0.0,
|
||||
max: 0.0,
|
||||
p1: 0.0,
|
||||
p99: 0.0,
|
||||
counts: vec![0; HISTOGRAM_BINS],
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Build preliminary histogram with uniform bins to compute percentiles
|
||||
// Use full HISTOGRAM_BINS for percentile precision
|
||||
let range = if max == min { 1.0 } else { max - min };
|
||||
let prelim_max = min + range * (1.0 + 1e-6);
|
||||
let prelim_bin_width = (prelim_max - min) / HISTOGRAM_BINS as f32;
|
||||
|
||||
let mut prelim_counts = vec![0u64; HISTOGRAM_BINS];
|
||||
for &value in vals {
|
||||
if value.is_finite() {
|
||||
let bin = ((value - min) / prelim_bin_width) as usize;
|
||||
prelim_counts[bin.min(HISTOGRAM_BINS - 1)] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute p1 and p99 from preliminary histogram
|
||||
let mut p1 = percentile_from_uniform_histogram(count, min, max, &prelim_counts, 1.0);
|
||||
let mut p99 = percentile_from_uniform_histogram(count, min, max, &prelim_counts, 99.0);
|
||||
|
||||
// Iterative refinement for outlier-dominated distributions.
|
||||
// When extreme outliers (e.g. 317M sqm from web scraping) dominate the range,
|
||||
// the uniform histogram puts all real data in one bin, making percentile
|
||||
// estimation useless. Zoom into the estimated data region and recompute.
|
||||
let mut refined_counts = prelim_counts;
|
||||
let mut refined_count = count;
|
||||
let mut refined_min = min;
|
||||
let mut refined_max = max;
|
||||
for _ in 0..3 {
|
||||
let iqr = p99 - p1;
|
||||
if iqr <= 0.0 || (refined_max - refined_min) <= 5.0 * iqr {
|
||||
break;
|
||||
}
|
||||
let new_min = (p1 - iqr).max(min);
|
||||
let new_max = p99 + iqr;
|
||||
if new_max <= new_min {
|
||||
break;
|
||||
}
|
||||
let bin_width = (new_max - new_min) / HISTOGRAM_BINS as f32;
|
||||
let mut counts = vec![0u64; HISTOGRAM_BINS];
|
||||
let mut cnt = 0usize;
|
||||
for &value in vals {
|
||||
if value.is_finite() && value >= new_min && value <= new_max {
|
||||
let bin = ((value - new_min) / bin_width) as usize;
|
||||
counts[bin.min(HISTOGRAM_BINS - 1)] += 1;
|
||||
cnt += 1;
|
||||
}
|
||||
}
|
||||
if cnt == 0 {
|
||||
break;
|
||||
}
|
||||
p1 = percentile_from_uniform_histogram(cnt, new_min, new_max, &counts, 1.0);
|
||||
p99 = percentile_from_uniform_histogram(cnt, new_min, new_max, &counts, 99.0);
|
||||
refined_counts = counts;
|
||||
refined_count = cnt;
|
||||
refined_min = new_min;
|
||||
refined_max = new_max;
|
||||
}
|
||||
|
||||
// For integer-binned features, snap p1/p99 to integer boundaries
|
||||
// so each middle bin is exactly 1 unit wide.
|
||||
if integer_bins {
|
||||
p1 = p1.floor();
|
||||
p99 = p99.ceil();
|
||||
}
|
||||
|
||||
// Determine number of histogram bins
|
||||
let num_bins = if integer_bins && p99 > p1 {
|
||||
// One middle bin per integer + 2 outlier bins
|
||||
(p99 - p1) as usize + 2
|
||||
} else {
|
||||
// Count unique values within the p1–p99 range to cap histogram bins.
|
||||
// Using the full-range cardinality would over-allocate bins when outliers
|
||||
// inflate it (e.g. bedrooms: 1–137 unique values but only ~10 within p1–p99).
|
||||
let cardinality = {
|
||||
let mut unique_set = rustc_hash::FxHashSet::default();
|
||||
for &val in vals {
|
||||
if val.is_finite() && val >= p1 && val <= p99 {
|
||||
unique_set.insert(val.to_bits());
|
||||
}
|
||||
}
|
||||
unique_set.len()
|
||||
};
|
||||
HISTOGRAM_BINS.min(cardinality).max(3)
|
||||
};
|
||||
|
||||
// Build final histogram with outlier bins at edges:
|
||||
// - Bin 0: [min, p1) — low outliers
|
||||
// - Bins 1 to n-2: [p1, p99) — main distribution, evenly divided
|
||||
// - Bin n-1: [p99, max] — high outliers
|
||||
let mut counts = vec![0u64; num_bins];
|
||||
let middle_bins = num_bins.saturating_sub(2);
|
||||
let middle_width = if middle_bins > 0 && p99 > p1 {
|
||||
(p99 - p1) / middle_bins as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
for &value in vals {
|
||||
if value.is_finite() {
|
||||
let bin = if value < p1 {
|
||||
0 // Low outlier bin
|
||||
} else if value >= p99 {
|
||||
num_bins - 1 // High outlier bin
|
||||
} else if middle_width > 0.0 {
|
||||
// Middle bins (1 to n-2)
|
||||
let middle_bin = ((value - p1) / middle_width) as usize;
|
||||
(1 + middle_bin).min(num_bins - 2)
|
||||
} else {
|
||||
num_bins / 2 // Fallback if p1 == p99
|
||||
};
|
||||
counts[bin] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let histogram = Histogram {
|
||||
min: refined_min,
|
||||
max: refined_max,
|
||||
p1,
|
||||
p99,
|
||||
counts,
|
||||
};
|
||||
|
||||
// Compute slider bounds (use refined histogram for accurate percentiles)
|
||||
let (slider_min, slider_max) = match bounds {
|
||||
Bounds::Fixed {
|
||||
min: fmin,
|
||||
max: fmax,
|
||||
} => (*fmin, *fmax),
|
||||
Bounds::Percentile { low, high } => {
|
||||
let p_low = percentile_from_uniform_histogram(
|
||||
refined_count,
|
||||
refined_min,
|
||||
refined_max,
|
||||
&refined_counts,
|
||||
*low as f32,
|
||||
);
|
||||
let p_high = percentile_from_uniform_histogram(
|
||||
refined_count,
|
||||
refined_min,
|
||||
refined_max,
|
||||
&refined_counts,
|
||||
*high as f32,
|
||||
);
|
||||
(p_low, p_high)
|
||||
}
|
||||
};
|
||||
|
||||
FeatureStats {
|
||||
slider_min,
|
||||
slider_max,
|
||||
histogram,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn column_to_f32_vec(column: &Column) -> anyhow::Result<Vec<f32>> {
|
||||
let float_series = column
|
||||
.cast(&DataType::Float32)
|
||||
.context("Failed to cast column to Float32")?;
|
||||
let chunked = float_series
|
||||
.f32()
|
||||
.context("Failed to get f32 chunked array")?;
|
||||
Ok(chunked
|
||||
.into_iter()
|
||||
.map(|value| value.unwrap_or(f32::NAN))
|
||||
.collect())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::consts::QUANT_SCALE;
|
||||
use crate::features::Bounds;
|
||||
|
||||
fn make_fixed_bounds(min: f32, max: f32) -> Bounds {
|
||||
Bounds::Fixed { min, max }
|
||||
}
|
||||
|
||||
fn make_percentile_bounds(low: f64, high: f64) -> Bounds {
|
||||
Bounds::Percentile { low, high }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn histogram_empty_data() {
|
||||
let data: Vec<f32> = vec![];
|
||||
let bounds = make_fixed_bounds(0.0, 100.0);
|
||||
let stats = compute_feature_stats(&data, &bounds, false);
|
||||
|
||||
assert_eq!(stats.slider_min, 0.0);
|
||||
assert_eq!(stats.slider_max, 100.0);
|
||||
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn histogram_single_value() {
|
||||
let data = vec![50.0_f32];
|
||||
let bounds = make_fixed_bounds(0.0, 100.0);
|
||||
let stats = compute_feature_stats(&data, &bounds, false);
|
||||
|
||||
assert_eq!(stats.histogram.min, 50.0);
|
||||
assert_eq!(stats.histogram.max, 50.0);
|
||||
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn histogram_uniform_distribution() {
|
||||
let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
|
||||
let bounds = make_fixed_bounds(0.0, 100.0);
|
||||
let stats = compute_feature_stats(&data, &bounds, false);
|
||||
|
||||
assert_eq!(stats.histogram.min, 0.0);
|
||||
assert_eq!(stats.histogram.max, 99.0);
|
||||
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn histogram_with_nan_values() {
|
||||
let data = vec![10.0_f32, f32::NAN, 20.0, f32::NAN, 30.0];
|
||||
let bounds = make_fixed_bounds(0.0, 100.0);
|
||||
let stats = compute_feature_stats(&data, &bounds, false);
|
||||
|
||||
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 3);
|
||||
assert_eq!(stats.histogram.min, 10.0);
|
||||
assert_eq!(stats.histogram.max, 30.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn histogram_all_nan() {
|
||||
let data = vec![f32::NAN, f32::NAN, f32::NAN];
|
||||
let bounds = make_fixed_bounds(0.0, 100.0);
|
||||
let stats = compute_feature_stats(&data, &bounds, false);
|
||||
|
||||
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn histogram_all_same_value() {
|
||||
let data = vec![42.0_f32; 1000];
|
||||
let bounds = make_fixed_bounds(0.0, 100.0);
|
||||
let stats = compute_feature_stats(&data, &bounds, false);
|
||||
|
||||
assert_eq!(stats.histogram.min, 42.0);
|
||||
assert_eq!(stats.histogram.max, 42.0);
|
||||
assert_eq!(stats.histogram.p1, 42.0);
|
||||
assert_eq!(stats.histogram.p99, 42.0);
|
||||
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn histogram_percentile_bounds() {
|
||||
let mut data: Vec<f32> = vec![0.0]; // Low outlier
|
||||
data.extend((1..99).map(|i| 50.0 + i as f32 * 0.01));
|
||||
data.push(1000.0); // High outlier
|
||||
|
||||
let bounds = make_percentile_bounds(2.0, 98.0);
|
||||
let stats = compute_feature_stats(&data, &bounds, false);
|
||||
|
||||
assert!(stats.slider_min > 0.0);
|
||||
assert!(stats.slider_max < 1000.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fixed_price_bounds_keep_slider_cap() {
|
||||
let data = vec![400_000.0_f32, 2_500_000.0, 3_750_000.0];
|
||||
let bounds = make_fixed_bounds(0.0, 2_500_000.0);
|
||||
let stats = compute_feature_stats(&data, &bounds, false);
|
||||
|
||||
assert_eq!(stats.slider_min, 0.0);
|
||||
assert_eq!(stats.slider_max, 2_500_000.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn histogram_bin_for_value() {
|
||||
let hist = Histogram {
|
||||
min: 0.0,
|
||||
max: 100.0,
|
||||
p1: 10.0,
|
||||
p99: 90.0,
|
||||
counts: vec![0; 10],
|
||||
};
|
||||
|
||||
assert_eq!(hist.bin_for_value(5.0), 0); // Low outlier bin
|
||||
assert_eq!(hist.bin_for_value(95.0), 9); // High outlier bin
|
||||
|
||||
let mid_value = 50.0;
|
||||
let bin = hist.bin_for_value(mid_value);
|
||||
assert!((1..=8).contains(&bin));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn histogram_middle_bin_width() {
|
||||
let hist = Histogram {
|
||||
min: 0.0,
|
||||
max: 100.0,
|
||||
p1: 10.0,
|
||||
p99: 90.0,
|
||||
counts: vec![0; 10],
|
||||
};
|
||||
|
||||
let expected_width = (90.0 - 10.0) / 8.0;
|
||||
assert!((hist.middle_bin_width() - expected_width).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn histogram_cardinality_caps_bins() {
|
||||
let data = vec![1.0_f32, 1.0, 2.0, 2.0, 3.0, 3.0];
|
||||
let bounds = make_fixed_bounds(0.0, 100.0);
|
||||
let stats = compute_feature_stats(&data, &bounds, false);
|
||||
|
||||
assert_eq!(stats.histogram.counts.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn min_max_skips_nan() {
|
||||
let values = vec![10.0_f32, f32::NAN, 20.0, f32::NAN, 5.0];
|
||||
|
||||
let mut min = f32::INFINITY;
|
||||
let mut max = f32::NEG_INFINITY;
|
||||
for &v in &values {
|
||||
if v.is_finite() {
|
||||
if v < min {
|
||||
min = v;
|
||||
}
|
||||
if v > max {
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(min, 5.0);
|
||||
assert_eq!(max, 20.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn count_skips_nan() {
|
||||
let values = [1.0_f32, f32::NAN, 2.0, f32::NAN, 3.0];
|
||||
let count = values.iter().filter(|v| v.is_finite()).count();
|
||||
assert_eq!(count, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infinity_values_excluded() {
|
||||
let data = vec![f32::INFINITY, f32::NEG_INFINITY, 50.0];
|
||||
let bounds = Bounds::Fixed {
|
||||
min: 0.0,
|
||||
max: 100.0,
|
||||
};
|
||||
let stats = compute_feature_stats(&data, &bounds, false);
|
||||
|
||||
assert_eq!(stats.histogram.min, 50.0);
|
||||
assert_eq!(stats.histogram.max, 50.0);
|
||||
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn only_finite_values() {
|
||||
let data = vec![10.0_f32, 20.0, 30.0];
|
||||
let bounds = Bounds::Fixed {
|
||||
min: 0.0,
|
||||
max: 100.0,
|
||||
};
|
||||
let stats = compute_feature_stats(&data, &bounds, false);
|
||||
|
||||
assert_eq!(stats.histogram.min, 10.0);
|
||||
assert_eq!(stats.histogram.max, 30.0);
|
||||
assert_eq!(stats.histogram.counts.iter().sum::<u64>(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extreme_outlier_does_not_destroy_quantization() {
|
||||
// Simulate floor area: 10k normal values (50-200 sqm) + one 317M outlier
|
||||
let mut data: Vec<f32> = (0..10_000).map(|i| 50.0 + (i % 150) as f32).collect();
|
||||
data.push(317_000_000.0); // Extreme outlier from web scraping
|
||||
|
||||
let bounds = make_percentile_bounds(0.0, 98.0);
|
||||
let stats = compute_feature_stats(&data, &bounds, false);
|
||||
|
||||
// After refinement, histogram range should be much tighter than 317M
|
||||
assert!(
|
||||
stats.histogram.max < 1_000_000.0,
|
||||
"histogram.max should be refined, got {}",
|
||||
stats.histogram.max,
|
||||
);
|
||||
// p1 should be near 50, not millions
|
||||
assert!(
|
||||
stats.histogram.p1 < 100.0,
|
||||
"p1 should be near real data, got {}",
|
||||
stats.histogram.p1,
|
||||
);
|
||||
// Slider min should reflect actual data range
|
||||
assert!(
|
||||
stats.slider_min < 100.0,
|
||||
"slider_min should be near real data, got {}",
|
||||
stats.slider_min,
|
||||
);
|
||||
|
||||
// Quantization using histogram.min/max should give usable range
|
||||
let qmin = stats.histogram.min;
|
||||
let qrange = stats.histogram.max - stats.histogram.min;
|
||||
assert!(qrange > 0.0 && qrange < 1_000_000.0);
|
||||
|
||||
// A typical floor area (100 sqm) should be distinguishable from min
|
||||
let normalized = (100.0 - qmin) / qrange;
|
||||
let encoded = (normalized * QUANT_SCALE).round() as u16;
|
||||
assert!(
|
||||
encoded > 100,
|
||||
"100 sqm should encode to a meaningful u16 value, got {}",
|
||||
encoded,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -272,6 +272,21 @@ pub fn slugify(name: &str) -> String {
|
|||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl TravelTimeStore {
|
||||
/// Minimal empty instance for integration tests that need an `AppState`
|
||||
/// but never touch travel time data.
|
||||
pub(crate) fn empty_for_tests() -> Self {
|
||||
Self {
|
||||
base_dir: PathBuf::new(),
|
||||
available_modes: Vec::new(),
|
||||
destinations: FxHashMap::default(),
|
||||
slug_to_file: FxHashMap::default(),
|
||||
cache: Mutex::new(LruCache::new(1)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
|
|
@ -1042,7 +1042,44 @@ async fn main() -> anyhow::Result<()> {
|
|||
listener,
|
||||
app.into_make_service_with_connect_info::<std::net::SocketAddr>(),
|
||||
)
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await
|
||||
.context("Server error")?;
|
||||
info!("Server shut down cleanly");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Resolves on SIGTERM or SIGINT so in-flight requests (exports, checkouts)
|
||||
/// can drain before the process exits. The realtime SSE proxy connections
|
||||
/// never complete, so a watchdog force-exits before Docker's default 10s
|
||||
/// stop grace period elapses and it sends SIGKILL.
|
||||
async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("Failed to install SIGINT handler");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.expect("Failed to install SIGTERM handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
() = ctrl_c => {},
|
||||
() = terminate => {},
|
||||
}
|
||||
info!("Shutdown signal received; draining in-flight requests");
|
||||
|
||||
tokio::spawn(async {
|
||||
tokio::time::sleep(Duration::from_secs(8)).await;
|
||||
tracing::warn!("Graceful shutdown drain timed out after 8s; forcing exit");
|
||||
std::process::exit(0);
|
||||
});
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
448
server-rs/src/routes/ai_filters/handler.rs
Normal file
448
server-rs/src/routes/ai_filters/handler.rs
Normal file
|
|
@ -0,0 +1,448 @@
|
|||
//! The `POST /api/ai-filters` route handler: rate limiting, the Gemini
|
||||
//! function-calling conversation loop, and zero-match refinement.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::State;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Json;
|
||||
use axum::Extension;
|
||||
use metrics::counter;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WEEKLY_TOKEN_LIMIT};
|
||||
use crate::pocketbase::log_ai_query;
|
||||
use crate::state::SharedState;
|
||||
use crate::utils::gemini_chat;
|
||||
|
||||
use super::matching::count_matching_rows;
|
||||
use super::parsing::{
|
||||
normalize_context_filters, strip_markdown_fences, validate_and_convert,
|
||||
validate_travel_time_filters,
|
||||
};
|
||||
use super::tools::{build_tool_declarations, execute_destination_search};
|
||||
use super::usage::{current_week_number, fetch_ai_usage, record_ai_request_usage};
|
||||
use super::{AiFiltersRequest, AiFiltersResponse};
|
||||
|
||||
/// Budget limits for the Gemini conversation loop. Separate counters prevent
|
||||
/// tool calls (destination searches) from starving JSON retries or zero-match
|
||||
/// refinements.
|
||||
const MAX_TOOL_CALLS: usize = 4;
|
||||
const MAX_RETRIES: usize = 3;
|
||||
const MAX_REFINEMENTS: u32 = 3;
|
||||
const MAX_TOTAL_ROUNDS: usize = 10;
|
||||
|
||||
const MAX_AI_QUERY_CHARS: usize = 5000;
|
||||
|
||||
pub async fn post_ai_filters(
|
||||
State(shared): State<Arc<SharedState>>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Json(req): Json<AiFiltersRequest>,
|
||||
) -> Result<Json<AiFiltersResponse>, (StatusCode, String)> {
|
||||
let state = shared.load_state();
|
||||
// Auth check
|
||||
let user = user
|
||||
.0
|
||||
.ok_or((StatusCode::UNAUTHORIZED, "Login required".into()))?;
|
||||
|
||||
if req.query.chars().count() > MAX_AI_QUERY_CHARS {
|
||||
counter!("ai_requests_total", "status" => "query_too_long").increment(1);
|
||||
return Err((
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
format!("Query too long (max {MAX_AI_QUERY_CHARS} chars)"),
|
||||
));
|
||||
}
|
||||
|
||||
// Check weekly token usage
|
||||
let current_week = current_week_number();
|
||||
let (stored_tokens, stored_week) = fetch_ai_usage(&state, &user.id).await?;
|
||||
let tokens_used = if stored_week == current_week {
|
||||
stored_tokens
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
if tokens_used >= AI_FILTERS_WEEKLY_TOKEN_LIMIT {
|
||||
counter!("ai_requests_total", "status" => "rate_limited").increment(1);
|
||||
return Err((
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
"Weekly AI usage limit reached. Resets next week.".into(),
|
||||
));
|
||||
}
|
||||
|
||||
info!(query = %req.query, user_id = %user.id, "POST /api/ai-filters");
|
||||
|
||||
let tools = build_tool_declarations(&state);
|
||||
|
||||
// Build user message with optional context for conversational refinement
|
||||
let user_text = if let Some(ref ctx) = req.context {
|
||||
let mut msg = String::new();
|
||||
msg.push_str("Currently active filters:\n");
|
||||
let normalized_filters = normalize_context_filters(&ctx.filters);
|
||||
msg.push_str(&serde_json::to_string(&normalized_filters).unwrap_or_default());
|
||||
if !ctx.travel_time.is_empty() {
|
||||
msg.push_str("\nCurrently active travel time filters:\n");
|
||||
for tt in &ctx.travel_time {
|
||||
let bounds = match (tt.min, tt.max) {
|
||||
(Some(min), Some(max)) => format!("{}-{} min", min, max),
|
||||
(Some(min), None) => format!("min {} min", min),
|
||||
(None, Some(max)) => format!("max {} min", max),
|
||||
(None, None) => "no range".to_string(),
|
||||
};
|
||||
msg.push_str(&format!("- {} to {} ({})\n", tt.mode, tt.label, bounds));
|
||||
}
|
||||
}
|
||||
msg.push_str(&format!("\nUser request: {}", req.query));
|
||||
msg
|
||||
} else {
|
||||
req.query.clone()
|
||||
};
|
||||
|
||||
let mut contents = vec![json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": user_text }]
|
||||
})];
|
||||
|
||||
let mut total_tokens_accumulated: u64 = 0;
|
||||
let mut tool_call_count = 0usize;
|
||||
let mut retry_count = 0usize;
|
||||
let mut refinement_attempts = 0u32;
|
||||
|
||||
// Function calling loop: model may call search_destinations, we execute and feed back
|
||||
for round in 0..MAX_TOTAL_ROUNDS {
|
||||
let body = json!({
|
||||
"systemInstruction": {
|
||||
"parts": [{ "text": state.ai_filters_system_prompt }]
|
||||
},
|
||||
"contents": contents,
|
||||
"tools": tools,
|
||||
"generationConfig": {
|
||||
"temperature": AI_FILTERS_TEMPERATURE,
|
||||
"maxOutputTokens": AI_FILTERS_MAX_TOKENS,
|
||||
"thinkingConfig": { "thinkingLevel": "LOW" },
|
||||
}
|
||||
});
|
||||
|
||||
let json_resp = match gemini_chat(
|
||||
&state.http_client,
|
||||
&state.gemini_api_key,
|
||||
&state.gemini_model,
|
||||
&body,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(err) => {
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"llm_error",
|
||||
)
|
||||
.await;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
// Accumulate token usage
|
||||
total_tokens_accumulated += json_resp
|
||||
.get("usageMetadata")
|
||||
.and_then(|md| md.get("totalTokenCount"))
|
||||
.and_then(|tc| tc.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let candidate = match json_resp
|
||||
.get("candidates")
|
||||
.and_then(|cs| cs.get(0))
|
||||
.and_then(|c| c.get("content"))
|
||||
{
|
||||
Some(candidate) => candidate,
|
||||
None => {
|
||||
warn!("Malformed Gemini response: missing candidates[0].content");
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"malformed_response",
|
||||
)
|
||||
.await;
|
||||
return Err((StatusCode::BAD_GATEWAY, "Malformed Gemini response".into()));
|
||||
}
|
||||
};
|
||||
|
||||
let parts = match candidate.get("parts").and_then(|p| p.as_array()) {
|
||||
Some(parts) => parts,
|
||||
None => {
|
||||
warn!("Malformed Gemini response: missing parts array");
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"malformed_response",
|
||||
)
|
||||
.await;
|
||||
return Err((StatusCode::BAD_GATEWAY, "Malformed Gemini response".into()));
|
||||
}
|
||||
};
|
||||
|
||||
// Check if the model made a function call.
|
||||
// Find the full part (includes thoughtSignature required by Gemini 3 models).
|
||||
if let Some(fc_part) = parts.iter().find(|part| part.get("functionCall").is_some()) {
|
||||
let fc = fc_part.get("functionCall").unwrap();
|
||||
let fn_name = fc.get("name").and_then(|n| n.as_str()).unwrap_or("");
|
||||
let fn_args = fc.get("args").cloned().unwrap_or(json!({}));
|
||||
|
||||
tool_call_count += 1;
|
||||
info!(
|
||||
function = fn_name,
|
||||
round = round,
|
||||
tool_call = tool_call_count,
|
||||
"AI called tool"
|
||||
);
|
||||
|
||||
if tool_call_count > MAX_TOOL_CALLS {
|
||||
warn!("Tool call budget exhausted, forcing text output");
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": "Tool call limit reached. Output your best JSON now using the destinations you already found. Do not call any more tools." }]
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
|
||||
let fn_result = if fn_name == "search_destinations" {
|
||||
let query = fn_args.get("query").and_then(|q| q.as_str()).unwrap_or("");
|
||||
let mode = fn_args
|
||||
.get("mode")
|
||||
.and_then(|m| m.as_str())
|
||||
.unwrap_or("transit");
|
||||
execute_destination_search(&state, query, mode)
|
||||
} else {
|
||||
json!({"error": "unknown function"})
|
||||
};
|
||||
|
||||
// Append the model's full response (preserves thoughtSignature) + our function result
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{
|
||||
"functionResponse": {
|
||||
"name": fn_name,
|
||||
"response": fn_result
|
||||
}
|
||||
}]
|
||||
}));
|
||||
|
||||
// Continue the loop — model will process the results
|
||||
continue;
|
||||
}
|
||||
|
||||
// Model returned text — extract and parse as JSON
|
||||
let text = parts
|
||||
.iter()
|
||||
.find_map(|part| part.get("text").and_then(|t| t.as_str()))
|
||||
.unwrap_or("");
|
||||
let text = strip_markdown_fences(text);
|
||||
let text = text.trim();
|
||||
|
||||
if text.is_empty() {
|
||||
retry_count += 1;
|
||||
warn!(
|
||||
"Gemini returned empty text content (round {}, retry {})",
|
||||
round, retry_count
|
||||
);
|
||||
if retry_count > MAX_RETRIES {
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"empty_response",
|
||||
)
|
||||
.await;
|
||||
return Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"AI returned empty responses".into(),
|
||||
));
|
||||
}
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": "Your response was empty. Please output the JSON object." }]
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
|
||||
let raw: Value = match serde_json::from_str(text) {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
retry_count += 1;
|
||||
warn!(error = %err, round = round, retry = retry_count, "Failed to parse Gemini JSON output");
|
||||
if retry_count > MAX_RETRIES {
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"invalid_json",
|
||||
)
|
||||
.await;
|
||||
return Err((StatusCode::BAD_GATEWAY, "AI returned invalid JSON".into()));
|
||||
}
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": "That was not valid JSON. Please output ONLY the JSON object with numeric_filters, enum_filters, travel_time_filters, and notes." }]
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let filters = validate_and_convert(&raw, &state.features_response);
|
||||
let travel_time_filters = validate_travel_time_filters(&raw, &state);
|
||||
let notes = raw
|
||||
.get("notes")
|
||||
.and_then(|val| val.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
// Count matching properties and refine if too restrictive
|
||||
let (match_count, match_bounds) =
|
||||
count_matching_rows(&state, &filters, &travel_time_filters);
|
||||
info!(
|
||||
match_count = match_count,
|
||||
round = round,
|
||||
"AI filter match count"
|
||||
);
|
||||
|
||||
if match_count == 0 {
|
||||
refinement_attempts += 1;
|
||||
let total_rows = state.data.lat.len();
|
||||
info!(
|
||||
attempt = refinement_attempts,
|
||||
"0 matches out of {total_rows} — asking AI to relax filters"
|
||||
);
|
||||
|
||||
if refinement_attempts > MAX_REFINEMENTS {
|
||||
warn!("Refinement budget exhausted, returning filters with 0 matches");
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"zero_matches",
|
||||
)
|
||||
.await;
|
||||
|
||||
let notes = if notes.is_empty() {
|
||||
"No properties match these filters. Try relaxing some constraints.".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"{}. No properties match. Try relaxing some constraints.",
|
||||
notes
|
||||
)
|
||||
};
|
||||
|
||||
return Ok(Json(AiFiltersResponse {
|
||||
filters,
|
||||
travel_time_filters,
|
||||
notes,
|
||||
match_count: 0,
|
||||
match_bounds: None,
|
||||
}));
|
||||
}
|
||||
|
||||
let feedback = match refinement_attempts {
|
||||
1 => format!(
|
||||
"Your proposed filters matched 0 properties out of {total_rows} total. \
|
||||
The combination is too restrictive. Please widen some numeric ranges \
|
||||
or add more enum values while keeping the user's intent. \
|
||||
Output the adjusted JSON."
|
||||
),
|
||||
2 => format!(
|
||||
"Still 0 matches out of {total_rows}. Please widen ranges further. \
|
||||
Output the adjusted JSON."
|
||||
),
|
||||
_ => format!(
|
||||
"Still 0 matches out of {total_rows}. Please remove additional filters \
|
||||
until some properties match, keeping the user's core priority. \
|
||||
Output the adjusted JSON."
|
||||
),
|
||||
};
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": feedback }]
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"success",
|
||||
)
|
||||
.await;
|
||||
|
||||
// Log the query to PocketBase (fire-and-forget)
|
||||
let filters_json = serde_json::to_string(&filters).unwrap_or_default();
|
||||
let log_state = state.clone();
|
||||
let log_user_id = user.id.clone();
|
||||
let log_query = req.query.clone();
|
||||
let log_notes = notes.clone();
|
||||
let log_rounds = (round + 1) as u64;
|
||||
tokio::spawn(async move {
|
||||
log_ai_query(
|
||||
&log_state,
|
||||
&log_user_id,
|
||||
&log_query,
|
||||
&filters_json,
|
||||
&log_notes,
|
||||
total_tokens_accumulated,
|
||||
log_rounds,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
return Ok(Json(AiFiltersResponse {
|
||||
filters,
|
||||
travel_time_filters,
|
||||
notes,
|
||||
match_count,
|
||||
match_bounds,
|
||||
}));
|
||||
}
|
||||
|
||||
// Exhausted total round budget without getting a valid response
|
||||
warn!(
|
||||
"AI exhausted {} total rounds without final response (tools={}, retries={}, refinements={})",
|
||||
MAX_TOTAL_ROUNDS, tool_call_count, retry_count, refinement_attempts
|
||||
);
|
||||
record_ai_request_usage(
|
||||
&state,
|
||||
&user.id,
|
||||
tokens_used,
|
||||
current_week,
|
||||
total_tokens_accumulated,
|
||||
"incomplete",
|
||||
)
|
||||
.await;
|
||||
Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"AI could not complete the request".into(),
|
||||
))
|
||||
}
|
||||
158
server-rs/src/routes/ai_filters/matching.rs
Normal file
158
server-rs/src/routes/ai_filters/matching.rs
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
//! Counting properties that match the AI-proposed property and travel time
|
||||
//! filters, and computing a camera-friendly bounding box of the matches.
|
||||
|
||||
use serde_json::Value;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::data::travel_time::TravelData;
|
||||
use crate::parsing::{parse_filters_with_poi, row_passes_filters, row_passes_poi_filters};
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::{MatchBounds, TravelTimeFilter};
|
||||
|
||||
/// Bounding box over matched coordinates, trimmed to the 5th–95th percentile
|
||||
/// per axis (when there are enough points) so a handful of remote outliers
|
||||
/// doesn't zoom the camera out to all of England.
|
||||
fn percentile_trimmed_bounds(mut lats: Vec<f32>, mut lons: Vec<f32>) -> Option<MatchBounds> {
|
||||
if lats.is_empty() || lats.len() != lons.len() {
|
||||
return None;
|
||||
}
|
||||
lats.sort_unstable_by(f32::total_cmp);
|
||||
lons.sort_unstable_by(f32::total_cmp);
|
||||
let last = lats.len() - 1;
|
||||
let (lo, hi) = if lats.len() >= 20 {
|
||||
let trim = lats.len() / 20;
|
||||
(trim, last - trim)
|
||||
} else {
|
||||
(0, last)
|
||||
};
|
||||
Some(MatchBounds {
|
||||
south: lats[lo],
|
||||
north: lats[hi],
|
||||
west: lons[lo],
|
||||
east: lons[hi],
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert validated filter JSON back to the `;;`-separated filter string format
|
||||
/// that `parse_filters` expects.
|
||||
///
|
||||
/// Numeric: `{"name": [min, max]}` → `name:min:max`
|
||||
/// Enum: `{"name": ["val1", "val2"]}` → `name:val1|val2`
|
||||
fn filters_to_filter_string(filters: &Value) -> String {
|
||||
let obj = match filters.as_object() {
|
||||
Some(obj) => obj,
|
||||
None => return String::new(),
|
||||
};
|
||||
|
||||
let mut parts = Vec::new();
|
||||
for (name, value) in obj {
|
||||
if let Some(arr) = value.as_array() {
|
||||
if arr.len() == 2 && arr[0].is_number() && arr[1].is_number() {
|
||||
let min = arr[0].as_f64().unwrap_or(0.0);
|
||||
let max = arr[1].as_f64().unwrap_or(0.0);
|
||||
parts.push(format!("{name}:{min}:{max}"));
|
||||
} else if !arr.is_empty() && arr[0].is_string() {
|
||||
let values: Vec<&str> = arr.iter().filter_map(|v| v.as_str()).collect();
|
||||
if !values.is_empty() {
|
||||
parts.push(format!("{name}:{}", values.join("|")));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parts.join(";;")
|
||||
}
|
||||
|
||||
/// Count how many rows in the property dataset pass the given property filters
|
||||
/// AND travel time filters. Travel time data is loaded from the TravelTimeStore
|
||||
/// and checked per-postcode (same logic as hexagons.rs).
|
||||
pub(super) fn count_matching_rows(
|
||||
state: &AppState,
|
||||
filters: &Value,
|
||||
travel_time_filters: &[TravelTimeFilter],
|
||||
) -> (usize, Option<MatchBounds>) {
|
||||
let filter_str = filters_to_filter_string(filters);
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let poi_quant = state.data.poi_metrics.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters, parsed_poi_filters) = if filter_str.is_empty() {
|
||||
(Vec::new(), Vec::new(), Vec::new())
|
||||
} else {
|
||||
match parse_filters_with_poi(
|
||||
Some(&filter_str),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
&state.data.poi_metrics.name_to_index,
|
||||
&poi_quant,
|
||||
) {
|
||||
Ok(f) => f,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse filters for match count: {err}");
|
||||
return (0, None);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Load travel time data for each filter entry
|
||||
let travel_data: Vec<(TravelData, Option<f32>, Option<f32>)> = travel_time_filters
|
||||
.iter()
|
||||
.filter_map(|ttf| {
|
||||
let data = state.travel_time_store.get(&ttf.mode, &ttf.slug).ok()?;
|
||||
Some((data, ttf.min, ttf.max))
|
||||
})
|
||||
.collect();
|
||||
let has_travel = !travel_data.is_empty();
|
||||
|
||||
let feature_data = &state.data.feature_data;
|
||||
let num_features = state.data.num_features;
|
||||
let num_rows = state.data.lat.len();
|
||||
let (pc_interner, pc_keys) = state.data.postcode_parts();
|
||||
let has_poi_filters = !parsed_poi_filters.is_empty();
|
||||
|
||||
let mut count = 0usize;
|
||||
let mut matched_lats: Vec<f32> = Vec::new();
|
||||
let mut matched_lons: Vec<f32> = Vec::new();
|
||||
for (row, pc_key) in pc_keys.iter().enumerate().take(num_rows) {
|
||||
if !row_passes_filters(
|
||||
row,
|
||||
&parsed_filters,
|
||||
&parsed_enum_filters,
|
||||
feature_data,
|
||||
num_features,
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
if has_poi_filters
|
||||
&& !row_passes_poi_filters(row, &parsed_poi_filters, &state.data.poi_metrics)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if has_travel {
|
||||
let postcode = pc_interner.resolve(pc_key);
|
||||
let mut passes_travel = true;
|
||||
for (data, fmin, fmax) in &travel_data {
|
||||
let pass = if let Some(mins) = data.get(postcode).map(|r| r.minutes as f32) {
|
||||
fmin.is_none_or(|min| mins >= min) && fmax.is_none_or(|max| mins <= max)
|
||||
} else {
|
||||
false // no travel data → postcode not reachable
|
||||
};
|
||||
if !pass {
|
||||
passes_travel = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !passes_travel {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
count += 1;
|
||||
matched_lats.push(state.data.lat[row]);
|
||||
matched_lons.push(state.data.lon[row]);
|
||||
}
|
||||
|
||||
(count, percentile_trimmed_bounds(matched_lats, matched_lons))
|
||||
}
|
||||
81
server-rs/src/routes/ai_filters/mod.rs
Normal file
81
server-rs/src/routes/ai_filters/mod.rs
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
//! AI filters: translate a natural-language property query into validated
|
||||
//! filter settings via Gemini.
|
||||
//!
|
||||
//! Split by concern:
|
||||
//! - [`handler`]: the `POST /api/ai-filters` route handler and Gemini
|
||||
//! conversation loop
|
||||
//! - [`prompt`]: system prompt building (precomputed at startup)
|
||||
//! - [`tools`]: the `search_destinations` tool declaration and execution
|
||||
//! - [`parsing`]: LLM response parsing and validation against feature metadata
|
||||
//! - [`matching`]: counting properties that match the proposed filters
|
||||
//! - [`usage`]: weekly token usage tracking / rate limiting
|
||||
|
||||
mod handler;
|
||||
mod matching;
|
||||
mod parsing;
|
||||
mod prompt;
|
||||
mod tools;
|
||||
mod usage;
|
||||
|
||||
pub use handler::post_ai_filters;
|
||||
pub use prompt::build_system_prompt;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AiFiltersContext {
|
||||
filters: Value,
|
||||
#[serde(default)]
|
||||
travel_time: Vec<AiTravelTimeContext>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AiTravelTimeContext {
|
||||
mode: String,
|
||||
label: String,
|
||||
min: Option<f32>,
|
||||
max: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AiFiltersRequest {
|
||||
query: String,
|
||||
/// Current filters for conversational refinement (e.g. "make it cheaper")
|
||||
context: Option<AiFiltersContext>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct TravelTimeFilter {
|
||||
mode: String,
|
||||
slug: String,
|
||||
label: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
min: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct AiFiltersResponse {
|
||||
filters: Value,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
travel_time_filters: Vec<TravelTimeFilter>,
|
||||
/// What the LLM couldn't map to existing filters (empty if everything matched)
|
||||
#[serde(skip_serializing_if = "String::is_empty")]
|
||||
notes: String,
|
||||
/// Number of properties matching the proposed property and travel time filters.
|
||||
match_count: usize,
|
||||
/// Bounding box of the matching properties so the client can move the
|
||||
/// camera to where matches actually are. Absent when nothing matches.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
match_bounds: Option<MatchBounds>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct MatchBounds {
|
||||
south: f32,
|
||||
west: f32,
|
||||
north: f32,
|
||||
east: f32,
|
||||
}
|
||||
385
server-rs/src/routes/ai_filters/parsing.rs
Normal file
385
server-rs/src/routes/ai_filters/parsing.rs
Normal file
|
|
@ -0,0 +1,385 @@
|
|||
//! LLM response parsing: stripping markdown fences, normalizing frontend
|
||||
//! synthetic filter keys, and validating proposed filters against feature
|
||||
//! metadata and available travel destinations.
|
||||
|
||||
use serde_json::{json, Map, Value};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::routes::{FeatureInfo, FeaturesResponse};
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::TravelTimeFilter;
|
||||
|
||||
/// Strip markdown code fences (```json ... ``` or ``` ... ```) from LLM output.
|
||||
/// Models occasionally wrap JSON in markdown fencing even when told not to.
|
||||
pub(super) fn strip_markdown_fences(text: &str) -> &str {
|
||||
let trimmed = text.trim();
|
||||
|
||||
// Try ```json\n...\n``` or ```\n...\n```
|
||||
if let Some(rest) = trimmed.strip_prefix("```") {
|
||||
// Skip optional language tag (e.g. "json")
|
||||
let rest = if let Some(newline_pos) = rest.find('\n') {
|
||||
&rest[newline_pos + 1..]
|
||||
} else {
|
||||
return trimmed;
|
||||
};
|
||||
if let Some(content) = rest.strip_suffix("```") {
|
||||
return content.trim();
|
||||
}
|
||||
}
|
||||
|
||||
trimmed
|
||||
}
|
||||
|
||||
fn school_feature_name_from_key(name: &str) -> Option<&'static str> {
|
||||
let rest = name.strip_prefix("Schools:")?;
|
||||
let mut parts = rest.split(':');
|
||||
let phase = parts.next()?;
|
||||
let rating = parts.next()?;
|
||||
|
||||
match (phase, rating) {
|
||||
("primary", "good") => Some("Good+ primary school catchments"),
|
||||
("secondary", "good") => Some("Good+ secondary school catchments"),
|
||||
("primary", "outstanding") => Some("Outstanding primary school catchments"),
|
||||
("secondary", "outstanding") => Some("Outstanding secondary school catchments"),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_synthetic_feature_key(name: &str, prefix: &str) -> Option<String> {
|
||||
let rest = name.strip_prefix(prefix)?;
|
||||
let (encoded, _id) = rest.rsplit_once(':')?;
|
||||
urlencoding::decode(encoded)
|
||||
.ok()
|
||||
.map(|decoded| decoded.into_owned())
|
||||
}
|
||||
|
||||
/// Convert frontend synthetic filter keys back to backend feature names.
|
||||
///
|
||||
/// The React filter UI stores configurable cards under keys such as
|
||||
/// `Political vote share:%25%20Labour:0`. The LLM and backend validators need
|
||||
/// the real feature name (`% Labour`) instead.
|
||||
fn backend_filter_name(name: &str) -> Option<String> {
|
||||
if let Some(feature_name) = school_feature_name_from_key(name) {
|
||||
return Some(feature_name.to_string());
|
||||
}
|
||||
|
||||
for prefix in [
|
||||
"Specific crimes:",
|
||||
"Political vote share:",
|
||||
"Ethnicities:",
|
||||
"Amenity distance:",
|
||||
"Transport distance:",
|
||||
"Amenities within 2km:",
|
||||
"Amenities within 5km:",
|
||||
] {
|
||||
if let Some(feature_name) = decode_synthetic_feature_key(name, prefix) {
|
||||
return Some(feature_name);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn canonical_filter_name(name: &str) -> String {
|
||||
backend_filter_name(name).unwrap_or_else(|| name.to_string())
|
||||
}
|
||||
|
||||
pub(super) fn normalize_context_filters(filters: &Value) -> Value {
|
||||
let Some(obj) = filters.as_object() else {
|
||||
return filters.clone();
|
||||
};
|
||||
|
||||
let mut normalized = Map::with_capacity(obj.len());
|
||||
for (name, value) in obj {
|
||||
normalized.insert(canonical_filter_name(name), value.clone());
|
||||
}
|
||||
Value::Object(normalized)
|
||||
}
|
||||
|
||||
/// Maximum travel-time minutes the data can contain. Matches the Java pipeline's
|
||||
/// MAX_TRIP_DURATION_MINUTES and the frontend's MAX_TRAVEL_MINUTES.
|
||||
const TRAVEL_TIME_MAX_MINUTES: f64 = 90.0;
|
||||
|
||||
fn travel_time_minute_field(item: &Value, key: &str) -> Option<f32> {
|
||||
item.get(key)
|
||||
.and_then(|val| val.as_f64())
|
||||
.filter(|val| val.is_finite())
|
||||
.map(|val| val.clamp(0.0, TRAVEL_TIME_MAX_MINUTES) as f32)
|
||||
}
|
||||
|
||||
fn parse_travel_time_bounds(item: &Value) -> (Option<f32>, Option<f32>) {
|
||||
let explicit_min = travel_time_minute_field(item, "min");
|
||||
let explicit_max = travel_time_minute_field(item, "max");
|
||||
|
||||
let (mut min, mut max) = if explicit_min.is_some() || explicit_max.is_some() {
|
||||
(explicit_min, explicit_max)
|
||||
} else {
|
||||
let value = travel_time_minute_field(item, "value");
|
||||
match (item.get("bound").and_then(|val| val.as_str()), value) {
|
||||
(Some("min"), Some(val)) => (Some(val), None),
|
||||
(Some("max"), Some(val)) => (None, Some(val)),
|
||||
_ => (None, None),
|
||||
}
|
||||
};
|
||||
|
||||
if let (Some(min_val), Some(max_val)) = (min, max) {
|
||||
if min_val > max_val {
|
||||
min = Some(max_val);
|
||||
max = Some(min_val);
|
||||
}
|
||||
}
|
||||
|
||||
(min, max)
|
||||
}
|
||||
|
||||
/// Validate travel time filters from LLM output against available destinations.
|
||||
pub(super) fn validate_travel_time_filters(raw: &Value, state: &AppState) -> Vec<TravelTimeFilter> {
|
||||
let arr = match raw
|
||||
.get("travel_time_filters")
|
||||
.and_then(|val| val.as_array())
|
||||
{
|
||||
Some(arr) => arr,
|
||||
None => return Vec::new(),
|
||||
};
|
||||
|
||||
let tt_store = &state.travel_time_store;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for item in arr {
|
||||
let mode = match item.get("mode").and_then(|val| val.as_str()) {
|
||||
Some(mode) => mode,
|
||||
None => continue,
|
||||
};
|
||||
let slug = match item.get("slug").and_then(|val| val.as_str()) {
|
||||
Some(slug) => slug,
|
||||
None => continue,
|
||||
};
|
||||
let label = item
|
||||
.get("label")
|
||||
.and_then(|val| val.as_str())
|
||||
.unwrap_or(slug);
|
||||
|
||||
// Verify this destination actually exists
|
||||
if !tt_store.has_destination(mode, slug) {
|
||||
warn!(
|
||||
mode = mode,
|
||||
slug = slug,
|
||||
"AI suggested non-existent destination"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let (min, max) = parse_travel_time_bounds(item);
|
||||
|
||||
// Only include if at least one bound is set
|
||||
if min.is_some() || max.is_some() {
|
||||
results.push(TravelTimeFilter {
|
||||
mode: mode.to_string(),
|
||||
slug: slug.to_string(),
|
||||
label: label.to_string(),
|
||||
min,
|
||||
max,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Validate LLM output against feature metadata and convert to FeatureFilters format.
|
||||
///
|
||||
/// Input format (array-based, each numeric filter sets one bound):
|
||||
/// ```json
|
||||
/// {
|
||||
/// "numeric_filters": [{"name": "Last known price", "bound": "max", "value": 300000}],
|
||||
/// "enum_filters": [{"name": "Leasehold/Freehold", "values": ["Freehold"]}]
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Output format (FeatureFilters):
|
||||
/// ```json
|
||||
/// { "Last known price": [0, 300000], "Leasehold/Freehold": ["Freehold"] }
|
||||
/// ```
|
||||
pub(super) fn validate_and_convert(raw: &Value, features: &FeaturesResponse) -> Value {
|
||||
let mut result = serde_json::Map::new();
|
||||
|
||||
// Build lookup maps from feature metadata.
|
||||
// Store both slider bounds (min/max from percentiles) and true data bounds
|
||||
// (histogram.min/max) so one-sided AI filters use the full data range.
|
||||
let mut numeric_features: rustc_hash::FxHashMap<&str, (f32, f32, f32, f32)> =
|
||||
rustc_hash::FxHashMap::default();
|
||||
let mut enum_features: rustc_hash::FxHashMap<&str, &[String]> =
|
||||
rustc_hash::FxHashMap::default();
|
||||
|
||||
for group in &features.groups {
|
||||
for feature in &group.features {
|
||||
match feature {
|
||||
FeatureInfo::Numeric {
|
||||
name,
|
||||
min,
|
||||
max,
|
||||
histogram,
|
||||
..
|
||||
} => {
|
||||
numeric_features.insert(name, (*min, *max, histogram.min, histogram.max));
|
||||
}
|
||||
FeatureInfo::Enum { name, values, .. } => {
|
||||
enum_features.insert(name, values);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process numeric filters — each sets one bound (min or max).
|
||||
// The unset side uses the true data min/max (from histogram), not
|
||||
// the slider bounds (percentile-based), so a "max" filter for crime
|
||||
// produces [0, value] rather than [2nd-percentile, value].
|
||||
if let Some(arr) = raw.get("numeric_filters").and_then(|val| val.as_array()) {
|
||||
for item in arr {
|
||||
let raw_name = match item.get("name").and_then(|val| val.as_str()) {
|
||||
Some(name) => name,
|
||||
None => continue,
|
||||
};
|
||||
let name = canonical_filter_name(raw_name);
|
||||
let (slider_min, slider_max, data_min, data_max) =
|
||||
match numeric_features.get(name.as_str()) {
|
||||
Some(range) => *range,
|
||||
None => continue,
|
||||
};
|
||||
let bound = match item.get("bound").and_then(|val| val.as_str()) {
|
||||
Some(b) => b,
|
||||
None => continue,
|
||||
};
|
||||
// Clamp value to true data range (not slider range)
|
||||
let value = match item.get("value").and_then(|val| val.as_f64()) {
|
||||
Some(v) => v.max(data_min as f64).min(data_max as f64) as f32,
|
||||
None => continue,
|
||||
};
|
||||
let (filter_min, filter_max) = match bound {
|
||||
"min" => (value, data_max),
|
||||
"max" => (data_min, value),
|
||||
_ => continue,
|
||||
};
|
||||
// Only include if range is narrower than full slider range
|
||||
if filter_min > slider_min || filter_max < slider_max {
|
||||
result.insert(name, json!([filter_min, filter_max]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process enum filters
|
||||
if let Some(arr) = raw.get("enum_filters").and_then(|val| val.as_array()) {
|
||||
for item in arr {
|
||||
let raw_name = match item.get("name").and_then(|val| val.as_str()) {
|
||||
Some(name) => name,
|
||||
None => continue,
|
||||
};
|
||||
let name = canonical_filter_name(raw_name);
|
||||
let valid_values = match enum_features.get(name.as_str()) {
|
||||
Some(values) => *values,
|
||||
None => continue,
|
||||
};
|
||||
if let Some(selected) = item.get("values").and_then(|val| val.as_array()) {
|
||||
let valid: Vec<&str> = selected
|
||||
.iter()
|
||||
.filter_map(|item| item.as_str())
|
||||
.filter(|str_val| valid_values.iter().any(|known| known == str_val))
|
||||
.collect();
|
||||
if !valid.is_empty() && valid.len() < valid_values.len() {
|
||||
result.insert(name, json!(valid));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Value::Object(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn strip_fences_json_tag() {
|
||||
let input = "```json\n{\"a\": 1}\n```";
|
||||
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_fences_no_tag() {
|
||||
let input = "```\n{\"a\": 1}\n```";
|
||||
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_fences_passthrough() {
|
||||
let input = "{\"a\": 1}";
|
||||
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_fences_whitespace() {
|
||||
let input = " ```json\n {\"a\": 1} \n``` ";
|
||||
assert_eq!(strip_markdown_fences(input), "{\"a\": 1}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn synthetic_filter_keys_are_normalized_to_backend_names() {
|
||||
assert_eq!(
|
||||
canonical_filter_name("Schools:primary:good:0"),
|
||||
"Good+ primary school catchments"
|
||||
);
|
||||
// Legacy keys still carry a distance segment; it is ignored.
|
||||
assert_eq!(
|
||||
canonical_filter_name("Schools:primary:good:2:0"),
|
||||
"Good+ primary school catchments"
|
||||
);
|
||||
assert_eq!(
|
||||
canonical_filter_name("Specific crimes:Burglary%20%28avg%2Fyr%29:1"),
|
||||
"Burglary (avg/yr)"
|
||||
);
|
||||
assert_eq!(
|
||||
canonical_filter_name("Political vote share:%25%20Labour:0"),
|
||||
"% Labour"
|
||||
);
|
||||
assert_eq!(
|
||||
canonical_filter_name(
|
||||
"Transport distance:Distance%20to%20nearest%20amenity%20%28Bus%20stop%29%20%28km%29:0"
|
||||
),
|
||||
"Distance to nearest amenity (Bus stop) (km)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_filters_are_normalized_before_prompting() {
|
||||
let filters = json!({
|
||||
"Political vote share:%25%20Green:0": [40, 100],
|
||||
"Estimated current price": [0, 500000],
|
||||
});
|
||||
|
||||
let normalized = normalize_context_filters(&filters);
|
||||
assert_eq!(normalized["% Green"], json!([40, 100]));
|
||||
assert_eq!(normalized["Estimated current price"], json!([0, 500000]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn travel_time_bounds_accept_min_max_schema() {
|
||||
let item = json!({ "min": 30, "max": 45 });
|
||||
assert_eq!(parse_travel_time_bounds(&item), (Some(30.0), Some(45.0)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn travel_time_bounds_accept_legacy_bound_value_schema() {
|
||||
let item = json!({ "bound": "max", "value": 30 });
|
||||
assert_eq!(parse_travel_time_bounds(&item), (None, Some(30.0)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn travel_time_bounds_clamp_and_order_range() {
|
||||
// Data ceiling is 90 (matches Java MAX_TRIP_DURATION_MINUTES).
|
||||
// Inputs outside [0, 90] clamp; min/max ordering is preserved as-given here.
|
||||
let item = json!({ "min": 150, "max": -10 });
|
||||
assert_eq!(parse_travel_time_bounds(&item), (Some(0.0), Some(90.0)));
|
||||
}
|
||||
}
|
||||
282
server-rs/src/routes/ai_filters/prompt.rs
Normal file
282
server-rs/src/routes/ai_filters/prompt.rs
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
//! System prompt building for the AI filters assistant.
|
||||
|
||||
use crate::routes::{FeatureInfo, FeaturesResponse};
|
||||
|
||||
/// Build the complete system prompt for AI filters.
|
||||
///
|
||||
/// Contains: role instructions, feature catalogue, travel time info,
|
||||
/// few-shot examples, output rules.
|
||||
/// Precomputed at startup and cached in AppState.
|
||||
pub fn build_system_prompt(
|
||||
features: &FeaturesResponse,
|
||||
mode_destinations: &[(String, usize)],
|
||||
) -> String {
|
||||
let mut parts = Vec::new();
|
||||
|
||||
parts.push(
|
||||
"You are a UK property search assistant. \
|
||||
The user describes their ideal property or area in natural language. \
|
||||
Translate their description into filter settings using ONLY the features listed below.\n\
|
||||
\n\
|
||||
Rules:\n\
|
||||
- ONLY set filters the user explicitly mentioned or clearly implied.\n\
|
||||
- Leave out any filter the user did not mention. Empty arrays are fine.\n\
|
||||
- Each numeric filter sets ONE bound only: \"min\" (at least this value) \
|
||||
or \"max\" (at most this value). Never set two filters on the same feature.\n\
|
||||
- Use EXACT feature names from the list — spelling, capitalisation, and punctuation must match.\n\
|
||||
- \"cheap\" / \"affordable\" = lower price range. \"expensive\" = higher price range.\n\
|
||||
- \"low crime\" / \"safe\" = low values on the Serious crime (avg/yr) and Minor crime (avg/yr) \
|
||||
features (area-normalised incident density near the postcode). Prefer these aggregates for broad \
|
||||
area safety; use specific crime features only when the user names a crime type.\n\
|
||||
- \"quiet\" = low Noise (dB). \"green\" / \"near parks\" = high Number of amenities (Park) within 2km \
|
||||
or low Distance to nearest park (km), depending on wording.\n\
|
||||
- \"good schools\" = Good+ school features. \"outstanding schools\" = Outstanding school features.\n\
|
||||
- Amenities and transport stops are normal filters in the feature catalogue. \
|
||||
For \"near a bus stop\", \"near a station\", \"near shops\", etc., use the exact \
|
||||
Distance to nearest amenity (...) or Number of amenities (...) feature when available.\n\
|
||||
- Politics/elections are normal filters in the Neighbours group. Use exact vote share \
|
||||
features such as % Labour, % Conservative, % Liberal Democrat, % Reform UK, % Green, \
|
||||
% Other parties, or Voter turnout (%) when the user asks for political character.\n\
|
||||
- When the user says a number like \"under 400k\", interpret it as 400000.\n\
|
||||
- When the user says \"3 bed\" or \"3 bedroom\", use Number of bedrooms & living rooms \
|
||||
(note: this counts bedrooms + living rooms combined, so 3 bed ~ min 4).\n\
|
||||
- If the user mentions something that has no matching filter, put it in \"notes\" \
|
||||
as a short phrase (e.g. \"No filter for: garden, sea view\"). \
|
||||
If everything was matched, set \"notes\" to an empty string.\n\
|
||||
\n\
|
||||
CONVERSATIONAL REFINEMENT:\n\
|
||||
The user's message may include their currently active filters as context. \
|
||||
When context is provided:\n\
|
||||
- \"make it cheaper\" / \"lower the price\" = adjust the existing price filter down\n\
|
||||
- \"also add ...\" / \"and good schools\" = keep existing filters and add new ones\n\
|
||||
- \"remove the ...\" / \"drop the ...\" = return filters WITHOUT the mentioned one\n\
|
||||
- If the request is a completely new search (not a refinement), ignore the context \
|
||||
and build filters from scratch.\n\
|
||||
- Always output the COMPLETE set of filters (existing + modified), not just the changes."
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
// Travel time section with available modes
|
||||
let modes_list = mode_destinations
|
||||
.iter()
|
||||
.map(|(mode, count)| format!("- {} ({} destinations available)", mode, count))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
parts.push(format!(
|
||||
"\n--- TRAVEL TIME FILTERS ---\n\
|
||||
You can add travel time filters when the user mentions commute times, \
|
||||
proximity to places, or wanting to be near/within X minutes of somewhere.\n\
|
||||
\n\
|
||||
Available travel-time modes (only use modes that have destinations):\n\
|
||||
{}\n\
|
||||
- \"car\" / \"drive\" / \"driving\" = car mode\n\
|
||||
- \"cycle\" / \"bike\" / \"cycling\" = bicycle mode\n\
|
||||
- \"walk\" / \"walking\" / \"on foot\" = walking mode\n\
|
||||
- \"train\" / \"tube\" / \"bus\" / \"public transport\" / \"commute\" = transit mode\n\
|
||||
- \"without buses\" / \"no bus\" / \"rail only\" = transit-no-bus mode\n\
|
||||
- \"no change\" / \"no transfer\" / \"direct\" / \"single bus/train\" = transit-no-change mode\n\
|
||||
- \"no change and no bus\" / \"direct rail/tube\" = transit-no-change-no-bus mode\n\
|
||||
- If a mode appears in the available mode list but is not named above, you may still \
|
||||
use the exact mode string from the list.\n\
|
||||
\n\
|
||||
When the user mentions a specific place, you MUST call the search_destinations \
|
||||
tool to find the exact slug. Use the name and slug from the search results.\n\
|
||||
If search_destinations returns an empty array, the destination is not available — \
|
||||
mention it in \"notes\" (e.g. \"No travel data for: Gatwick Airport\") and do NOT \
|
||||
include a travel_time_filter for it.\n\
|
||||
\n\
|
||||
Travel time values are in MINUTES (0-90 range; data is capped at 90 min).\n\
|
||||
- \"within 30 minutes\" = set \"max\": 30\n\
|
||||
- \"at least 10 minutes\" = set \"min\": 10\n\
|
||||
- \"30-45 minute commute\" = set \"min\": 30 and \"max\": 45 on the same travel_time_filter\n\
|
||||
- If only a max is given, omit min (and vice versa). Do not use bound/value for travel time.\n\
|
||||
\n\
|
||||
INFERRING TRANSPORT MODE (when the user does not specify one explicitly):\n\
|
||||
- \"commute\" to a major city centre or station = transit\n\
|
||||
- \"near\" / \"close to\" a city centre or station = transit\n\
|
||||
- \"near\" / \"close to\" a smaller town, village, or rural area = car\n\
|
||||
- \"drive\" / \"driving distance\" / \"driving time\" = always car\n\
|
||||
- If multiple modes are plausible, prefer transit for urban destinations \
|
||||
(London, Manchester, Birmingham, Leeds, etc.) and car for everything else.",
|
||||
modes_list,
|
||||
));
|
||||
|
||||
// Feature guidance
|
||||
parts.push(
|
||||
"\n--- DATA SOURCE ---\n\
|
||||
The data is historical property sales from the Land Registry.\n\
|
||||
\n\
|
||||
Use these features for price queries:\n\
|
||||
- For purchase price: use \"Estimated current price\" or \"Last known price\"\n\
|
||||
- For price per sqm: use \"Est. price per sqm\"\n\
|
||||
- For rent estimates: use \"Estimated monthly rent\""
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
// Feature catalogue
|
||||
parts.push("\n--- AVAILABLE FEATURES ---\n".to_string());
|
||||
for group in &features.groups {
|
||||
parts.push(format!("## {}", group.name));
|
||||
for feature in &group.features {
|
||||
match feature {
|
||||
FeatureInfo::Numeric {
|
||||
name,
|
||||
min,
|
||||
max,
|
||||
description,
|
||||
prefix,
|
||||
suffix,
|
||||
..
|
||||
} => {
|
||||
parts.push(format!(
|
||||
"- \"{}\" (numeric, {}{:.0}{} to {}{:.0}{}): {}",
|
||||
name, prefix, min, suffix, prefix, max, suffix, description
|
||||
));
|
||||
}
|
||||
FeatureInfo::Enum {
|
||||
name,
|
||||
values,
|
||||
description,
|
||||
..
|
||||
} => {
|
||||
parts.push(format!(
|
||||
"- \"{}\" (enum, values: [{}]): {}",
|
||||
name,
|
||||
values
|
||||
.iter()
|
||||
.map(|val| format!("\"{}\"", val))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
description
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Few-shot examples
|
||||
parts.push("\n--- EXAMPLES ---\n".to_string());
|
||||
|
||||
parts.push(
|
||||
"User: \"cheap freehold house under 400k\"\n\
|
||||
Output: {\"numeric_filters\": [{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 400000}], \
|
||||
\"enum_filters\": [{\"name\": \"Leasehold/Freehold\", \"values\": [\"Freehold\"]}, \
|
||||
{\"name\": \"Property type\", \"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
|
||||
\"travel_time_filters\": [], \
|
||||
\"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"safe quiet area with good schools and parks\"\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Serious crime (avg/yr)\", \"bound\": \"max\", \"value\": 5}, \
|
||||
{\"name\": \"Minor crime (avg/yr)\", \"bound\": \"max\", \"value\": 20}, \
|
||||
{\"name\": \"Noise (dB)\", \"bound\": \"max\", \"value\": 55}, \
|
||||
{\"name\": \"Good+ primary school catchments\", \"bound\": \"min\", \"value\": 2}, \
|
||||
{\"name\": \"Good+ secondary school catchments\", \"bound\": \"min\", \"value\": 1}, \
|
||||
{\"name\": \"Number of amenities (Park) within 2km\", \"bound\": \"min\", \"value\": 3}], \
|
||||
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"quiet area with outstanding schools\"\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Noise (dB)\", \"bound\": \"max\", \"value\": 55}, \
|
||||
{\"name\": \"Outstanding primary school catchments\", \"bound\": \"min\", \"value\": 1}, \
|
||||
{\"name\": \"Outstanding secondary school catchments\", \"bound\": \"min\", \"value\": 1}], \
|
||||
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"3 bed flat under 300k with fast broadband near the beach\"\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 300000}, \
|
||||
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 4}], \
|
||||
\"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}, \
|
||||
{\"name\": \"Max available download speed (Mbps)\", \"values\": [\"100\", \"300\", \"1000\"]}], \
|
||||
\"travel_time_filters\": [], \
|
||||
\"notes\": \"No filter for: beach proximity\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"within 30 minutes commute of Kings Cross, under 500k\"\n\
|
||||
(After calling search_destinations for \"Kings Cross\" with mode \"transit\" \
|
||||
and getting [{\"name\": \"Kings Cross\", \"slug\": \"kings-cross\", \"place_type\": \"station\"}])\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 500000}], \
|
||||
\"enum_filters\": [], \
|
||||
\"travel_time_filters\": [{\"mode\": \"transit\", \"slug\": \"kings-cross\", \
|
||||
\"label\": \"Kings Cross\", \"max\": 30}], \
|
||||
\"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"family home with garden, 45 min drive from Manchester, good schools\"\n\
|
||||
(After calling search_destinations for \"Manchester\" with mode \"car\" \
|
||||
and getting [{\"name\": \"Manchester\", \"slug\": \"manchester\", \"place_type\": \"city\"}])\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Total floor area (sqm)\", \"bound\": \"min\", \"value\": 100}, \
|
||||
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 5}, \
|
||||
{\"name\": \"Good+ primary school catchments\", \"bound\": \"min\", \"value\": 2}, \
|
||||
{\"name\": \"Good+ secondary school catchments\", \"bound\": \"min\", \"value\": 1}], \
|
||||
\"enum_filters\": [{\"name\": \"Property type\", \
|
||||
\"values\": [\"Detached\", \"Semi-Detached\"]}], \
|
||||
\"travel_time_filters\": [{\"mode\": \"car\", \"slug\": \"manchester\", \
|
||||
\"label\": \"Manchester\", \"max\": 45}], \
|
||||
\"notes\": \"No filter for: garden\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"Labour-voting area with low burglary and a station nearby\"\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"% Labour\", \"bound\": \"min\", \"value\": 40}, \
|
||||
{\"name\": \"Burglary (avg/yr)\", \"bound\": \"max\", \"value\": 10}, \
|
||||
{\"name\": \"Distance to nearest amenity (Rail station) (km)\", \"bound\": \"max\", \"value\": 1}], \
|
||||
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
// Examples showing rent and price features
|
||||
parts.push(
|
||||
"\nUser: \"2 bed flat with rent under £1500/month\"\n\
|
||||
Output: {\
|
||||
\"numeric_filters\": [{\"name\": \"Estimated monthly rent\", \"bound\": \"max\", \"value\": 1500}], \
|
||||
\"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}], \
|
||||
\"travel_time_filters\": [], \
|
||||
\"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"3 bed house under 500k with good schools\"\n\
|
||||
Output: {\
|
||||
\"numeric_filters\": [{\"name\": \"Estimated current price\", \"bound\": \"max\", \"value\": 500000}, \
|
||||
{\"name\": \"Good+ primary school catchments\", \"bound\": \"min\", \"value\": 2}], \
|
||||
\"enum_filters\": [{\"name\": \"Property type\", \
|
||||
\"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
|
||||
\"travel_time_filters\": [], \
|
||||
\"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
// Output format reminder
|
||||
parts.push(
|
||||
"\n--- OUTPUT FORMAT ---\n\
|
||||
{\"numeric_filters\": [...], \"enum_filters\": [...], \
|
||||
\"travel_time_filters\": [{\"mode\": \"...\", \"slug\": \"...\", \"label\": \"...\", \
|
||||
\"min\": N, \"max\": N}, ...], \"notes\": \"...\"}\n\
|
||||
- travel_time_filters: min and max are both optional, but include at least one. \
|
||||
Use ONLY slugs returned by search_destinations. If a place isn't found, mention it in notes.\n\
|
||||
Respond with ONLY the JSON object. No explanation."
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.join("\n")
|
||||
}
|
||||
188
server-rs/src/routes/ai_filters/tools.rs
Normal file
188
server-rs/src/routes/ai_filters/tools.rs
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
//! The `search_destinations` Gemini tool: declaration and execution against
|
||||
//! PlaceData + TravelTimeStore.
|
||||
|
||||
use serde_json::{json, Value};
|
||||
use tracing::info;
|
||||
|
||||
use crate::data::slugify;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Build the Gemini tool declaration for destination search.
|
||||
pub(super) fn build_tool_declarations(state: &AppState) -> Value {
|
||||
let modes: Vec<&str> = state
|
||||
.travel_time_store
|
||||
.available_modes
|
||||
.iter()
|
||||
.map(|mode| mode.as_str())
|
||||
.collect();
|
||||
|
||||
json!([{
|
||||
"functionDeclarations": [{
|
||||
"name": "search_destinations",
|
||||
"description": "Search for available travel time destinations (cities, stations, towns) that have precomputed travel time data. Call this when the user mentions wanting to be near, close to, or within a certain travel time of a specific place.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Place name to search for (e.g. 'Manchester', 'Kings Cross', 'Heathrow')"
|
||||
},
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": modes,
|
||||
"description": "Transport mode to search destinations for"
|
||||
}
|
||||
},
|
||||
"required": ["query", "mode"]
|
||||
}
|
||||
}]
|
||||
}])
|
||||
}
|
||||
|
||||
/// Execute a destination search against PlaceData + TravelTimeStore.
|
||||
/// Returns matching destinations as a JSON value with `results` and optional `message`.
|
||||
///
|
||||
/// Uses word-based matching: all words in the query must appear somewhere in the
|
||||
/// place name (order-independent). Also matches against slugs for short queries.
|
||||
pub(super) fn execute_destination_search(state: &AppState, query: &str, mode: &str) -> Value {
|
||||
let query_lower = query.to_lowercase();
|
||||
let query_words: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
let query_slug = slugify(query);
|
||||
let tt_store = &state.travel_time_store;
|
||||
let pd = &state.place_data;
|
||||
|
||||
let slug_set = match tt_store.destinations.get(mode) {
|
||||
Some(slugs) => slugs,
|
||||
None => {
|
||||
return json!({ "results": [], "message": format!("No travel data available for mode '{}'", mode) })
|
||||
}
|
||||
};
|
||||
|
||||
// Find places matching the query that have travel time data.
|
||||
// A place matches if ALL query words appear in its name, OR its slug matches the query slug.
|
||||
let mut matches: Vec<(usize, String, u8, u32)> = pd
|
||||
.name_lower
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, name_lower)| {
|
||||
if !pd.travel_destination[idx] {
|
||||
return None;
|
||||
}
|
||||
let words_match = query_words.iter().all(|word| name_lower.contains(word));
|
||||
let slug = slugify(&pd.name[idx]);
|
||||
let slug_match = slug.contains(&query_slug) || query_slug.contains(&slug);
|
||||
if !words_match && !slug_match {
|
||||
return None;
|
||||
}
|
||||
if slug_set.contains(&slug) {
|
||||
Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort: type rank asc, population desc
|
||||
matches.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(b.3.cmp(&a.3)));
|
||||
matches.truncate(10);
|
||||
|
||||
if matches.is_empty() {
|
||||
// Check if the query matched a city that lacks its own travel data.
|
||||
// If so, return nearby stations within that city as suggestions.
|
||||
let matched_city_name: Option<&str> =
|
||||
pd.name_lower
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find_map(|(idx, name_lower)| {
|
||||
if !pd.travel_destination[idx] {
|
||||
return None;
|
||||
}
|
||||
let words_match = query_words.iter().all(|word| name_lower.contains(word));
|
||||
let slug = slugify(&pd.name[idx]);
|
||||
let slug_match = slug.contains(&query_slug) || query_slug.contains(&slug);
|
||||
if (words_match || slug_match) && pd.type_rank[idx] == 0 {
|
||||
Some(pd.name[idx].as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(city_name) = matched_city_name {
|
||||
let city_lower = city_name.to_lowercase();
|
||||
let mut city_matches: Vec<(usize, String, u8, u32)> = pd
|
||||
.city
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, city_opt)| {
|
||||
if !pd.travel_destination[idx] {
|
||||
return None;
|
||||
}
|
||||
let city = city_opt.as_deref()?;
|
||||
if city.to_lowercase() != city_lower {
|
||||
return None;
|
||||
}
|
||||
let slug = slugify(&pd.name[idx]);
|
||||
if slug_set.contains(&slug) {
|
||||
Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
city_matches.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(b.3.cmp(&a.3)));
|
||||
city_matches.truncate(10);
|
||||
|
||||
if !city_matches.is_empty() {
|
||||
let results: Vec<Value> = city_matches
|
||||
.into_iter()
|
||||
.map(|(idx, slug, ..)| {
|
||||
json!({
|
||||
"name": pd.name[idx],
|
||||
"slug": slug,
|
||||
"place_type": pd.place_type.get(idx).to_string(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
info!(
|
||||
query = query,
|
||||
city = city_name,
|
||||
results = results.len(),
|
||||
"Destination search fell back to city stations"
|
||||
);
|
||||
|
||||
return json!({
|
||||
"results": results,
|
||||
"message": format!(
|
||||
"No travel data for '{}' directly. Pick one of these nearby stations:",
|
||||
city_name
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
query = query,
|
||||
mode = mode,
|
||||
"Destination search returned no results"
|
||||
);
|
||||
return json!({
|
||||
"results": [],
|
||||
"message": format!("No travel time data available for '{}' by {}. This destination cannot be used as a travel time filter.", query, mode)
|
||||
});
|
||||
}
|
||||
|
||||
let results: Vec<Value> = matches
|
||||
.into_iter()
|
||||
.map(|(idx, slug, ..)| {
|
||||
json!({
|
||||
"name": pd.name[idx],
|
||||
"slug": slug,
|
||||
"place_type": pd.place_type.get(idx).to_string(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
json!({ "results": results })
|
||||
}
|
||||
119
server-rs/src/routes/ai_filters/usage.rs
Normal file
119
server-rs/src/routes/ai_filters/usage.rs
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
//! Weekly AI token usage tracking and rate limiting, persisted on the user's
|
||||
//! PocketBase record.
|
||||
|
||||
use axum::http::StatusCode;
|
||||
use metrics::counter;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::pocketbase::get_superuser_token;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Monotonically increasing week number derived from Unix epoch.
|
||||
/// Resets every 7 days (604800 seconds). Used for weekly rate limiting.
|
||||
pub(super) fn current_week_number() -> u64 {
|
||||
let secs = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
// Only possible if the system clock is before 1970; fall back to
|
||||
// week 0 rather than panicking inside a request handler.
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
secs / 604_800
|
||||
}
|
||||
|
||||
/// Fetch the user's current AI token usage from PocketBase.
|
||||
/// Returns `(tokens_used, week_number)`.
|
||||
pub(super) async fn fetch_ai_usage(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
) -> Result<(u64, u64), (StatusCode, String)> {
|
||||
let token = get_superuser_token(state).await.map_err(|err| {
|
||||
warn!("Failed to auth superuser for AI usage check: {err}");
|
||||
(StatusCode::BAD_GATEWAY, "Internal error".into())
|
||||
})?;
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
warn!("Failed to fetch user record for AI usage: {err}");
|
||||
(StatusCode::BAD_GATEWAY, "Internal error".into())
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
warn!("PocketBase user fetch failed ({status})");
|
||||
return Err((StatusCode::BAD_GATEWAY, "Internal error".into()));
|
||||
}
|
||||
|
||||
let body: Value = resp.json().await.map_err(|err| {
|
||||
warn!("Failed to parse user record: {err}");
|
||||
(StatusCode::BAD_GATEWAY, "Internal error".into())
|
||||
})?;
|
||||
|
||||
let tokens_used = body
|
||||
.get("ai_tokens_used")
|
||||
.and_then(|val| val.as_u64())
|
||||
.unwrap_or(0);
|
||||
let week = body
|
||||
.get("ai_tokens_week")
|
||||
.and_then(|val| val.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
Ok((tokens_used, week))
|
||||
}
|
||||
|
||||
/// Update the user's AI token usage in PocketBase.
|
||||
/// Best-effort — logs warnings on failure but does not propagate errors.
|
||||
async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week: u64) {
|
||||
let token = match get_superuser_token(state).await {
|
||||
Ok(tk) => tk,
|
||||
Err(err) => {
|
||||
warn!("Failed to auth superuser for AI usage update: {err}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let res = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&json!({
|
||||
"ai_tokens_used": tokens_used,
|
||||
"ai_tokens_week": week,
|
||||
}))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
warn!("Failed to update AI usage ({status})");
|
||||
}
|
||||
Err(err) => warn!("Failed to update AI usage: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn record_ai_request_usage(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
existing_tokens_used: u64,
|
||||
week: u64,
|
||||
request_tokens_used: u64,
|
||||
status: &'static str,
|
||||
) {
|
||||
if request_tokens_used > 0 {
|
||||
let new_total = existing_tokens_used.saturating_add(request_tokens_used);
|
||||
update_ai_usage(state, user_id, new_total, week).await;
|
||||
counter!("ai_tokens_total").increment(request_tokens_used);
|
||||
}
|
||||
counter!("ai_requests_total", "status" => status).increment(1);
|
||||
}
|
||||
|
|
@ -780,31 +780,28 @@ pub async fn get_export(
|
|||
// groups themselves; postcodes within a group are sorted alphabetically.
|
||||
// Each group carries a rolled-up summary aggregate for its header row.
|
||||
let outcode_groups: Vec<OutcodeGroup> = {
|
||||
let mut order: Vec<String> = Vec::new();
|
||||
let mut by_outcode: FxHashMap<String, OutcodeGroup> = FxHashMap::default();
|
||||
let mut groups: Vec<OutcodeGroup> = Vec::new();
|
||||
let mut idx_by_outcode: FxHashMap<String, usize> = FxHashMap::default();
|
||||
for (i, (pc_idx, agg)) in postcode_aggs.iter().enumerate() {
|
||||
let outcode = outcode_of(&postcode_data.postcodes[*pc_idx]).to_string();
|
||||
let group = by_outcode.entry(outcode.clone()).or_insert_with(|| {
|
||||
order.push(outcode.clone());
|
||||
OutcodeGroup {
|
||||
outcode: outcode.clone(),
|
||||
let idx = *idx_by_outcode.entry(outcode.clone()).or_insert_with(|| {
|
||||
groups.push(OutcodeGroup {
|
||||
outcode,
|
||||
members: Vec::new(),
|
||||
summary: PostcodeExportAgg::new(total_export_features),
|
||||
}
|
||||
});
|
||||
groups.len() - 1
|
||||
});
|
||||
group.members.push(i);
|
||||
group.summary.merge_from(agg);
|
||||
groups[idx].members.push(i);
|
||||
groups[idx].summary.merge_from(agg);
|
||||
}
|
||||
for group in by_outcode.values_mut() {
|
||||
for group in &mut groups {
|
||||
group.members.sort_by(|&a, &b| {
|
||||
postcode_data.postcodes[postcode_aggs[a].0]
|
||||
.cmp(&postcode_data.postcodes[postcode_aggs[b].0])
|
||||
});
|
||||
}
|
||||
order
|
||||
.into_iter()
|
||||
.map(|outcode| by_outcode.remove(&outcode).unwrap())
|
||||
.collect()
|
||||
groups
|
||||
};
|
||||
|
||||
// Build Excel workbook with two sheets
|
||||
|
|
|
|||
|
|
@ -130,6 +130,11 @@ pub struct HexagonStatsResponse {
|
|||
pub price_history: Vec<PricePoint>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub crime_by_year: Vec<CrimeYearStats>,
|
||||
/// Latest year in the crime dataset as a whole. When a selection's series
|
||||
/// end earlier (force-level publication gap, e.g. Greater Manchester),
|
||||
/// the client captions the data as stale.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub crime_latest_year: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub central_postcode: Option<String>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
|
|
@ -645,12 +650,19 @@ pub async fn get_hexagon_stats(
|
|||
"GET /api/hexagon-stats"
|
||||
);
|
||||
|
||||
let crime_latest_year = if crime_by_year.is_empty() {
|
||||
None
|
||||
} else {
|
||||
stats::crime_latest_available_year(&state.crime_by_year)
|
||||
};
|
||||
|
||||
Ok(HexagonStatsResponse {
|
||||
count: total_count,
|
||||
numeric_features,
|
||||
enum_features: enum_features_out,
|
||||
price_history,
|
||||
crime_by_year,
|
||||
crime_latest_year,
|
||||
central_postcode,
|
||||
filter_exclusions,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -36,8 +36,9 @@ fn is_allowed_pb_path(path: &str) -> bool {
|
|||
|
||||
/// Dedicated HTTP client for proxying — does not follow redirects so 3xx
|
||||
/// responses are passed through to the browser (needed for OAuth flows).
|
||||
/// No overall timeout because SSE (Server-Sent Events) connections used by
|
||||
/// PocketBase realtime/OAuth2 are long-lived streams.
|
||||
/// No client-wide timeout because SSE (Server-Sent Events) connections used
|
||||
/// by PocketBase realtime/OAuth2 are long-lived streams; non-realtime
|
||||
/// requests get a per-request timeout instead (see PROXY_REQUEST_TIMEOUT).
|
||||
static PROXY_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
reqwest::Client::builder()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
|
|
@ -47,6 +48,11 @@ static PROXY_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
|||
.expect("Failed to build proxy HTTP client")
|
||||
});
|
||||
|
||||
/// Timeout for proxied requests other than the realtime SSE stream, so a hung
|
||||
/// PocketBase cannot pile up handlers indefinitely. Generous enough for file
|
||||
/// uploads/downloads over slow links.
|
||||
const PROXY_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
pub async fn proxy_to_pocketbase(
|
||||
State(shared): State<Arc<SharedState>>,
|
||||
req: Request,
|
||||
|
|
@ -58,10 +64,7 @@ pub async fn proxy_to_pocketbase(
|
|||
let target_path = path.strip_prefix("/pb").unwrap_or(path);
|
||||
if !is_allowed_pb_path(target_path) {
|
||||
warn!(path = %target_path, "Rejected PocketBase proxy request to disallowed path");
|
||||
return Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
}
|
||||
let query = req
|
||||
.uri()
|
||||
|
|
@ -73,6 +76,12 @@ pub async fn proxy_to_pocketbase(
|
|||
let method = req.method().clone();
|
||||
let mut builder = PROXY_CLIENT.request(method, &url);
|
||||
|
||||
// The realtime SSE stream is intentionally unbounded; everything else
|
||||
// must complete within the timeout.
|
||||
if target_path != "/api/realtime" {
|
||||
builder = builder.timeout(PROXY_REQUEST_TIMEOUT);
|
||||
}
|
||||
|
||||
// Forward only safe headers (allowlist)
|
||||
const ALLOWED_HEADERS: &[&str] = &[
|
||||
"content-type",
|
||||
|
|
@ -96,10 +105,7 @@ pub async fn proxy_to_pocketbase(
|
|||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
warn!("Failed to read request body: {err}");
|
||||
return Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Body::from("Failed to read request body"))
|
||||
.unwrap();
|
||||
return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response();
|
||||
}
|
||||
};
|
||||
builder = builder.body(body_bytes);
|
||||
|
|
@ -129,14 +135,14 @@ pub async fn proxy_to_pocketbase(
|
|||
// realtime system and OAuth2 flow — buffering would hang forever
|
||||
// since SSE responses never complete.
|
||||
let body = Body::from_stream(upstream.bytes_stream());
|
||||
response.body(body).unwrap()
|
||||
response.body(body).unwrap_or_else(|err| {
|
||||
warn!("Failed to build proxied response: {err}");
|
||||
(StatusCode::BAD_GATEWAY, "Invalid upstream response").into_response()
|
||||
})
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("PocketBase proxy error: {err}");
|
||||
Response::builder()
|
||||
.status(StatusCode::BAD_GATEWAY)
|
||||
.body(Body::from("PocketBase unavailable"))
|
||||
.unwrap()
|
||||
(StatusCode::BAD_GATEWAY, "PocketBase unavailable").into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -184,12 +184,19 @@ pub async fn get_postcode_stats(
|
|||
"GET /api/postcode-stats"
|
||||
);
|
||||
|
||||
let crime_latest_year = if crime_by_year.is_empty() {
|
||||
None
|
||||
} else {
|
||||
stats::crime_latest_available_year(&state.crime_by_year)
|
||||
};
|
||||
|
||||
Ok(HexagonStatsResponse {
|
||||
count: total_count,
|
||||
numeric_features,
|
||||
enum_features: enum_features_out,
|
||||
price_history,
|
||||
crime_by_year,
|
||||
crime_latest_year,
|
||||
central_postcode: None,
|
||||
filter_exclusions,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -340,16 +340,14 @@ pub fn compute_crime_by_year(
|
|||
let points: Vec<CrimeYearPoint> = years
|
||||
.iter()
|
||||
.filter_map(|&year| {
|
||||
let denom = fully_covered_rows
|
||||
+ covered_counts.get(&year).copied().unwrap_or(0);
|
||||
let denom = fully_covered_rows + covered_counts.get(&year).copied().unwrap_or(0);
|
||||
if denom == 0 {
|
||||
// No selected postcode has published data for this year.
|
||||
return None;
|
||||
}
|
||||
Some(CrimeYearPoint {
|
||||
year,
|
||||
count: (sums.get(&year).copied().unwrap_or(0.0) / denom as f64)
|
||||
as f32,
|
||||
count: (sums.get(&year).copied().unwrap_or(0.0) / denom as f64) as f32,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
|
@ -365,6 +363,19 @@ pub fn compute_crime_by_year(
|
|||
out
|
||||
}
|
||||
|
||||
/// Latest year present anywhere in the by-year crime dataset. The client
|
||||
/// compares each selection's last charted year against this to caption
|
||||
/// force-level publication gaps (e.g. Greater Manchester ends mid-2019) as
|
||||
/// stale data instead of letting old numbers read as current.
|
||||
pub fn crime_latest_available_year(crime_by_year: &CrimeByYearData) -> Option<i32> {
|
||||
crime_by_year
|
||||
.years_by_type
|
||||
.iter()
|
||||
.flatten()
|
||||
.copied()
|
||||
.max()
|
||||
}
|
||||
|
||||
pub fn compute_poi_feature_stats(
|
||||
matching_rows: &[usize],
|
||||
poi_metrics: &PostcodePoiMetrics,
|
||||
|
|
|
|||
|
|
@ -87,6 +87,105 @@ pub struct AppState {
|
|||
pub bugsink_frontend_config: Option<BugsinkFrontendConfig>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl AppState {
|
||||
/// Minimal AppState for integration tests of the PocketBase/Stripe money
|
||||
/// paths (checkout, webhook, licensing, invites). All map/property data is
|
||||
/// empty; only the HTTP-facing config (PocketBase URL, Stripe secrets,
|
||||
/// caches) carries meaningful values.
|
||||
pub(crate) fn for_tests(pocketbase_url: String) -> Self {
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::data::{
|
||||
ActualListingData, CrimeByYearData, OutcodeData, POIData, PlaceData, PostcodeData,
|
||||
PropertyData, TravelTimeStore,
|
||||
};
|
||||
use crate::utils::InternedColumn;
|
||||
|
||||
let http_client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(5))
|
||||
.connect_timeout(Duration::from_secs(2))
|
||||
.build()
|
||||
.expect("test HTTP client should build");
|
||||
|
||||
AppState {
|
||||
data: PropertyData::empty_for_tests(),
|
||||
grid: GridIndex::build(&[], &[], 0.01),
|
||||
h3_cells: Vec::new(),
|
||||
feature_name_to_index: FxHashMap::default(),
|
||||
min_keys: Vec::new(),
|
||||
max_keys: Vec::new(),
|
||||
avg_keys: Vec::new(),
|
||||
features_response: FeaturesResponse { groups: Vec::new() },
|
||||
ai_filters_system_prompt: String::new(),
|
||||
poi_data: Arc::new(POIData::empty_for_tests()),
|
||||
poi_grid: Arc::new(GridIndex::build(&[], &[], 0.01)),
|
||||
place_data: Arc::new(PlaceData::empty_for_tests()),
|
||||
postcode_data: Arc::new(PostcodeData {
|
||||
postcodes: Vec::new(),
|
||||
centroids: Vec::new(),
|
||||
aabbs: Vec::new(),
|
||||
polygons: Vec::new(),
|
||||
postcode_to_idx: FxHashMap::default(),
|
||||
}),
|
||||
outcode_data: Arc::new(OutcodeData {
|
||||
names: Vec::new(),
|
||||
name_lower: Vec::new(),
|
||||
centroids: Vec::new(),
|
||||
cities: Vec::new(),
|
||||
}),
|
||||
poi_category_groups: Arc::new(Vec::new()),
|
||||
travel_time_store: Arc::new(TravelTimeStore::empty_for_tests()),
|
||||
actual_listings: Arc::new(ActualListingData {
|
||||
lat: Vec::new(),
|
||||
lon: Vec::new(),
|
||||
postcode: Vec::new(),
|
||||
address: Vec::new(),
|
||||
property_type: InternedColumn::build(&[]),
|
||||
property_sub_type: InternedColumn::build(&[]),
|
||||
leasehold_freehold: InternedColumn::build(&[]),
|
||||
price_qualifier: InternedColumn::build(&[]),
|
||||
bedrooms: Vec::new(),
|
||||
bathrooms: Vec::new(),
|
||||
rooms_total: Vec::new(),
|
||||
floor_area_sqm: Vec::new(),
|
||||
asking_price: Vec::new(),
|
||||
asking_price_per_sqm: Vec::new(),
|
||||
listing_url: Vec::new(),
|
||||
listing_status: InternedColumn::build(&[]),
|
||||
listing_date_iso: Vec::new(),
|
||||
features: Vec::new(),
|
||||
filter_feature_data: Vec::new(),
|
||||
poi_filter_feature_data: Vec::new(),
|
||||
grid: GridIndex::build(&[], &[], 0.01),
|
||||
}),
|
||||
crime_by_year: Arc::new(CrimeByYearData {
|
||||
crime_types: Vec::new(),
|
||||
years_by_type: Vec::new(),
|
||||
series_by_postcode: FxHashMap::default(),
|
||||
covered_years_by_postcode: FxHashMap::default(),
|
||||
}),
|
||||
token_cache: Arc::new(TokenCache::new()),
|
||||
superuser_token_cache: Arc::new(SuperuserTokenCache::new()),
|
||||
share_cache: Arc::new(ShareBoundsCache::new()),
|
||||
screenshot_url: "http://127.0.0.1:1/screenshot".to_string(),
|
||||
public_url: "https://test.example".to_string(),
|
||||
is_dev: false,
|
||||
http_client,
|
||||
pocketbase_url,
|
||||
pocketbase_admin_email: "admin@test.example".to_string(),
|
||||
pocketbase_admin_password: "test-admin-password".to_string(),
|
||||
gemini_api_key: "test-gemini-key".to_string(),
|
||||
gemini_model: "test-model".to_string(),
|
||||
google_maps_api_key: "test-maps-key".to_string(),
|
||||
stripe_secret_key: "sk_test_dummy".to_string(),
|
||||
stripe_webhook_secret: "whsec_test_secret".to_string(),
|
||||
stripe_referral_coupon_id: "couponTest30".to_string(),
|
||||
bugsink_frontend_config: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps AppState for shared access across route handlers.
|
||||
/// Route handlers call `load_state()` to get the current snapshot.
|
||||
pub struct SharedState {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue