125 lines
3.8 KiB
Rust
125 lines
3.8 KiB
Rust
use std::sync::Arc;
|
|
|
|
use axum::extract::State;
|
|
use axum::http::StatusCode;
|
|
use axum::response::{IntoResponse, Response};
|
|
use axum::{Extension, Json};
|
|
use serde::{Deserialize, Serialize};
|
|
use tracing::{info, warn};
|
|
|
|
use crate::auth::OptionalUser;
|
|
use crate::checkout_sessions::{start_license_checkout, CheckoutStart};
|
|
use crate::state::SharedState;
|
|
|
|
#[derive(Deserialize)]
|
|
pub struct CheckoutRequest {
|
|
referral_code: Option<String>,
|
|
return_path: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct CheckoutResponse {
|
|
url: String,
|
|
}
|
|
|
|
/// Reduce a client-supplied return path to a safe, same-site relative path.
///
/// Anything from the first `#` onward is discarded. The remaining path is accepted
/// only if it meets all of these conditions:
/// - it is non-empty and at most 2048 bytes long;
/// - it starts with exactly one `/`;
/// - it contains no control characters and no backslashes.
///
/// Otherwise, and when `path` is `None`, the result is `"/pricing"`.
///
/// Backslashes are rejected because WHATWG-compliant URL parsers (i.e. browsers)
/// treat `\` like `/` in http(s) URLs. Without this check, `/\host` would behave
/// like the protocol-relative `//host` that is already refused.
fn sanitize_return_path(path: Option<&str>) -> &str {
    const FALLBACK: &str = "/pricing";
    const MAX_LEN: usize = 2048;

    let Some(path) = path else {
        return FALLBACK;
    };
    // `split` always yields at least one item, so the `unwrap_or` never triggers.
    let path = path.split('#').next().unwrap_or(path);

    let acceptable = !path.is_empty()
        && path.len() <= MAX_LEN
        && path.starts_with('/')
        && !path.starts_with("//")
        && !path.contains('\\')
        && !path.chars().any(char::is_control);

    if acceptable {
        path
    } else {
        FALLBACK
    }
}
|
|
|
|
/// Append `key=value` to the query string of a relative URL path.
///
/// Separator rules:
/// - no query yet: a `?` is inserted;
/// - the query already ends with a separator (`/p?` or `/p?a=1&`): nothing is inserted;
/// - otherwise: a `&` is inserted.
///
/// A trailing `#fragment` is preserved after the new parameter, so the parameter
/// lands in the query rather than inside the fragment.
///
/// `key` and `value` are inserted verbatim (no percent-encoding), so callers must
/// pass URL-safe text.
fn append_query_param(path: &str, key: &str, value: &str) -> String {
    // The fragment begins at the first '#'; everything before it is path + query.
    let (base, fragment) = match path.find('#') {
        Some(idx) => path.split_at(idx),
        None => (path, ""),
    };

    let separator = if !base.contains('?') {
        "?"
    } else if base.ends_with('?') || base.ends_with('&') {
        ""
    } else {
        "&"
    };

    format!("{base}{separator}{key}={value}{fragment}")
}
|
|
|
|
/// Create a reserved Stripe Checkout session for the lifetime license.
|
|
/// Requires authentication. Referral discounts are issued via invite redemption.
|
|
pub async fn post_checkout(
|
|
State(shared): State<Arc<SharedState>>,
|
|
Extension(user): Extension<OptionalUser>,
|
|
Json(req): Json<CheckoutRequest>,
|
|
) -> Response {
|
|
let state = shared.load_state();
|
|
let user = match user.0 {
|
|
Some(u) => u,
|
|
None => return StatusCode::UNAUTHORIZED.into_response(),
|
|
};
|
|
|
|
let public_url = state.public_url.trim_end_matches('/');
|
|
let return_path = sanitize_return_path(req.return_path.as_deref());
|
|
let success_url = format!(
|
|
"{public_url}{}",
|
|
append_query_param(return_path, "license_success", "1")
|
|
);
|
|
let cancel_url = format!("{public_url}{return_path}");
|
|
|
|
if req.referral_code.is_some() {
|
|
return (
|
|
StatusCode::BAD_REQUEST,
|
|
"Referral codes must be redeemed from the invite link",
|
|
)
|
|
.into_response();
|
|
}
|
|
|
|
if user.is_admin || user.subscription == "licensed" {
|
|
return (StatusCode::CONFLICT, "This account already has full access").into_response();
|
|
}
|
|
|
|
match start_license_checkout(&state, &user, &success_url, &cancel_url, None, None).await {
|
|
Ok(CheckoutStart::Free) => {
|
|
info!(user_id = %user.id, "Granted free early-bird license");
|
|
Json(CheckoutResponse { url: success_url }).into_response()
|
|
}
|
|
Ok(CheckoutStart::Stripe { url }) => Json(CheckoutResponse { url }).into_response(),
|
|
Err(err) => {
|
|
warn!(user_id = %user.id, "Failed to start checkout: {err:?}");
|
|
StatusCode::BAD_GATEWAY.into_response()
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sanitize_return_path_accepts_local_paths_and_strips_fragments() {
        assert_eq!(
            sanitize_return_path(Some("/map?postcode=SW1A#details")),
            "/map?postcode=SW1A"
        );
        assert_eq!(sanitize_return_path(Some("/")), "/");
    }

    #[test]
    fn sanitize_return_path_defaults_missing_or_empty_paths() {
        assert_eq!(sanitize_return_path(None), "/pricing");
        assert_eq!(sanitize_return_path(Some("")), "/pricing");
        // Only a fragment: stripping it leaves an empty path.
        assert_eq!(sanitize_return_path(Some("#details")), "/pricing");
        // Relative paths without a leading slash are not accepted.
        assert_eq!(sanitize_return_path(Some("map")), "/pricing");
    }

    #[test]
    fn sanitize_return_path_rejects_external_or_control_paths() {
        assert_eq!(sanitize_return_path(Some("//evil.test/path")), "/pricing");
        assert_eq!(
            sanitize_return_path(Some("https://evil.test/path")),
            "/pricing"
        );
        assert_eq!(sanitize_return_path(Some("/map\nbad")), "/pricing");
    }

    #[test]
    fn sanitize_return_path_enforces_length_limit_after_stripping_fragment() {
        let at_limit = format!("/{}", "a".repeat(2047));
        assert_eq!(sanitize_return_path(Some(&at_limit)), at_limit.as_str());

        let over_limit = format!("/{}", "a".repeat(2048));
        assert_eq!(sanitize_return_path(Some(&over_limit)), "/pricing");

        // A long fragment is discarded before the length check.
        let long_fragment = format!("/a#{}", "x".repeat(5000));
        assert_eq!(sanitize_return_path(Some(&long_fragment)), "/a");
    }

    #[test]
    fn append_query_param_preserves_existing_query_separator() {
        assert_eq!(
            append_query_param("/map?postcode=SW1A", "license_success", "1"),
            "/map?postcode=SW1A&license_success=1"
        );
        assert_eq!(
            append_query_param("/pricing", "license_success", "1"),
            "/pricing?license_success=1"
        );
    }
}
|