Improve diff

This commit is contained in:
Andras Schmelczer 2026-05-09 16:27:48 +01:00
parent 792f57dc7e
commit e5373ab2bb
23 changed files with 312 additions and 220 deletions

View file

@ -85,7 +85,7 @@ describe("File operations", () => {
const result = await ops.create("a", new Uint8Array());
assertSetContainsExactly(fs.names, "a");
assert.equal(result.actualPath, "a");
assert.equal(result, "a");
});
it("create throws FileAlreadyExistsError when the path is occupied", async () => {
@ -109,7 +109,7 @@ describe("File operations", () => {
const result = await ops.move("a", "b");
assertSetContainsExactly(fs.names, "b");
assert.equal(result.actualPath, "b");
assert.equal(result, "b");
});
it("move with same source and target is a no-op", async () => {
@ -119,7 +119,7 @@ describe("File operations", () => {
const result = await ops.move("a", "a");
assertSetContainsExactly(fs.names, "a");
assert.equal(result.actualPath, "a");
assert.equal(result, "a");
});
it("move throws FileAlreadyExistsError when the target is occupied", async () => {

View file

@ -11,16 +11,6 @@ import { FileNotFoundError } from "../errors/file-not-found-error";
import { FileAlreadyExistsError } from "../errors/file-already-exists-error";
import type { ExpectedFsEvents } from "../sync-operations/expected-fs-events";
/**
* Outcome of a `move`/`create`. `actualPath` is where the file ended up;
* with the conflict-path machinery removed it is always equal to the
* requested path. The shape is preserved so callers don't all need to
* change.
*/
export interface FileOpResult {
actualPath: RelativePath;
}
export class FileOperations {
private readonly fs: SafeFileSystemOperations;
@ -68,7 +58,7 @@ export class FileOperations {
public async create(
path: RelativePath,
newContent: Uint8Array
): Promise<FileOpResult> {
): Promise<RelativePath> {
if (await this.fs.exists(path)) {
throw new FileAlreadyExistsError(
`Refusing to create '${path}': file already exists`,
@ -84,7 +74,7 @@ export class FileOperations {
this.expectedFsEvents.unexpectCreate(path);
throw e;
}
return { actualPath: path };
return path;
}
/**
@ -220,9 +210,9 @@ export class FileOperations {
public async move(
oldPath: RelativePath,
newPath: RelativePath
): Promise<FileOpResult> {
): Promise<RelativePath> {
if (oldPath === newPath) {
return { actualPath: oldPath };
return oldPath;
}
if (await this.fs.exists(newPath)) {
@ -241,7 +231,7 @@ export class FileOperations {
throw e;
}
await this.deletingEmptyParentDirectoriesOfDeletedFile(oldPath);
return { actualPath: newPath };
return newPath;
}
private async deletingEmptyParentDirectoriesOfDeletedFile(

View file

@ -1103,7 +1103,7 @@ export class Syncer {
remoteHash,
localPath: target
});
const result = await this.operations.create(
const createdPath = await this.operations.create(
target,
remoteContent
);
@ -1112,7 +1112,7 @@ export class Syncer {
);
localPath =
liveRecord === undefined
? result.actualPath
? createdPath
: liveRecord.localPath;
await this.updateCache(
remoteVersion.vaultUpdateId,

View file

@ -21,9 +21,10 @@ cargo test --verbose
if [[ "$FIX_MODE" == true ]]; then
cargo clippy --all-targets --all-features --fix --allow-dirty --allow-staged
cargo clippy --all-targets --all-features -- -D warnings
cargo fmt --all
else
cargo clippy --all-targets --all-features
cargo clippy --all-targets --all-features -- -D warnings
cargo fmt --all -- --check
fi

View file

@ -2181,7 +2181,6 @@ dependencies = [
"log",
"rand 0.9.0",
"reconcile-text",
"regex",
"sanitize-filename",
"serde",
"serde_json",

View file

@ -26,7 +26,6 @@ sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio", "uuid", "chro
chrono = { version = "0.4.41", features = ["serde"] }
rand = "0.9.0"
sanitize-filename = "0.6.0"
regex = "1.12.2"
clap = { version = "4.5.38", features = ["derive"] }
futures = "0.3.31"
serde_yaml = "0.9.34"
@ -49,16 +48,19 @@ rust_2018_idioms = { level = "warn", priority = -1 }
missing_debug_implementations = "warn"
[lints.clippy]
arithmetic_side_effects = "deny"
await_holding_lock = "warn"
dbg_macro = "warn"
disallowed_macros = { level = "deny", priority = 1 }
empty_enums = "warn"
enum_glob_use = "warn"
expect_used = "deny"
exit = "warn"
filter_map_next = "warn"
fn_params_excessive_bools = "warn"
if_let_mutex = "warn"
imprecise_flops = "warn"
indexing_slicing = "deny"
inefficient_to_string = "warn"
linkedlist = "warn"
lossy_float_literal = "warn"
@ -68,13 +70,19 @@ mem_forget = "warn"
needless_borrow = "warn"
needless_continue = "warn"
option_option = "warn"
panic = "deny"
panic_in_result_fn = "deny"
rest_pat_in_fully_bound_structs = "warn"
str_to_string = "warn"
suboptimal_flops = "warn"
todo = "warn"
todo = "deny"
uninlined_format_args = "warn"
unimplemented = "deny"
unreachable = "deny"
unnested_or_patterns = "warn"
unused_self = "warn"
unwrap_in_result = "deny"
unwrap_used = "deny"
verbose_file_reads = "warn"
large_stack_arrays = { level = "allow", priority = 1 } # https://github.com/rust-lang/rust-clippy/issues/13774
@ -88,7 +96,7 @@ single_call_fn = { level = "allow", priority = 1 }
similar_names = { level = "allow", priority = 1 }
missing_docs_in_private_items = { level = "allow", priority = 1 }
pedantic = { level = "warn", priority = 0 }
pedantic = { level = "warn", priority = -1 }
[package.metadata.cargo-machete]
ignored = ["humantime-serde"] # only used in serde macro

View file

@ -15,6 +15,7 @@ use super::{
};
use crate::{
app_state::websocket::models::DocumentWithCursors, config::database_config::DatabaseConfig,
errors::SyncServerError,
};
#[derive(Clone, Debug)]
@ -39,7 +40,7 @@ impl Cursors {
user_name: String,
device_id: &DeviceId,
document_to_cursors: Vec<DocumentWithCursors>,
) {
) -> Result<(), SyncServerError> {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
let all_device_cursors = vault_to_cursors
@ -54,7 +55,7 @@ impl Cursors {
}));
drop(vault_to_cursors); // Explicitly drop the lock before broadcasting to avoid deadlock
self.broadcast_cursors_for_vault(&vault_id).await;
self.broadcast_cursors_for_vault(&vault_id).await
}
pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec<ClientCursors> {
@ -76,15 +77,17 @@ impl Cursors {
loop {
tokio::select! {
() = tokio::time::sleep(Duration::from_secs(1)) => {
self.remove_expired_cursors().await;
self.remove_expired_cursors().await?;
}
Ok(()) = shutdown.changed() => break,
}
}
Ok::<(), SyncServerError>(())
});
}
async fn remove_expired_cursors(&self) {
async fn remove_expired_cursors(&self) -> Result<(), SyncServerError> {
let changed_vaults: Vec<VaultId> = {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
@ -104,11 +107,13 @@ impl Cursors {
};
for vault_id in &changed_vaults {
self.broadcast_cursors_for_vault(vault_id).await;
self.broadcast_cursors_for_vault(vault_id).await?;
}
Ok(())
}
async fn broadcast_cursors_for_vault(&self, vault_id: &VaultId) {
async fn broadcast_cursors_for_vault(&self, vault_id: &VaultId) -> Result<(), SyncServerError> {
let client_cursors: Vec<ClientCursors> = {
let vault_to_cursors = self.vault_to_cursors.lock().await;
vault_to_cursors
@ -124,10 +129,14 @@ impl Cursors {
clients: client_cursors,
},
)),
);
)
}
pub async fn remove_cursors_of_device(&self, vault_id: &VaultId, device_id: &DeviceId) {
pub async fn remove_cursors_of_device(
&self,
vault_id: &VaultId,
device_id: &DeviceId,
) -> Result<(), SyncServerError> {
let changed = {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
@ -145,8 +154,9 @@ impl Cursors {
};
if changed {
self.broadcast_cursors_for_vault(vault_id).await;
self.broadcast_cursors_for_vault(vault_id).await?;
}
Ok(())
}
}

View file

@ -5,7 +5,7 @@ use std::{
sync::atomic::{AtomicU64, Ordering},
};
use anyhow::{Context as _, Result};
use anyhow::{Context as _, Result, anyhow};
use log::info;
use models::{
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, VaultUpdateId,
@ -132,6 +132,12 @@ impl WriteTransaction {
}
Ok(())
}
pub fn connection_mut(&mut self) -> Result<&mut SqliteConnection> {
self.conn
.as_deref_mut()
.context("WriteTransaction already consumed")
}
}
impl Drop for WriteTransaction {
@ -147,25 +153,6 @@ impl Drop for WriteTransaction {
}
}
impl std::ops::Deref for WriteTransaction {
type Target = SqliteConnection;
fn deref(&self) -> &Self::Target {
self.conn
.as_ref()
.expect("BUG: WriteTransaction dereferenced after being consumed")
.deref()
}
}
impl std::ops::DerefMut for WriteTransaction {
fn deref_mut(&mut self) -> &mut Self::Target {
self.conn
.as_mut()
.expect("BUG: WriteTransaction dereferenced after being consumed")
.deref_mut()
}
}
/// Ensure the connection has no leftover open transaction (e.g. from a
/// `WriteTransaction` that was dropped without commit/rollback). ROLLBACK
/// is a harmless no-op if no transaction is active.
@ -797,7 +784,7 @@ impl Database {
let _send_guard = self.broadcasts.acquire_send_lock(vault_id).await;
query
.execute(&mut *transaction)
.execute(transaction.connection_mut()?)
.await
.context("Cannot insert document version")?;
@ -821,7 +808,8 @@ impl Database {
} else {
WebSocketServerMessageWithOrigin::with_origin(version.device_id.clone(), envelope)
};
self.broadcasts.send_document_update(vault_id, with_origin);
self.broadcasts
.send_document_update(vault_id, with_origin)?;
Ok(())
}

View file

@ -7,7 +7,11 @@ use log::{debug, info, warn};
use tokio::sync::{Mutex, broadcast};
use super::models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin};
use crate::{app_state::database::models::VaultId, config::server_config::ServerConfig};
use crate::{
app_state::database::models::VaultId,
config::server_config::ServerConfig,
errors::{SyncServerError, client_error, server_error},
};
#[derive(Debug, Clone)]
pub struct Broadcasts {
@ -60,30 +64,31 @@ impl Broadcasts {
pub fn get_receiver(
&self,
vault: VaultId,
vault: &VaultId,
max_clients: usize,
) -> Result<broadcast::Receiver<WebSocketServerMessageWithOrigin>, crate::errors::SyncServerError>
{
) -> Result<broadcast::Receiver<WebSocketServerMessageWithOrigin>, SyncServerError> {
let mut tx_map = self
.tx
.lock()
.expect("broadcasts.tx mutex poisoned — a previous holder panicked");
.map_err(|_| server_error(anyhow::anyhow!("broadcasts.tx mutex poisoned")))?;
let count_before_prune = tx_map
.get(&vault)
.get(vault)
.map_or(0, tokio::sync::broadcast::Sender::receiver_count);
let pruned = Self::prune_inactive_vaults(&mut tx_map);
let pruned_self = pruned.contains(&vault);
let pruned_self = pruned
.iter()
.any(|pruned_vault| pruned_vault.as_str() == vault);
let sender = tx_map
.entry(vault.clone())
.entry(vault.to_owned())
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
// Hold the lock across the count check *and* the subscribe so the
// `max_clients` cap is atomic: two concurrent callers can't both
// observe `receiver_count() < max_clients` and both subscribe.
if sender.receiver_count() >= max_clients {
return Err(crate::errors::client_error(anyhow::anyhow!(
return Err(client_error(anyhow::anyhow!(
"Vault has reached the maximum number of clients ({max_clients})"
)));
}
@ -100,8 +105,13 @@ impl Broadcasts {
/// Notify all clients (who are subscribed to the vault) about an update.
/// Synchronous: safe to invoke from a handler between `commit()` and
/// function return without worrying about task cancellation dropping
/// the broadcast mid-flight. Failures are logged, never propagated.
pub fn send_document_update(&self, vault: VaultId, document: WebSocketServerMessageWithOrigin) {
/// the broadcast mid-flight. Mutex poison is returned; send failures
/// are logged because they can happen when receivers disconnect.
pub fn send_document_update(
&self,
vault: &str,
document: WebSocketServerMessageWithOrigin,
) -> Result<(), SyncServerError> {
let vault_update_id = match &document.message {
WebSocketServerMessage::VaultUpdate(u) => Some(u.document.vault_update_id),
WebSocketServerMessage::CursorPositions(_) => None,
@ -110,18 +120,21 @@ impl Broadcasts {
WebSocketServerMessage::VaultUpdate(u) => Some(u.document.is_deleted),
WebSocketServerMessage::CursorPositions(_) => None,
};
let mut tx_map = self
.tx
.lock()
.expect("broadcasts.tx mutex poisoned — a previous holder panicked");
let mut tx_map = self.tx.lock().map_err(|_| {
server_error(anyhow::anyhow!(
"broadcasts.tx mutex poisoned; skipping document update broadcast"
))
})?;
let count_before_prune = tx_map
.get(&vault)
.get(vault)
.map_or(0, tokio::sync::broadcast::Sender::receiver_count);
let pruned = Self::prune_inactive_vaults(&mut tx_map);
let pruned_self = pruned.contains(&vault);
let pruned_self = pruned
.iter()
.any(|pruned_vault| pruned_vault.as_str() == vault);
let sender = tx_map
.entry(vault.clone())
.entry(vault.to_owned())
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
let count_before_send = sender.receiver_count();
@ -131,7 +144,7 @@ impl Broadcasts {
"[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send=0 SKIPPED"
);
debug!("Skipping broadcast, no clients connected for vault `{vault}`");
return;
return Ok(());
}
let send_result = sender.send(document);
@ -143,5 +156,6 @@ impl Broadcasts {
"[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send={count_before_send} FAILED err={e}"
),
}
Ok(())
}
}

View file

@ -23,9 +23,10 @@ impl ColorWhen {
impl std::fmt::Display for ColorWhen {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.to_possible_value()
.expect("no values are skipped")
.get_name()
.fmt(f)
f.write_str(match self {
Self::Always => "always",
Self::Auto => "auto",
Self::Never => "never",
})
}
}

View file

@ -71,6 +71,10 @@ impl ServerConfig {
self.max_pending_websocket_connections > 0,
"max_pending_websocket_connections must be greater than 0"
);
ensure!(
self.rate_limit_per_user_per_second != Some(0),
"rate_limit_per_user_per_second must be greater than 0 when set (use null to disable rate limiting)"
);
Ok(())
}

View file

@ -20,15 +20,7 @@ where
let mut user_token_map = BiHashMap::new();
for user in &users {
if let Some(existing_name) = user_token_map.get_by_right(&user.token) {
let redacted = if user.token.len() > 6 {
format!(
"{}...{}",
&user.token[..3],
&user.token[user.token.len() - 3..]
)
} else {
"***".to_owned()
};
let redacted = redact_token(&user.token);
return Err(D::Error::custom(format!(
"Duplicate user token found: `{redacted}` for users `{}` and `{}`. User tokens \
must be unique.",
@ -49,6 +41,23 @@ where
Ok(users)
}
fn redact_token(token: &str) -> String {
if token.chars().count() <= 6 {
return "***".to_owned();
}
let prefix = token.chars().take(3).collect::<String>();
let suffix = token
.chars()
.rev()
.take(3)
.collect::<Vec<_>>()
.into_iter()
.rev()
.collect::<String>();
format!("{prefix}...{suffix}")
}
impl UserConfig {
pub fn get_user(&self, token: &str) -> Option<&User> {
self.user_configs

View file

@ -1,3 +1,19 @@
#![cfg_attr(
test,
allow(
clippy::arithmetic_side_effects,
clippy::expect_used,
clippy::indexing_slicing,
clippy::panic,
clippy::panic_in_result_fn,
clippy::todo,
clippy::unimplemented,
clippy::unreachable,
clippy::unwrap_in_result,
clippy::unwrap_used
)
)]
mod app_state;
mod cli;
mod config;

View file

@ -71,7 +71,13 @@ pub async fn create_server(config: Config) -> Result<()> {
let app = app
.layer(DefaultBodyLimit::disable())
.layer(RequestBodyLimitLayer::new(
app_state.config.server.max_body_size_mb * 1024 * 1024,
app_state
.config
.server
.max_body_size_mb
.checked_mul(1024)
.and_then(|kb| kb.checked_mul(1024))
.context("max_body_size_mb is too large")?,
))
.layer(TimeoutLayer::new(server_config.response_timeout))
.layer(cors_layer)
@ -104,7 +110,7 @@ pub async fn create_server(config: Config) -> Result<()> {
fn build_cors_layer(server_config: &ServerConfig) -> Result<CorsLayer> {
let origins = &server_config.allowed_origins;
let cors = if origins.len() == 1 && origins[0] == "*" {
let cors = if origins.len() == 1 && origins.first().is_some_and(|origin| origin == "*") {
info!("CORS: allowing all origins");
let header: HeaderValue = "*"
.parse()

View file

@ -60,7 +60,7 @@ pub async fn create_document(
.get_latest_non_deleted_document_by_path(
&vault_id,
&sanitized_relative_path,
Some(&mut *transaction),
Some(transaction.connection_mut().map_err(server_error)?),
)
.await
.map_err(server_error)?;
@ -129,7 +129,7 @@ pub async fn create_document(
&device_id.0,
request.last_seen_vault_update_id,
&new_content,
Some(&mut *transaction),
Some(transaction.connection_mut().map_err(server_error)?),
)
.await
.map_err(server_error)?
@ -157,7 +157,10 @@ pub async fn create_document(
let last_update_id = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut *transaction))
.get_max_update_id_in_vault(
&vault_id,
Some(transaction.connection_mut().map_err(server_error)?),
)
.await
.map_err(server_error)?;
@ -176,7 +179,9 @@ pub async fn create_document(
);
}
let new_vault_update_id = last_update_id + 1;
let new_vault_update_id = last_update_id
.checked_add(1)
.ok_or_else(|| server_error(anyhow::anyhow!("Vault update id overflow")))?;
let new_version = StoredDocumentVersion {
vault_update_id: new_vault_update_id,
creation_vault_update_id: new_vault_update_id,

View file

@ -48,13 +48,20 @@ pub async fn delete_document(
let last_update_id = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
.get_max_update_id_in_vault(
&vault_id,
Some(transaction.connection_mut().map_err(server_error)?),
)
.await
.map_err(server_error)?;
let latest_version = state
.database
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
.get_latest_document(
&vault_id,
&document_id,
Some(transaction.connection_mut().map_err(server_error)?),
)
.await
.map_err(server_error)?;
@ -80,7 +87,9 @@ pub async fn delete_document(
return Ok(Json(latest_version.into()));
}
let new_vault_update_id = last_update_id + 1;
let new_vault_update_id = last_update_id
.checked_add(1)
.ok_or_else(|| server_error(anyhow!("Vault update id overflow")))?;
let latest_relative_path = latest_version.relative_path;
let latest_content = latest_version.content;
let creation_vault_update_id = latest_version.creation_vault_update_id;

View file

@ -32,26 +32,23 @@ struct BucketState {
impl RateLimiter {
/// Create a new per-user rate limiter.
///
/// # Panics
///
/// Panics if `max_per_second` is 0.
pub fn new(max_per_second: u64) -> Self {
assert!(
max_per_second > 0,
"max_per_second must be > 0 (set rate_limit_per_user_per_second to null in config to disable)"
);
Self {
max_per_second,
buckets: Arc::new(Mutex::new(HashMap::new())),
}
}
fn get_or_create_bucket(&self, token: &str) -> Arc<TokenBucket> {
self.buckets
fn get_or_create_bucket(
&self,
token: &str,
) -> std::result::Result<Arc<TokenBucket>, StatusCode> {
let mut buckets = self
.buckets
.lock()
.expect("rate limiter lock poisoned")
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(buckets
.entry(token.to_owned())
.or_insert_with(|| {
Arc::new(TokenBucket {
@ -62,23 +59,26 @@ impl RateLimiter {
max_tokens: self.max_per_second,
})
})
.clone()
.clone())
}
}
impl TokenBucket {
fn try_acquire(&self) -> bool {
let mut state = self.state.lock().expect("token bucket lock poisoned");
fn try_acquire(&self) -> std::result::Result<bool, StatusCode> {
let mut state = self
.state
.lock()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let now = Instant::now();
if now.duration_since(state.last_refill).as_secs() >= 1 {
state.tokens = self.max_tokens;
state.last_refill = now;
}
if state.tokens > 0 {
state.tokens -= 1;
true
state.tokens = state.tokens.saturating_sub(1);
Ok(true)
} else {
false
Ok(false)
}
}
}
@ -88,13 +88,13 @@ pub async fn rate_limit_middleware(
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
req: Request,
next: Next,
) -> Result<Response, StatusCode> {
) -> std::result::Result<Response, StatusCode> {
let Some(TypedHeader(auth)) = auth_header else {
return Ok(next.run(req).await);
};
let bucket = limiter.get_or_create_bucket(auth.token());
if bucket.try_acquire() {
let bucket = limiter.get_or_create_bucket(auth.token())?;
if bucket.try_acquire()? {
Ok(next.run(req).await)
} else {
Err(StatusCode::TOO_MANY_REQUESTS)

View file

@ -27,7 +27,7 @@ use crate::{
},
server::requests::UpdateBinaryDocumentVersion,
utils::{
find_first_available_path::find_first_available_path, is_binary::is_binary,
find_first_available_path::find_first_available_path, is_binary::as_non_binary_text,
is_file_type_mergable::is_file_type_mergable, normalize::normalize,
sanitize_path::sanitize_path,
},
@ -173,13 +173,20 @@ pub async fn update_document(
let last_update_id = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
.get_max_update_id_in_vault(
&vault_id,
Some(transaction.connection_mut().map_err(server_error)?),
)
.await
.map_err(server_error)?;
let latest_version = state
.database
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
.get_latest_document(
&vault_id,
&document_id,
Some(transaction.connection_mut().map_err(server_error)?),
)
.await
.map_err(server_error)?
.map_or_else(
@ -225,64 +232,56 @@ pub async fn update_document(
)));
}
// For mergability, use whichever path the new version will live at — the
// requested rename target if the client sent one, otherwise the existing
// server-side path.
// For mergability, use whichever path the new version will live at:
// - the requested rename target if the client sent one
// - otherwise the existing server-side path.
let mergable_check_path = sanitized_relative_path
.as_deref()
.unwrap_or(&latest_version.relative_path);
let are_all_participants_mergable = is_file_type_mergable(
let mergeable_texts = if is_file_type_mergable(
mergable_check_path,
&state.config.server.mergeable_file_extensions,
) && !is_binary(&parent_content)
&& !is_binary(&latest_version.content)
&& !is_binary(&content);
let (merged_content, is_different_from_request_content) = if are_all_participants_mergable {
info!("Merging changes for document `{document_id}` in vault `{vault_id}`");
let parent_text = str::from_utf8(&parent_content)
.context("Parent document content is not valid UTF-8")
.map_err(client_error)?;
let latest_text = str::from_utf8(&latest_version.content)
.context("Latest version content is not valid UTF-8")
.map_err(client_error)?;
let new_text = str::from_utf8(&content)
.context("New content is not valid UTF-8")
.map_err(client_error)?;
let parent_owned = parent_text.to_owned();
let latest_owned = latest_text.to_owned();
let new_owned = new_text.to_owned();
let content_clone = content.clone();
let (merged, is_different) = tokio::task::spawn_blocking(move || {
let merged = reconcile(
&parent_owned,
&latest_owned.into(),
&new_owned.into(),
&*BuiltinTokenizer::Word,
)
.apply()
.text()
.into_bytes();
let is_different = merged != content_clone;
(merged, is_different)
})
.await
.map_err(|e| server_error(anyhow::anyhow!("Reconcile task failed: {e}")))?;
(merged, is_different)
) {
as_non_binary_texts(&parent_content, &latest_version.content, &content)
} else {
(content, false) // false means that the client doesn't need to refetch the file as we can ensure the remote and local versions are the same as LWW is the merging method for binary files
None
};
let are_all_participants_mergable = mergeable_texts.is_some();
// Rename resolution: only apply the client's rename if (a) the client
// requested one (`sanitized_relative_path` is `Some`) and (b) the
// document's path hasn't changed since this client's parent version.
// If the parent and latest paths differ, another client already renamed
// the document — keep the latest path (first rename wins). Content
// changes from both clients are still merged correctly via the 3-way
// reconcile above, independent of which rename wins. A missing
// relative_path means "keep current path" (content-only edit).
let (merged_content, is_same_as_request) =
if let Some((parent_text, latest_text, new_text)) = mergeable_texts {
info!("Merging changes for document `{document_id}` in vault `{vault_id}`");
let parent_owned = parent_text.to_owned();
let latest_owned = latest_text.to_owned();
let new_owned = new_text.to_owned();
let content_clone = content.clone();
let merged = tokio::task::spawn_blocking(move || {
let merged = reconcile(
&parent_owned,
&latest_owned.into(),
&new_owned.into(),
&*BuiltinTokenizer::Word,
)
.apply()
.text()
.into_bytes();
merged
})
.await
.map_err(|e| server_error(anyhow::anyhow!("Reconcile task failed: {e}")))?;
let is_same = merged == content_clone;
(merged, is_same)
} else {
(content, true) // true means that the client doesn't need to refetch the file as we can ensure the remote and local versions are the same as LWW is the merging method for binary files
};
// First rename wins: apply the client's rename only if the doc's path
// hasn't changed since its parent version. Content from both clients
// still merges via the 3-way reconcile above
let new_relative_path = match sanitized_relative_path.as_deref() {
Some(requested)
if parent_relative_path == latest_version.relative_path
@ -306,7 +305,9 @@ pub async fn update_document(
let new_version = StoredDocumentVersion {
document_id,
vault_update_id: last_update_id + 1,
vault_update_id: last_update_id
.checked_add(1)
.ok_or_else(|| server_error(anyhow!("Vault update id overflow")))?,
creation_vault_update_id: latest_version.creation_vault_update_id,
relative_path: new_relative_path,
content: merged_content,
@ -314,7 +315,7 @@ pub async fn update_document(
is_deleted: false,
user_id: user.name,
device_id: device_id.0,
has_been_merged: are_all_participants_mergable && is_different_from_request_content,
has_been_merged: are_all_participants_mergable && !is_same_as_request,
};
state
@ -323,9 +324,21 @@ pub async fn update_document(
.await
.map_err(server_error)?;
Ok(Json(if is_different_from_request_content {
DocumentUpdateResponse::MergingUpdate(new_version.into())
} else {
Ok(Json(if is_same_as_request {
DocumentUpdateResponse::FastForwardUpdate(new_version.into())
} else {
DocumentUpdateResponse::MergingUpdate(new_version.into())
}))
}
fn as_non_binary_texts<'a>(
parent_content: &'a [u8],
latest_content: &'a [u8],
new_content: &'a [u8],
) -> Option<(&'a str, &'a str, &'a str)> {
Some((
as_non_binary_text(parent_content)?,
as_non_binary_text(latest_content)?,
as_non_binary_text(new_content)?,
))
}

View file

@ -306,7 +306,7 @@ async fn websocket(
&device_id,
docs,
)
.await;
.await?;
}
}
}
@ -351,7 +351,7 @@ async fn websocket(
state
.cursors
.remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id)
.await;
.await?;
match &result {
Ok(()) => {

View file

@ -1,10 +1,3 @@
use std::sync::LazyLock;
use regex::Regex;
static DEDUP_SUFFIX_REGEX: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r" \((\d+)\)$").expect("invalid regex"));
pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
let mut path_parts = path.split('/').collect::<Vec<_>>();
let file_name = path_parts
@ -24,29 +17,19 @@ pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
let (stem, extension) = if is_simple_dotfile {
(file_name.clone(), String::new())
} else {
// Regular file or dotfile with extension
let name_parts = file_name.rsplitn(2, '.').collect::<Vec<_>>();
let mut reverse_parts = name_parts.into_iter().rev();
match (reverse_parts.next(), reverse_parts.next()) {
(Some(stem), maybe_extension) => (
stem.to_owned(),
maybe_extension
.map(|ext| format!(".{ext}"))
.unwrap_or_default(),
),
_ => unreachable!("Path must have at least one part"),
match file_name.rsplit_once('.') {
Some((stem, extension)) => (stem.to_owned(), format!(".{extension}")),
None => (file_name.clone(), String::new()),
}
};
let start_number = DEDUP_SUFFIX_REGEX
.captures(&stem)
.and_then(|caps| caps.get(1))
.and_then(|m| m.as_str().parse::<u32>().ok())
.unwrap_or(0);
let (clean_stem, start_number) = strip_dedup_suffix(&stem);
let clean_stem = clean_stem.to_owned();
let clean_stem = DEDUP_SUFFIX_REGEX.replace(&stem, "").to_string();
(start_number..).map(move |dedup_number| {
std::iter::successors(Some(start_number), |dedup_number| {
dedup_number.checked_add(1)
})
.map(move |dedup_number| {
if dedup_number == 0 {
format!("{directory}{clean_stem}{extension}")
} else {
@ -55,6 +38,20 @@ pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
})
}
fn strip_dedup_suffix(stem: &str) -> (&str, u64) {
let Some(without_closing_paren) = stem.strip_suffix(')') else {
return (stem, 0);
};
let Some((clean_stem, number)) = without_closing_paren.rsplit_once(" (") else {
return (stem, 0);
};
if number.is_empty() || !number.chars().all(|c| c.is_ascii_digit()) {
return (stem, 0);
}
(clean_stem, number.parse::<u64>().unwrap_or(0))
}
#[cfg(test)]
mod test {
use super::*;
@ -103,7 +100,7 @@ mod test {
}
#[test]
fn test_regex_capturing_group() {
fn test_dedup_suffix_parsing() {
// Single digit in parentheses
let mut deduped = dedup_paths("document (5).md");
assert_eq!(deduped.next(), Some("document (5).md".to_owned()));

View file

@ -1,20 +1,23 @@
use crate::app_state::database::models::VaultId;
use crate::app_state::database::{WriteTransaction, models::VaultId};
use crate::utils::dedup_paths::dedup_paths;
use anyhow::Result;
use anyhow::{Result, anyhow};
use log::{debug, info};
use sqlx::sqlite::SqliteConnection;
pub async fn find_first_available_path(
vault_id: &VaultId,
sanitized_relative_path: &str,
database: &crate::app_state::database::Database,
connection: &mut SqliteConnection,
transaction: &mut WriteTransaction,
) -> Result<String> {
info!("Finding first available path for `{sanitized_relative_path}` in vault `{vault_id}`");
for candidate in dedup_paths(sanitized_relative_path) {
debug!("Checking candidate path for deconflicting names: `{candidate}`");
if database
.get_latest_non_deleted_document_by_path(vault_id, &candidate, Some(connection))
.get_latest_non_deleted_document_by_path(
vault_id,
&candidate,
Some(transaction.connection_mut()?),
)
.await?
.is_none()
{
@ -27,5 +30,7 @@ pub async fn find_first_available_path(
);
}
unreachable!("dedup_paths produces infinite paths");
Err(anyhow!(
"No available path candidates produced for `{sanitized_relative_path}` in vault `{vault_id}`"
))
}

View file

@ -1,16 +1,22 @@
/// Heuristically determine if the given data is a binary or a text file's
/// content.
/// Return the given data as UTF-8 text if it is not considered binary.
///
/// Only text inputs can be reconciled using the crate's functions.
#[must_use]
pub fn is_binary(data: &[u8]) -> bool {
pub fn as_non_binary_text(data: &[u8]) -> Option<&str> {
if data.contains(&0) {
// Even though the NUL character is valid in UTF-8, it's highly suspicious in
// human-readable text.
return true;
return None;
}
std::str::from_utf8(data).is_err()
std::str::from_utf8(data).ok()
}
/// Heuristically determine if the given data is a binary or a text file's
/// content.
#[must_use]
pub fn is_binary(data: &[u8]) -> bool {
as_non_binary_text(data).is_none()
}
#[cfg(test)]
@ -23,4 +29,11 @@ mod tests {
assert!(is_binary(&[0, 12]));
assert!(!is_binary(b"hello"));
}
#[test]
fn test_as_non_binary_text() {
assert_eq!(as_non_binary_text(b"hello"), Some("hello"));
assert_eq!(as_non_binary_text(&[0, 12]), None);
assert_eq!(as_non_binary_text(&[0xff]), None);
}
}

View file

@ -52,14 +52,14 @@ impl RotatingFileWriter {
/// Parse timestamp from log filename and return as `SystemTime`
fn parse_log_timestamp(filename: &str, file_prefix: &str) -> Option<SystemTime> {
// Expected format: {prefix}.{timestamp}.log where timestamp is %Y-%m-%d_%H-%M-%S
let prefix_len = file_prefix.len() + 1; // +1 for the dot
let prefix_len = file_prefix.len().checked_add(1)?; // +1 for the dot
let timestamp_str = filename.get(prefix_len..filename.len().checked_sub(4)?)?;
let dt = NaiveDateTime::parse_from_str(timestamp_str, "%Y-%m-%d_%H-%M-%S").ok()?;
let timestamp = dt.and_utc();
let secs: u64 = timestamp.timestamp().try_into().ok()?;
Some(UNIX_EPOCH + Duration::from_secs(secs))
UNIX_EPOCH.checked_add(Duration::from_secs(secs))
}
fn find_latest_log_file(directory: &Path, file_prefix: &str) -> Option<String> {
@ -86,7 +86,9 @@ impl RotatingFileWriter {
Self::find_latest_log_file(directory, file_prefix)
.and_then(|filename| Self::parse_log_timestamp(&filename, file_prefix))
.map_or_else(SystemTime::now, |last_rotation| {
last_rotation + rotation_duration
last_rotation
.checked_add(rotation_duration)
.unwrap_or_else(SystemTime::now)
})
}
@ -136,7 +138,9 @@ impl RotatingFileWriter {
.open(&filepath)?;
inner.current_file = Some(file);
inner.next_rotation_time = SystemTime::now() + inner.rotation_duration;
inner.next_rotation_time = SystemTime::now()
.checked_add(inner.rotation_duration)
.unwrap_or_else(SystemTime::now);
Ok(())
}