perfect-postcode/server-rs/src/routes/invites.rs
2026-05-04 16:19:09 +01:00

704 lines
22 KiB
Rust

use std::collections::HashSet;
use std::sync::{Arc, LazyLock, Mutex};
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::{Extension, Json};
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
use crate::auth::{OptionalUser, PocketBaseUser};
use crate::pocketbase::get_superuser_token;
use crate::state::{AppState, SharedState};
/// Invite codes whose redemption is currently being processed by this server process.
///
/// NOTE(review): the set lives in process memory, so it only serialises redemptions
/// handled by this one process — confirm the service runs as a single instance.
static INVITE_REDEMPTIONS_IN_PROGRESS: LazyLock<Mutex<HashSet<String>>> =
    LazyLock::new(|| Mutex::new(HashSet::default()));

/// Exclusive, in-process claim on one invite code.
///
/// While a guard is alive no other caller can `acquire` the same code; dropping the
/// guard removes the code from `INVITE_REDEMPTIONS_IN_PROGRESS` again.
struct InviteRedemptionGuard {
    code: String,
}

impl InviteRedemptionGuard {
    /// Claims `code`, returning `None` if a live guard already holds it.
    fn acquire(code: &str) -> Option<Self> {
        let code = code.to_owned();
        // A poisoned lock only means some thread panicked while holding it; the set
        // is still consistent, so recover it rather than propagating the panic.
        let newly_claimed = INVITE_REDEMPTIONS_IN_PROGRESS
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .insert(code.clone());
        newly_claimed.then(|| Self { code })
    }
}

