// perfect-postcode/server-rs/src/auth.rs
//
// (150 lines, 4.1 KiB, Rust)

use std::sync::Arc;
use std::time::Instant;
use axum::extract::Request;
use axum::middleware::Next;
use axum::response::Response;
use parking_lot::RwLock;
use reqwest::Client;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tracing::warn;
/// Number of whole seconds a validated token is served from [`TokenCache`]
/// before a cache miss forces re-validation against PocketBase.
const TOKEN_TTL_SECS: u64 = 60;

/// Entry count at which [`TokenCache::insert`] starts evicting entries.
const MAX_CACHE_ENTRIES: usize = 1_000;
/// A PocketBase user record, as deserialized from the `record` field of the
/// auth-refresh response (see `validate_token`).
///
/// Field names map 1:1 to JSON keys (no serde renames). Fields marked
/// `#[serde(default)]` fall back to `false` / an empty string when the key is
/// absent from the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PocketBaseUser {
/// PocketBase record id; `TokenCache::invalidate_by_user_id` matches on this.
pub id: String,
/// Email address from the record (required in the JSON).
pub email: String,
/// `verified` flag from the record; `false` if missing.
#[serde(default)]
pub verified: bool,
/// `is_admin` flag from the record; `false` if missing.
#[serde(default)]
pub is_admin: bool,
/// Subscription value stored on the record (free-form string here); empty if missing.
#[serde(default)]
pub subscription: String,
/// `newsletter` flag from the record; `false` if missing.
#[serde(default)]
pub newsletter: bool,
}
/// Request extension attached by `auth_middleware` to every request it handles.
///
/// Holds `Some(user)` when the request carried a bearer token that resolved to a
/// PocketBase user (from the token cache or a fresh validation), and `None`
/// otherwise. Derives `Debug` so handlers can log it; the inner type is already `Debug`.
#[derive(Debug, Clone)]
pub struct OptionalUser(pub Option<PocketBaseUser>);
/// In-memory cache mapping raw bearer tokens to the PocketBase user they resolved to.
///
/// Each value carries the `Instant` at which it was stored, so entries can be
/// treated as expired after `TOKEN_TTL_SECS`. The map sits behind a
/// `parking_lot::RwLock` because the cache is shared by concurrent requests
/// (reached through the app state's `token_cache`).
pub struct TokenCache {
// token -> (resolved user, time the entry was inserted)
entries: RwLock<FxHashMap<String, (PocketBaseUser, Instant)>>,
}
impl TokenCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            entries: RwLock::new(FxHashMap::default()),
        }
    }

    /// Returns `true` while an entry stored at `created` is still within the TTL.
    ///
    /// Compares whole elapsed seconds, so an entry is served for strictly less
    /// than `TOKEN_TTL_SECS` seconds after insertion.
    fn is_fresh(created: Instant, now: Instant) -> bool {
        now.saturating_duration_since(created).as_secs() < TOKEN_TTL_SECS
    }

    /// Returns the cached user for `token` if its entry is still fresh.
    ///
    /// Expired entries are not removed here (only a read lock is held); they are
    /// purged by [`TokenCache::insert`] or overwritten when the token is re-validated.
    fn get(&self, token: &str) -> Option<PocketBaseUser> {
        let map = self.entries.read();
        let now = Instant::now();
        map.get(token)
            .filter(|(_, created)| Self::is_fresh(*created, now))
            .map(|(user, _)| user.clone())
    }

    /// Caches `user` for `token`, stamping the entry with the current time.
    ///
    /// If the cache is at capacity and `token` is not already present (replacing
    /// an existing key does not grow the map), expired entries are purged first.
    /// If it is still full, the oldest live entries are evicted until there is room.
    /// Evicting only the oldest entries, rather than clearing the whole map, avoids
    /// forcing every active session to re-validate against PocketBase at once.
    fn insert(&self, token: String, user: PocketBaseUser) {
        let mut map = self.entries.write();
        let now = Instant::now();
        if map.len() >= MAX_CACHE_ENTRIES && !map.contains_key(token.as_str()) {
            map.retain(|_, (_, created)| Self::is_fresh(*created, now));
            // The O(n) scan only runs when the map is full of live entries;
            // n is bounded by MAX_CACHE_ENTRIES.
            while map.len() >= MAX_CACHE_ENTRIES {
                let oldest = map
                    .iter()
                    .min_by_key(|(_, (_, created))| *created)
                    .map(|(key, _)| key.clone());
                match oldest {
                    Some(key) => {
                        map.remove(&key);
                    }
                    // Empty map (only possible if MAX_CACHE_ENTRIES == 0): nothing to evict.
                    None => break,
                }
            }
        }
        map.insert(token, (user, now));
    }

    /// Removes all cached tokens for a given user ID so the next request re-validates.
    pub fn invalidate_by_user_id(&self, user_id: &str) {
        let mut map = self.entries.write();
        map.retain(|_, (user, _)| user.id != user_id);
    }
}

