use std::sync::Arc;

use axum::body::Bytes;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use tracing::{info, warn};

use crate::pocketbase::get_superuser_token;
use crate::state::SharedState;

type HmacSha256 = Hmac<Sha256>;

/// Maximum allowed distance, in seconds, between the signed webhook timestamp and the
/// current time. Events outside this window (past or future) are rejected to limit replays.
const SIGNATURE_TOLERANCE_SECS: u64 = 300;

/// Verify a Stripe webhook signature (v1 scheme).
///
/// `sig_header` is the raw `Stripe-Signature` header value, of the form
/// `t=TIMESTAMP,v1=SIGNATURE[,v1=SIGNATURE...]`. The expected signature is
/// `HMAC-SHA256(secret, "TIMESTAMP." ++ payload)`, computed over the raw payload bytes.
///
/// Returns `true` only if the timestamp is within [`SIGNATURE_TOLERANCE_SECS`] of now and at
/// least one `v1` signature matches. Returns `false` for an empty secret, a header without a
/// timestamp or `v1` signature, an unparsable or stale timestamp, or no matching signature.
fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
    // HMAC accepts an empty key, so an unconfigured secret would let anyone forge a
    // "valid" signature. Treat it as a hard failure instead.
    if secret.is_empty() {
        return false;
    }

    // Several `v1` entries can be present, e.g. while a secret is being rolled.
    // Keep all of them rather than only the last one.
    let mut timestamp = None;
    let mut signatures = Vec::new();
    for part in sig_header.split(',').map(str::trim) {
        if let Some(ts) = part.strip_prefix("t=") {
            timestamp = Some(ts);
        } else if let Some(sig) = part.strip_prefix("v1=") {
            signatures.push(sig);
        }
    }

    let ts = match timestamp {
        Some(t) if !signatures.is_empty() => t,
        _ => return false,
    };

    // Reject webhooks outside the tolerance window to prevent replay attacks.
    let ts_secs: i64 = match ts.parse() {
        Ok(secs) => secs,
        Err(_) => return false,
    };
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs() as i64;
    // `abs_diff` cannot overflow. `now - ts_secs` could, because `ts_secs` is
    // attacker-controlled and has not been authenticated yet.
    if now.abs_diff(ts_secs) > SIGNATURE_TOLERANCE_SECS {
        return false;
    }

    // Feed the raw bytes to the MAC. A lossy UTF-8 conversion would alter
    // non-UTF-8 payloads and break the comparison.
    let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
        Ok(m) => m,
        Err(_) => return false,
    };
    mac.update(ts.as_bytes());
    mac.update(b".");
    mac.update(payload);

    // `verify_slice` consumes the MAC and compares in constant time, so clone it per candidate.
    // Malformed hex candidates are simply skipped.
    signatures.into_iter().any(|sig_hex| match hex::decode(sig_hex) {
        Ok(sig_bytes) => mac.clone().verify_slice(&sig_bytes).is_ok(),
        Err(_) => false,
    })
}

/// Handle Stripe webhook events.
/// On `checkout.session.completed`, updates the user's subscription to "licensed".
pub async fn post_stripe_webhook( State(shared): State>, headers: HeaderMap, body: Bytes, ) -> Response { let state = shared.load_state(); let webhook_secret = &state.stripe_webhook_secret; let sig_header = match headers .get("stripe-signature") .and_then(|h| h.to_str().ok()) { Some(s) => s, None => { warn!("Missing Stripe-Signature header"); return StatusCode::BAD_REQUEST.into_response(); } }; if !verify_signature(&body, sig_header, webhook_secret) { warn!("Invalid Stripe webhook signature"); return StatusCode::BAD_REQUEST.into_response(); } let event: serde_json::Value = match serde_json::from_slice(&body) { Ok(v) => v, Err(err) => { warn!("Failed to parse webhook body: {err}"); return StatusCode::BAD_REQUEST.into_response(); } }; let event_type = event["type"].as_str().unwrap_or(""); info!(event_type, "Received Stripe webhook"); if event_type == "checkout.session.completed" { let user_id = event["data"]["object"]["client_reference_id"] .as_str() .unwrap_or(""); if user_id.is_empty() { warn!("checkout.session.completed missing client_reference_id"); return StatusCode::OK.into_response(); } if !user_id.bytes().all(|b| b.is_ascii_alphanumeric()) || user_id.len() > 20 { warn!(user_id, "Invalid client_reference_id format in webhook"); return StatusCode::BAD_REQUEST.into_response(); } // Update user subscription to "licensed" via PocketBase superuser auth let token = match get_superuser_token(&state).await { Ok(t) => t, Err(err) => { warn!("Failed to auth as PocketBase superuser in webhook: {err}"); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } }; let pb_url = state.pocketbase_url.trim_end_matches('/'); let url = format!("{pb_url}/api/collections/users/records/{user_id}"); let res = state .http_client .patch(&url) .header("Authorization", format!("Bearer {token}")) .json(&serde_json::json!({ "subscription": "licensed" })) .send() .await; match res { Ok(resp) if resp.status().is_success() => { state.token_cache.invalidate_by_user_id(user_id); info!( 
user_id, "User subscription updated to licensed via Stripe webhook" ); } Ok(resp) => { let status = resp.status(); let text = resp.text().await.unwrap_or_default(); warn!( user_id, "Failed to update user subscription ({status}): {text}" ); } Err(err) => { warn!(user_id, "PocketBase request error in webhook: {err}"); } } } StatusCode::OK.into_response() }