Various server improvements

This commit is contained in:
Andras Schmelczer 2026-03-26 21:19:06 +00:00
parent 3fe5f49050
commit 233ce1254b
10 changed files with 177 additions and 55 deletions

View file

@ -22,6 +22,7 @@ pub struct StoredDocumentVersion {
pub device_id: DeviceId,
#[allow(dead_code)] // This is for manual analysis
pub has_been_merged: bool,
pub idempotency_key: Option<String>,
}
impl PartialEq<Self> for StoredDocumentVersion {
@ -33,7 +34,7 @@ impl PartialEq<Self> for StoredDocumentVersion {
#[derive(TS, Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct DocumentVersionWithoutContent {
#[ts(as = "i32")]
#[ts(type = "number")]
pub vault_update_id: VaultUpdateId,
pub document_id: DocumentId,
@ -43,7 +44,7 @@ pub struct DocumentVersionWithoutContent {
pub user_id: UserId,
pub device_id: DeviceId,
#[ts(as = "i32")]
#[ts(type = "number")]
pub content_size: u64,
}
@ -65,7 +66,7 @@ impl From<StoredDocumentVersion> for DocumentVersionWithoutContent {
#[derive(TS, Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct DocumentVersion {
#[ts(as = "i32")]
#[ts(type = "number")]
pub vault_update_id: VaultUpdateId,
pub document_id: DocumentId,

View file

@ -11,7 +11,7 @@ pub struct WebSocketHandshake {
pub token: String,
pub device_id: DeviceId,
#[ts(as = "Option<i32>")]
#[ts(type = "number | null")]
pub last_seen_vault_update_id: Option<VaultUpdateId>,
}
@ -28,7 +28,7 @@ pub struct DocumentWithCursors {
// that it exists and can be client-side
// interpolated. However, the actual
// position is meaningless.
#[ts(as = "Option<u32>")]
#[ts(type = "number | null")]
pub vault_update_id: Option<VaultUpdateId>,
pub document_id: DocumentId,
@ -70,6 +70,7 @@ pub struct WebSocketVaultUpdate {
pub enum WebSocketClientMessage {
Handshake(WebSocketHandshake),
CursorPositions(CursorPositionFromClient),
Ping {},
}
#[derive(TS, Serialize, Clone, Debug)]

View file

@ -9,7 +9,7 @@ use crate::{
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
},
config::user_config::User,
errors::{SyncServerError, server_error, unauthenticated_error},
errors::{SyncServerError, client_error, server_error, unauthenticated_error},
server::auth::auth,
};
@ -26,16 +26,16 @@ pub fn get_authenticated_handshake(
if let Some(Message::Text(message)) = message {
let message: WebSocketClientMessage = serde_json::from_str(&message)
.context("Failed to parse message")
.map_err(server_error)?;
.map_err(client_error)?;
match message {
WebSocketClientMessage::Handshake(handshake) => {
let user = auth(state, handshake.token.trim(), vault_id)?;
Ok(AuthenticatedWebSocketHandshake { handshake, user })
}
WebSocketClientMessage::CursorPositions(_) => Err(unauthenticated_error(
anyhow::anyhow!("Expected a handshake message"),
)),
WebSocketClientMessage::CursorPositions(_) | WebSocketClientMessage::Ping {} => Err(
unauthenticated_error(anyhow::anyhow!("Expected a handshake message")),
),
}
} else {
Err(unauthenticated_error(anyhow::anyhow!(

View file

@ -28,23 +28,20 @@ pub struct Config {
impl Config {
pub async fn read_or_create(path: &Path) -> Result<Self> {
let config = if path.exists() {
info!(
"Loading configuration from `{}`",
path.canonicalize().unwrap().display()
);
Self::load_from_file(path).await?
let display_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
if path.exists() {
info!("Loading configuration from `{}`", display_path.display());
Self::load_from_file(path).await
} else {
Self::default()
};
config.write(path).await?;
info!(
"Updated configuration at `{}`",
path.canonicalize().unwrap().display()
);
Ok(config)
let config = Self::default();
config.write(path).await?;
info!(
"Created default configuration at `{}`",
display_path.display()
);
Ok(config)
}
}
pub async fn load_from_file(path: &Path) -> Result<Self> {

View file

@ -1,6 +1,7 @@
use bimap::BiHashMap;
use rand::{Rng, distr::Alphanumeric, rng};
use serde::{Deserialize, Deserializer, Serialize, de::Error};
use subtle::ConstantTimeEq;
use crate::app_state::database::models::VaultId;
@ -19,10 +20,19 @@ where
let mut user_token_map = BiHashMap::new();
for user in &users {
if let Some(existing_name) = user_token_map.get_by_right(&user.token) {
let redacted = if user.token.len() > 6 {
format!(
"{}...{}",
&user.token[..3],
&user.token[user.token.len() - 3..]
)
} else {
"***".to_owned()
};
return Err(D::Error::custom(format!(
"Duplicate user token found: `{}` for users `{}` and `{}`. User tokens must be \
unique.",
user.token, existing_name, user.name
"Duplicate user token found: `{redacted}` for users `{}` and `{}`. User tokens \
must be unique.",
existing_name, user.name
)));
}
@ -41,7 +51,9 @@ where
impl UserConfig {
pub fn get_user(&self, token: &str) -> Option<&User> {
self.user_configs.iter().find(|u| u.token == token)
self.user_configs
.iter()
.find(|u| u.token.as_bytes().ct_eq(token.as_bytes()).into())
}
}

View file

@ -5,7 +5,7 @@ use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use log::{debug, error};
use log::{debug, error, warn};
use serde::Serialize;
use thiserror::Error;
use ts_rs::TS;
@ -29,6 +29,9 @@ pub enum SyncServerError {
#[error("Permission denied error: {0}")]
PermissionDeniedError(#[source] anyhow::Error),
#[error("Too many requests: {0}")]
TooManyRequests(#[source] anyhow::Error),
}
impl SyncServerError {
@ -39,7 +42,8 @@ impl SyncServerError {
| Self::ServerError(error)
| Self::NotFound(error)
| Self::Unauthenticated(error)
| Self::PermissionDeniedError(error) => error.into(),
| Self::PermissionDeniedError(error)
| Self::TooManyRequests(error) => error.into(),
}
}
}
@ -69,7 +73,22 @@ impl Display for SerializedError {
impl IntoResponse for SyncServerError {
fn into_response(self) -> Response {
let body = Json(self.serialize());
let serialized = self.serialize();
match &self {
Self::InitError(_) | Self::ServerError(_) => {
error!("{serialized}");
}
Self::ClientError(_) | Self::NotFound(_) => {
warn!("{serialized}");
}
Self::TooManyRequests(_) => {
warn!("{serialized}");
}
Self::Unauthenticated(_) | Self::PermissionDeniedError(_) => {}
}
let body = Json(serialized);
match self {
Self::InitError(_) | Self::ServerError(_) => {
@ -79,6 +98,9 @@ impl IntoResponse for SyncServerError {
Self::NotFound(_) => (StatusCode::NOT_FOUND, body).into_response(),
Self::Unauthenticated(_) => (StatusCode::UNAUTHORIZED, body).into_response(),
Self::PermissionDeniedError(_) => (StatusCode::FORBIDDEN, body).into_response(),
Self::TooManyRequests(_) => {
(StatusCode::TOO_MANY_REQUESTS, body).into_response()
}
}
}
}
@ -102,6 +124,7 @@ impl From<&anyhow::Error> for SerializedError {
SyncServerError::NotFound(_) => "NotFound",
SyncServerError::Unauthenticated(_) => "Unauthenticated",
SyncServerError::PermissionDeniedError(_) => "PermissionDeniedError",
SyncServerError::TooManyRequests(_) => "TooManyRequests",
},
),
message: error.to_string(),
@ -139,3 +162,18 @@ pub fn permission_denied_error(error: anyhow::Error) -> SyncServerError {
debug!("Permission denied: {error:?}");
SyncServerError::PermissionDeniedError(error)
}
pub fn too_many_requests_error(error: anyhow::Error) -> SyncServerError {
debug!("Too many requests: {error:?}");
SyncServerError::TooManyRequests(error)
}
/// Maps a `create_write_transaction` error to 429 if the database is busy,
/// or 500 for all other failures.
pub fn write_transaction_error(error: anyhow::Error) -> SyncServerError {
if error.downcast_ref::<crate::app_state::database::WriteBusyError>().is_some() {
too_many_requests_error(error)
} else {
server_error(error)
}
}

View file

@ -9,7 +9,7 @@ use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use log::info;
use log::{debug, info};
use crate::{
app_state::{AppState, database::models::VaultId},
@ -21,10 +21,12 @@ use crate::{
pub async fn auth_middleware(
State(state): State<AppState>,
Path(path_params): Path<HashMap<String, String>>,
TypedHeader(auth_header): TypedHeader<Authorization<Bearer>>,
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
mut req: Request,
next: Next,
) -> Result<Response, SyncServerError> {
let auth_header = auth_header
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Missing Authorization header")))?;
let token = auth_header.token().trim();
let vault_id = normalize_string(
path_params
@ -51,8 +53,8 @@ pub fn auth(state: &AppState, token: &str, vault_id: &VaultId) -> Result<User, S
VaultAccess::AllowAccessToAll => true,
VaultAccess::AllowList(AllowListedVaults { ref allowed }) => allowed.contains(vault_id),
} {
info!(
"User `{}` is authenticated and is authorised to access to vault `{vault_id}`",
debug!(
"User `{}` is authenticated and is authorised to access vault `{vault_id}`",
user.name
);

View file

@ -0,0 +1,72 @@
use std::sync::{
Arc,
atomic::{AtomicU64, Ordering},
};
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
/// Simple token-bucket rate limiter that refills every second.
#[derive(Clone, Debug)]
pub struct RateLimiter {
inner: Arc<TokenBucket>,
}
#[derive(Debug)]
struct TokenBucket {
tokens: AtomicU64,
max_tokens: u64,
}
impl RateLimiter {
/// Create a new rate limiter. Spawns a background task that refills tokens
/// every second.
///
/// # Panics
///
/// Panics if `max_per_second` is 0.
pub fn new(max_per_second: u64) -> Self {
assert!(
max_per_second > 0,
"max_per_second must be > 0 (use 0 in config to disable rate limiting entirely)"
);
let bucket = Arc::new(TokenBucket {
tokens: AtomicU64::new(max_per_second),
max_tokens: max_per_second,
});
let bucket_clone = bucket.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
loop {
interval.tick().await;
bucket_clone
.tokens
.store(bucket_clone.max_tokens, Ordering::Release);
}
});
Self { inner: bucket }
}
fn try_acquire(&self) -> bool {
self.inner
.tokens
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
if current > 0 { Some(current - 1) } else { None }
})
.is_ok()
}
}
pub async fn rate_limit_middleware(
axum::extract::State(limiter): axum::extract::State<RateLimiter>,
req: Request,
next: Next,
) -> Result<Response, StatusCode> {
if limiter.try_acquire() {
Ok(next.run(req).await)
} else {
Err(StatusCode::TOO_MANY_REQUESTS)
}
}

View file

@ -4,21 +4,18 @@ use reconcile_text::NumberOrText;
use serde::{self, Deserialize};
use ts_rs::TS;
use crate::app_state::database::models::{DocumentId, VaultUpdateId};
use crate::app_state::database::models::VaultUpdateId;
#[derive(TS, Debug, TryFromMultipart)]
#[ts(export)]
pub struct CreateDocumentVersion {
/// The client can decide the document id (if it wishes to) in order
/// to help with syncing. If the client does not provide a document id,
/// the server will generate one. If the client provides a document id
/// it must not already exist in the database.
pub document_id: Option<DocumentId>,
pub relative_path: String,
#[ts(as = "Vec<u8>")]
#[form_data(limit = "unlimited")]
pub content: FieldData<Bytes>,
pub idempotency_key: Option<String>,
}
#[derive(Debug, TryFromMultipart)]
@ -34,7 +31,7 @@ pub struct UpdateBinaryDocumentVersion {
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct UpdateTextDocumentVersion {
#[ts(as = "i32")]
#[ts(type = "number")]
pub parent_version_id: VaultUpdateId,
pub relative_path: String,
@ -43,9 +40,5 @@ pub struct UpdateTextDocumentVersion {
pub content: Vec<NumberOrText>,
}
#[derive(TS, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct DeleteDocumentVersion {
pub relative_path: String,
}
#[derive(Debug, Deserialize)]
pub struct DeleteDocumentVersion {}

View file

@ -1,25 +1,31 @@
use crate::app_state::database::models::VaultId;
use crate::{app_state::database::Transaction, utils::dedup_paths::dedup_paths};
use anyhow::Result;
use log::{debug, info};
use crate::utils::dedup_paths::dedup_paths;
use anyhow::{Result, bail};
use log::info;
use sqlx::sqlite::SqliteConnection;
pub async fn find_first_available_path(
vault_id: &VaultId,
sanitized_relative_path: &str,
database: &crate::app_state::database::Database,
transaction: &mut Transaction<'_>,
connection: &mut SqliteConnection,
) -> Result<String> {
info!("Finding first available path for `{sanitized_relative_path}` in vault `{vault_id}`");
for candidate in dedup_paths(sanitized_relative_path) {
debug!("Checking candidate path for deconflicting names: `{candidate}`");
if database
.get_latest_document_by_path(vault_id, &candidate, Some(transaction))
.get_latest_non_deleted_document_by_path(vault_id, &candidate, Some(connection))
.await?
.is_none()
{
info!("Selected available path: `{candidate}`");
return Ok(candidate);
}
info!(
"Finding first available path for `{sanitized_relative_path}` in vault `{vault_id}` as `{candidate}` is already taken"
);
}
unreachable!("dedup_paths produces infinite paths");