impl Default for TokenCache {
    /// Equivalent to [`TokenCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// The part of PocketBase's auth-refresh response body that `validate_token`
/// reads. Any other JSON fields are ignored during deserialization.
#[derive(Deserialize)]
struct AuthRefreshResponse {
// The authenticated user's record.
record: PocketBaseUser,
}
/// Validates `token` by calling PocketBase's `users` auth-refresh endpoint and
/// returns the user record it resolves to.
///
/// The token is sent as `Authorization: Bearer <token>`. Returns `None`, after
/// logging a warning, if the request cannot be sent, if PocketBase responds with
/// a non-success status (the status and response body are logged), or if the
/// body cannot be decoded as an [`AuthRefreshResponse`].
async fn validate_token(
    client: &Client,
    pocketbase_url: &str,
    token: &str,
) -> Option<PocketBaseUser> {
    let base = pocketbase_url.trim_end_matches('/');
    let endpoint = format!("{base}/api/collections/users/auth-refresh");
    let bearer = format!("Bearer {token}");

    let response = match client
        .post(&endpoint)
        .header("Authorization", bearer)
        .send()
        .await
    {
        Ok(response) => response,
        Err(err) => {
            warn!("Token validation request failed: {err}");
            return None;
        }
    };

    let status = response.status();
    if !status.is_success() {
        let body = response.text().await.unwrap_or_default();
        warn!("PocketBase auth-refresh returned {status}: {body}");
        return None;
    }

    match response.json::<AuthRefreshResponse>().await {
        Ok(parsed) => Some(parsed.record),
        Err(err) => {
            warn!("Failed to parse auth refresh response: {err}");
            None
        }
    }
}
/// Extracts the token from an `Authorization` header value that uses the `Bearer` scheme.
///
/// The scheme name is compared case-insensitively (auth-schemes are
/// case-insensitive per RFC 7235 §2.1), and whitespace around the token is
/// trimmed. Returns `None` for any other scheme, a missing separator, or an
/// empty token, so unusable input never triggers a validation round-trip.
fn parse_bearer_token(value: &str) -> Option<&str> {
    let (scheme, token) = value.trim_start().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = token.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}

/// Axum middleware that resolves the caller's PocketBase user and attaches it to
/// the request as an [`OptionalUser`] extension.
///
/// The application state is read from the request's extensions
/// (`Arc<crate::state::AppState>`). The request proceeds with
/// `OptionalUser(None)` in either of these cases:
/// - the state is absent;
/// - the request has no usable bearer token.
///
/// Tokens are looked up in the shared [`TokenCache`] first. Only on a miss are
/// they validated against PocketBase, and successful validations are cached.
/// An invalid token is logged and the request is still forwarded: this
/// middleware never rejects a request itself.
pub async fn auth_middleware(mut req: Request, next: Next) -> Response {
    // Cloned (an Arc refcount bump) so `req` is not borrowed across the `.await`
    // below or while it is mutated afterwards.
    let state = req
        .extensions()
        .get::<Arc<crate::state::AppState>>()
        .cloned();
    let token = req
        .headers()
        .get("authorization")
        .and_then(|hv| hv.to_str().ok())
        .and_then(parse_bearer_token)
        .map(String::from);

    let user = match (&state, &token) {
        (Some(st), Some(tk)) => match st.token_cache.get(tk) {
            Some(cached) => Some(cached),
            None => match validate_token(&st.http_client, &st.pocketbase_url, tk).await {
                Some(user) => {
                    st.token_cache.insert(tk.clone(), user.clone());
                    Some(user)
                }
                None => {
                    warn!("Invalid auth token");
                    None
                }
            },
        },
        _ => None,
    };

    req.extensions_mut().insert(OptionalUser(user));
    next.run(req).await
}