Various server improvements

This commit is contained in:
Andras Schmelczer 2026-03-26 21:19:06 +00:00
parent 3fe5f49050
commit 233ce1254b
10 changed files with 177 additions and 55 deletions

View file

@ -9,7 +9,7 @@ use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use log::info;
use log::{debug, info};
use crate::{
app_state::{AppState, database::models::VaultId},
@ -21,10 +21,12 @@ use crate::{
pub async fn auth_middleware(
State(state): State<AppState>,
Path(path_params): Path<HashMap<String, String>>,
TypedHeader(auth_header): TypedHeader<Authorization<Bearer>>,
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
mut req: Request,
next: Next,
) -> Result<Response, SyncServerError> {
let auth_header = auth_header
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Missing Authorization header")))?;
let token = auth_header.token().trim();
let vault_id = normalize_string(
path_params
@ -51,8 +53,8 @@ pub fn auth(state: &AppState, token: &str, vault_id: &VaultId) -> Result<User, S
VaultAccess::AllowAccessToAll => true,
VaultAccess::AllowList(AllowListedVaults { ref allowed }) => allowed.contains(vault_id),
} {
info!(
"User `{}` is authenticated and is authorised to access to vault `{vault_id}`",
debug!(
"User `{}` is authenticated and is authorised to access vault `{vault_id}`",
user.name
);

View file

@ -0,0 +1,72 @@
use std::sync::{
Arc,
atomic::{AtomicU64, Ordering},
};
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
/// Simple token-bucket rate limiter that refills every second.
#[derive(Clone, Debug)]
pub struct RateLimiter {
inner: Arc<TokenBucket>,
}
#[derive(Debug)]
struct TokenBucket {
tokens: AtomicU64,
max_tokens: u64,
}
impl RateLimiter {
/// Create a new rate limiter. Spawns a background task that refills tokens
/// every second.
///
/// # Panics
///
/// Panics if `max_per_second` is 0.
pub fn new(max_per_second: u64) -> Self {
assert!(
max_per_second > 0,
"max_per_second must be > 0 (use 0 in config to disable rate limiting entirely)"
);
let bucket = Arc::new(TokenBucket {
tokens: AtomicU64::new(max_per_second),
max_tokens: max_per_second,
});
let bucket_clone = bucket.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
loop {
interval.tick().await;
bucket_clone
.tokens
.store(bucket_clone.max_tokens, Ordering::Release);
}
});
Self { inner: bucket }
}
fn try_acquire(&self) -> bool {
self.inner
.tokens
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
if current > 0 { Some(current - 1) } else { None }
})
.is_ok()
}
}
pub async fn rate_limit_middleware(
axum::extract::State(limiter): axum::extract::State<RateLimiter>,
req: Request,
next: Next,
) -> Result<Response, StatusCode> {
if limiter.try_acquire() {
Ok(next.run(req).await)
} else {
Err(StatusCode::TOO_MANY_REQUESTS)
}
}

View file

@ -4,21 +4,18 @@ use reconcile_text::NumberOrText;
use serde::{self, Deserialize};
use ts_rs::TS;
use crate::app_state::database::models::{DocumentId, VaultUpdateId};
use crate::app_state::database::models::VaultUpdateId;
#[derive(TS, Debug, TryFromMultipart)]
#[ts(export)]
pub struct CreateDocumentVersion {
/// The client can decide the document id (if it wishes to) in order
/// to help with syncing. If the client does not provide a document id,
/// the server will generate one. If the client provides a document id
/// it must not already exist in the database.
pub document_id: Option<DocumentId>,
pub relative_path: String,
#[ts(as = "Vec<u8>")]
#[form_data(limit = "unlimited")]
pub content: FieldData<Bytes>,
pub idempotency_key: Option<String>,
}
#[derive(Debug, TryFromMultipart)]
@ -34,7 +31,7 @@ pub struct UpdateBinaryDocumentVersion {
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct UpdateTextDocumentVersion {
#[ts(as = "i32")]
#[ts(type = "number")]
pub parent_version_id: VaultUpdateId,
pub relative_path: String,
@ -43,9 +40,5 @@ pub struct UpdateTextDocumentVersion {
pub content: Vec<NumberOrText>,
}
#[derive(TS, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct DeleteDocumentVersion {
pub relative_path: String,
}
#[derive(Debug, Deserialize)]
pub struct DeleteDocumentVersion {}