use std::collections::HashSet; use std::sync::{Arc, LazyLock, Mutex}; use axum::extract::{Path, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::{Extension, Json}; use serde::{Deserialize, Serialize}; use tracing::{info, warn}; use crate::auth::{OptionalUser, PocketBaseUser}; use crate::checkout_sessions::{ active_referral_checkout_user, grant_license_with_pricing_lock, start_license_checkout, CheckoutStart, }; use crate::pocketbase::get_superuser_token; use crate::pocketbase_locks::acquire_pocketbase_lock; use crate::state::{AppState, SharedState}; static INVITE_REDEMPTIONS_IN_PROGRESS: LazyLock>> = LazyLock::new(|| Mutex::new(HashSet::new())); const INVITE_REDEMPTION_LOCK_TTL_SECS: u64 = 5 * 60; struct InviteRedemptionGuard { code: String, } impl InviteRedemptionGuard { fn acquire(code: &str) -> Option { let mut in_progress = INVITE_REDEMPTIONS_IN_PROGRESS .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); if !in_progress.insert(code.to_string()) { return None; } Some(Self { code: code.to_string(), }) } } impl Drop for InviteRedemptionGuard { fn drop(&mut self) { let mut in_progress = INVITE_REDEMPTIONS_IN_PROGRESS .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); in_progress.remove(&self.code); } } #[derive(Serialize)] struct InviteResponse { code: String, url: String, invite_type: String, } #[derive(Serialize)] struct InviteListItem { code: String, url: String, invite_type: String, used: bool, created: String, } #[derive(Serialize)] struct InviteListResponse { invites: Vec, } #[derive(Serialize)] struct InviteValidation { valid: bool, invite_type: String, used: bool, invited_by: Option, } #[derive(Deserialize)] pub struct CreateInviteRequest { /// Admins can explicitly choose "admin" or "referral". Ignored for non-admins. invite_type: Option, } #[derive(Deserialize)] pub struct RedeemRequest { code: String, } #[derive(Serialize)] struct RedeemResponse { /// "licensed" if admin invite was redeemed directly, or a checkout URL for referral result: String, /// For referral invites: the Stripe checkout URL with coupon checkout_url: Option, } /// Validate that an invite code contains only safe characters (alphanumeric, lowercase). /// Rejects any code that could be used for PocketBase filter injection. fn validate_invite_code(code: &str) -> Result<(), &'static str> { if code.is_empty() || code.len() > 20 { return Err("Invalid invite code length"); } if !code.bytes().all(|b| b.is_ascii_alphanumeric()) { return Err("Invalid invite code characters"); } Ok(()) } /// Sanitize the inviter's display name returned to anonymous clients. /// The value comes from the inviter's email local-part stored in PocketBase; /// we don't trust it, so strip control chars and HTML-meaningful characters /// and cap the length. Returns None if nothing usable remains. fn sanitize_invited_by(raw: &str) -> Option { const MAX_LEN: usize = 40; let cleaned: String = raw .chars() .filter(|c| !c.is_control() && !matches!(*c, '<' | '>' | '"' | '\'' | '&' | '\\')) .take(MAX_LEN) .collect(); let trimmed = cleaned.trim(); if trimmed.is_empty() { None } else { Some(trimmed.to_string()) } } fn generate_invite_code() -> String { use rand::RngExt; let mut rng = rand::rng(); let chars: Vec = (0..12) .map(|_| { let idx: u8 = rng.random_range(0..36); if idx < 10 { (b'0' + idx) as char } else { (b'a' + idx - 10) as char } }) .collect(); chars.into_iter().collect() } fn current_unix_secs_string() -> String { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() .to_string() } /// Fetch the live `is_admin` flag for a user, bypassing any cached token /// claims. Returns Err with an HTTP response if PocketBase is unreachable /// or returns an unexpected payload — the caller should propagate that. async fn verify_is_admin( state: &AppState, pb_url: &str, token: &str, user_id: &str, ) -> Result { if user_id.is_empty() || user_id.len() > 32 || !user_id.bytes().all(|b| b.is_ascii_alphanumeric()) { return Err(StatusCode::FORBIDDEN.into_response()); } let url = format!("{pb_url}/api/collections/users/records/{user_id}"); let resp = match state .http_client .get(&url) .header("Authorization", format!("Bearer {token}")) .send() .await { Ok(r) => r, Err(err) => { warn!("Failed to verify is_admin: {err}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } }; if !resp.status().is_success() { return Err(StatusCode::BAD_GATEWAY.into_response()); } let body: serde_json::Value = match resp.json().await { Ok(v) => v, Err(err) => { warn!("Failed to parse user record for is_admin verify: {err}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } }; Ok(body["is_admin"].as_bool().unwrap_or(false)) } fn redeemable_invite_filter(code: &str, user_id: &str) -> Result { validate_invite_code(code)?; if user_id.is_empty() || user_id.len() > 32 || !user_id.bytes().all(|b| b.is_ascii_alphanumeric()) { return Err("Invalid user id"); } Ok(format!( "code=\"{}\" && (used_by_id=\"\" || used_by_id=\"{}\")", code, user_id )) } async fn lookup_redeemable_invite( state: &AppState, pb_url: &str, token: &str, code: &str, user_id: &str, ) -> Result, Response> { let filter = match redeemable_invite_filter(code, user_id) { Ok(filter) => filter, Err(msg) => return Err((StatusCode::BAD_REQUEST, msg).into_response()), }; let lookup_url = format!( "{pb_url}/api/collections/invites/records?filter={}&perPage=1", urlencoding::encode(&filter) ); let res = match state .http_client .get(&lookup_url) .header("Authorization", format!("Bearer {token}")) .send() .await { Ok(resp) => resp, Err(err) => { warn!("Failed to look up invite: {err}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } }; if !res.status().is_success() { let status = res.status(); let text = res.text().await.unwrap_or_default(); warn!("PocketBase invite lookup failed ({status}): {text}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } let body: serde_json::Value = match res.json().await { Ok(value) => value, Err(err) => { warn!("Failed to parse invite lookup response: {err}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } }; Ok(body["items"] .as_array() .and_then(|arr| arr.first()) .cloned()) } async fn mark_invite_used( state: &AppState, pb_url: &str, token: &str, invite_id: &str, user_id: &str, ) -> Result<(), Response> { let resp = match state .http_client .patch(format!( "{pb_url}/api/collections/invites/records/{invite_id}" )) .header("Authorization", format!("Bearer {token}")) .json(&serde_json::json!({ "used_by_id": user_id, "used_at": current_unix_secs_string(), })) .send() .await { Ok(resp) => resp, Err(err) => { warn!("Failed to mark invite as used: {err}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } }; if !resp.status().is_success() { let status = resp.status(); let text = resp.text().await.unwrap_or_default(); warn!("PocketBase invite usage update failed ({status}): {text}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } Ok(()) } async fn grant_license_for_invite( state: &AppState, _pb_url: &str, _token: &str, user_id: &str, ) -> Result<(), Response> { grant_license_with_pricing_lock(state, user_id) .await .map_err(|err| { warn!("Failed to update user subscription for admin invite: {err}"); StatusCode::BAD_GATEWAY.into_response() }) } async fn create_referral_checkout( state: &AppState, user: &PocketBaseUser, invite_id: &str, ) -> Result { let public_url = &state.public_url; let success_url = format!("{public_url}/pricing?license_success=1"); let cancel_url = format!("{public_url}/pricing"); match start_license_checkout( state, user, &success_url, &cancel_url, Some(&state.stripe_referral_coupon_id), Some(invite_id), ) .await { Ok(CheckoutStart::Free) => Ok(success_url), Ok(CheckoutStart::Stripe { url }) => Ok(url), Err(err) => { warn!("Failed to create reserved Stripe checkout for referral invite: {err:?}"); Err(StatusCode::BAD_GATEWAY.into_response()) } } } /// Create an invite. Admins create "admin" invites (free license) by default, /// but can explicitly request "referral" type. Licensed non-admin users always create "referral" invites (30% off). pub async fn post_invites( State(shared): State>, Extension(user): Extension, Json(body): Json, ) -> Response { let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), }; // Cached token claims could be stale or, in the worst case, tampered with // upstream of us. For admin-only actions, re-fetch the live record from // PocketBase and trust only that. let wants_admin_invite = user.is_admin && !matches!(body.invite_type.as_deref(), Some("referral")); let pb_url = state.pocketbase_url.trim_end_matches('/'); let token = match get_superuser_token(&state).await { Ok(t) => t, Err(err) => { warn!("Failed to auth as PocketBase superuser: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let invite_type = if wants_admin_invite { match verify_is_admin(&state, pb_url, &token, &user.id).await { Ok(true) => "admin", Ok(false) => { warn!(user_id = %user.id, "is_admin claim rejected by live PB lookup"); return (StatusCode::FORBIDDEN, "Not authorised").into_response(); } Err(response) => return response, } } else if user.is_admin || user.subscription == "licensed" { "referral" } else { return ( StatusCode::FORBIDDEN, "Only licensed users can create invites", ) .into_response(); }; let code = generate_invite_code(); let create_url = format!("{pb_url}/api/collections/invites/records"); let res = state .http_client .post(&create_url) .header("Authorization", format!("Bearer {token}")) .json(&serde_json::json!({ "code": code, "created_by": user.id, "invite_type": invite_type, "used_by_id": "", "used_at": "", })) .send() .await; match res { Ok(resp) if resp.status().is_success() => { let public_url = &state.public_url; let url = format!("{public_url}/invite/{code}"); info!(code = %code, invite_type, user_id = %user.id, "Created invite"); Json(InviteResponse { code, url, invite_type: invite_type.to_string(), }) .into_response() } Ok(resp) => { let status = resp.status(); let text = resp.text().await.unwrap_or_default(); warn!("Failed to create invite ({status}): {text}"); StatusCode::BAD_GATEWAY.into_response() } Err(err) => { warn!("PocketBase request error: {err}"); StatusCode::BAD_GATEWAY.into_response() } } } /// Dev-only fake invite code (12 alphanumeric chars, passes validation). /// Only recognized when `--dist` is not set (i.e., dev mode). const DEV_INVITE_CODE: &str = "devdevdevdev"; /// Validate an invite code. Public endpoint — codes are 12-char random alphanumeric /// so enumeration is impractical, and the response only reveals valid/invalid + type. pub async fn get_invite( State(shared): State>, Extension(_user): Extension, Path(code): Path, ) -> Response { let state = shared.load_state(); if let Err(msg) = validate_invite_code(&code) { return (StatusCode::BAD_REQUEST, msg).into_response(); } // Dev-only: return a fake valid admin invite without hitting PocketBase if state.is_dev && code == DEV_INVITE_CODE { return Json(InviteValidation { valid: true, invite_type: "admin".to_string(), used: false, invited_by: Some("Developer".to_string()), }) .into_response(); } let pb_url = state.pocketbase_url.trim_end_matches('/'); let token = match get_superuser_token(&state).await { Ok(t) => t, Err(err) => { warn!("Failed to auth as PocketBase superuser: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let filter = format!("code=\"{}\"", code); let url = format!( "{pb_url}/api/collections/invites/records?filter={}&perPage=1", urlencoding::encode(&filter) ); let res = match state .http_client .get(&url) .header("Authorization", format!("Bearer {token}")) .send() .await { Ok(r) => r, Err(err) => { warn!("Failed to look up invite: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; if !res.status().is_success() { return StatusCode::BAD_GATEWAY.into_response(); } let body: serde_json::Value = match res.json().await { Ok(v) => v, Err(_) => return StatusCode::BAD_GATEWAY.into_response(), }; let items = body["items"].as_array(); match items.and_then(|arr| arr.first()) { Some(invite) => { let invite_type = invite["invite_type"].as_str().unwrap_or("").to_string(); let used_by = invite["used_by_id"].as_str().unwrap_or(""); let used = !used_by.is_empty(); let created_by = invite["created_by"].as_str().unwrap_or(""); // Look up inviter's name (email local part) — sanitized before returning. let invited_by = if !created_by.is_empty() { let user_url = format!("{pb_url}/api/collections/users/records/{created_by}"); match state .http_client .get(&user_url) .header("Authorization", format!("Bearer {token}")) .send() .await { Ok(resp) if resp.status().is_success() => { let user_body: serde_json::Value = resp.json().await.unwrap_or_default(); user_body["email"] .as_str() .and_then(|e| e.split('@').next()) .and_then(sanitize_invited_by) } _ => None, } } else { None }; Json(InviteValidation { valid: true, invite_type, used, invited_by, }) .into_response() } None => Json(InviteValidation { valid: false, invite_type: String::new(), used: false, invited_by: None, }) .into_response(), } } /// Redeem an invite code. Requires authentication. /// Admin invite: sets subscription to "licensed" directly. /// Referral invite: returns a discounted Stripe checkout URL. pub async fn post_redeem_invite( State(shared): State>, Extension(user): Extension, Json(req): Json, ) -> Response { let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), }; if let Err(msg) = validate_invite_code(&req.code) { return (StatusCode::BAD_REQUEST, msg).into_response(); } // Dev-only: fake redeem — just return "licensed" without touching PocketBase if state.is_dev && req.code == DEV_INVITE_CODE { info!(user_id = %user.id, "Dev invite redeemed (no-op)"); return Json(RedeemResponse { result: "licensed".to_string(), checkout_url: None, }) .into_response(); } if user.is_admin || user.subscription == "licensed" { return (StatusCode::CONFLICT, "Account already has full access").into_response(); } let pb_url = state.pocketbase_url.trim_end_matches('/'); let token = match get_superuser_token(&state).await { Ok(t) => t, Err(err) => { warn!("Failed to auth as PocketBase superuser: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let _redemption_guard = match InviteRedemptionGuard::acquire(&req.code) { Some(guard) => guard, None => { return ( StatusCode::CONFLICT, "Invite redemption is already in progress", ) .into_response() } }; let lock_name = format!("invite:{}", req.code); let _distributed_redemption_guard = match acquire_pocketbase_lock(&state, &lock_name, INVITE_REDEMPTION_LOCK_TTL_SECS).await { Ok(guard) => guard, Err(err) => { warn!(code = %req.code, "Failed to acquire invite redemption lock: {err}"); return ( StatusCode::CONFLICT, "Invite redemption is already in progress", ) .into_response(); } }; let invite = match lookup_redeemable_invite(&state, pb_url, &token, &req.code, &user.id).await { Ok(Some(invite)) => invite, Ok(None) => { return (StatusCode::NOT_FOUND, "Invalid or already used invite code").into_response() } Err(response) => return response, }; let invite_id = match invite["id"].as_str().filter(|id| !id.is_empty()) { Some(id) => id, None => { warn!(code = %req.code, "Invite lookup returned record without id"); return StatusCode::BAD_GATEWAY.into_response(); } }; let invite_type = match invite["invite_type"].as_str() { Some("admin") => "admin", Some("referral") => "referral", Some(other) => { warn!(code = %req.code, invite_type = other, "Invite has unsupported type"); return StatusCode::BAD_GATEWAY.into_response(); } None => { warn!(code = %req.code, "Invite lookup returned record without invite_type"); return StatusCode::BAD_GATEWAY.into_response(); } }; let used_by_id = invite["used_by_id"].as_str().unwrap_or_default(); if !used_by_id.is_empty() && used_by_id != user.id { return (StatusCode::NOT_FOUND, "Invalid or already used invite code").into_response(); } if invite_type == "admin" { if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await { return response; } if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await { return response; } info!(user_id = %user.id, code = %req.code, "Admin invite redeemed; user licensed"); return Json(RedeemResponse { result: "licensed".to_string(), checkout_url: None, }) .into_response(); } if !used_by_id.is_empty() { return (StatusCode::NOT_FOUND, "Invalid or already used invite code").into_response(); } match active_referral_checkout_user(&state, invite_id).await { Ok(Some(active_user_id)) if active_user_id != user.id => { return ( StatusCode::CONFLICT, "Invite checkout is already in progress", ) .into_response() } Ok(_) => {} Err(err) => { warn!(code = %req.code, "Failed to check active referral checkout: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } } let checkout_url = match create_referral_checkout(&state, &user, invite_id).await { Ok(url) => url, Err(response) => return response, }; info!(user_id = %user.id, code = %req.code, "Referral invite redeemed; checkout created"); Json(RedeemResponse { result: "checkout".to_string(), checkout_url: Some(checkout_url), }) .into_response() } #[cfg(test)] mod tests { use super::*; #[test] fn redeemable_invite_filter_allows_unused_or_same_user_invite() { let filter = redeemable_invite_filter("abc123", "user123").unwrap(); assert_eq!( filter, "code=\"abc123\" && (used_by_id=\"\" || used_by_id=\"user123\")" ); } #[test] fn redeemable_invite_filter_rejects_unsafe_values() { assert!(redeemable_invite_filter("bad-code", "user123").is_err()); assert!(redeemable_invite_filter("abc123", "bad-user").is_err()); } } /// List invites. Users only see invites they created, including admins. pub async fn get_invites( State(shared): State>, Extension(user): Extension, ) -> Response { let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), }; let pb_url = state.pocketbase_url.trim_end_matches('/'); let token = match get_superuser_token(&state).await { Ok(t) => t, Err(err) => { warn!("Failed to auth as PocketBase superuser: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let filter = format!("created_by=\"{}\"", user.id); let mut url = format!("{pb_url}/api/collections/invites/records?sort=-created&perPage=200"); url.push_str(&format!("&filter={}", urlencoding::encode(&filter))); let res = match state .http_client .get(&url) .header("Authorization", format!("Bearer {token}")) .send() .await { Ok(r) => r, Err(err) => { warn!("Failed to list invites: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; if !res.status().is_success() { let status = res.status(); let text = res.text().await.unwrap_or_default(); warn!("PocketBase list invites failed ({status}): {text}"); return StatusCode::BAD_GATEWAY.into_response(); } let body: serde_json::Value = match res.json().await { Ok(v) => v, Err(err) => { warn!("Failed to parse invites response: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let public_url = &state.public_url; let invites: Vec = body["items"] .as_array() .map(|arr| { arr.iter() .map(|inv| { let code = inv["code"].as_str().unwrap_or("").to_string(); let invite_type = inv["invite_type"].as_str().unwrap_or("").to_string(); let used_by = inv["used_by_id"].as_str().unwrap_or(""); let created = inv["created"].as_str().unwrap_or("").to_string(); InviteListItem { url: format!("{public_url}/invite/{code}"), code, invite_type, used: !used_by.is_empty(), created, } }) .collect() }) .unwrap_or_default(); Json(InviteListResponse { invites }).into_response() }