impl Drop for InviteRedemptionGuard {
    /// Releases the claim so the code can be acquired again.
    fn drop(&mut self) {
        INVITE_REDEMPTIONS_IN_PROGRESS
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .remove(&self.code);
    }
}
/// JSON body returned by `post_invites` once an invite record has been created.
#[derive(Serialize)]
struct InviteResponse {
    /// The newly generated invite code.
    code: String,
    /// Shareable link, built as `{public_url}/invite/{code}`.
    url: String,
    /// Either "admin" or "referral".
    invite_type: String,
}
/// One invite in the `get_invites` listing.
#[derive(Serialize)]
struct InviteListItem {
    /// Invite code as stored in PocketBase (empty if the record lacks one).
    code: String,
    /// Shareable link, built as `{public_url}/invite/{code}`.
    url: String,
    /// Stored invite type (e.g. "admin" / "referral"); empty if missing.
    invite_type: String,
    /// True when the record's `used_by_id` is non-empty.
    used: bool,
    /// PocketBase `created` timestamp string, passed through unchanged (empty if missing).
    created: String,
}
/// JSON body returned by `get_invites`.
#[derive(Serialize)]
struct InviteListResponse {
    /// Invites created by the caller, in the order PocketBase returned them
    /// (the query sorts by `-created`, i.e. newest first).
    invites: Vec<InviteListItem>,
}
/// JSON body returned by the public `get_invite` validation endpoint.
#[derive(Serialize)]
struct InviteValidation {
    /// True when an invite record with the requested code exists (used or not).
    valid: bool,
    /// Stored invite type; empty when `valid` is false or the record lacks a type.
    invite_type: String,
    /// True when the record's `used_by_id` is non-empty.
    used: bool,
    /// Local part (before `@`) of the creator's email, when it could be looked up.
    invited_by: Option<String>,
}
/// JSON body accepted by `post_invites`.
#[derive(Deserialize)]
pub struct CreateInviteRequest {
    /// Admins get a "referral" invite only when this is exactly "referral"; any other
    /// value, or omitting the field, yields an "admin" invite. Ignored for non-admins.
    invite_type: Option<String>,
}
/// JSON body accepted by `post_redeem_invite`.
#[derive(Deserialize)]
pub struct RedeemRequest {
    /// Invite code to redeem; rejected with 400 unless it passes `validate_invite_code`.
    code: String,
}
/// JSON body returned by `post_redeem_invite`.
#[derive(Serialize)]
struct RedeemResponse {
    /// "licensed" when the caller ends up licensed without paying (e.g. an admin
    /// invite was redeemed), or "checkout" when they must complete the Stripe
    /// checkout in `checkout_url`.
    result: String,
    /// For referral invites: the Stripe checkout URL with the referral coupon applied.
    /// `None` (serialised as `null`) otherwise.
    checkout_url: Option<String>,
}
/// Validate that an invite code is safe to embed in a PocketBase filter expression.
///
/// Accepts 1–20 ASCII alphanumeric characters. Generated codes are 12 lowercase
/// characters, but uppercase letters are accepted too: like digits and lowercase
/// letters, they cannot terminate the quoted filter string or inject operators.
///
/// # Errors
/// Returns a short, user-facing message when the code is empty or longer than 20
/// bytes, or when it contains any byte outside `[0-9A-Za-z]` (quotes, operators,
/// whitespace, non-ASCII), any of which could otherwise alter the filter.
fn validate_invite_code(code: &str) -> Result<(), &'static str> {
    // Headroom above the 12-character generated length.
    const MAX_LEN: usize = 20;

    if code.is_empty() || code.len() > MAX_LEN {
        return Err("Invalid invite code length");
    }
    if !code.bytes().all(|b| b.is_ascii_alphanumeric()) {
        return Err("Invalid invite code characters");
    }
    Ok(())
}
/// Produce a fresh 12-character invite code drawn from `[0-9a-z]`.
///
/// Every generated code satisfies `validate_invite_code`.
fn generate_invite_code() -> String {
    use rand::Rng;

    const ALPHABET: &[u8; 36] = b"0123456789abcdefghijklmnopqrstuvwxyz";
    const CODE_LEN: usize = 12;

    let mut rng = rand::rng();
    std::iter::repeat_with(|| {
        let idx: u8 = rng.random_range(0..36);
        char::from(ALPHABET[usize::from(idx)])
    })
    .take(CODE_LEN)
    .collect()
}
/// Current wall-clock time as whole seconds since the Unix epoch, as a decimal
/// string. A system clock set before 1970 yields "0" rather than an error.
fn current_unix_secs_string() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    let secs = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    };
    secs.to_string()
}
/// Fetch the invite record with this `code` that has not been redeemed yet
/// (`used_by_id` empty). Returns `Ok(None)` when no such record exists.
///
/// `code` is interpolated verbatim into the PocketBase filter, so callers must run
/// it through `validate_invite_code` first.
///
/// # Errors
/// Returns a `502 Bad Gateway` response (after logging the cause) if PocketBase is
/// unreachable, answers with a non-2xx status, or returns unparseable JSON.
async fn lookup_unused_invite(
    state: &AppState,
    pb_url: &str,
    token: &str,
    code: &str,
) -> Result<Option<serde_json::Value>, Response> {
    let unused_filter = format!("code=\"{code}\" && used_by_id=\"\"");
    let request_url = format!(
        "{pb_url}/api/collections/invites/records?filter={}&perPage=1",
        urlencoding::encode(&unused_filter)
    );

    let response = state
        .http_client
        .get(&request_url)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await
        .map_err(|err| {
            warn!("Failed to look up invite: {err}");
            StatusCode::BAD_GATEWAY.into_response()
        })?;

    let status = response.status();
    if !status.is_success() {
        let text = response.text().await.unwrap_or_default();
        warn!("PocketBase invite lookup failed ({status}): {text}");
        return Err(StatusCode::BAD_GATEWAY.into_response());
    }

    let payload: serde_json::Value = response.json().await.map_err(|err| {
        warn!("Failed to parse invite lookup response: {err}");
        StatusCode::BAD_GATEWAY.into_response()
    })?;

    let first_match = payload["items"]
        .as_array()
        .and_then(|items| items.first())
        .cloned();
    Ok(first_match)
}
/// Record that `user_id` consumed the invite whose PocketBase record id is
/// `invite_id`, by setting `used_by_id` and `used_at` (Unix seconds, as a string).
///
/// # Errors
/// Returns a `502 Bad Gateway` response (after logging the cause) if the PATCH
/// cannot be sent or PocketBase answers with a non-2xx status.
async fn mark_invite_used(
    state: &AppState,
    pb_url: &str,
    token: &str,
    invite_id: &str,
    user_id: &str,
) -> Result<(), Response> {
    let record_url = format!("{pb_url}/api/collections/invites/records/{invite_id}");
    let usage = serde_json::json!({
        "used_by_id": user_id,
        "used_at": current_unix_secs_string(),
    });

    let response = state
        .http_client
        .patch(record_url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&usage)
        .send()
        .await
        .map_err(|err| {
            warn!("Failed to mark invite as used: {err}");
            StatusCode::BAD_GATEWAY.into_response()
        })?;

    let status = response.status();
    if status.is_success() {
        return Ok(());
    }
    let text = response.text().await.unwrap_or_default();
    warn!("PocketBase invite usage update failed ({status}): {text}");
    Err(StatusCode::BAD_GATEWAY.into_response())
}
/// Set the user's `subscription` field to "licensed" in PocketBase. On success,
/// also invalidates the user's entries in `state.token_cache`.
///
/// # Errors
/// Returns a `502 Bad Gateway` response (after logging the cause) if the PATCH
/// cannot be sent or PocketBase answers with a non-2xx status. In that case the
/// token cache is left untouched.
async fn grant_license_for_invite(
    state: &AppState,
    pb_url: &str,
    token: &str,
    user_id: &str,
) -> Result<(), Response> {
    let response = state
        .http_client
        .patch(format!("{pb_url}/api/collections/users/records/{user_id}"))
        .header("Authorization", format!("Bearer {token}"))
        .json(&serde_json::json!({ "subscription": "licensed" }))
        .send()
        .await
        .map_err(|err| {
            warn!("Failed to update user subscription for admin invite: {err}");
            StatusCode::BAD_GATEWAY.into_response()
        })?;

    let status = response.status();
    if !status.is_success() {
        let text = response.text().await.unwrap_or_default();
        warn!("PocketBase user subscription update failed ({status}): {text}");
        return Err(StatusCode::BAD_GATEWAY.into_response());
    }

    // Presumably evicts cached auth state so the upgraded subscription is seen on the
    // user's next request — confirm against the token cache implementation.
    state.token_cache.invalidate_by_user_id(user_id);
    Ok(())
}
/// Start a Stripe Checkout session for a one-off "Perfect Postcodes Lifetime License"
/// payment, with the referral coupon (`state.stripe_referral_coupon_id`) applied.
///
/// The unit amount (GBP, in pence) comes from `pricing::price_for_count` for the
/// current number of licensed users. The caller's id is sent as
/// `client_reference_id` and their email as `customer_email`. Returns the hosted
/// checkout URL.
///
/// # Errors
/// - `503 Service Unavailable` if the licensed-user count cannot be obtained.
/// - `502 Bad Gateway` if Stripe is unreachable, answers with a non-2xx status,
///   returns an unparseable body, or returns a missing or empty `url`.
async fn create_referral_checkout(
    state: &AppState,
    user: &PocketBaseUser,
) -> Result<String, Response> {
    let licensed_count = super::pricing::count_licensed_users(state)
        .await
        .map_err(|err| {
            warn!("Failed to count licensed users for invite checkout: {err}");
            StatusCode::SERVICE_UNAVAILABLE.into_response()
        })?;
    let unit_amount = super::pricing::price_for_count(licensed_count).to_string();
    let public_url = &state.public_url;

    let form_params = vec![
        ("mode", "payment".to_string()),
        ("line_items[0][price_data][unit_amount]", unit_amount),
        ("line_items[0][price_data][currency]", "gbp".to_string()),
        (
            "line_items[0][price_data][product_data][name]",
            "Perfect Postcodes Lifetime License".to_string(),
        ),
        ("line_items[0][quantity]", "1".to_string()),
        ("success_url", format!("{public_url}/pricing?license_success=1")),
        ("cancel_url", format!("{public_url}/pricing")),
        ("client_reference_id", user.id.clone()),
        ("customer_email", user.email.clone()),
        ("discounts[0][coupon]", state.stripe_referral_coupon_id.clone()),
    ];

    let response = state
        .http_client
        .post("https://api.stripe.com/v1/checkout/sessions")
        .basic_auth(&state.stripe_secret_key, None::<&str>)
        .form(&form_params)
        .send()
        .await
        .map_err(|err| {
            warn!("Stripe request error for referral invite: {err}");
            StatusCode::BAD_GATEWAY.into_response()
        })?;

    let status = response.status();
    if !status.is_success() {
        let text = response.text().await.unwrap_or_default();
        warn!("Failed to create Stripe checkout for referral invite ({status}): {text}");
        return Err(StatusCode::BAD_GATEWAY.into_response());
    }

    let session: serde_json::Value = response.json().await.map_err(|err| {
        warn!("Failed to parse Stripe checkout response: {err}");
        StatusCode::BAD_GATEWAY.into_response()
    })?;

    match session["url"].as_str() {
        Some(url) if !url.is_empty() => Ok(url.to_string()),
        _ => {
            warn!("Stripe checkout response did not include a URL");
            Err(StatusCode::BAD_GATEWAY.into_response())
        }
    }
}
/// Create an invite for the authenticated caller.
///
/// - Admins create "admin" invites (free license) by default, or a "referral" invite
///   when the body sets `"invite_type": "referral"`.
/// - Licensed non-admin users always create "referral" invites (30% off).
/// - Other authenticated users get `403 Forbidden`; anonymous callers get `401`.
///
/// On success returns an `InviteResponse` carrying the shareable `/invite/{code}`
/// URL. PocketBase auth or write failures are logged and mapped to `502 Bad Gateway`.
pub async fn post_invites(
    State(shared): State<Arc<SharedState>>,
    Extension(user): Extension<OptionalUser>,
    Json(body): Json<CreateInviteRequest>,
) -> Response {
    let state = shared.load_state();
    let Some(user) = user.0 else {
        return StatusCode::UNAUTHORIZED.into_response();
    };

    let invite_type = if !user.is_admin {
        if user.subscription != "licensed" {
            return (
                StatusCode::FORBIDDEN,
                "Only licensed users can create invites",
            )
                .into_response();
        }
        "referral"
    } else if body.invite_type.as_deref() == Some("referral") {
        "referral"
    } else {
        "admin"
    };

    let code = generate_invite_code();
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let token = match get_superuser_token(&state).await {
        Err(err) => {
            warn!("Failed to auth as PocketBase superuser: {err}");
            return StatusCode::BAD_GATEWAY.into_response();
        }
        Ok(token) => token,
    };

    let record = serde_json::json!({
        "code": code,
        "created_by": user.id,
        "invite_type": invite_type,
        "used_by_id": "",
        "used_at": "",
    });
    let outcome = state
        .http_client
        .post(format!("{pb_url}/api/collections/invites/records"))
        .header("Authorization", format!("Bearer {token}"))
        .json(&record)
        .send()
        .await;
    let resp = match outcome {
        Err(err) => {
            warn!("PocketBase request error: {err}");
            return StatusCode::BAD_GATEWAY.into_response();
        }
        Ok(resp) => resp,
    };

    let status = resp.status();
    if !status.is_success() {
        let text = resp.text().await.unwrap_or_default();
        warn!("Failed to create invite ({status}): {text}");
        return StatusCode::BAD_GATEWAY.into_response();
    }

    let public_url = &state.public_url;
    let url = format!("{public_url}/invite/{code}");
    info!(code = %code, invite_type, user_id = %user.id, "Created invite");
    Json(InviteResponse {
        code,
        url,
        invite_type: invite_type.to_string(),
    })
    .into_response()
}
/// Dev-only fake invite code (12 alphanumeric chars, passes validation).
/// Only recognized when `state.is_dev` is true: the validation and redemption
/// handlers then answer for this code without contacting PocketBase.
/// NOTE(review): `is_dev` is presumably set when `--dist` is not passed — confirm.
const DEV_INVITE_CODE: &str = "devdevdevdev";
/// Validate an invite code. Public endpoint — codes are 12-char random alphanumeric,
/// so enumeration is impractical.
///
/// The response reveals more than valid/invalid and the invite type:
/// - whether the invite has already been used;
/// - when the creator's user record can be fetched, the local part (before `@`) of
///   the creator's email, as `invited_by`.
///
/// The creator lookup is best-effort: any failure leaves `invited_by` as `None`.
///
/// Status codes:
/// - `400` for malformed codes.
/// - `502` (logged) when PocketBase auth, the invite lookup, or parsing its response
///   fails.
/// - Otherwise `200` with an `InviteValidation`, whose `valid` is false when no
///   record has this code.
pub async fn get_invite(
    State(shared): State<Arc<SharedState>>,
    Extension(_user): Extension<OptionalUser>,
    Path(code): Path<String>,
) -> Response {
    let state = shared.load_state();
    if let Err(msg) = validate_invite_code(&code) {
        return (StatusCode::BAD_REQUEST, msg).into_response();
    }
    // Dev-only: return a fake valid admin invite without hitting PocketBase
    if state.is_dev && code == DEV_INVITE_CODE {
        return Json(InviteValidation {
            valid: true,
            invite_type: "admin".to_string(),
            used: false,
            invited_by: Some("Developer".to_string()),
        })
        .into_response();
    }
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let token = match get_superuser_token(&state).await {
        Ok(t) => t,
        Err(err) => {
            warn!("Failed to auth as PocketBase superuser: {err}");
            return StatusCode::BAD_GATEWAY.into_response();
        }
    };
    // Safe to embed: `code` passed `validate_invite_code` above.
    let filter = format!("code=\"{code}\"");
    let url = format!(
        "{pb_url}/api/collections/invites/records?filter={}&perPage=1",
        urlencoding::encode(&filter)
    );
    let res = match state
        .http_client
        .get(&url)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await
    {
        Ok(r) => r,
        Err(err) => {
            warn!("Failed to look up invite: {err}");
            return StatusCode::BAD_GATEWAY.into_response();
        }
    };
    // Log upstream failures like every other handler does, instead of a silent 502.
    if !res.status().is_success() {
        let status = res.status();
        let text = res.text().await.unwrap_or_default();
        warn!("PocketBase invite lookup failed ({status}): {text}");
        return StatusCode::BAD_GATEWAY.into_response();
    }
    let body: serde_json::Value = match res.json().await {
        Ok(v) => v,
        Err(err) => {
            warn!("Failed to parse invite lookup response: {err}");
            return StatusCode::BAD_GATEWAY.into_response();
        }
    };
    let Some(invite) = body["items"].as_array().and_then(|arr| arr.first()) else {
        return Json(InviteValidation {
            valid: false,
            invite_type: String::new(),
            used: false,
            invited_by: None,
        })
        .into_response();
    };

    let invite_type = invite["invite_type"].as_str().unwrap_or("").to_string();
    let used = !invite["used_by_id"].as_str().unwrap_or("").is_empty();
    let created_by = invite["created_by"].as_str().unwrap_or("");

    // `created_by` comes from the database, but it is spliced into a URL path that is
    // requested with superuser credentials. Only follow ids made of path-safe
    // characters, so `/`, `?`, `#`, `.` or `%` can never redirect the request.
    let creator_id_is_safe = !created_by.is_empty()
        && created_by
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'-');

    // Best-effort: any failure in the creator lookup only omits `invited_by`.
    let invited_by = if creator_id_is_safe {
        let user_url = format!("{pb_url}/api/collections/users/records/{created_by}");
        match state
            .http_client
            .get(&user_url)
            .header("Authorization", format!("Bearer {token}"))
            .send()
            .await
        {
            Ok(resp) if resp.status().is_success() => {
                let user_body: serde_json::Value = resp.json().await.unwrap_or_default();
                user_body["email"]
                    .as_str()
                    .and_then(|e| e.split('@').next())
                    .map(String::from)
            }
            _ => None,
        }
    } else {
        None
    };

    Json(InviteValidation {
        valid: true,
        invite_type,
        used,
        invited_by,
    })
    .into_response()
}
/// Redeem an invite code. Requires authentication.
///
/// - Admin invite: sets the caller's subscription to "licensed" directly, then marks
///   the invite used. Responds `{"result": "licensed", "checkout_url": null}`.
/// - Referral invite: creates a discounted Stripe checkout, then marks the invite
///   used. Responds `{"result": "checkout", "checkout_url": "<url>"}`.
/// - Caller already licensed: responds `{"result": "licensed"}` without looking up or
///   consuming the invite, so it stays available for someone who needs it.
///
/// Concurrent redemptions of the same code within this process are rejected with
/// `409 Conflict`. Malformed codes get `400`. Unknown or already-used codes get
/// `404`. Upstream (PocketBase/Stripe) failures are logged and mapped to `502`,
/// or to `503` when the price cannot be computed.
pub async fn post_redeem_invite(
    State(shared): State<Arc<SharedState>>,
    Extension(user): Extension<OptionalUser>,
    Json(req): Json<RedeemRequest>,
) -> Response {
    let state = shared.load_state();
    let user = match user.0 {
        Some(u) => u,
        None => return StatusCode::UNAUTHORIZED.into_response(),
    };
    if let Err(msg) = validate_invite_code(&req.code) {
        return (StatusCode::BAD_REQUEST, msg).into_response();
    }
    // Dev-only: fake redeem — just return "licensed" without touching PocketBase
    if state.is_dev && req.code == DEV_INVITE_CODE {
        info!(user_id = %user.id, "Dev invite redeemed (no-op)");
        return Json(RedeemResponse {
            result: "licensed".to_string(),
            checkout_url: None,
        })
        .into_response();
    }
    // An already-licensed user gains nothing from an invite. Redeeming would burn an
    // admin invite on a no-op license grant, or send them to pay for a second
    // lifetime license via a referral checkout. Report their existing license and
    // leave the invite unused.
    if user.subscription == "licensed" {
        info!(
            user_id = %user.id,
            code = %req.code,
            "Invite not redeemed; user is already licensed"
        );
        return Json(RedeemResponse {
            result: "licensed".to_string(),
            checkout_url: None,
        })
        .into_response();
    }
    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let token = match get_superuser_token(&state).await {
        Ok(t) => t,
        Err(err) => {
            warn!("Failed to auth as PocketBase superuser: {err}");
            return StatusCode::BAD_GATEWAY.into_response();
        }
    };
    // Held until this function returns, so a second request for the same code in
    // this process gets 409 instead of racing the lookup/mark-used sequence.
    let _redemption_guard = match InviteRedemptionGuard::acquire(&req.code) {
        Some(guard) => guard,
        None => {
            return (
                StatusCode::CONFLICT,
                "Invite redemption is already in progress",
            )
                .into_response()
        }
    };
    let invite = match lookup_unused_invite(&state, pb_url, &token, &req.code).await {
        Ok(Some(invite)) => invite,
        Ok(None) => {
            return (StatusCode::NOT_FOUND, "Invalid or already used invite code").into_response()
        }
        Err(response) => return response,
    };
    let invite_id = match invite["id"].as_str().filter(|id| !id.is_empty()) {
        Some(id) => id,
        None => {
            warn!(code = %req.code, "Invite lookup returned record without id");
            return StatusCode::BAD_GATEWAY.into_response();
        }
    };
    let invite_type = match invite["invite_type"].as_str() {
        Some("admin") => "admin",
        Some("referral") => "referral",
        Some(other) => {
            warn!(code = %req.code, invite_type = other, "Invite has unsupported type");
            return StatusCode::BAD_GATEWAY.into_response();
        }
        None => {
            warn!(code = %req.code, "Invite lookup returned record without invite_type");
            return StatusCode::BAD_GATEWAY.into_response();
        }
    };
    if invite_type == "admin" {
        // NOTE(review): the license is granted before the invite is marked used; if
        // marking fails, the user keeps the license and the invite stays redeemable.
        if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await {
            return response;
        }
        if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
            return response;
        }
        info!(user_id = %user.id, code = %req.code, "Admin invite redeemed; user licensed");
        return Json(RedeemResponse {
            result: "licensed".to_string(),
            checkout_url: None,
        })
        .into_response();
    }
    // NOTE(review): the referral invite is consumed as soon as a checkout session
    // exists, whether or not the user completes payment.
    let checkout_url = match create_referral_checkout(&state, &user).await {
        Ok(url) => url,
        Err(response) => return response,
    };
    if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
        return response;
    }
    info!(user_id = %user.id, code = %req.code, "Referral invite redeemed; checkout created");
    Json(RedeemResponse {
        result: "checkout".to_string(),
        checkout_url: Some(checkout_url),
    })
    .into_response()
}
/// List the invites created by the calling user (admins included), newest first.
///
/// Only a single PocketBase page of at most 200 records is requested. Returns `401`
/// for anonymous callers. PocketBase auth, request, or parse failures are logged
/// and mapped to `502 Bad Gateway`.
pub async fn get_invites(
    State(shared): State<Arc<SharedState>>,
    Extension(user): Extension<OptionalUser>,
) -> Response {
    /// Reads `record[key]` as a string, treating missing or non-string values as "".
    fn str_field<'a>(record: &'a serde_json::Value, key: &str) -> &'a str {
        record[key].as_str().unwrap_or("")
    }

    let state = shared.load_state();
    let Some(user) = user.0 else {
        return StatusCode::UNAUTHORIZED.into_response();
    };

    let pb_url = state.pocketbase_url.trim_end_matches('/');
    let token = match get_superuser_token(&state).await {
        Err(err) => {
            warn!("Failed to auth as PocketBase superuser: {err}");
            return StatusCode::BAD_GATEWAY.into_response();
        }
        Ok(token) => token,
    };

    let owner_filter = format!("created_by=\"{}\"", user.id);
    let list_url = format!(
        "{pb_url}/api/collections/invites/records?sort=-created&perPage=200&filter={}",
        urlencoding::encode(&owner_filter)
    );

    let sent = state
        .http_client
        .get(&list_url)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await;
    let response = match sent {
        Err(err) => {
            warn!("Failed to list invites: {err}");
            return StatusCode::BAD_GATEWAY.into_response();
        }
        Ok(response) => response,
    };

    let status = response.status();
    if !status.is_success() {
        let text = response.text().await.unwrap_or_default();
        warn!("PocketBase list invites failed ({status}): {text}");
        return StatusCode::BAD_GATEWAY.into_response();
    }

    let body: serde_json::Value = match response.json().await {
        Err(err) => {
            warn!("Failed to parse invites response: {err}");
            return StatusCode::BAD_GATEWAY.into_response();
        }
        Ok(body) => body,
    };

    let public_url = &state.public_url;
    let records = body["items"]
        .as_array()
        .map(Vec::as_slice)
        .unwrap_or_default();

    let mut invites = Vec::with_capacity(records.len());
    for record in records {
        let code = str_field(record, "code");
        invites.push(InviteListItem {
            url: format!("{public_url}/invite/{code}"),
            code: code.to_string(),
            invite_type: str_field(record, "invite_type").to_string(),
            used: !str_field(record, "used_by_id").is_empty(),
            created: str_field(record, "created").to_string(),
        });
    }

    Json(InviteListResponse { invites }).into_response()
}