diff --git a/sync-server/src/app_state/database/models.rs b/sync-server/src/app_state/database/models.rs index a216125a..f6b35424 100644 --- a/sync-server/src/app_state/database/models.rs +++ b/sync-server/src/app_state/database/models.rs @@ -22,6 +22,7 @@ pub struct StoredDocumentVersion { pub device_id: DeviceId, #[allow(dead_code)] // This is for manual analysis pub has_been_merged: bool, + pub idempotency_key: Option, } impl PartialEq for StoredDocumentVersion { @@ -33,7 +34,7 @@ impl PartialEq for StoredDocumentVersion { #[derive(TS, Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct DocumentVersionWithoutContent { - #[ts(as = "i32")] + #[ts(type = "number")] pub vault_update_id: VaultUpdateId, pub document_id: DocumentId, @@ -43,7 +44,7 @@ pub struct DocumentVersionWithoutContent { pub user_id: UserId, pub device_id: DeviceId, - #[ts(as = "i32")] + #[ts(type = "number")] pub content_size: u64, } @@ -65,7 +66,7 @@ impl From for DocumentVersionWithoutContent { #[derive(TS, Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct DocumentVersion { - #[ts(as = "i32")] + #[ts(type = "number")] pub vault_update_id: VaultUpdateId, pub document_id: DocumentId, diff --git a/sync-server/src/app_state/websocket/models.rs b/sync-server/src/app_state/websocket/models.rs index e037fb7e..fb1d24b9 100644 --- a/sync-server/src/app_state/websocket/models.rs +++ b/sync-server/src/app_state/websocket/models.rs @@ -11,7 +11,7 @@ pub struct WebSocketHandshake { pub token: String, pub device_id: DeviceId, - #[ts(as = "Option")] + #[ts(type = "number | null")] pub last_seen_vault_update_id: Option, } @@ -28,7 +28,7 @@ pub struct DocumentWithCursors { // that it exists and can be client-side // interpolated. However, the actual // position is meaningless. - #[ts(as = "Option")] + #[ts(type = "number | null")] pub vault_update_id: Option, pub document_id: DocumentId, @@ -70,6 +70,7 @@ pub struct WebSocketVaultUpdate { pub enum WebSocketClientMessage { Handshake(WebSocketHandshake), CursorPositions(CursorPositionFromClient), + Ping {}, } #[derive(TS, Serialize, Clone, Debug)] diff --git a/sync-server/src/app_state/websocket/utils.rs b/sync-server/src/app_state/websocket/utils.rs index 1e0dd243..ce8205fa 100644 --- a/sync-server/src/app_state/websocket/utils.rs +++ b/sync-server/src/app_state/websocket/utils.rs @@ -9,7 +9,7 @@ use crate::{ database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId}, }, config::user_config::User, - errors::{SyncServerError, server_error, unauthenticated_error}, + errors::{SyncServerError, client_error, server_error, unauthenticated_error}, server::auth::auth, }; @@ -26,16 +26,16 @@ pub fn get_authenticated_handshake( if let Some(Message::Text(message)) = message { let message: WebSocketClientMessage = serde_json::from_str(&message) .context("Failed to parse message") - .map_err(server_error)?; + .map_err(client_error)?; match message { WebSocketClientMessage::Handshake(handshake) => { let user = auth(state, handshake.token.trim(), vault_id)?; Ok(AuthenticatedWebSocketHandshake { handshake, user }) } - WebSocketClientMessage::CursorPositions(_) => Err(unauthenticated_error( - anyhow::anyhow!("Expected a handshake message"), - )), + WebSocketClientMessage::CursorPositions(_) | WebSocketClientMessage::Ping {} => Err( + unauthenticated_error(anyhow::anyhow!("Expected a handshake message")), + ), } } else { Err(unauthenticated_error(anyhow::anyhow!( diff --git a/sync-server/src/config.rs b/sync-server/src/config.rs index 6a003d2e..75d4dba7 100644 --- a/sync-server/src/config.rs +++ b/sync-server/src/config.rs @@ -28,23 +28,20 @@ pub struct Config { impl Config { pub async fn read_or_create(path: &Path) -> Result { - let config = if path.exists() { - info!( - "Loading configuration from `{}`", - path.canonicalize().unwrap().display() - ); - Self::load_from_file(path).await? + let display_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf()); + + if path.exists() { + info!("Loading configuration from `{}`", display_path.display()); + Self::load_from_file(path).await } else { - Self::default() - }; - - config.write(path).await?; - info!( - "Updated configuration at `{}`", - path.canonicalize().unwrap().display() - ); - - Ok(config) + let config = Self::default(); + config.write(path).await?; + info!( + "Created default configuration at `{}`", + display_path.display() + ); + Ok(config) + } } pub async fn load_from_file(path: &Path) -> Result { diff --git a/sync-server/src/config/user_config.rs b/sync-server/src/config/user_config.rs index 8b2537f0..fd824f39 100644 --- a/sync-server/src/config/user_config.rs +++ b/sync-server/src/config/user_config.rs @@ -1,6 +1,7 @@ use bimap::BiHashMap; use rand::{Rng, distr::Alphanumeric, rng}; use serde::{Deserialize, Deserializer, Serialize, de::Error}; +use subtle::ConstantTimeEq; use crate::app_state::database::models::VaultId; @@ -19,10 +20,19 @@ where let mut user_token_map = BiHashMap::new(); for user in &users { if let Some(existing_name) = user_token_map.get_by_right(&user.token) { + let redacted = if user.token.len() > 6 { + format!( + "{}...{}", + &user.token[..3], + &user.token[user.token.len() - 3..] + ) + } else { + "***".to_owned() + }; return Err(D::Error::custom(format!( - "Duplicate user token found: `{}` for users `{}` and `{}`. User tokens must be \ - unique.", - user.token, existing_name, user.name + "Duplicate user token found: `{redacted}` for users `{}` and `{}`. User tokens \ + must be unique.", + existing_name, user.name ))); } @@ -41,7 +51,9 @@ where impl UserConfig { pub fn get_user(&self, token: &str) -> Option<&User> { - self.user_configs.iter().find(|u| u.token == token) + self.user_configs + .iter() + .find(|u| u.token.as_bytes().ct_eq(token.as_bytes()).into()) } } diff --git a/sync-server/src/errors.rs b/sync-server/src/errors.rs index 831b0e86..0dad0463 100644 --- a/sync-server/src/errors.rs +++ b/sync-server/src/errors.rs @@ -5,7 +5,7 @@ use axum::{ http::StatusCode, response::{IntoResponse, Response}, }; -use log::{debug, error}; +use log::{debug, error, warn}; use serde::Serialize; use thiserror::Error; use ts_rs::TS; @@ -29,6 +29,9 @@ pub enum SyncServerError { #[error("Permission denied error: {0}")] PermissionDeniedError(#[source] anyhow::Error), + + #[error("Too many requests: {0}")] + TooManyRequests(#[source] anyhow::Error), } impl SyncServerError { @@ -39,7 +42,8 @@ impl SyncServerError { | Self::ServerError(error) | Self::NotFound(error) | Self::Unauthenticated(error) - | Self::PermissionDeniedError(error) => error.into(), + | Self::PermissionDeniedError(error) + | Self::TooManyRequests(error) => error.into(), } } } @@ -69,7 +73,22 @@ impl Display for SerializedError { impl IntoResponse for SyncServerError { fn into_response(self) -> Response { - let body = Json(self.serialize()); + let serialized = self.serialize(); + + match &self { + Self::InitError(_) | Self::ServerError(_) => { + error!("{serialized}"); + } + Self::ClientError(_) | Self::NotFound(_) => { + warn!("{serialized}"); + } + Self::TooManyRequests(_) => { + warn!("{serialized}"); + } + Self::Unauthenticated(_) | Self::PermissionDeniedError(_) => {} + } + + let body = Json(serialized); match self { Self::InitError(_) | Self::ServerError(_) => { @@ -79,6 +98,9 @@ impl IntoResponse for SyncServerError { Self::NotFound(_) => (StatusCode::NOT_FOUND, body).into_response(), Self::Unauthenticated(_) => (StatusCode::UNAUTHORIZED, body).into_response(), Self::PermissionDeniedError(_) => (StatusCode::FORBIDDEN, body).into_response(), + Self::TooManyRequests(_) => { + (StatusCode::TOO_MANY_REQUESTS, body).into_response() + } } } } @@ -102,6 +124,7 @@ impl From<&anyhow::Error> for SerializedError { SyncServerError::NotFound(_) => "NotFound", SyncServerError::Unauthenticated(_) => "Unauthenticated", SyncServerError::PermissionDeniedError(_) => "PermissionDeniedError", + SyncServerError::TooManyRequests(_) => "TooManyRequests", }, ), message: error.to_string(), @@ -139,3 +162,18 @@ pub fn permission_denied_error(error: anyhow::Error) -> SyncServerError { debug!("Permission denied: {error:?}"); SyncServerError::PermissionDeniedError(error) } + +pub fn too_many_requests_error(error: anyhow::Error) -> SyncServerError { + debug!("Too many requests: {error:?}"); + SyncServerError::TooManyRequests(error) +} + +/// Maps a `create_write_transaction` error to 429 if the database is busy, +/// or 500 for all other failures. +pub fn write_transaction_error(error: anyhow::Error) -> SyncServerError { + if error.downcast_ref::().is_some() { + too_many_requests_error(error) + } else { + server_error(error) + } +} diff --git a/sync-server/src/server/auth.rs b/sync-server/src/server/auth.rs index e56f4acc..3b5474d4 100644 --- a/sync-server/src/server/auth.rs +++ b/sync-server/src/server/auth.rs @@ -9,7 +9,7 @@ use axum_extra::{ TypedHeader, headers::{Authorization, authorization::Bearer}, }; -use log::info; +use log::{debug, info}; use crate::{ app_state::{AppState, database::models::VaultId}, @@ -21,10 +21,12 @@ use crate::{ pub async fn auth_middleware( State(state): State, Path(path_params): Path>, - TypedHeader(auth_header): TypedHeader>, + auth_header: Option>>, mut req: Request, next: Next, ) -> Result { + let auth_header = auth_header + .ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Missing Authorization header")))?; let token = auth_header.token().trim(); let vault_id = normalize_string( path_params @@ -51,8 +53,8 @@ pub fn auth(state: &AppState, token: &str, vault_id: &VaultId) -> Result true, VaultAccess::AllowList(AllowListedVaults { ref allowed }) => allowed.contains(vault_id), } { - info!( - "User `{}` is authenticated and is authorised to access to vault `{vault_id}`", + debug!( + "User `{}` is authenticated and is authorised to access vault `{vault_id}`", user.name ); diff --git a/sync-server/src/server/rate_limit.rs b/sync-server/src/server/rate_limit.rs new file mode 100644 index 00000000..8047adc2 --- /dev/null +++ b/sync-server/src/server/rate_limit.rs @@ -0,0 +1,72 @@ +use std::sync::{ + Arc, + atomic::{AtomicU64, Ordering}, +}; + +use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response}; + +/// Simple token-bucket rate limiter that refills every second. +#[derive(Clone, Debug)] +pub struct RateLimiter { + inner: Arc, +} + +#[derive(Debug)] +struct TokenBucket { + tokens: AtomicU64, + max_tokens: u64, +} + +impl RateLimiter { + /// Create a new rate limiter. Spawns a background task that refills tokens + /// every second. + /// + /// # Panics + /// + /// Panics if `max_per_second` is 0. + pub fn new(max_per_second: u64) -> Self { + assert!( + max_per_second > 0, + "max_per_second must be > 0 (use 0 in config to disable rate limiting entirely)" + ); + + let bucket = Arc::new(TokenBucket { + tokens: AtomicU64::new(max_per_second), + max_tokens: max_per_second, + }); + + let bucket_clone = bucket.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs(1)); + loop { + interval.tick().await; + bucket_clone + .tokens + .store(bucket_clone.max_tokens, Ordering::Release); + } + }); + + Self { inner: bucket } + } + + fn try_acquire(&self) -> bool { + self.inner + .tokens + .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| { + if current > 0 { Some(current - 1) } else { None } + }) + .is_ok() + } +} + +pub async fn rate_limit_middleware( + axum::extract::State(limiter): axum::extract::State, + req: Request, + next: Next, +) -> Result { + if limiter.try_acquire() { + Ok(next.run(req).await) + } else { + Err(StatusCode::TOO_MANY_REQUESTS) + } +} diff --git a/sync-server/src/server/requests.rs b/sync-server/src/server/requests.rs index 119ad467..2e612234 100644 --- a/sync-server/src/server/requests.rs +++ b/sync-server/src/server/requests.rs @@ -4,21 +4,18 @@ use reconcile_text::NumberOrText; use serde::{self, Deserialize}; use ts_rs::TS; -use crate::app_state::database::models::{DocumentId, VaultUpdateId}; +use crate::app_state::database::models::VaultUpdateId; #[derive(TS, Debug, TryFromMultipart)] #[ts(export)] pub struct CreateDocumentVersion { - /// The client can decide the document id (if it wishes to) in order - /// to help with syncing. If the client does not provide a document id, - /// the server will generate one. If the client provides a document id - /// it must not already exist in the database. - pub document_id: Option, pub relative_path: String, #[ts(as = "Vec")] #[form_data(limit = "unlimited")] pub content: FieldData, + + pub idempotency_key: Option, } #[derive(Debug, TryFromMultipart)] @@ -34,7 +31,7 @@ pub struct UpdateBinaryDocumentVersion { #[serde(rename_all = "camelCase")] #[ts(export)] pub struct UpdateTextDocumentVersion { - #[ts(as = "i32")] + #[ts(type = "number")] pub parent_version_id: VaultUpdateId, pub relative_path: String, @@ -43,9 +40,5 @@ pub struct UpdateTextDocumentVersion { pub content: Vec, } -#[derive(TS, Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -#[ts(export)] -pub struct DeleteDocumentVersion { - pub relative_path: String, -} +#[derive(Debug, Deserialize)] +pub struct DeleteDocumentVersion {} diff --git a/sync-server/src/utils/find_first_available_path.rs b/sync-server/src/utils/find_first_available_path.rs index 7629d8f1..d80564b0 100644 --- a/sync-server/src/utils/find_first_available_path.rs +++ b/sync-server/src/utils/find_first_available_path.rs @@ -1,25 +1,31 @@ use crate::app_state::database::models::VaultId; -use crate::{app_state::database::Transaction, utils::dedup_paths::dedup_paths}; -use anyhow::Result; -use log::{debug, info}; +use crate::utils::dedup_paths::dedup_paths; +use anyhow::{Result, bail}; +use log::info; +use sqlx::sqlite::SqliteConnection; + pub async fn find_first_available_path( vault_id: &VaultId, sanitized_relative_path: &str, database: &crate::app_state::database::Database, - transaction: &mut Transaction<'_>, + connection: &mut SqliteConnection, ) -> Result { info!("Finding first available path for `{sanitized_relative_path}` in vault `{vault_id}`"); for candidate in dedup_paths(sanitized_relative_path) { debug!("Checking candidate path for deconflicting names: `{candidate}`"); if database - .get_latest_document_by_path(vault_id, &candidate, Some(transaction)) + .get_latest_non_deleted_document_by_path(vault_id, &candidate, Some(connection)) .await? .is_none() { info!("Selected available path: `{candidate}`"); return Ok(candidate); } + + info!( + "Finding first available path for `{sanitized_relative_path}` in vault `{vault_id}` as `{candidate}` is already taken" + ); } unreachable!("dedup_paths produces infinite paths");