Various server improvements
This commit is contained in:
parent
3fe5f49050
commit
233ce1254b
10 changed files with 177 additions and 55 deletions
|
|
@ -9,7 +9,7 @@ use axum_extra::{
|
|||
TypedHeader,
|
||||
headers::{Authorization, authorization::Bearer},
|
||||
};
|
||||
use log::info;
|
||||
use log::{debug, info};
|
||||
|
||||
use crate::{
|
||||
app_state::{AppState, database::models::VaultId},
|
||||
|
|
@ -21,10 +21,12 @@ use crate::{
|
|||
pub async fn auth_middleware(
|
||||
State(state): State<AppState>,
|
||||
Path(path_params): Path<HashMap<String, String>>,
|
||||
TypedHeader(auth_header): TypedHeader<Authorization<Bearer>>,
|
||||
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, SyncServerError> {
|
||||
let auth_header = auth_header
|
||||
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Missing Authorization header")))?;
|
||||
let token = auth_header.token().trim();
|
||||
let vault_id = normalize_string(
|
||||
path_params
|
||||
|
|
@ -51,8 +53,8 @@ pub fn auth(state: &AppState, token: &str, vault_id: &VaultId) -> Result<User, S
|
|||
VaultAccess::AllowAccessToAll => true,
|
||||
VaultAccess::AllowList(AllowListedVaults { ref allowed }) => allowed.contains(vault_id),
|
||||
} {
|
||||
info!(
|
||||
"User `{}` is authenticated and is authorised to access to vault `{vault_id}`",
|
||||
debug!(
|
||||
"User `{}` is authenticated and is authorised to access vault `{vault_id}`",
|
||||
user.name
|
||||
);
|
||||
|
||||
|
|
|
|||
72
sync-server/src/server/rate_limit.rs
Normal file
72
sync-server/src/server/rate_limit.rs
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicU64, Ordering},
|
||||
};
|
||||
|
||||
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
|
||||
|
||||
/// Simple token-bucket rate limiter that refills every second.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RateLimiter {
|
||||
inner: Arc<TokenBucket>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TokenBucket {
|
||||
tokens: AtomicU64,
|
||||
max_tokens: u64,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
/// Create a new rate limiter. Spawns a background task that refills tokens
|
||||
/// every second.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `max_per_second` is 0.
|
||||
pub fn new(max_per_second: u64) -> Self {
|
||||
assert!(
|
||||
max_per_second > 0,
|
||||
"max_per_second must be > 0 (use 0 in config to disable rate limiting entirely)"
|
||||
);
|
||||
|
||||
let bucket = Arc::new(TokenBucket {
|
||||
tokens: AtomicU64::new(max_per_second),
|
||||
max_tokens: max_per_second,
|
||||
});
|
||||
|
||||
let bucket_clone = bucket.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
bucket_clone
|
||||
.tokens
|
||||
.store(bucket_clone.max_tokens, Ordering::Release);
|
||||
}
|
||||
});
|
||||
|
||||
Self { inner: bucket }
|
||||
}
|
||||
|
||||
fn try_acquire(&self) -> bool {
|
||||
self.inner
|
||||
.tokens
|
||||
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
|
||||
if current > 0 { Some(current - 1) } else { None }
|
||||
})
|
||||
.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn rate_limit_middleware(
|
||||
axum::extract::State(limiter): axum::extract::State<RateLimiter>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if limiter.try_acquire() {
|
||||
Ok(next.run(req).await)
|
||||
} else {
|
||||
Err(StatusCode::TOO_MANY_REQUESTS)
|
||||
}
|
||||
}
|
||||
|
|
@ -4,21 +4,18 @@ use reconcile_text::NumberOrText;
|
|||
use serde::{self, Deserialize};
|
||||
use ts_rs::TS;
|
||||
|
||||
use crate::app_state::database::models::{DocumentId, VaultUpdateId};
|
||||
use crate::app_state::database::models::VaultUpdateId;
|
||||
|
||||
#[derive(TS, Debug, TryFromMultipart)]
|
||||
#[ts(export)]
|
||||
pub struct CreateDocumentVersion {
|
||||
/// The client can decide the document id (if it wishes to) in order
|
||||
/// to help with syncing. If the client does not provide a document id,
|
||||
/// the server will generate one. If the client provides a document id
|
||||
/// it must not already exist in the database.
|
||||
pub document_id: Option<DocumentId>,
|
||||
pub relative_path: String,
|
||||
|
||||
#[ts(as = "Vec<u8>")]
|
||||
#[form_data(limit = "unlimited")]
|
||||
pub content: FieldData<Bytes>,
|
||||
|
||||
pub idempotency_key: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, TryFromMultipart)]
|
||||
|
|
@ -34,7 +31,7 @@ pub struct UpdateBinaryDocumentVersion {
|
|||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
pub struct UpdateTextDocumentVersion {
|
||||
#[ts(as = "i32")]
|
||||
#[ts(type = "number")]
|
||||
pub parent_version_id: VaultUpdateId,
|
||||
|
||||
pub relative_path: String,
|
||||
|
|
@ -43,9 +40,5 @@ pub struct UpdateTextDocumentVersion {
|
|||
pub content: Vec<NumberOrText>,
|
||||
}
|
||||
|
||||
#[derive(TS, Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
pub struct DeleteDocumentVersion {
|
||||
pub relative_path: String,
|
||||
}
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DeleteDocumentVersion {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue