use std::sync::Arc;

use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::{Extension, Json};
use serde::{Deserialize, Serialize};
use tracing::{info, warn};

use crate::auth::OptionalUser;
use crate::checkout_sessions::{start_license_checkout, CheckoutStart};
use crate::state::SharedState;

/// JSON body accepted by the checkout endpoint.
#[derive(Deserialize)]
pub struct CheckoutRequest {
    /// Optional referral code supplied by the client.
    referral_code: Option<String>,
    /// Optional site-local path to send the user back to after checkout.
    return_path: Option<String>,
}

/// JSON body returned by the checkout endpoint: the URL the client should navigate to.
#[derive(Serialize)]
struct CheckoutResponse {
    url: String,
}

/// Reduce a client-supplied return path to a safe, site-local path.
///
/// Any `#fragment` is dropped. The remaining path is returned unchanged when it
/// starts with a single `/`, is at most 2048 bytes long and contains no control
/// characters. In every other case (including `None` and the empty string) the
/// fallback `"/pricing"` is returned.
fn sanitize_return_path(path: Option<&str>) -> &str {
    const FALLBACK: &str = "/pricing";
    const MAX_LEN: usize = 2048;

    let Some(raw) = path else {
        return FALLBACK;
    };

    // Keep only the part in front of the first '#'.
    let local = raw.split_once('#').map_or(raw, |(head, _)| head);

    // Must be rooted at this site. `//host` is a scheme-relative URL, and browsers
    // treat `\` like `/` for http(s) URLs, so `/\host` is rejected for the same reason.
    let is_site_local =
        local.starts_with('/') && !local.starts_with("//") && !local.starts_with("/\\");
    let is_well_formed = local.len() <= MAX_LEN && !local.chars().any(char::is_control);

    if is_site_local && is_well_formed {
        local
    } else {
        FALLBACK
    }
}

/// Append `key=value` to `path`, using `&` when `path` already contains a `?`
/// and `?` otherwise.
///
/// `key` and `value` are inserted verbatim (no percent-encoding), so callers
/// must pass values that are already URL-safe.
fn append_query_param(path: &str, key: &str, value: &str) -> String {
    let separator = if path.contains('?') { '&' } else { '?' };
    format!("{path}{separator}{key}={value}")
}

/// Create a reserved Stripe Checkout session for the lifetime license.
/// Requires authentication. Referral discounts are issued via invite redemption.
pub async fn post_checkout( State(shared): State>, Extension(user): Extension, Json(req): Json, ) -> Response { let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), }; let public_url = state.public_url.trim_end_matches('/'); let return_path = sanitize_return_path(req.return_path.as_deref()); let success_url = format!( "{public_url}{}", append_query_param(return_path, "license_success", "1") ); let cancel_url = format!("{public_url}{return_path}"); if req.referral_code.is_some() { return ( StatusCode::BAD_REQUEST, "Referral codes must be redeemed from the invite link", ) .into_response(); } if user.is_admin || user.subscription == "licensed" { return (StatusCode::CONFLICT, "This account already has full access").into_response(); } match start_license_checkout(&state, &user, &success_url, &cancel_url, None, None).await { Ok(CheckoutStart::Free) => { info!(user_id = %user.id, "Granted free early-bird license"); Json(CheckoutResponse { url: success_url }).into_response() } Ok(CheckoutStart::Stripe { url }) => Json(CheckoutResponse { url }).into_response(), Err(err) => { warn!(user_id = %user.id, "Failed to start checkout: {err:?}"); StatusCode::BAD_GATEWAY.into_response() } } } #[cfg(test)] mod tests { use super::*; #[test] fn sanitize_return_path_accepts_local_paths_and_strips_fragments() { assert_eq!( sanitize_return_path(Some("/map?postcode=SW1A#details")), "/map?postcode=SW1A" ); } #[test] fn sanitize_return_path_rejects_external_or_control_paths() { assert_eq!(sanitize_return_path(Some("//evil.test/path")), "/pricing"); assert_eq!( sanitize_return_path(Some("https://evil.test/path")), "/pricing" ); assert_eq!(sanitize_return_path(Some("/map\nbad")), "/pricing"); } #[test] fn append_query_param_preserves_existing_query_separator() { assert_eq!( append_query_param("/map?postcode=SW1A", "license_success", "1"), "/map?postcode=SW1A&license_success=1" ); assert_eq!( 
append_query_param("/pricing", "license_success", "1"), "/pricing?license_success=1" ); } }