150 lines
4.1 KiB
Rust
150 lines
4.1 KiB
Rust
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
|
|
use axum::extract::Request;
|
|
use axum::middleware::Next;
|
|
use axum::response::Response;
|
|
use parking_lot::RwLock;
|
|
use reqwest::Client;
|
|
use rustc_hash::FxHashMap;
|
|
use serde::{Deserialize, Serialize};
|
|
use tracing::warn;
|
|
|
|
/// How long, in whole seconds, a validated token is served from the cache before
/// the next request carrying it is re-validated against PocketBase.
const TOKEN_TTL_SECS: u64 = 60;

/// Maximum number of tokens the cache aims to hold at once.
const MAX_CACHE_ENTRIES: usize = 1_000;
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct PocketBaseUser {
|
|
pub id: String,
|
|
pub email: String,
|
|
#[serde(default)]
|
|
pub verified: bool,
|
|
#[serde(default)]
|
|
pub is_admin: bool,
|
|
#[serde(default)]
|
|
pub subscription: String,
|
|
#[serde(default)]
|
|
pub newsletter: bool,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct OptionalUser(pub Option<PocketBaseUser>);
|
|
|
|
pub struct TokenCache {
|
|
entries: RwLock<FxHashMap<String, (PocketBaseUser, Instant)>>,
|
|
}
|
|
|
|
impl TokenCache {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
entries: RwLock::new(FxHashMap::default()),
|
|
}
|
|
}
|
|
|
|
fn get(&self, token: &str) -> Option<PocketBaseUser> {
|
|
let map = self.entries.read();
|
|
if let Some((user, created)) = map.get(token) {
|
|
if created.elapsed().as_secs() < TOKEN_TTL_SECS {
|
|
return Some(user.clone());
|
|
}
|
|
}
|
|
None
|
|
}
|
|
|
|
fn insert(&self, token: String, user: PocketBaseUser) {
|
|
let mut map = self.entries.write();
|
|
if map.len() >= MAX_CACHE_ENTRIES {
|
|
// Evict expired entries first
|
|
let now = Instant::now();
|
|
map.retain(|_, (_, created)| now.duration_since(*created).as_secs() < TOKEN_TTL_SECS);
|
|
// If still too many, clear all
|
|
if map.len() >= MAX_CACHE_ENTRIES {
|
|
map.clear();
|
|
}
|
|
}
|
|
map.insert(token, (user, Instant::now()));
|
|
}
|
|
|
|
/// Remove all cached tokens for a given user ID so the next request re-validates.
|
|
pub fn invalidate_by_user_id(&self, user_id: &str) {
|
|
let mut map = self.entries.write();
|
|
map.retain(|_, (user, _)| user.id != user_id);
|
|
}
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct AuthRefreshResponse {
|
|
record: PocketBaseUser,
|
|
}
|
|
|
|
async fn validate_token(
|
|
client: &Client,
|
|
pocketbase_url: &str,
|
|
token: &str,
|
|
) -> Option<PocketBaseUser> {
|
|
let url = format!(
|
|
"{}/api/collections/users/auth-refresh",
|
|
pocketbase_url.trim_end_matches('/')
|
|
);
|
|
let res = client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {token}"))
|
|
.send()
|
|
.await
|
|
.map_err(|err| warn!("Token validation request failed: {err}"))
|
|
.ok()?;
|
|
|
|
if !res.status().is_success() {
|
|
let status = res.status();
|
|
let body = res.text().await.unwrap_or_default();
|
|
warn!("PocketBase auth-refresh returned {status}: {body}");
|
|
return None;
|
|
}
|
|
|
|
let body: AuthRefreshResponse = res
|
|
.json()
|
|
.await
|
|
.map_err(|err| warn!("Failed to parse auth refresh response: {err}"))
|
|
.ok()?;
|
|
Some(body.record)
|
|
}
|
|
|
|
pub async fn auth_middleware(req: Request, next: Next) -> Response {
|
|
let state = req
|
|
.extensions()
|
|
.get::<Arc<crate::state::AppState>>()
|
|
.cloned();
|
|
|
|
let token = req
|
|
.headers()
|
|
.get("authorization")
|
|
.and_then(|hv| hv.to_str().ok())
|
|
.and_then(|hv| hv.strip_prefix("Bearer "))
|
|
.map(String::from);
|
|
|
|
let user = match (&state, &token) {
|
|
(Some(st), Some(tk)) => {
|
|
if let Some(cached) = st.token_cache.get(tk) {
|
|
Some(cached)
|
|
} else {
|
|
match validate_token(&st.http_client, &st.pocketbase_url, tk).await {
|
|
Some(user) => {
|
|
st.token_cache.insert(tk.clone(), user.clone());
|
|
Some(user)
|
|
}
|
|
None => {
|
|
warn!("Invalid auth token");
|
|
None
|
|
}
|
|
}
|
|
}
|
|
}
|
|
_ => None,
|
|
};
|
|
|
|
let (mut parts, body) = req.into_parts();
|
|
parts.extensions.insert(OptionalUser(user));
|
|
let req = Request::from_parts(parts, body);
|
|
|
|
next.run(req).await
|
|
}
|