//! Bearer-token authentication middleware backed by PocketBase.
//!
//! `auth_middleware` resolves the request's bearer token to a [`PocketBaseUser`]
//! (via a short-lived in-memory [`TokenCache`], falling back to PocketBase's
//! `auth-refresh` endpoint) and attaches the result to the request as an
//! [`OptionalUser`] extension. It never rejects a request itself.
//!
//! NOTE(review): the generic type arguments in this copy of the file appear to
//! have been stripped (e.g. `Option`, `RwLock>`, `get::>()`). The notes below
//! give the types implied by usage; confirm them against the original source.

use std::sync::Arc;
use std::time::Instant;

use axum::extract::Request;
use axum::middleware::Next;
use axum::response::Response;
use parking_lot::RwLock;
use reqwest::Client;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tracing::warn;

/// How long, in seconds, a validated token is served from `TokenCache`
/// before it must be re-validated against PocketBase.
const TOKEN_TTL_SECS: u64 = 60;
/// Entry count at which `TokenCache::insert` starts evicting entries.
const MAX_CACHE_ENTRIES: usize = 1000;

/// User record as deserialized from the `record` field of PocketBase's
/// `auth-refresh` response.
///
/// Fields marked `#[serde(default)]` fall back to `false` / an empty string
/// when they are absent from the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PocketBaseUser {
    /// PocketBase record id; `TokenCache::invalidate_by_user_id` matches on it.
    pub id: String,
    pub email: String,
    #[serde(default)]
    pub verified: bool,
    #[serde(default)]
    pub is_admin: bool,
    #[serde(default)]
    pub subscription: String,
    #[serde(default)]
    pub newsletter: bool,
}

/// Request extension that `auth_middleware` inserts on every request passing
/// through it: `Some(user)` when the bearer token resolved to a user, `None`
/// otherwise.
///
/// NOTE(review): the type argument is missing in this copy; usage in
/// `auth_middleware` implies `Option<PocketBaseUser>`.
#[derive(Clone)]
pub struct OptionalUser(pub Option);

/// In-memory cache that maps raw bearer tokens to the user they resolved to.
/// Repeat requests within `TOKEN_TTL_SECS` skip the PocketBase round trip.
pub struct TokenCache {
    // NOTE(review): type arguments missing in this copy; `insert` implies
    // `FxHashMap<String, (PocketBaseUser, Instant)>`, i.e.
    // token -> (user, time the entry was cached).
    entries: RwLock>,
}

impl TokenCache {
    // NOTE(review): clippy's `new_without_default` lint would suggest also
    // implementing `Default` for this public, argument-less constructor.
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            entries: RwLock::new(FxHashMap::default()),
        }
    }

    // NOTE(review): return type argument missing in this copy; presumably
    // `Option<PocketBaseUser>` (it returns a clone of the cached user).
    /// Returns a clone of the cached user for `token` if its entry is younger
    /// than `TOKEN_TTL_SECS`; otherwise returns `None`.
    ///
    /// Only a read lock is taken, so expired entries are left in place. They
    /// are purged later by `insert`, once the cache is full.
    fn get(&self, token: &str) -> Option {
        let map = self.entries.read();
        if let Some((user, created)) = map.get(token) {
            // Whole-second comparison: the entry is live while fewer than
            // TOKEN_TTL_SECS full seconds have elapsed since it was cached.
            if created.elapsed().as_secs() < TOKEN_TTL_SECS {
                return Some(user.clone());
            }
        }
        None
    }

    // NOTE(review): clearing the whole map also discards still-valid entries.
    // Every active token must then be re-validated at once; evicting only the
    // oldest entry would avoid that burst of PocketBase calls.
    /// Caches `user` under `token`, stamped with the current time, and replaces
    /// any existing entry for that token.
    ///
    /// When the cache already holds `MAX_CACHE_ENTRIES` entries, expired
    /// entries are dropped first. If that frees no space, the whole cache is
    /// cleared. Either way, the map never exceeds `MAX_CACHE_ENTRIES`.
    fn insert(&self, token: String, user: PocketBaseUser) {
        let mut map = self.entries.write();
        if map.len() >= MAX_CACHE_ENTRIES {
            // Evict expired entries first
            let now = Instant::now();
            map.retain(|_, (_, created)| now.duration_since(*created).as_secs() < TOKEN_TTL_SECS);
            // If still too many, clear all
            if map.len() >= MAX_CACHE_ENTRIES {
                map.clear();
            }
        }
        map.insert(token, (user, Instant::now()));
    }

    /// Remove all cached tokens for a given user ID so the next request re-validates.
    ///
    /// This scans every cached entry while holding the write lock.
    pub fn invalidate_by_user_id(&self, user_id: &str) {
        let mut map = self.entries.write();
        map.retain(|_, (user, _)| user.id != user_id);
    }
}

