use std::collections::HashSet; use std::sync::{Arc, LazyLock, Mutex}; use axum::extract::{Path, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::{Extension, Json}; use serde::{Deserialize, Serialize}; use tracing::{info, warn}; use crate::auth::{OptionalUser, PocketBaseUser}; use crate::pocketbase::get_superuser_token; use crate::state::{AppState, SharedState}; static INVITE_REDEMPTIONS_IN_PROGRESS: LazyLock>> = LazyLock::new(|| Mutex::new(HashSet::new())); struct InviteRedemptionGuard { code: String, } impl InviteRedemptionGuard { fn acquire(code: &str) -> Option { let mut in_progress = INVITE_REDEMPTIONS_IN_PROGRESS .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); if !in_progress.insert(code.to_string()) { return None; } Some(Self { code: code.to_string(), }) } } impl Drop for InviteRedemptionGuard { fn drop(&mut self) { let mut in_progress = INVITE_REDEMPTIONS_IN_PROGRESS .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); in_progress.remove(&self.code); } } #[derive(Serialize)] struct InviteResponse { code: String, url: String, invite_type: String, } #[derive(Serialize)] struct InviteListItem { code: String, url: String, invite_type: String, used: bool, created: String, } #[derive(Serialize)] struct InviteListResponse { invites: Vec, } #[derive(Serialize)] struct InviteValidation { valid: bool, invite_type: String, used: bool, invited_by: Option, } #[derive(Deserialize)] pub struct CreateInviteRequest { /// Admins can explicitly choose "admin" or "referral". Ignored for non-admins. invite_type: Option, } #[derive(Deserialize)] pub struct RedeemRequest { code: String, } #[derive(Serialize)] struct RedeemResponse { /// "licensed" if admin invite was redeemed directly, or a checkout URL for referral result: String, /// For referral invites: the Stripe checkout URL with coupon checkout_url: Option, } /// Validate that an invite code contains only safe characters (alphanumeric, lowercase). /// Rejects any code that could be used for PocketBase filter injection. fn validate_invite_code(code: &str) -> Result<(), &'static str> { if code.is_empty() || code.len() > 20 { return Err("Invalid invite code length"); } if !code.bytes().all(|b| b.is_ascii_alphanumeric()) { return Err("Invalid invite code characters"); } Ok(()) } fn generate_invite_code() -> String { use rand::Rng; let mut rng = rand::rng(); let chars: Vec = (0..12) .map(|_| { let idx: u8 = rng.random_range(0..36); if idx < 10 { (b'0' + idx) as char } else { (b'a' + idx - 10) as char } }) .collect(); chars.into_iter().collect() } fn current_unix_secs_string() -> String { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() .to_string() } async fn lookup_unused_invite( state: &AppState, pb_url: &str, token: &str, code: &str, ) -> Result, Response> { let filter = format!("code=\"{}\" && used_by_id=\"\"", code); let lookup_url = format!( "{pb_url}/api/collections/invites/records?filter={}&perPage=1", urlencoding::encode(&filter) ); let res = match state .http_client .get(&lookup_url) .header("Authorization", format!("Bearer {token}")) .send() .await { Ok(resp) => resp, Err(err) => { warn!("Failed to look up invite: {err}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } }; if !res.status().is_success() { let status = res.status(); let text = res.text().await.unwrap_or_default(); warn!("PocketBase invite lookup failed ({status}): {text}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } let body: serde_json::Value = match res.json().await { Ok(value) => value, Err(err) => { warn!("Failed to parse invite lookup response: {err}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } }; Ok(body["items"] .as_array() .and_then(|arr| arr.first()) .cloned()) } async fn mark_invite_used( state: &AppState, pb_url: &str, token: &str, invite_id: &str, user_id: &str, ) -> Result<(), Response> { let resp = match state .http_client .patch(format!( "{pb_url}/api/collections/invites/records/{invite_id}" )) .header("Authorization", format!("Bearer {token}")) .json(&serde_json::json!({ "used_by_id": user_id, "used_at": current_unix_secs_string(), })) .send() .await { Ok(resp) => resp, Err(err) => { warn!("Failed to mark invite as used: {err}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } }; if !resp.status().is_success() { let status = resp.status(); let text = resp.text().await.unwrap_or_default(); warn!("PocketBase invite usage update failed ({status}): {text}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } Ok(()) } async fn grant_license_for_invite( state: &AppState, pb_url: &str, token: &str, user_id: &str, ) -> Result<(), Response> { let update_url = format!("{pb_url}/api/collections/users/records/{user_id}"); let resp = match state .http_client .patch(&update_url) .header("Authorization", format!("Bearer {token}")) .json(&serde_json::json!({ "subscription": "licensed" })) .send() .await { Ok(resp) => resp, Err(err) => { warn!("Failed to update user subscription for admin invite: {err}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } }; if !resp.status().is_success() { let status = resp.status(); let text = resp.text().await.unwrap_or_default(); warn!("PocketBase user subscription update failed ({status}): {text}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } state.token_cache.invalidate_by_user_id(user_id); Ok(()) } async fn create_referral_checkout( state: &AppState, user: &PocketBaseUser, ) -> Result { let count = match super::pricing::count_licensed_users(state).await { Ok(count) => count, Err(err) => { warn!("Failed to count licensed users for invite checkout: {err}"); return Err(StatusCode::SERVICE_UNAVAILABLE.into_response()); } }; let price_pence = super::pricing::price_for_count(count); let public_url = &state.public_url; let success_url = format!("{public_url}/pricing?license_success=1"); let cancel_url = format!("{public_url}/pricing"); let form_params = vec![ ("mode", "payment".to_string()), ( "line_items[0][price_data][unit_amount]", price_pence.to_string(), ), ("line_items[0][price_data][currency]", "gbp".to_string()), ( "line_items[0][price_data][product_data][name]", "Perfect Postcodes Lifetime License".to_string(), ), ("line_items[0][quantity]", "1".to_string()), ("success_url", success_url), ("cancel_url", cancel_url), ("client_reference_id", user.id.clone()), ("customer_email", user.email.clone()), ( "discounts[0][coupon]", state.stripe_referral_coupon_id.clone(), ), ]; let stripe_res = state .http_client .post("https://api.stripe.com/v1/checkout/sessions") .basic_auth(&state.stripe_secret_key, None::<&str>) .form(&form_params) .send() .await; match stripe_res { Ok(resp) if resp.status().is_success() => { let stripe_body: serde_json::Value = match resp.json().await { Ok(value) => value, Err(err) => { warn!("Failed to parse Stripe checkout response: {err}"); return Err(StatusCode::BAD_GATEWAY.into_response()); } }; let checkout_url = stripe_body["url"].as_str().unwrap_or_default().to_string(); if checkout_url.is_empty() { warn!("Stripe checkout response did not include a URL"); return Err(StatusCode::BAD_GATEWAY.into_response()); } Ok(checkout_url) } Ok(resp) => { let status = resp.status(); let text = resp.text().await.unwrap_or_default(); warn!("Failed to create Stripe checkout for referral invite ({status}): {text}"); Err(StatusCode::BAD_GATEWAY.into_response()) } Err(err) => { warn!("Stripe request error for referral invite: {err}"); Err(StatusCode::BAD_GATEWAY.into_response()) } } } /// Create an invite. Admins create "admin" invites (free license) by default, /// but can explicitly request "referral" type. Licensed non-admin users always create "referral" invites (30% off). pub async fn post_invites( State(shared): State>, Extension(user): Extension, Json(body): Json, ) -> Response { let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), }; let invite_type = if user.is_admin { match body.invite_type.as_deref() { Some("referral") => "referral", _ => "admin", } } else if user.subscription == "licensed" { "referral" } else { return ( StatusCode::FORBIDDEN, "Only licensed users can create invites", ) .into_response(); }; let code = generate_invite_code(); let pb_url = state.pocketbase_url.trim_end_matches('/'); let token = match get_superuser_token(&state).await { Ok(t) => t, Err(err) => { warn!("Failed to auth as PocketBase superuser: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let create_url = format!("{pb_url}/api/collections/invites/records"); let res = state .http_client .post(&create_url) .header("Authorization", format!("Bearer {token}")) .json(&serde_json::json!({ "code": code, "created_by": user.id, "invite_type": invite_type, "used_by_id": "", "used_at": "", })) .send() .await; match res { Ok(resp) if resp.status().is_success() => { let public_url = &state.public_url; let url = format!("{public_url}/invite/{code}"); info!(code = %code, invite_type, user_id = %user.id, "Created invite"); Json(InviteResponse { code, url, invite_type: invite_type.to_string(), }) .into_response() } Ok(resp) => { let status = resp.status(); let text = resp.text().await.unwrap_or_default(); warn!("Failed to create invite ({status}): {text}"); StatusCode::BAD_GATEWAY.into_response() } Err(err) => { warn!("PocketBase request error: {err}"); StatusCode::BAD_GATEWAY.into_response() } } } /// Dev-only fake invite code (12 alphanumeric chars, passes validation). /// Only recognized when `--dist` is not set (i.e., dev mode). const DEV_INVITE_CODE: &str = "devdevdevdev"; /// Validate an invite code. Public endpoint — codes are 12-char random alphanumeric /// so enumeration is impractical, and the response only reveals valid/invalid + type. pub async fn get_invite( State(shared): State>, Extension(_user): Extension, Path(code): Path, ) -> Response { let state = shared.load_state(); if let Err(msg) = validate_invite_code(&code) { return (StatusCode::BAD_REQUEST, msg).into_response(); } // Dev-only: return a fake valid admin invite without hitting PocketBase if state.is_dev && code == DEV_INVITE_CODE { return Json(InviteValidation { valid: true, invite_type: "admin".to_string(), used: false, invited_by: Some("Developer".to_string()), }) .into_response(); } let pb_url = state.pocketbase_url.trim_end_matches('/'); let token = match get_superuser_token(&state).await { Ok(t) => t, Err(err) => { warn!("Failed to auth as PocketBase superuser: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let filter = format!("code=\"{}\"", code); let url = format!( "{pb_url}/api/collections/invites/records?filter={}&perPage=1", urlencoding::encode(&filter) ); let res = match state .http_client .get(&url) .header("Authorization", format!("Bearer {token}")) .send() .await { Ok(r) => r, Err(err) => { warn!("Failed to look up invite: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; if !res.status().is_success() { return StatusCode::BAD_GATEWAY.into_response(); } let body: serde_json::Value = match res.json().await { Ok(v) => v, Err(_) => return StatusCode::BAD_GATEWAY.into_response(), }; let items = body["items"].as_array(); match items.and_then(|arr| arr.first()) { Some(invite) => { let invite_type = invite["invite_type"].as_str().unwrap_or("").to_string(); let used_by = invite["used_by_id"].as_str().unwrap_or(""); let used = !used_by.is_empty(); let created_by = invite["created_by"].as_str().unwrap_or(""); // Look up inviter's name (email local part) let invited_by = if !created_by.is_empty() { let user_url = format!("{pb_url}/api/collections/users/records/{created_by}"); match state .http_client .get(&user_url) .header("Authorization", format!("Bearer {token}")) .send() .await { Ok(resp) if resp.status().is_success() => { let user_body: serde_json::Value = resp.json().await.unwrap_or_default(); user_body["email"] .as_str() .and_then(|e| e.split('@').next()) .map(String::from) } _ => None, } } else { None }; Json(InviteValidation { valid: true, invite_type, used, invited_by, }) .into_response() } None => Json(InviteValidation { valid: false, invite_type: String::new(), used: false, invited_by: None, }) .into_response(), } } /// Redeem an invite code. Requires authentication. /// Admin invite: sets subscription to "licensed" directly. /// Referral invite: returns a discounted Stripe checkout URL. pub async fn post_redeem_invite( State(shared): State>, Extension(user): Extension, Json(req): Json, ) -> Response { let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), }; if let Err(msg) = validate_invite_code(&req.code) { return (StatusCode::BAD_REQUEST, msg).into_response(); } // Dev-only: fake redeem — just return "licensed" without touching PocketBase if state.is_dev && req.code == DEV_INVITE_CODE { info!(user_id = %user.id, "Dev invite redeemed (no-op)"); return Json(RedeemResponse { result: "licensed".to_string(), checkout_url: None, }) .into_response(); } let pb_url = state.pocketbase_url.trim_end_matches('/'); let token = match get_superuser_token(&state).await { Ok(t) => t, Err(err) => { warn!("Failed to auth as PocketBase superuser: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let _redemption_guard = match InviteRedemptionGuard::acquire(&req.code) { Some(guard) => guard, None => { return ( StatusCode::CONFLICT, "Invite redemption is already in progress", ) .into_response() } }; let invite = match lookup_unused_invite(&state, pb_url, &token, &req.code).await { Ok(Some(invite)) => invite, Ok(None) => { return (StatusCode::NOT_FOUND, "Invalid or already used invite code").into_response() } Err(response) => return response, }; let invite_id = match invite["id"].as_str().filter(|id| !id.is_empty()) { Some(id) => id, None => { warn!(code = %req.code, "Invite lookup returned record without id"); return StatusCode::BAD_GATEWAY.into_response(); } }; let invite_type = match invite["invite_type"].as_str() { Some("admin") => "admin", Some("referral") => "referral", Some(other) => { warn!(code = %req.code, invite_type = other, "Invite has unsupported type"); return StatusCode::BAD_GATEWAY.into_response(); } None => { warn!(code = %req.code, "Invite lookup returned record without invite_type"); return StatusCode::BAD_GATEWAY.into_response(); } }; if invite_type == "admin" { if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await { return response; } if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await { return response; } info!(user_id = %user.id, code = %req.code, "Admin invite redeemed; user licensed"); return Json(RedeemResponse { result: "licensed".to_string(), checkout_url: None, }) .into_response(); } let checkout_url = match create_referral_checkout(&state, &user).await { Ok(url) => url, Err(response) => return response, }; if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await { return response; } info!(user_id = %user.id, code = %req.code, "Referral invite redeemed; checkout created"); Json(RedeemResponse { result: "checkout".to_string(), checkout_url: Some(checkout_url), }) .into_response() } /// List invites. Users only see invites they created, including admins. pub async fn get_invites( State(shared): State>, Extension(user): Extension, ) -> Response { let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), }; let pb_url = state.pocketbase_url.trim_end_matches('/'); let token = match get_superuser_token(&state).await { Ok(t) => t, Err(err) => { warn!("Failed to auth as PocketBase superuser: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let filter = format!("created_by=\"{}\"", user.id); let mut url = format!("{pb_url}/api/collections/invites/records?sort=-created&perPage=200"); url.push_str(&format!("&filter={}", urlencoding::encode(&filter))); let res = match state .http_client .get(&url) .header("Authorization", format!("Bearer {token}")) .send() .await { Ok(r) => r, Err(err) => { warn!("Failed to list invites: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; if !res.status().is_success() { let status = res.status(); let text = res.text().await.unwrap_or_default(); warn!("PocketBase list invites failed ({status}): {text}"); return StatusCode::BAD_GATEWAY.into_response(); } let body: serde_json::Value = match res.json().await { Ok(v) => v, Err(err) => { warn!("Failed to parse invites response: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let public_url = &state.public_url; let invites: Vec = body["items"] .as_array() .map(|arr| { arr.iter() .map(|inv| { let code = inv["code"].as_str().unwrap_or("").to_string(); let invite_type = inv["invite_type"].as_str().unwrap_or("").to_string(); let used_by = inv["used_by_id"].as_str().unwrap_or(""); let created = inv["created"].as_str().unwrap_or("").to_string(); InviteListItem { url: format!("{public_url}/invite/{code}"), code, invite_type, used: !used_by.is_empty(), created, } }) .collect() }) .unwrap_or_default(); Json(InviteListResponse { invites }).into_response() }