Improve diff

This commit is contained in:
Andras Schmelczer 2026-05-09 16:27:48 +01:00
parent 792f57dc7e
commit e5373ab2bb
23 changed files with 312 additions and 220 deletions

View file

@ -85,7 +85,7 @@ describe("File operations", () => {
const result = await ops.create("a", new Uint8Array()); const result = await ops.create("a", new Uint8Array());
assertSetContainsExactly(fs.names, "a"); assertSetContainsExactly(fs.names, "a");
assert.equal(result.actualPath, "a"); assert.equal(result, "a");
}); });
it("create throws FileAlreadyExistsError when the path is occupied", async () => { it("create throws FileAlreadyExistsError when the path is occupied", async () => {
@ -109,7 +109,7 @@ describe("File operations", () => {
const result = await ops.move("a", "b"); const result = await ops.move("a", "b");
assertSetContainsExactly(fs.names, "b"); assertSetContainsExactly(fs.names, "b");
assert.equal(result.actualPath, "b"); assert.equal(result, "b");
}); });
it("move with same source and target is a no-op", async () => { it("move with same source and target is a no-op", async () => {
@ -119,7 +119,7 @@ describe("File operations", () => {
const result = await ops.move("a", "a"); const result = await ops.move("a", "a");
assertSetContainsExactly(fs.names, "a"); assertSetContainsExactly(fs.names, "a");
assert.equal(result.actualPath, "a"); assert.equal(result, "a");
}); });
it("move throws FileAlreadyExistsError when the target is occupied", async () => { it("move throws FileAlreadyExistsError when the target is occupied", async () => {

View file

@ -11,16 +11,6 @@ import { FileNotFoundError } from "../errors/file-not-found-error";
import { FileAlreadyExistsError } from "../errors/file-already-exists-error"; import { FileAlreadyExistsError } from "../errors/file-already-exists-error";
import type { ExpectedFsEvents } from "../sync-operations/expected-fs-events"; import type { ExpectedFsEvents } from "../sync-operations/expected-fs-events";
/**
* Outcome of a `move`/`create`. `actualPath` is where the file ended up;
* with the conflict-path machinery removed it is always equal to the
* requested path. The shape is preserved so callers don't all need to
* change.
*/
export interface FileOpResult {
actualPath: RelativePath;
}
export class FileOperations { export class FileOperations {
private readonly fs: SafeFileSystemOperations; private readonly fs: SafeFileSystemOperations;
@ -68,7 +58,7 @@ export class FileOperations {
public async create( public async create(
path: RelativePath, path: RelativePath,
newContent: Uint8Array newContent: Uint8Array
): Promise<FileOpResult> { ): Promise<RelativePath> {
if (await this.fs.exists(path)) { if (await this.fs.exists(path)) {
throw new FileAlreadyExistsError( throw new FileAlreadyExistsError(
`Refusing to create '${path}': file already exists`, `Refusing to create '${path}': file already exists`,
@ -84,7 +74,7 @@ export class FileOperations {
this.expectedFsEvents.unexpectCreate(path); this.expectedFsEvents.unexpectCreate(path);
throw e; throw e;
} }
return { actualPath: path }; return path;
} }
/** /**
@ -220,9 +210,9 @@ export class FileOperations {
public async move( public async move(
oldPath: RelativePath, oldPath: RelativePath,
newPath: RelativePath newPath: RelativePath
): Promise<FileOpResult> { ): Promise<RelativePath> {
if (oldPath === newPath) { if (oldPath === newPath) {
return { actualPath: oldPath }; return oldPath;
} }
if (await this.fs.exists(newPath)) { if (await this.fs.exists(newPath)) {
@ -241,7 +231,7 @@ export class FileOperations {
throw e; throw e;
} }
await this.deletingEmptyParentDirectoriesOfDeletedFile(oldPath); await this.deletingEmptyParentDirectoriesOfDeletedFile(oldPath);
return { actualPath: newPath }; return newPath;
} }
private async deletingEmptyParentDirectoriesOfDeletedFile( private async deletingEmptyParentDirectoriesOfDeletedFile(

View file

@ -1103,7 +1103,7 @@ export class Syncer {
remoteHash, remoteHash,
localPath: target localPath: target
}); });
const result = await this.operations.create( const createdPath = await this.operations.create(
target, target,
remoteContent remoteContent
); );
@ -1112,7 +1112,7 @@ export class Syncer {
); );
localPath = localPath =
liveRecord === undefined liveRecord === undefined
? result.actualPath ? createdPath
: liveRecord.localPath; : liveRecord.localPath;
await this.updateCache( await this.updateCache(
remoteVersion.vaultUpdateId, remoteVersion.vaultUpdateId,

View file

@ -21,9 +21,10 @@ cargo test --verbose
if [[ "$FIX_MODE" == true ]]; then if [[ "$FIX_MODE" == true ]]; then
cargo clippy --all-targets --all-features --fix --allow-dirty --allow-staged cargo clippy --all-targets --all-features --fix --allow-dirty --allow-staged
cargo clippy --all-targets --all-features -- -D warnings
cargo fmt --all cargo fmt --all
else else
cargo clippy --all-targets --all-features cargo clippy --all-targets --all-features -- -D warnings
cargo fmt --all -- --check cargo fmt --all -- --check
fi fi

View file

@ -2181,7 +2181,6 @@ dependencies = [
"log", "log",
"rand 0.9.0", "rand 0.9.0",
"reconcile-text", "reconcile-text",
"regex",
"sanitize-filename", "sanitize-filename",
"serde", "serde",
"serde_json", "serde_json",

View file

@ -26,7 +26,6 @@ sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio", "uuid", "chro
chrono = { version = "0.4.41", features = ["serde"] } chrono = { version = "0.4.41", features = ["serde"] }
rand = "0.9.0" rand = "0.9.0"
sanitize-filename = "0.6.0" sanitize-filename = "0.6.0"
regex = "1.12.2"
clap = { version = "4.5.38", features = ["derive"] } clap = { version = "4.5.38", features = ["derive"] }
futures = "0.3.31" futures = "0.3.31"
serde_yaml = "0.9.34" serde_yaml = "0.9.34"
@ -49,16 +48,19 @@ rust_2018_idioms = { level = "warn", priority = -1 }
missing_debug_implementations = "warn" missing_debug_implementations = "warn"
[lints.clippy] [lints.clippy]
arithmetic_side_effects = "deny"
await_holding_lock = "warn" await_holding_lock = "warn"
dbg_macro = "warn" dbg_macro = "warn"
disallowed_macros = { level = "deny", priority = 1 } disallowed_macros = { level = "deny", priority = 1 }
empty_enums = "warn" empty_enums = "warn"
enum_glob_use = "warn" enum_glob_use = "warn"
expect_used = "deny"
exit = "warn" exit = "warn"
filter_map_next = "warn" filter_map_next = "warn"
fn_params_excessive_bools = "warn" fn_params_excessive_bools = "warn"
if_let_mutex = "warn" if_let_mutex = "warn"
imprecise_flops = "warn" imprecise_flops = "warn"
indexing_slicing = "deny"
inefficient_to_string = "warn" inefficient_to_string = "warn"
linkedlist = "warn" linkedlist = "warn"
lossy_float_literal = "warn" lossy_float_literal = "warn"
@ -68,13 +70,19 @@ mem_forget = "warn"
needless_borrow = "warn" needless_borrow = "warn"
needless_continue = "warn" needless_continue = "warn"
option_option = "warn" option_option = "warn"
panic = "deny"
panic_in_result_fn = "deny"
rest_pat_in_fully_bound_structs = "warn" rest_pat_in_fully_bound_structs = "warn"
str_to_string = "warn" str_to_string = "warn"
suboptimal_flops = "warn" suboptimal_flops = "warn"
todo = "warn" todo = "deny"
uninlined_format_args = "warn" uninlined_format_args = "warn"
unimplemented = "deny"
unreachable = "deny"
unnested_or_patterns = "warn" unnested_or_patterns = "warn"
unused_self = "warn" unused_self = "warn"
unwrap_in_result = "deny"
unwrap_used = "deny"
verbose_file_reads = "warn" verbose_file_reads = "warn"
large_stack_arrays = { level = "allow", priority = 1 } # https://github.com/rust-lang/rust-clippy/issues/13774 large_stack_arrays = { level = "allow", priority = 1 } # https://github.com/rust-lang/rust-clippy/issues/13774
@ -88,7 +96,7 @@ single_call_fn = { level = "allow", priority = 1 }
similar_names = { level = "allow", priority = 1 } similar_names = { level = "allow", priority = 1 }
missing_docs_in_private_items = { level = "allow", priority = 1 } missing_docs_in_private_items = { level = "allow", priority = 1 }
pedantic = { level = "warn", priority = 0 } pedantic = { level = "warn", priority = -1 }
[package.metadata.cargo-machete] [package.metadata.cargo-machete]
ignored = ["humantime-serde"] # only used in serde macro ignored = ["humantime-serde"] # only used in serde macro

View file

@ -15,6 +15,7 @@ use super::{
}; };
use crate::{ use crate::{
app_state::websocket::models::DocumentWithCursors, config::database_config::DatabaseConfig, app_state::websocket::models::DocumentWithCursors, config::database_config::DatabaseConfig,
errors::SyncServerError,
}; };
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -39,7 +40,7 @@ impl Cursors {
user_name: String, user_name: String,
device_id: &DeviceId, device_id: &DeviceId,
document_to_cursors: Vec<DocumentWithCursors>, document_to_cursors: Vec<DocumentWithCursors>,
) { ) -> Result<(), SyncServerError> {
let mut vault_to_cursors = self.vault_to_cursors.lock().await; let mut vault_to_cursors = self.vault_to_cursors.lock().await;
let all_device_cursors = vault_to_cursors let all_device_cursors = vault_to_cursors
@ -54,7 +55,7 @@ impl Cursors {
})); }));
drop(vault_to_cursors); // Explicitly drop the lock before broadcasting to avoid deadlock drop(vault_to_cursors); // Explicitly drop the lock before broadcasting to avoid deadlock
self.broadcast_cursors_for_vault(&vault_id).await; self.broadcast_cursors_for_vault(&vault_id).await
} }
pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec<ClientCursors> { pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec<ClientCursors> {
@ -76,15 +77,17 @@ impl Cursors {
loop { loop {
tokio::select! { tokio::select! {
() = tokio::time::sleep(Duration::from_secs(1)) => { () = tokio::time::sleep(Duration::from_secs(1)) => {
self.remove_expired_cursors().await; self.remove_expired_cursors().await?;
} }
Ok(()) = shutdown.changed() => break, Ok(()) = shutdown.changed() => break,
} }
} }
Ok::<(), SyncServerError>(())
}); });
} }
async fn remove_expired_cursors(&self) { async fn remove_expired_cursors(&self) -> Result<(), SyncServerError> {
let changed_vaults: Vec<VaultId> = { let changed_vaults: Vec<VaultId> = {
let mut vault_to_cursors = self.vault_to_cursors.lock().await; let mut vault_to_cursors = self.vault_to_cursors.lock().await;
@ -104,11 +107,13 @@ impl Cursors {
}; };
for vault_id in &changed_vaults { for vault_id in &changed_vaults {
self.broadcast_cursors_for_vault(vault_id).await; self.broadcast_cursors_for_vault(vault_id).await?;
} }
Ok(())
} }
async fn broadcast_cursors_for_vault(&self, vault_id: &VaultId) { async fn broadcast_cursors_for_vault(&self, vault_id: &VaultId) -> Result<(), SyncServerError> {
let client_cursors: Vec<ClientCursors> = { let client_cursors: Vec<ClientCursors> = {
let vault_to_cursors = self.vault_to_cursors.lock().await; let vault_to_cursors = self.vault_to_cursors.lock().await;
vault_to_cursors vault_to_cursors
@ -124,10 +129,14 @@ impl Cursors {
clients: client_cursors, clients: client_cursors,
}, },
)), )),
); )
} }
pub async fn remove_cursors_of_device(&self, vault_id: &VaultId, device_id: &DeviceId) { pub async fn remove_cursors_of_device(
&self,
vault_id: &VaultId,
device_id: &DeviceId,
) -> Result<(), SyncServerError> {
let changed = { let changed = {
let mut vault_to_cursors = self.vault_to_cursors.lock().await; let mut vault_to_cursors = self.vault_to_cursors.lock().await;
@ -145,8 +154,9 @@ impl Cursors {
}; };
if changed { if changed {
self.broadcast_cursors_for_vault(vault_id).await; self.broadcast_cursors_for_vault(vault_id).await?;
} }
Ok(())
} }
} }

View file

@ -5,7 +5,7 @@ use std::{
sync::atomic::{AtomicU64, Ordering}, sync::atomic::{AtomicU64, Ordering},
}; };
use anyhow::{Context as _, Result}; use anyhow::{Context as _, Result, anyhow};
use log::info; use log::info;
use models::{ use models::{
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, VaultUpdateId, DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, VaultUpdateId,
@ -132,6 +132,12 @@ impl WriteTransaction {
} }
Ok(()) Ok(())
} }
pub fn connection_mut(&mut self) -> Result<&mut SqliteConnection> {
self.conn
.as_deref_mut()
.context("WriteTransaction already consumed")
}
} }
impl Drop for WriteTransaction { impl Drop for WriteTransaction {
@ -147,25 +153,6 @@ impl Drop for WriteTransaction {
} }
} }
impl std::ops::Deref for WriteTransaction {
type Target = SqliteConnection;
fn deref(&self) -> &Self::Target {
self.conn
.as_ref()
.expect("BUG: WriteTransaction dereferenced after being consumed")
.deref()
}
}
impl std::ops::DerefMut for WriteTransaction {
fn deref_mut(&mut self) -> &mut Self::Target {
self.conn
.as_mut()
.expect("BUG: WriteTransaction dereferenced after being consumed")
.deref_mut()
}
}
/// Ensure the connection has no leftover open transaction (e.g. from a /// Ensure the connection has no leftover open transaction (e.g. from a
/// `WriteTransaction` that was dropped without commit/rollback). ROLLBACK /// `WriteTransaction` that was dropped without commit/rollback). ROLLBACK
/// is a harmless no-op if no transaction is active. /// is a harmless no-op if no transaction is active.
@ -797,7 +784,7 @@ impl Database {
let _send_guard = self.broadcasts.acquire_send_lock(vault_id).await; let _send_guard = self.broadcasts.acquire_send_lock(vault_id).await;
query query
.execute(&mut *transaction) .execute(transaction.connection_mut()?)
.await .await
.context("Cannot insert document version")?; .context("Cannot insert document version")?;
@ -821,7 +808,8 @@ impl Database {
} else { } else {
WebSocketServerMessageWithOrigin::with_origin(version.device_id.clone(), envelope) WebSocketServerMessageWithOrigin::with_origin(version.device_id.clone(), envelope)
}; };
self.broadcasts.send_document_update(vault_id, with_origin); self.broadcasts
.send_document_update(vault_id, with_origin)?;
Ok(()) Ok(())
} }

View file

@ -7,7 +7,11 @@ use log::{debug, info, warn};
use tokio::sync::{Mutex, broadcast}; use tokio::sync::{Mutex, broadcast};
use super::models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin}; use super::models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin};
use crate::{app_state::database::models::VaultId, config::server_config::ServerConfig}; use crate::{
app_state::database::models::VaultId,
config::server_config::ServerConfig,
errors::{SyncServerError, client_error, server_error},
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Broadcasts { pub struct Broadcasts {
@ -60,30 +64,31 @@ impl Broadcasts {
pub fn get_receiver( pub fn get_receiver(
&self, &self,
vault: VaultId, vault: &VaultId,
max_clients: usize, max_clients: usize,
) -> Result<broadcast::Receiver<WebSocketServerMessageWithOrigin>, crate::errors::SyncServerError> ) -> Result<broadcast::Receiver<WebSocketServerMessageWithOrigin>, SyncServerError> {
{
let mut tx_map = self let mut tx_map = self
.tx .tx
.lock() .lock()
.expect("broadcasts.tx mutex poisoned — a previous holder panicked"); .map_err(|_| server_error(anyhow::anyhow!("broadcasts.tx mutex poisoned")))?;
let count_before_prune = tx_map let count_before_prune = tx_map
.get(&vault) .get(vault)
.map_or(0, tokio::sync::broadcast::Sender::receiver_count); .map_or(0, tokio::sync::broadcast::Sender::receiver_count);
let pruned = Self::prune_inactive_vaults(&mut tx_map); let pruned = Self::prune_inactive_vaults(&mut tx_map);
let pruned_self = pruned.contains(&vault); let pruned_self = pruned
.iter()
.any(|pruned_vault| pruned_vault.as_str() == vault);
let sender = tx_map let sender = tx_map
.entry(vault.clone()) .entry(vault.to_owned())
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0); .or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
// Hold the lock across the count check *and* the subscribe so the // Hold the lock across the count check *and* the subscribe so the
// `max_clients` cap is atomic: two concurrent callers can't both // `max_clients` cap is atomic: two concurrent callers can't both
// observe `receiver_count() < max_clients` and both subscribe. // observe `receiver_count() < max_clients` and both subscribe.
if sender.receiver_count() >= max_clients { if sender.receiver_count() >= max_clients {
return Err(crate::errors::client_error(anyhow::anyhow!( return Err(client_error(anyhow::anyhow!(
"Vault has reached the maximum number of clients ({max_clients})" "Vault has reached the maximum number of clients ({max_clients})"
))); )));
} }
@ -100,8 +105,13 @@ impl Broadcasts {
/// Notify all clients (who are subscribed to the vault) about an update. /// Notify all clients (who are subscribed to the vault) about an update.
/// Synchronous: safe to invoke from a handler between `commit()` and /// Synchronous: safe to invoke from a handler between `commit()` and
/// function return without worrying about task cancellation dropping /// function return without worrying about task cancellation dropping
/// the broadcast mid-flight. Failures are logged, never propagated. /// the broadcast mid-flight. Mutex poison is returned; send failures
pub fn send_document_update(&self, vault: VaultId, document: WebSocketServerMessageWithOrigin) { /// are logged because they can happen when receivers disconnect.
pub fn send_document_update(
&self,
vault: &str,
document: WebSocketServerMessageWithOrigin,
) -> Result<(), SyncServerError> {
let vault_update_id = match &document.message { let vault_update_id = match &document.message {
WebSocketServerMessage::VaultUpdate(u) => Some(u.document.vault_update_id), WebSocketServerMessage::VaultUpdate(u) => Some(u.document.vault_update_id),
WebSocketServerMessage::CursorPositions(_) => None, WebSocketServerMessage::CursorPositions(_) => None,
@ -110,18 +120,21 @@ impl Broadcasts {
WebSocketServerMessage::VaultUpdate(u) => Some(u.document.is_deleted), WebSocketServerMessage::VaultUpdate(u) => Some(u.document.is_deleted),
WebSocketServerMessage::CursorPositions(_) => None, WebSocketServerMessage::CursorPositions(_) => None,
}; };
let mut tx_map = self let mut tx_map = self.tx.lock().map_err(|_| {
.tx server_error(anyhow::anyhow!(
.lock() "broadcasts.tx mutex poisoned; skipping document update broadcast"
.expect("broadcasts.tx mutex poisoned — a previous holder panicked"); ))
})?;
let count_before_prune = tx_map let count_before_prune = tx_map
.get(&vault) .get(vault)
.map_or(0, tokio::sync::broadcast::Sender::receiver_count); .map_or(0, tokio::sync::broadcast::Sender::receiver_count);
let pruned = Self::prune_inactive_vaults(&mut tx_map); let pruned = Self::prune_inactive_vaults(&mut tx_map);
let pruned_self = pruned.contains(&vault); let pruned_self = pruned
.iter()
.any(|pruned_vault| pruned_vault.as_str() == vault);
let sender = tx_map let sender = tx_map
.entry(vault.clone()) .entry(vault.to_owned())
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0); .or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
let count_before_send = sender.receiver_count(); let count_before_send = sender.receiver_count();
@ -131,7 +144,7 @@ impl Broadcasts {
"[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send=0 SKIPPED" "[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send=0 SKIPPED"
); );
debug!("Skipping broadcast, no clients connected for vault `{vault}`"); debug!("Skipping broadcast, no clients connected for vault `{vault}`");
return; return Ok(());
} }
let send_result = sender.send(document); let send_result = sender.send(document);
@ -143,5 +156,6 @@ impl Broadcasts {
"[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send={count_before_send} FAILED err={e}" "[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send={count_before_send} FAILED err={e}"
), ),
} }
Ok(())
} }
} }

View file

@ -23,9 +23,10 @@ impl ColorWhen {
impl std::fmt::Display for ColorWhen { impl std::fmt::Display for ColorWhen {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.to_possible_value() f.write_str(match self {
.expect("no values are skipped") Self::Always => "always",
.get_name() Self::Auto => "auto",
.fmt(f) Self::Never => "never",
})
} }
} }

View file

@ -71,6 +71,10 @@ impl ServerConfig {
self.max_pending_websocket_connections > 0, self.max_pending_websocket_connections > 0,
"max_pending_websocket_connections must be greater than 0" "max_pending_websocket_connections must be greater than 0"
); );
ensure!(
self.rate_limit_per_user_per_second != Some(0),
"rate_limit_per_user_per_second must be greater than 0 when set (use null to disable rate limiting)"
);
Ok(()) Ok(())
} }

View file

@ -20,15 +20,7 @@ where
let mut user_token_map = BiHashMap::new(); let mut user_token_map = BiHashMap::new();
for user in &users { for user in &users {
if let Some(existing_name) = user_token_map.get_by_right(&user.token) { if let Some(existing_name) = user_token_map.get_by_right(&user.token) {
let redacted = if user.token.len() > 6 { let redacted = redact_token(&user.token);
format!(
"{}...{}",
&user.token[..3],
&user.token[user.token.len() - 3..]
)
} else {
"***".to_owned()
};
return Err(D::Error::custom(format!( return Err(D::Error::custom(format!(
"Duplicate user token found: `{redacted}` for users `{}` and `{}`. User tokens \ "Duplicate user token found: `{redacted}` for users `{}` and `{}`. User tokens \
must be unique.", must be unique.",
@ -49,6 +41,23 @@ where
Ok(users) Ok(users)
} }
fn redact_token(token: &str) -> String {
if token.chars().count() <= 6 {
return "***".to_owned();
}
let prefix = token.chars().take(3).collect::<String>();
let suffix = token
.chars()
.rev()
.take(3)
.collect::<Vec<_>>()
.into_iter()
.rev()
.collect::<String>();
format!("{prefix}...{suffix}")
}
impl UserConfig { impl UserConfig {
pub fn get_user(&self, token: &str) -> Option<&User> { pub fn get_user(&self, token: &str) -> Option<&User> {
self.user_configs self.user_configs

View file

@ -1,3 +1,19 @@
#![cfg_attr(
test,
allow(
clippy::arithmetic_side_effects,
clippy::expect_used,
clippy::indexing_slicing,
clippy::panic,
clippy::panic_in_result_fn,
clippy::todo,
clippy::unimplemented,
clippy::unreachable,
clippy::unwrap_in_result,
clippy::unwrap_used
)
)]
mod app_state; mod app_state;
mod cli; mod cli;
mod config; mod config;

View file

@ -71,7 +71,13 @@ pub async fn create_server(config: Config) -> Result<()> {
let app = app let app = app
.layer(DefaultBodyLimit::disable()) .layer(DefaultBodyLimit::disable())
.layer(RequestBodyLimitLayer::new( .layer(RequestBodyLimitLayer::new(
app_state.config.server.max_body_size_mb * 1024 * 1024, app_state
.config
.server
.max_body_size_mb
.checked_mul(1024)
.and_then(|kb| kb.checked_mul(1024))
.context("max_body_size_mb is too large")?,
)) ))
.layer(TimeoutLayer::new(server_config.response_timeout)) .layer(TimeoutLayer::new(server_config.response_timeout))
.layer(cors_layer) .layer(cors_layer)
@ -104,7 +110,7 @@ pub async fn create_server(config: Config) -> Result<()> {
fn build_cors_layer(server_config: &ServerConfig) -> Result<CorsLayer> { fn build_cors_layer(server_config: &ServerConfig) -> Result<CorsLayer> {
let origins = &server_config.allowed_origins; let origins = &server_config.allowed_origins;
let cors = if origins.len() == 1 && origins[0] == "*" { let cors = if origins.len() == 1 && origins.first().is_some_and(|origin| origin == "*") {
info!("CORS: allowing all origins"); info!("CORS: allowing all origins");
let header: HeaderValue = "*" let header: HeaderValue = "*"
.parse() .parse()

View file

@ -60,7 +60,7 @@ pub async fn create_document(
.get_latest_non_deleted_document_by_path( .get_latest_non_deleted_document_by_path(
&vault_id, &vault_id,
&sanitized_relative_path, &sanitized_relative_path,
Some(&mut *transaction), Some(transaction.connection_mut().map_err(server_error)?),
) )
.await .await
.map_err(server_error)?; .map_err(server_error)?;
@ -129,7 +129,7 @@ pub async fn create_document(
&device_id.0, &device_id.0,
request.last_seen_vault_update_id, request.last_seen_vault_update_id,
&new_content, &new_content,
Some(&mut *transaction), Some(transaction.connection_mut().map_err(server_error)?),
) )
.await .await
.map_err(server_error)? .map_err(server_error)?
@ -157,7 +157,10 @@ pub async fn create_document(
let last_update_id = state let last_update_id = state
.database .database
.get_max_update_id_in_vault(&vault_id, Some(&mut *transaction)) .get_max_update_id_in_vault(
&vault_id,
Some(transaction.connection_mut().map_err(server_error)?),
)
.await .await
.map_err(server_error)?; .map_err(server_error)?;
@ -176,7 +179,9 @@ pub async fn create_document(
); );
} }
let new_vault_update_id = last_update_id + 1; let new_vault_update_id = last_update_id
.checked_add(1)
.ok_or_else(|| server_error(anyhow::anyhow!("Vault update id overflow")))?;
let new_version = StoredDocumentVersion { let new_version = StoredDocumentVersion {
vault_update_id: new_vault_update_id, vault_update_id: new_vault_update_id,
creation_vault_update_id: new_vault_update_id, creation_vault_update_id: new_vault_update_id,

View file

@ -48,13 +48,20 @@ pub async fn delete_document(
let last_update_id = state let last_update_id = state
.database .database
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction)) .get_max_update_id_in_vault(
&vault_id,
Some(transaction.connection_mut().map_err(server_error)?),
)
.await .await
.map_err(server_error)?; .map_err(server_error)?;
let latest_version = state let latest_version = state
.database .database
.get_latest_document(&vault_id, &document_id, Some(&mut transaction)) .get_latest_document(
&vault_id,
&document_id,
Some(transaction.connection_mut().map_err(server_error)?),
)
.await .await
.map_err(server_error)?; .map_err(server_error)?;
@ -80,7 +87,9 @@ pub async fn delete_document(
return Ok(Json(latest_version.into())); return Ok(Json(latest_version.into()));
} }
let new_vault_update_id = last_update_id + 1; let new_vault_update_id = last_update_id
.checked_add(1)
.ok_or_else(|| server_error(anyhow!("Vault update id overflow")))?;
let latest_relative_path = latest_version.relative_path; let latest_relative_path = latest_version.relative_path;
let latest_content = latest_version.content; let latest_content = latest_version.content;
let creation_vault_update_id = latest_version.creation_vault_update_id; let creation_vault_update_id = latest_version.creation_vault_update_id;

View file

@ -32,26 +32,23 @@ struct BucketState {
impl RateLimiter { impl RateLimiter {
/// Create a new per-user rate limiter. /// Create a new per-user rate limiter.
///
/// # Panics
///
/// Panics if `max_per_second` is 0.
pub fn new(max_per_second: u64) -> Self { pub fn new(max_per_second: u64) -> Self {
assert!(
max_per_second > 0,
"max_per_second must be > 0 (set rate_limit_per_user_per_second to null in config to disable)"
);
Self { Self {
max_per_second, max_per_second,
buckets: Arc::new(Mutex::new(HashMap::new())), buckets: Arc::new(Mutex::new(HashMap::new())),
} }
} }
fn get_or_create_bucket(&self, token: &str) -> Arc<TokenBucket> { fn get_or_create_bucket(
self.buckets &self,
token: &str,
) -> std::result::Result<Arc<TokenBucket>, StatusCode> {
let mut buckets = self
.buckets
.lock() .lock()
.expect("rate limiter lock poisoned") .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(buckets
.entry(token.to_owned()) .entry(token.to_owned())
.or_insert_with(|| { .or_insert_with(|| {
Arc::new(TokenBucket { Arc::new(TokenBucket {
@ -62,23 +59,26 @@ impl RateLimiter {
max_tokens: self.max_per_second, max_tokens: self.max_per_second,
}) })
}) })
.clone() .clone())
} }
} }
impl TokenBucket { impl TokenBucket {
fn try_acquire(&self) -> bool { fn try_acquire(&self) -> std::result::Result<bool, StatusCode> {
let mut state = self.state.lock().expect("token bucket lock poisoned"); let mut state = self
.state
.lock()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let now = Instant::now(); let now = Instant::now();
if now.duration_since(state.last_refill).as_secs() >= 1 { if now.duration_since(state.last_refill).as_secs() >= 1 {
state.tokens = self.max_tokens; state.tokens = self.max_tokens;
state.last_refill = now; state.last_refill = now;
} }
if state.tokens > 0 { if state.tokens > 0 {
state.tokens -= 1; state.tokens = state.tokens.saturating_sub(1);
true Ok(true)
} else { } else {
false Ok(false)
} }
} }
} }
@ -88,13 +88,13 @@ pub async fn rate_limit_middleware(
auth_header: Option<TypedHeader<Authorization<Bearer>>>, auth_header: Option<TypedHeader<Authorization<Bearer>>>,
req: Request, req: Request,
next: Next, next: Next,
) -> Result<Response, StatusCode> { ) -> std::result::Result<Response, StatusCode> {
let Some(TypedHeader(auth)) = auth_header else { let Some(TypedHeader(auth)) = auth_header else {
return Ok(next.run(req).await); return Ok(next.run(req).await);
}; };
let bucket = limiter.get_or_create_bucket(auth.token()); let bucket = limiter.get_or_create_bucket(auth.token())?;
if bucket.try_acquire() { if bucket.try_acquire()? {
Ok(next.run(req).await) Ok(next.run(req).await)
} else { } else {
Err(StatusCode::TOO_MANY_REQUESTS) Err(StatusCode::TOO_MANY_REQUESTS)

View file

@ -27,7 +27,7 @@ use crate::{
}, },
server::requests::UpdateBinaryDocumentVersion, server::requests::UpdateBinaryDocumentVersion,
utils::{ utils::{
find_first_available_path::find_first_available_path, is_binary::is_binary, find_first_available_path::find_first_available_path, is_binary::as_non_binary_text,
is_file_type_mergable::is_file_type_mergable, normalize::normalize, is_file_type_mergable::is_file_type_mergable, normalize::normalize,
sanitize_path::sanitize_path, sanitize_path::sanitize_path,
}, },
@ -173,13 +173,20 @@ pub async fn update_document(
let last_update_id = state let last_update_id = state
.database .database
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction)) .get_max_update_id_in_vault(
&vault_id,
Some(transaction.connection_mut().map_err(server_error)?),
)
.await .await
.map_err(server_error)?; .map_err(server_error)?;
let latest_version = state let latest_version = state
.database .database
.get_latest_document(&vault_id, &document_id, Some(&mut transaction)) .get_latest_document(
&vault_id,
&document_id,
Some(transaction.connection_mut().map_err(server_error)?),
)
.await .await
.map_err(server_error)? .map_err(server_error)?
.map_or_else( .map_or_else(
@ -225,64 +232,56 @@ pub async fn update_document(
))); )));
} }
// For mergability, use whichever path the new version will live at — the // For mergability, use whichever path the new version will live at:
// requested rename target if the client sent one, otherwise the existing // - the requested rename target if the client sent one
// server-side path. // - otherwise the existing server-side path.
let mergable_check_path = sanitized_relative_path let mergable_check_path = sanitized_relative_path
.as_deref() .as_deref()
.unwrap_or(&latest_version.relative_path); .unwrap_or(&latest_version.relative_path);
let are_all_participants_mergable = is_file_type_mergable(
let mergeable_texts = if is_file_type_mergable(
mergable_check_path, mergable_check_path,
&state.config.server.mergeable_file_extensions, &state.config.server.mergeable_file_extensions,
) && !is_binary(&parent_content) ) {
&& !is_binary(&latest_version.content) as_non_binary_texts(&parent_content, &latest_version.content, &content)
&& !is_binary(&content);
let (merged_content, is_different_from_request_content) = if are_all_participants_mergable {
info!("Merging changes for document `{document_id}` in vault `{vault_id}`");
let parent_text = str::from_utf8(&parent_content)
.context("Parent document content is not valid UTF-8")
.map_err(client_error)?;
let latest_text = str::from_utf8(&latest_version.content)
.context("Latest version content is not valid UTF-8")
.map_err(client_error)?;
let new_text = str::from_utf8(&content)
.context("New content is not valid UTF-8")
.map_err(client_error)?;
let parent_owned = parent_text.to_owned();
let latest_owned = latest_text.to_owned();
let new_owned = new_text.to_owned();
let content_clone = content.clone();
let (merged, is_different) = tokio::task::spawn_blocking(move || {
let merged = reconcile(
&parent_owned,
&latest_owned.into(),
&new_owned.into(),
&*BuiltinTokenizer::Word,
)
.apply()
.text()
.into_bytes();
let is_different = merged != content_clone;
(merged, is_different)
})
.await
.map_err(|e| server_error(anyhow::anyhow!("Reconcile task failed: {e}")))?;
(merged, is_different)
} else { } else {
(content, false) // false means that the client doesn't need to refetch the file as we can ensure the remote and local versions are the same as LWW is the merging method for binary files None
}; };
let are_all_participants_mergable = mergeable_texts.is_some();
// Rename resolution: only apply the client's rename if (a) the client let (merged_content, is_same_as_request) =
// requested one (`sanitized_relative_path` is `Some`) and (b) the if let Some((parent_text, latest_text, new_text)) = mergeable_texts {
// document's path hasn't changed since this client's parent version. info!("Merging changes for document `{document_id}` in vault `{vault_id}`");
// If the parent and latest paths differ, another client already renamed
// the document — keep the latest path (first rename wins). Content let parent_owned = parent_text.to_owned();
// changes from both clients are still merged correctly via the 3-way let latest_owned = latest_text.to_owned();
// reconcile above, independent of which rename wins. A missing let new_owned = new_text.to_owned();
// relative_path means "keep current path" (content-only edit). let content_clone = content.clone();
let merged = tokio::task::spawn_blocking(move || {
let merged = reconcile(
&parent_owned,
&latest_owned.into(),
&new_owned.into(),
&*BuiltinTokenizer::Word,
)
.apply()
.text()
.into_bytes();
merged
})
.await
.map_err(|e| server_error(anyhow::anyhow!("Reconcile task failed: {e}")))?;
let is_same = merged == content_clone;
(merged, is_same)
} else {
(content, true) // true means that the client doesn't need to refetch the file as we can ensure the remote and local versions are the same as LWW is the merging method for binary files
};
// First rename wins: apply the client's rename only if the doc's path
// hasn't changed since its parent version. Content from both clients
// still merges via the 3-way reconcile above
let new_relative_path = match sanitized_relative_path.as_deref() { let new_relative_path = match sanitized_relative_path.as_deref() {
Some(requested) Some(requested)
if parent_relative_path == latest_version.relative_path if parent_relative_path == latest_version.relative_path
@ -306,7 +305,9 @@ pub async fn update_document(
let new_version = StoredDocumentVersion { let new_version = StoredDocumentVersion {
document_id, document_id,
vault_update_id: last_update_id + 1, vault_update_id: last_update_id
.checked_add(1)
.ok_or_else(|| server_error(anyhow!("Vault update id overflow")))?,
creation_vault_update_id: latest_version.creation_vault_update_id, creation_vault_update_id: latest_version.creation_vault_update_id,
relative_path: new_relative_path, relative_path: new_relative_path,
content: merged_content, content: merged_content,
@ -314,7 +315,7 @@ pub async fn update_document(
is_deleted: false, is_deleted: false,
user_id: user.name, user_id: user.name,
device_id: device_id.0, device_id: device_id.0,
has_been_merged: are_all_participants_mergable && is_different_from_request_content, has_been_merged: are_all_participants_mergable && !is_same_as_request,
}; };
state state
@ -323,9 +324,21 @@ pub async fn update_document(
.await .await
.map_err(server_error)?; .map_err(server_error)?;
Ok(Json(if is_different_from_request_content { Ok(Json(if is_same_as_request {
DocumentUpdateResponse::MergingUpdate(new_version.into())
} else {
DocumentUpdateResponse::FastForwardUpdate(new_version.into()) DocumentUpdateResponse::FastForwardUpdate(new_version.into())
} else {
DocumentUpdateResponse::MergingUpdate(new_version.into())
})) }))
} }
fn as_non_binary_texts<'a>(
parent_content: &'a [u8],
latest_content: &'a [u8],
new_content: &'a [u8],
) -> Option<(&'a str, &'a str, &'a str)> {
Some((
as_non_binary_text(parent_content)?,
as_non_binary_text(latest_content)?,
as_non_binary_text(new_content)?,
))
}

View file

@ -306,7 +306,7 @@ async fn websocket(
&device_id, &device_id,
docs, docs,
) )
.await; .await?;
} }
} }
} }
@ -351,7 +351,7 @@ async fn websocket(
state state
.cursors .cursors
.remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id) .remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id)
.await; .await?;
match &result { match &result {
Ok(()) => { Ok(()) => {

View file

@ -1,10 +1,3 @@
use std::sync::LazyLock;
use regex::Regex;
static DEDUP_SUFFIX_REGEX: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r" \((\d+)\)$").expect("invalid regex"));
pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> { pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
let mut path_parts = path.split('/').collect::<Vec<_>>(); let mut path_parts = path.split('/').collect::<Vec<_>>();
let file_name = path_parts let file_name = path_parts
@ -24,29 +17,19 @@ pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
let (stem, extension) = if is_simple_dotfile { let (stem, extension) = if is_simple_dotfile {
(file_name.clone(), String::new()) (file_name.clone(), String::new())
} else { } else {
// Regular file or dotfile with extension match file_name.rsplit_once('.') {
let name_parts = file_name.rsplitn(2, '.').collect::<Vec<_>>(); Some((stem, extension)) => (stem.to_owned(), format!(".{extension}")),
let mut reverse_parts = name_parts.into_iter().rev(); None => (file_name.clone(), String::new()),
match (reverse_parts.next(), reverse_parts.next()) {
(Some(stem), maybe_extension) => (
stem.to_owned(),
maybe_extension
.map(|ext| format!(".{ext}"))
.unwrap_or_default(),
),
_ => unreachable!("Path must have at least one part"),
} }
}; };
let start_number = DEDUP_SUFFIX_REGEX let (clean_stem, start_number) = strip_dedup_suffix(&stem);
.captures(&stem) let clean_stem = clean_stem.to_owned();
.and_then(|caps| caps.get(1))
.and_then(|m| m.as_str().parse::<u32>().ok())
.unwrap_or(0);
let clean_stem = DEDUP_SUFFIX_REGEX.replace(&stem, "").to_string(); std::iter::successors(Some(start_number), |dedup_number| {
dedup_number.checked_add(1)
(start_number..).map(move |dedup_number| { })
.map(move |dedup_number| {
if dedup_number == 0 { if dedup_number == 0 {
format!("{directory}{clean_stem}{extension}") format!("{directory}{clean_stem}{extension}")
} else { } else {
@ -55,6 +38,20 @@ pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
}) })
} }
fn strip_dedup_suffix(stem: &str) -> (&str, u64) {
let Some(without_closing_paren) = stem.strip_suffix(')') else {
return (stem, 0);
};
let Some((clean_stem, number)) = without_closing_paren.rsplit_once(" (") else {
return (stem, 0);
};
if number.is_empty() || !number.chars().all(|c| c.is_ascii_digit()) {
return (stem, 0);
}
(clean_stem, number.parse::<u64>().unwrap_or(0))
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;
@ -103,7 +100,7 @@ mod test {
} }
#[test] #[test]
fn test_regex_capturing_group() { fn test_dedup_suffix_parsing() {
// Single digit in parentheses // Single digit in parentheses
let mut deduped = dedup_paths("document (5).md"); let mut deduped = dedup_paths("document (5).md");
assert_eq!(deduped.next(), Some("document (5).md".to_owned())); assert_eq!(deduped.next(), Some("document (5).md".to_owned()));

View file

@ -1,20 +1,23 @@
use crate::app_state::database::models::VaultId; use crate::app_state::database::{WriteTransaction, models::VaultId};
use crate::utils::dedup_paths::dedup_paths; use crate::utils::dedup_paths::dedup_paths;
use anyhow::Result; use anyhow::{Result, anyhow};
use log::{debug, info}; use log::{debug, info};
use sqlx::sqlite::SqliteConnection;
pub async fn find_first_available_path( pub async fn find_first_available_path(
vault_id: &VaultId, vault_id: &VaultId,
sanitized_relative_path: &str, sanitized_relative_path: &str,
database: &crate::app_state::database::Database, database: &crate::app_state::database::Database,
connection: &mut SqliteConnection, transaction: &mut WriteTransaction,
) -> Result<String> { ) -> Result<String> {
info!("Finding first available path for `{sanitized_relative_path}` in vault `{vault_id}`"); info!("Finding first available path for `{sanitized_relative_path}` in vault `{vault_id}`");
for candidate in dedup_paths(sanitized_relative_path) { for candidate in dedup_paths(sanitized_relative_path) {
debug!("Checking candidate path for deconflicting names: `{candidate}`"); debug!("Checking candidate path for deconflicting names: `{candidate}`");
if database if database
.get_latest_non_deleted_document_by_path(vault_id, &candidate, Some(connection)) .get_latest_non_deleted_document_by_path(
vault_id,
&candidate,
Some(transaction.connection_mut()?),
)
.await? .await?
.is_none() .is_none()
{ {
@ -27,5 +30,7 @@ pub async fn find_first_available_path(
); );
} }
unreachable!("dedup_paths produces infinite paths"); Err(anyhow!(
"No available path candidates produced for `{sanitized_relative_path}` in vault `{vault_id}`"
))
} }

View file

@ -1,16 +1,22 @@
/// Heuristically determine if the given data is a binary or a text file's /// Return the given data as UTF-8 text if it is not considered binary.
/// content.
/// ///
/// Only text inputs can be reconciled using the crate's functions. /// Only text inputs can be reconciled using the crate's functions.
#[must_use] #[must_use]
pub fn is_binary(data: &[u8]) -> bool { pub fn as_non_binary_text(data: &[u8]) -> Option<&str> {
if data.contains(&0) { if data.contains(&0) {
// Even though the NUL character is valid in UTF-8, it's highly suspicious in // Even though the NUL character is valid in UTF-8, it's highly suspicious in
// human-readable text. // human-readable text.
return true; return None;
} }
std::str::from_utf8(data).is_err() std::str::from_utf8(data).ok()
}
/// Heuristically determine if the given data is a binary or a text file's
/// content.
#[must_use]
pub fn is_binary(data: &[u8]) -> bool {
as_non_binary_text(data).is_none()
} }
#[cfg(test)] #[cfg(test)]
@ -23,4 +29,11 @@ mod tests {
assert!(is_binary(&[0, 12])); assert!(is_binary(&[0, 12]));
assert!(!is_binary(b"hello")); assert!(!is_binary(b"hello"));
} }
#[test]
fn test_as_non_binary_text() {
assert_eq!(as_non_binary_text(b"hello"), Some("hello"));
assert_eq!(as_non_binary_text(&[0, 12]), None);
assert_eq!(as_non_binary_text(&[0xff]), None);
}
} }

View file

@ -52,14 +52,14 @@ impl RotatingFileWriter {
/// Parse timestamp from log filename and return as `SystemTime` /// Parse timestamp from log filename and return as `SystemTime`
fn parse_log_timestamp(filename: &str, file_prefix: &str) -> Option<SystemTime> { fn parse_log_timestamp(filename: &str, file_prefix: &str) -> Option<SystemTime> {
// Expected format: {prefix}.{timestamp}.log where timestamp is %Y-%m-%d_%H-%M-%S // Expected format: {prefix}.{timestamp}.log where timestamp is %Y-%m-%d_%H-%M-%S
let prefix_len = file_prefix.len() + 1; // +1 for the dot let prefix_len = file_prefix.len().checked_add(1)?; // +1 for the dot
let timestamp_str = filename.get(prefix_len..filename.len().checked_sub(4)?)?; let timestamp_str = filename.get(prefix_len..filename.len().checked_sub(4)?)?;
let dt = NaiveDateTime::parse_from_str(timestamp_str, "%Y-%m-%d_%H-%M-%S").ok()?; let dt = NaiveDateTime::parse_from_str(timestamp_str, "%Y-%m-%d_%H-%M-%S").ok()?;
let timestamp = dt.and_utc(); let timestamp = dt.and_utc();
let secs: u64 = timestamp.timestamp().try_into().ok()?; let secs: u64 = timestamp.timestamp().try_into().ok()?;
Some(UNIX_EPOCH + Duration::from_secs(secs)) UNIX_EPOCH.checked_add(Duration::from_secs(secs))
} }
fn find_latest_log_file(directory: &Path, file_prefix: &str) -> Option<String> { fn find_latest_log_file(directory: &Path, file_prefix: &str) -> Option<String> {
@ -86,7 +86,9 @@ impl RotatingFileWriter {
Self::find_latest_log_file(directory, file_prefix) Self::find_latest_log_file(directory, file_prefix)
.and_then(|filename| Self::parse_log_timestamp(&filename, file_prefix)) .and_then(|filename| Self::parse_log_timestamp(&filename, file_prefix))
.map_or_else(SystemTime::now, |last_rotation| { .map_or_else(SystemTime::now, |last_rotation| {
last_rotation + rotation_duration last_rotation
.checked_add(rotation_duration)
.unwrap_or_else(SystemTime::now)
}) })
} }
@ -136,7 +138,9 @@ impl RotatingFileWriter {
.open(&filepath)?; .open(&filepath)?;
inner.current_file = Some(file); inner.current_file = Some(file);
inner.next_rotation_time = SystemTime::now() + inner.rotation_duration; inner.next_rotation_time = SystemTime::now()
.checked_add(inner.rotation_duration)
.unwrap_or_else(SystemTime::now);
Ok(()) Ok(())
} }