/// The part of the PocketBase `auth-refresh` response body this module reads.
/// Only `record` is deserialized; serde ignores any other fields by default.
#[derive(Deserialize)]
struct AuthRefreshResponse {
    record: PocketBaseUser,
}

// NOTE(review): return type argument missing in this copy; presumably
// `Option<PocketBaseUser>`. If `auth-refresh` also issues a fresh token, as
// its name suggests, that token is discarded here. Confirm this is intended.
/// Checks `token` with PocketBase by POSTing to
/// `{pocketbase_url}/api/collections/users/auth-refresh` with it as a bearer
/// token. Returns the user record on a 2xx response.
///
/// Returns `None`, after logging a warning, in three cases: the request cannot
/// be sent, the status is not successful, or the body does not parse. Callers
/// therefore cannot distinguish an invalid token from a PocketBase outage.
async fn validate_token(
    client: &Client,
    pocketbase_url: &str,
    token: &str,
) -> Option {
    // Strip trailing slashes so a base URL ending in `/` doesn't yield `//api/...`.
    let url = format!(
        "{}/api/collections/users/auth-refresh",
        pocketbase_url.trim_end_matches('/')
    );
    let res = client
        .post(&url)
        .header("Authorization", format!("Bearer {token}"))
        .send()
        .await
        .map_err(|err| warn!("Token validation request failed: {err}"))
        .ok()?;
    if !res.status().is_success() {
        let status = res.status();
        // Best effort: an unreadable body is logged as empty.
        let body = res.text().await.unwrap_or_default();
        warn!("PocketBase auth-refresh returned {status}: {body}");
        return None;
    }
    let body: AuthRefreshResponse = res
        .json()
        .await
        .map_err(|err| warn!("Failed to parse auth refresh response: {err}"))
        .ok()?;
    Some(body.record)
}

/// Axum middleware that attaches an [`OptionalUser`] extension to every request.
///
/// The token is taken from the `Authorization` header after an exact `"Bearer "`
/// prefix. It is looked up in the shared state's `token_cache`, with
/// `validate_token` as the fallback on a miss. Successful validations are
/// cached; failures are not, so every request carrying a rejected token
/// triggers a new PocketBase call. If the shared-state extension or the token
/// is missing, the user is `None`. The request is always forwarded to `next`.
///
/// NOTE(review):
/// - The extension type in `get::>()` is missing in this copy. It is presumably
///   an `Arc` of the application state (the only visible use of the imported
///   `Arc`), exposing `token_cache`, `http_client` and `pocketbase_url`.
///   Confirm the exact type.
/// - RFC 7235 treats the auth scheme as case-insensitive, so `bearer <token>` is
///   currently ignored. A header value of exactly `"Bearer "` yields an empty
///   token that is still sent to PocketBase.
/// - `validate_token` also returns `None` on transport errors, so a PocketBase
///   outage is logged here as "Invalid auth token".
pub async fn auth_middleware(req: Request, next: Next) -> Response {
    // Shared state must already be present in the request extensions;
    // if it is absent, no user is resolved.
    let state = req
        .extensions()
        .get::>()
        .cloned();
    // Header values that are not visible ASCII (`to_str` fails) are treated
    // as "no token".
    let token = req
        .headers()
        .get("authorization")
        .and_then(|hv| hv.to_str().ok())
        .and_then(|hv| hv.strip_prefix("Bearer "))
        .map(String::from);
    let user = match (&state, &token) {
        (Some(st), Some(tk)) => {
            if let Some(cached) = st.token_cache.get(tk) {
                Some(cached)
            } else {
                // Cache miss. No cache lock is held across this await, so
                // concurrent misses for the same token may each validate.
                match validate_token(&st.http_client, &st.pocketbase_url, tk).await {
                    Some(user) => {
                        st.token_cache.insert(tk.clone(), user.clone());
                        Some(user)
                    }
                    None => {
                        warn!("Invalid auth token");
                        None
                    }
                }
            }
        }
        _ => None,
    };
    // Rebuild the request with the resolved user attached as an extension.
    let (mut parts, body) = req.into_parts();
    parts.extensions.insert(OptionalUser(user));
    let req = Request::from_parts(parts, body);
    next.run(req).await
}