Improve diff
This commit is contained in:
parent
792f57dc7e
commit
e5373ab2bb
23 changed files with 312 additions and 220 deletions
|
|
@ -85,7 +85,7 @@ describe("File operations", () => {
|
||||||
const result = await ops.create("a", new Uint8Array());
|
const result = await ops.create("a", new Uint8Array());
|
||||||
|
|
||||||
assertSetContainsExactly(fs.names, "a");
|
assertSetContainsExactly(fs.names, "a");
|
||||||
assert.equal(result.actualPath, "a");
|
assert.equal(result, "a");
|
||||||
});
|
});
|
||||||
|
|
||||||
it("create throws FileAlreadyExistsError when the path is occupied", async () => {
|
it("create throws FileAlreadyExistsError when the path is occupied", async () => {
|
||||||
|
|
@ -109,7 +109,7 @@ describe("File operations", () => {
|
||||||
|
|
||||||
const result = await ops.move("a", "b");
|
const result = await ops.move("a", "b");
|
||||||
assertSetContainsExactly(fs.names, "b");
|
assertSetContainsExactly(fs.names, "b");
|
||||||
assert.equal(result.actualPath, "b");
|
assert.equal(result, "b");
|
||||||
});
|
});
|
||||||
|
|
||||||
it("move with same source and target is a no-op", async () => {
|
it("move with same source and target is a no-op", async () => {
|
||||||
|
|
@ -119,7 +119,7 @@ describe("File operations", () => {
|
||||||
const result = await ops.move("a", "a");
|
const result = await ops.move("a", "a");
|
||||||
|
|
||||||
assertSetContainsExactly(fs.names, "a");
|
assertSetContainsExactly(fs.names, "a");
|
||||||
assert.equal(result.actualPath, "a");
|
assert.equal(result, "a");
|
||||||
});
|
});
|
||||||
|
|
||||||
it("move throws FileAlreadyExistsError when the target is occupied", async () => {
|
it("move throws FileAlreadyExistsError when the target is occupied", async () => {
|
||||||
|
|
|
||||||
|
|
@ -11,16 +11,6 @@ import { FileNotFoundError } from "../errors/file-not-found-error";
|
||||||
import { FileAlreadyExistsError } from "../errors/file-already-exists-error";
|
import { FileAlreadyExistsError } from "../errors/file-already-exists-error";
|
||||||
import type { ExpectedFsEvents } from "../sync-operations/expected-fs-events";
|
import type { ExpectedFsEvents } from "../sync-operations/expected-fs-events";
|
||||||
|
|
||||||
/**
|
|
||||||
* Outcome of a `move`/`create`. `actualPath` is where the file ended up;
|
|
||||||
* with the conflict-path machinery removed it is always equal to the
|
|
||||||
* requested path. The shape is preserved so callers don't all need to
|
|
||||||
* change.
|
|
||||||
*/
|
|
||||||
export interface FileOpResult {
|
|
||||||
actualPath: RelativePath;
|
|
||||||
}
|
|
||||||
|
|
||||||
export class FileOperations {
|
export class FileOperations {
|
||||||
private readonly fs: SafeFileSystemOperations;
|
private readonly fs: SafeFileSystemOperations;
|
||||||
|
|
||||||
|
|
@ -68,7 +58,7 @@ export class FileOperations {
|
||||||
public async create(
|
public async create(
|
||||||
path: RelativePath,
|
path: RelativePath,
|
||||||
newContent: Uint8Array
|
newContent: Uint8Array
|
||||||
): Promise<FileOpResult> {
|
): Promise<RelativePath> {
|
||||||
if (await this.fs.exists(path)) {
|
if (await this.fs.exists(path)) {
|
||||||
throw new FileAlreadyExistsError(
|
throw new FileAlreadyExistsError(
|
||||||
`Refusing to create '${path}': file already exists`,
|
`Refusing to create '${path}': file already exists`,
|
||||||
|
|
@ -84,7 +74,7 @@ export class FileOperations {
|
||||||
this.expectedFsEvents.unexpectCreate(path);
|
this.expectedFsEvents.unexpectCreate(path);
|
||||||
throw e;
|
throw e;
|
||||||
}
|
}
|
||||||
return { actualPath: path };
|
return path;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -220,9 +210,9 @@ export class FileOperations {
|
||||||
public async move(
|
public async move(
|
||||||
oldPath: RelativePath,
|
oldPath: RelativePath,
|
||||||
newPath: RelativePath
|
newPath: RelativePath
|
||||||
): Promise<FileOpResult> {
|
): Promise<RelativePath> {
|
||||||
if (oldPath === newPath) {
|
if (oldPath === newPath) {
|
||||||
return { actualPath: oldPath };
|
return oldPath;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (await this.fs.exists(newPath)) {
|
if (await this.fs.exists(newPath)) {
|
||||||
|
|
@ -241,7 +231,7 @@ export class FileOperations {
|
||||||
throw e;
|
throw e;
|
||||||
}
|
}
|
||||||
await this.deletingEmptyParentDirectoriesOfDeletedFile(oldPath);
|
await this.deletingEmptyParentDirectoriesOfDeletedFile(oldPath);
|
||||||
return { actualPath: newPath };
|
return newPath;
|
||||||
}
|
}
|
||||||
|
|
||||||
private async deletingEmptyParentDirectoriesOfDeletedFile(
|
private async deletingEmptyParentDirectoriesOfDeletedFile(
|
||||||
|
|
|
||||||
|
|
@ -1103,7 +1103,7 @@ export class Syncer {
|
||||||
remoteHash,
|
remoteHash,
|
||||||
localPath: target
|
localPath: target
|
||||||
});
|
});
|
||||||
const result = await this.operations.create(
|
const createdPath = await this.operations.create(
|
||||||
target,
|
target,
|
||||||
remoteContent
|
remoteContent
|
||||||
);
|
);
|
||||||
|
|
@ -1112,7 +1112,7 @@ export class Syncer {
|
||||||
);
|
);
|
||||||
localPath =
|
localPath =
|
||||||
liveRecord === undefined
|
liveRecord === undefined
|
||||||
? result.actualPath
|
? createdPath
|
||||||
: liveRecord.localPath;
|
: liveRecord.localPath;
|
||||||
await this.updateCache(
|
await this.updateCache(
|
||||||
remoteVersion.vaultUpdateId,
|
remoteVersion.vaultUpdateId,
|
||||||
|
|
|
||||||
|
|
@ -21,9 +21,10 @@ cargo test --verbose
|
||||||
|
|
||||||
if [[ "$FIX_MODE" == true ]]; then
|
if [[ "$FIX_MODE" == true ]]; then
|
||||||
cargo clippy --all-targets --all-features --fix --allow-dirty --allow-staged
|
cargo clippy --all-targets --all-features --fix --allow-dirty --allow-staged
|
||||||
|
cargo clippy --all-targets --all-features -- -D warnings
|
||||||
cargo fmt --all
|
cargo fmt --all
|
||||||
else
|
else
|
||||||
cargo clippy --all-targets --all-features
|
cargo clippy --all-targets --all-features -- -D warnings
|
||||||
cargo fmt --all -- --check
|
cargo fmt --all -- --check
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
|
||||||
1
sync-server/Cargo.lock
generated
1
sync-server/Cargo.lock
generated
|
|
@ -2181,7 +2181,6 @@ dependencies = [
|
||||||
"log",
|
"log",
|
||||||
"rand 0.9.0",
|
"rand 0.9.0",
|
||||||
"reconcile-text",
|
"reconcile-text",
|
||||||
"regex",
|
|
||||||
"sanitize-filename",
|
"sanitize-filename",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@ sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio", "uuid", "chro
|
||||||
chrono = { version = "0.4.41", features = ["serde"] }
|
chrono = { version = "0.4.41", features = ["serde"] }
|
||||||
rand = "0.9.0"
|
rand = "0.9.0"
|
||||||
sanitize-filename = "0.6.0"
|
sanitize-filename = "0.6.0"
|
||||||
regex = "1.12.2"
|
|
||||||
clap = { version = "4.5.38", features = ["derive"] }
|
clap = { version = "4.5.38", features = ["derive"] }
|
||||||
futures = "0.3.31"
|
futures = "0.3.31"
|
||||||
serde_yaml = "0.9.34"
|
serde_yaml = "0.9.34"
|
||||||
|
|
@ -49,16 +48,19 @@ rust_2018_idioms = { level = "warn", priority = -1 }
|
||||||
missing_debug_implementations = "warn"
|
missing_debug_implementations = "warn"
|
||||||
|
|
||||||
[lints.clippy]
|
[lints.clippy]
|
||||||
|
arithmetic_side_effects = "deny"
|
||||||
await_holding_lock = "warn"
|
await_holding_lock = "warn"
|
||||||
dbg_macro = "warn"
|
dbg_macro = "warn"
|
||||||
disallowed_macros = { level = "deny", priority = 1 }
|
disallowed_macros = { level = "deny", priority = 1 }
|
||||||
empty_enums = "warn"
|
empty_enums = "warn"
|
||||||
enum_glob_use = "warn"
|
enum_glob_use = "warn"
|
||||||
|
expect_used = "deny"
|
||||||
exit = "warn"
|
exit = "warn"
|
||||||
filter_map_next = "warn"
|
filter_map_next = "warn"
|
||||||
fn_params_excessive_bools = "warn"
|
fn_params_excessive_bools = "warn"
|
||||||
if_let_mutex = "warn"
|
if_let_mutex = "warn"
|
||||||
imprecise_flops = "warn"
|
imprecise_flops = "warn"
|
||||||
|
indexing_slicing = "deny"
|
||||||
inefficient_to_string = "warn"
|
inefficient_to_string = "warn"
|
||||||
linkedlist = "warn"
|
linkedlist = "warn"
|
||||||
lossy_float_literal = "warn"
|
lossy_float_literal = "warn"
|
||||||
|
|
@ -68,13 +70,19 @@ mem_forget = "warn"
|
||||||
needless_borrow = "warn"
|
needless_borrow = "warn"
|
||||||
needless_continue = "warn"
|
needless_continue = "warn"
|
||||||
option_option = "warn"
|
option_option = "warn"
|
||||||
|
panic = "deny"
|
||||||
|
panic_in_result_fn = "deny"
|
||||||
rest_pat_in_fully_bound_structs = "warn"
|
rest_pat_in_fully_bound_structs = "warn"
|
||||||
str_to_string = "warn"
|
str_to_string = "warn"
|
||||||
suboptimal_flops = "warn"
|
suboptimal_flops = "warn"
|
||||||
todo = "warn"
|
todo = "deny"
|
||||||
uninlined_format_args = "warn"
|
uninlined_format_args = "warn"
|
||||||
|
unimplemented = "deny"
|
||||||
|
unreachable = "deny"
|
||||||
unnested_or_patterns = "warn"
|
unnested_or_patterns = "warn"
|
||||||
unused_self = "warn"
|
unused_self = "warn"
|
||||||
|
unwrap_in_result = "deny"
|
||||||
|
unwrap_used = "deny"
|
||||||
verbose_file_reads = "warn"
|
verbose_file_reads = "warn"
|
||||||
|
|
||||||
large_stack_arrays = { level = "allow", priority = 1 } # https://github.com/rust-lang/rust-clippy/issues/13774
|
large_stack_arrays = { level = "allow", priority = 1 } # https://github.com/rust-lang/rust-clippy/issues/13774
|
||||||
|
|
@ -88,7 +96,7 @@ single_call_fn = { level = "allow", priority = 1 }
|
||||||
similar_names = { level = "allow", priority = 1 }
|
similar_names = { level = "allow", priority = 1 }
|
||||||
missing_docs_in_private_items = { level = "allow", priority = 1 }
|
missing_docs_in_private_items = { level = "allow", priority = 1 }
|
||||||
|
|
||||||
pedantic = { level = "warn", priority = 0 }
|
pedantic = { level = "warn", priority = -1 }
|
||||||
|
|
||||||
[package.metadata.cargo-machete]
|
[package.metadata.cargo-machete]
|
||||||
ignored = ["humantime-serde"] # only used in serde macro
|
ignored = ["humantime-serde"] # only used in serde macro
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ use super::{
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
app_state::websocket::models::DocumentWithCursors, config::database_config::DatabaseConfig,
|
app_state::websocket::models::DocumentWithCursors, config::database_config::DatabaseConfig,
|
||||||
|
errors::SyncServerError,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
|
@ -39,7 +40,7 @@ impl Cursors {
|
||||||
user_name: String,
|
user_name: String,
|
||||||
device_id: &DeviceId,
|
device_id: &DeviceId,
|
||||||
document_to_cursors: Vec<DocumentWithCursors>,
|
document_to_cursors: Vec<DocumentWithCursors>,
|
||||||
) {
|
) -> Result<(), SyncServerError> {
|
||||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||||
|
|
||||||
let all_device_cursors = vault_to_cursors
|
let all_device_cursors = vault_to_cursors
|
||||||
|
|
@ -54,7 +55,7 @@ impl Cursors {
|
||||||
}));
|
}));
|
||||||
|
|
||||||
drop(vault_to_cursors); // Explicitly drop the lock before broadcasting to avoid deadlock
|
drop(vault_to_cursors); // Explicitly drop the lock before broadcasting to avoid deadlock
|
||||||
self.broadcast_cursors_for_vault(&vault_id).await;
|
self.broadcast_cursors_for_vault(&vault_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec<ClientCursors> {
|
pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec<ClientCursors> {
|
||||||
|
|
@ -76,15 +77,17 @@ impl Cursors {
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
() = tokio::time::sleep(Duration::from_secs(1)) => {
|
() = tokio::time::sleep(Duration::from_secs(1)) => {
|
||||||
self.remove_expired_cursors().await;
|
self.remove_expired_cursors().await?;
|
||||||
}
|
}
|
||||||
Ok(()) = shutdown.changed() => break,
|
Ok(()) = shutdown.changed() => break,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok::<(), SyncServerError>(())
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn remove_expired_cursors(&self) {
|
async fn remove_expired_cursors(&self) -> Result<(), SyncServerError> {
|
||||||
let changed_vaults: Vec<VaultId> = {
|
let changed_vaults: Vec<VaultId> = {
|
||||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||||
|
|
||||||
|
|
@ -104,11 +107,13 @@ impl Cursors {
|
||||||
};
|
};
|
||||||
|
|
||||||
for vault_id in &changed_vaults {
|
for vault_id in &changed_vaults {
|
||||||
self.broadcast_cursors_for_vault(vault_id).await;
|
self.broadcast_cursors_for_vault(vault_id).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn broadcast_cursors_for_vault(&self, vault_id: &VaultId) {
|
async fn broadcast_cursors_for_vault(&self, vault_id: &VaultId) -> Result<(), SyncServerError> {
|
||||||
let client_cursors: Vec<ClientCursors> = {
|
let client_cursors: Vec<ClientCursors> = {
|
||||||
let vault_to_cursors = self.vault_to_cursors.lock().await;
|
let vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||||
vault_to_cursors
|
vault_to_cursors
|
||||||
|
|
@ -124,10 +129,14 @@ impl Cursors {
|
||||||
clients: client_cursors,
|
clients: client_cursors,
|
||||||
},
|
},
|
||||||
)),
|
)),
|
||||||
);
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn remove_cursors_of_device(&self, vault_id: &VaultId, device_id: &DeviceId) {
|
pub async fn remove_cursors_of_device(
|
||||||
|
&self,
|
||||||
|
vault_id: &VaultId,
|
||||||
|
device_id: &DeviceId,
|
||||||
|
) -> Result<(), SyncServerError> {
|
||||||
let changed = {
|
let changed = {
|
||||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||||
|
|
||||||
|
|
@ -145,8 +154,9 @@ impl Cursors {
|
||||||
};
|
};
|
||||||
|
|
||||||
if changed {
|
if changed {
|
||||||
self.broadcast_cursors_for_vault(vault_id).await;
|
self.broadcast_cursors_for_vault(vault_id).await?;
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ use std::{
|
||||||
sync::atomic::{AtomicU64, Ordering},
|
sync::atomic::{AtomicU64, Ordering},
|
||||||
};
|
};
|
||||||
|
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use log::info;
|
use log::info;
|
||||||
use models::{
|
use models::{
|
||||||
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, VaultUpdateId,
|
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, VaultUpdateId,
|
||||||
|
|
@ -132,6 +132,12 @@ impl WriteTransaction {
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn connection_mut(&mut self) -> Result<&mut SqliteConnection> {
|
||||||
|
self.conn
|
||||||
|
.as_deref_mut()
|
||||||
|
.context("WriteTransaction already consumed")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for WriteTransaction {
|
impl Drop for WriteTransaction {
|
||||||
|
|
@ -147,25 +153,6 @@ impl Drop for WriteTransaction {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::ops::Deref for WriteTransaction {
|
|
||||||
type Target = SqliteConnection;
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
self.conn
|
|
||||||
.as_ref()
|
|
||||||
.expect("BUG: WriteTransaction dereferenced after being consumed")
|
|
||||||
.deref()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::ops::DerefMut for WriteTransaction {
|
|
||||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
|
||||||
self.conn
|
|
||||||
.as_mut()
|
|
||||||
.expect("BUG: WriteTransaction dereferenced after being consumed")
|
|
||||||
.deref_mut()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Ensure the connection has no leftover open transaction (e.g. from a
|
/// Ensure the connection has no leftover open transaction (e.g. from a
|
||||||
/// `WriteTransaction` that was dropped without commit/rollback). ROLLBACK
|
/// `WriteTransaction` that was dropped without commit/rollback). ROLLBACK
|
||||||
/// is a harmless no-op if no transaction is active.
|
/// is a harmless no-op if no transaction is active.
|
||||||
|
|
@ -797,7 +784,7 @@ impl Database {
|
||||||
let _send_guard = self.broadcasts.acquire_send_lock(vault_id).await;
|
let _send_guard = self.broadcasts.acquire_send_lock(vault_id).await;
|
||||||
|
|
||||||
query
|
query
|
||||||
.execute(&mut *transaction)
|
.execute(transaction.connection_mut()?)
|
||||||
.await
|
.await
|
||||||
.context("Cannot insert document version")?;
|
.context("Cannot insert document version")?;
|
||||||
|
|
||||||
|
|
@ -821,7 +808,8 @@ impl Database {
|
||||||
} else {
|
} else {
|
||||||
WebSocketServerMessageWithOrigin::with_origin(version.device_id.clone(), envelope)
|
WebSocketServerMessageWithOrigin::with_origin(version.device_id.clone(), envelope)
|
||||||
};
|
};
|
||||||
self.broadcasts.send_document_update(vault_id, with_origin);
|
self.broadcasts
|
||||||
|
.send_document_update(vault_id, with_origin)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,11 @@ use log::{debug, info, warn};
|
||||||
use tokio::sync::{Mutex, broadcast};
|
use tokio::sync::{Mutex, broadcast};
|
||||||
|
|
||||||
use super::models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin};
|
use super::models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin};
|
||||||
use crate::{app_state::database::models::VaultId, config::server_config::ServerConfig};
|
use crate::{
|
||||||
|
app_state::database::models::VaultId,
|
||||||
|
config::server_config::ServerConfig,
|
||||||
|
errors::{SyncServerError, client_error, server_error},
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Broadcasts {
|
pub struct Broadcasts {
|
||||||
|
|
@ -60,30 +64,31 @@ impl Broadcasts {
|
||||||
|
|
||||||
pub fn get_receiver(
|
pub fn get_receiver(
|
||||||
&self,
|
&self,
|
||||||
vault: VaultId,
|
vault: &VaultId,
|
||||||
max_clients: usize,
|
max_clients: usize,
|
||||||
) -> Result<broadcast::Receiver<WebSocketServerMessageWithOrigin>, crate::errors::SyncServerError>
|
) -> Result<broadcast::Receiver<WebSocketServerMessageWithOrigin>, SyncServerError> {
|
||||||
{
|
|
||||||
let mut tx_map = self
|
let mut tx_map = self
|
||||||
.tx
|
.tx
|
||||||
.lock()
|
.lock()
|
||||||
.expect("broadcasts.tx mutex poisoned — a previous holder panicked");
|
.map_err(|_| server_error(anyhow::anyhow!("broadcasts.tx mutex poisoned")))?;
|
||||||
|
|
||||||
let count_before_prune = tx_map
|
let count_before_prune = tx_map
|
||||||
.get(&vault)
|
.get(vault)
|
||||||
.map_or(0, tokio::sync::broadcast::Sender::receiver_count);
|
.map_or(0, tokio::sync::broadcast::Sender::receiver_count);
|
||||||
let pruned = Self::prune_inactive_vaults(&mut tx_map);
|
let pruned = Self::prune_inactive_vaults(&mut tx_map);
|
||||||
let pruned_self = pruned.contains(&vault);
|
let pruned_self = pruned
|
||||||
|
.iter()
|
||||||
|
.any(|pruned_vault| pruned_vault.as_str() == vault);
|
||||||
|
|
||||||
let sender = tx_map
|
let sender = tx_map
|
||||||
.entry(vault.clone())
|
.entry(vault.to_owned())
|
||||||
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
|
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
|
||||||
|
|
||||||
// Hold the lock across the count check *and* the subscribe so the
|
// Hold the lock across the count check *and* the subscribe so the
|
||||||
// `max_clients` cap is atomic: two concurrent callers can't both
|
// `max_clients` cap is atomic: two concurrent callers can't both
|
||||||
// observe `receiver_count() < max_clients` and both subscribe.
|
// observe `receiver_count() < max_clients` and both subscribe.
|
||||||
if sender.receiver_count() >= max_clients {
|
if sender.receiver_count() >= max_clients {
|
||||||
return Err(crate::errors::client_error(anyhow::anyhow!(
|
return Err(client_error(anyhow::anyhow!(
|
||||||
"Vault has reached the maximum number of clients ({max_clients})"
|
"Vault has reached the maximum number of clients ({max_clients})"
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
@ -100,8 +105,13 @@ impl Broadcasts {
|
||||||
/// Notify all clients (who are subscribed to the vault) about an update.
|
/// Notify all clients (who are subscribed to the vault) about an update.
|
||||||
/// Synchronous: safe to invoke from a handler between `commit()` and
|
/// Synchronous: safe to invoke from a handler between `commit()` and
|
||||||
/// function return without worrying about task cancellation dropping
|
/// function return without worrying about task cancellation dropping
|
||||||
/// the broadcast mid-flight. Failures are logged, never propagated.
|
/// the broadcast mid-flight. Mutex poison is returned; send failures
|
||||||
pub fn send_document_update(&self, vault: VaultId, document: WebSocketServerMessageWithOrigin) {
|
/// are logged because they can happen when receivers disconnect.
|
||||||
|
pub fn send_document_update(
|
||||||
|
&self,
|
||||||
|
vault: &str,
|
||||||
|
document: WebSocketServerMessageWithOrigin,
|
||||||
|
) -> Result<(), SyncServerError> {
|
||||||
let vault_update_id = match &document.message {
|
let vault_update_id = match &document.message {
|
||||||
WebSocketServerMessage::VaultUpdate(u) => Some(u.document.vault_update_id),
|
WebSocketServerMessage::VaultUpdate(u) => Some(u.document.vault_update_id),
|
||||||
WebSocketServerMessage::CursorPositions(_) => None,
|
WebSocketServerMessage::CursorPositions(_) => None,
|
||||||
|
|
@ -110,18 +120,21 @@ impl Broadcasts {
|
||||||
WebSocketServerMessage::VaultUpdate(u) => Some(u.document.is_deleted),
|
WebSocketServerMessage::VaultUpdate(u) => Some(u.document.is_deleted),
|
||||||
WebSocketServerMessage::CursorPositions(_) => None,
|
WebSocketServerMessage::CursorPositions(_) => None,
|
||||||
};
|
};
|
||||||
let mut tx_map = self
|
let mut tx_map = self.tx.lock().map_err(|_| {
|
||||||
.tx
|
server_error(anyhow::anyhow!(
|
||||||
.lock()
|
"broadcasts.tx mutex poisoned; skipping document update broadcast"
|
||||||
.expect("broadcasts.tx mutex poisoned — a previous holder panicked");
|
))
|
||||||
|
})?;
|
||||||
let count_before_prune = tx_map
|
let count_before_prune = tx_map
|
||||||
.get(&vault)
|
.get(vault)
|
||||||
.map_or(0, tokio::sync::broadcast::Sender::receiver_count);
|
.map_or(0, tokio::sync::broadcast::Sender::receiver_count);
|
||||||
let pruned = Self::prune_inactive_vaults(&mut tx_map);
|
let pruned = Self::prune_inactive_vaults(&mut tx_map);
|
||||||
let pruned_self = pruned.contains(&vault);
|
let pruned_self = pruned
|
||||||
|
.iter()
|
||||||
|
.any(|pruned_vault| pruned_vault.as_str() == vault);
|
||||||
|
|
||||||
let sender = tx_map
|
let sender = tx_map
|
||||||
.entry(vault.clone())
|
.entry(vault.to_owned())
|
||||||
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
|
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
|
||||||
|
|
||||||
let count_before_send = sender.receiver_count();
|
let count_before_send = sender.receiver_count();
|
||||||
|
|
@ -131,7 +144,7 @@ impl Broadcasts {
|
||||||
"[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send=0 SKIPPED"
|
"[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send=0 SKIPPED"
|
||||||
);
|
);
|
||||||
debug!("Skipping broadcast, no clients connected for vault `{vault}`");
|
debug!("Skipping broadcast, no clients connected for vault `{vault}`");
|
||||||
return;
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let send_result = sender.send(document);
|
let send_result = sender.send(document);
|
||||||
|
|
@ -143,5 +156,6 @@ impl Broadcasts {
|
||||||
"[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send={count_before_send} FAILED err={e}"
|
"[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send={count_before_send} FAILED err={e}"
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,9 +23,10 @@ impl ColorWhen {
|
||||||
|
|
||||||
impl std::fmt::Display for ColorWhen {
|
impl std::fmt::Display for ColorWhen {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
self.to_possible_value()
|
f.write_str(match self {
|
||||||
.expect("no values are skipped")
|
Self::Always => "always",
|
||||||
.get_name()
|
Self::Auto => "auto",
|
||||||
.fmt(f)
|
Self::Never => "never",
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,10 @@ impl ServerConfig {
|
||||||
self.max_pending_websocket_connections > 0,
|
self.max_pending_websocket_connections > 0,
|
||||||
"max_pending_websocket_connections must be greater than 0"
|
"max_pending_websocket_connections must be greater than 0"
|
||||||
);
|
);
|
||||||
|
ensure!(
|
||||||
|
self.rate_limit_per_user_per_second != Some(0),
|
||||||
|
"rate_limit_per_user_per_second must be greater than 0 when set (use null to disable rate limiting)"
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,15 +20,7 @@ where
|
||||||
let mut user_token_map = BiHashMap::new();
|
let mut user_token_map = BiHashMap::new();
|
||||||
for user in &users {
|
for user in &users {
|
||||||
if let Some(existing_name) = user_token_map.get_by_right(&user.token) {
|
if let Some(existing_name) = user_token_map.get_by_right(&user.token) {
|
||||||
let redacted = if user.token.len() > 6 {
|
let redacted = redact_token(&user.token);
|
||||||
format!(
|
|
||||||
"{}...{}",
|
|
||||||
&user.token[..3],
|
|
||||||
&user.token[user.token.len() - 3..]
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
"***".to_owned()
|
|
||||||
};
|
|
||||||
return Err(D::Error::custom(format!(
|
return Err(D::Error::custom(format!(
|
||||||
"Duplicate user token found: `{redacted}` for users `{}` and `{}`. User tokens \
|
"Duplicate user token found: `{redacted}` for users `{}` and `{}`. User tokens \
|
||||||
must be unique.",
|
must be unique.",
|
||||||
|
|
@ -49,6 +41,23 @@ where
|
||||||
Ok(users)
|
Ok(users)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn redact_token(token: &str) -> String {
|
||||||
|
if token.chars().count() <= 6 {
|
||||||
|
return "***".to_owned();
|
||||||
|
}
|
||||||
|
|
||||||
|
let prefix = token.chars().take(3).collect::<String>();
|
||||||
|
let suffix = token
|
||||||
|
.chars()
|
||||||
|
.rev()
|
||||||
|
.take(3)
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.into_iter()
|
||||||
|
.rev()
|
||||||
|
.collect::<String>();
|
||||||
|
format!("{prefix}...{suffix}")
|
||||||
|
}
|
||||||
|
|
||||||
impl UserConfig {
|
impl UserConfig {
|
||||||
pub fn get_user(&self, token: &str) -> Option<&User> {
|
pub fn get_user(&self, token: &str) -> Option<&User> {
|
||||||
self.user_configs
|
self.user_configs
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,19 @@
|
||||||
|
#![cfg_attr(
|
||||||
|
test,
|
||||||
|
allow(
|
||||||
|
clippy::arithmetic_side_effects,
|
||||||
|
clippy::expect_used,
|
||||||
|
clippy::indexing_slicing,
|
||||||
|
clippy::panic,
|
||||||
|
clippy::panic_in_result_fn,
|
||||||
|
clippy::todo,
|
||||||
|
clippy::unimplemented,
|
||||||
|
clippy::unreachable,
|
||||||
|
clippy::unwrap_in_result,
|
||||||
|
clippy::unwrap_used
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
|
||||||
mod app_state;
|
mod app_state;
|
||||||
mod cli;
|
mod cli;
|
||||||
mod config;
|
mod config;
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,13 @@ pub async fn create_server(config: Config) -> Result<()> {
|
||||||
let app = app
|
let app = app
|
||||||
.layer(DefaultBodyLimit::disable())
|
.layer(DefaultBodyLimit::disable())
|
||||||
.layer(RequestBodyLimitLayer::new(
|
.layer(RequestBodyLimitLayer::new(
|
||||||
app_state.config.server.max_body_size_mb * 1024 * 1024,
|
app_state
|
||||||
|
.config
|
||||||
|
.server
|
||||||
|
.max_body_size_mb
|
||||||
|
.checked_mul(1024)
|
||||||
|
.and_then(|kb| kb.checked_mul(1024))
|
||||||
|
.context("max_body_size_mb is too large")?,
|
||||||
))
|
))
|
||||||
.layer(TimeoutLayer::new(server_config.response_timeout))
|
.layer(TimeoutLayer::new(server_config.response_timeout))
|
||||||
.layer(cors_layer)
|
.layer(cors_layer)
|
||||||
|
|
@ -104,7 +110,7 @@ pub async fn create_server(config: Config) -> Result<()> {
|
||||||
fn build_cors_layer(server_config: &ServerConfig) -> Result<CorsLayer> {
|
fn build_cors_layer(server_config: &ServerConfig) -> Result<CorsLayer> {
|
||||||
let origins = &server_config.allowed_origins;
|
let origins = &server_config.allowed_origins;
|
||||||
|
|
||||||
let cors = if origins.len() == 1 && origins[0] == "*" {
|
let cors = if origins.len() == 1 && origins.first().is_some_and(|origin| origin == "*") {
|
||||||
info!("CORS: allowing all origins");
|
info!("CORS: allowing all origins");
|
||||||
let header: HeaderValue = "*"
|
let header: HeaderValue = "*"
|
||||||
.parse()
|
.parse()
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ pub async fn create_document(
|
||||||
.get_latest_non_deleted_document_by_path(
|
.get_latest_non_deleted_document_by_path(
|
||||||
&vault_id,
|
&vault_id,
|
||||||
&sanitized_relative_path,
|
&sanitized_relative_path,
|
||||||
Some(&mut *transaction),
|
Some(transaction.connection_mut().map_err(server_error)?),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(server_error)?;
|
.map_err(server_error)?;
|
||||||
|
|
@ -129,7 +129,7 @@ pub async fn create_document(
|
||||||
&device_id.0,
|
&device_id.0,
|
||||||
request.last_seen_vault_update_id,
|
request.last_seen_vault_update_id,
|
||||||
&new_content,
|
&new_content,
|
||||||
Some(&mut *transaction),
|
Some(transaction.connection_mut().map_err(server_error)?),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(server_error)?
|
.map_err(server_error)?
|
||||||
|
|
@ -157,7 +157,10 @@ pub async fn create_document(
|
||||||
|
|
||||||
let last_update_id = state
|
let last_update_id = state
|
||||||
.database
|
.database
|
||||||
.get_max_update_id_in_vault(&vault_id, Some(&mut *transaction))
|
.get_max_update_id_in_vault(
|
||||||
|
&vault_id,
|
||||||
|
Some(transaction.connection_mut().map_err(server_error)?),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(server_error)?;
|
.map_err(server_error)?;
|
||||||
|
|
||||||
|
|
@ -176,7 +179,9 @@ pub async fn create_document(
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let new_vault_update_id = last_update_id + 1;
|
let new_vault_update_id = last_update_id
|
||||||
|
.checked_add(1)
|
||||||
|
.ok_or_else(|| server_error(anyhow::anyhow!("Vault update id overflow")))?;
|
||||||
let new_version = StoredDocumentVersion {
|
let new_version = StoredDocumentVersion {
|
||||||
vault_update_id: new_vault_update_id,
|
vault_update_id: new_vault_update_id,
|
||||||
creation_vault_update_id: new_vault_update_id,
|
creation_vault_update_id: new_vault_update_id,
|
||||||
|
|
|
||||||
|
|
@ -48,13 +48,20 @@ pub async fn delete_document(
|
||||||
|
|
||||||
let last_update_id = state
|
let last_update_id = state
|
||||||
.database
|
.database
|
||||||
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
|
.get_max_update_id_in_vault(
|
||||||
|
&vault_id,
|
||||||
|
Some(transaction.connection_mut().map_err(server_error)?),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(server_error)?;
|
.map_err(server_error)?;
|
||||||
|
|
||||||
let latest_version = state
|
let latest_version = state
|
||||||
.database
|
.database
|
||||||
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
|
.get_latest_document(
|
||||||
|
&vault_id,
|
||||||
|
&document_id,
|
||||||
|
Some(transaction.connection_mut().map_err(server_error)?),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(server_error)?;
|
.map_err(server_error)?;
|
||||||
|
|
||||||
|
|
@ -80,7 +87,9 @@ pub async fn delete_document(
|
||||||
return Ok(Json(latest_version.into()));
|
return Ok(Json(latest_version.into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let new_vault_update_id = last_update_id + 1;
|
let new_vault_update_id = last_update_id
|
||||||
|
.checked_add(1)
|
||||||
|
.ok_or_else(|| server_error(anyhow!("Vault update id overflow")))?;
|
||||||
let latest_relative_path = latest_version.relative_path;
|
let latest_relative_path = latest_version.relative_path;
|
||||||
let latest_content = latest_version.content;
|
let latest_content = latest_version.content;
|
||||||
let creation_vault_update_id = latest_version.creation_vault_update_id;
|
let creation_vault_update_id = latest_version.creation_vault_update_id;
|
||||||
|
|
|
||||||
|
|
@ -32,26 +32,23 @@ struct BucketState {
|
||||||
|
|
||||||
impl RateLimiter {
|
impl RateLimiter {
|
||||||
/// Create a new per-user rate limiter.
|
/// Create a new per-user rate limiter.
|
||||||
///
|
|
||||||
/// # Panics
|
|
||||||
///
|
|
||||||
/// Panics if `max_per_second` is 0.
|
|
||||||
pub fn new(max_per_second: u64) -> Self {
|
pub fn new(max_per_second: u64) -> Self {
|
||||||
assert!(
|
|
||||||
max_per_second > 0,
|
|
||||||
"max_per_second must be > 0 (set rate_limit_per_user_per_second to null in config to disable)"
|
|
||||||
);
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
max_per_second,
|
max_per_second,
|
||||||
buckets: Arc::new(Mutex::new(HashMap::new())),
|
buckets: Arc::new(Mutex::new(HashMap::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_or_create_bucket(&self, token: &str) -> Arc<TokenBucket> {
|
fn get_or_create_bucket(
|
||||||
self.buckets
|
&self,
|
||||||
|
token: &str,
|
||||||
|
) -> std::result::Result<Arc<TokenBucket>, StatusCode> {
|
||||||
|
let mut buckets = self
|
||||||
|
.buckets
|
||||||
.lock()
|
.lock()
|
||||||
.expect("rate limiter lock poisoned")
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||||
|
|
||||||
|
Ok(buckets
|
||||||
.entry(token.to_owned())
|
.entry(token.to_owned())
|
||||||
.or_insert_with(|| {
|
.or_insert_with(|| {
|
||||||
Arc::new(TokenBucket {
|
Arc::new(TokenBucket {
|
||||||
|
|
@ -62,23 +59,26 @@ impl RateLimiter {
|
||||||
max_tokens: self.max_per_second,
|
max_tokens: self.max_per_second,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.clone()
|
.clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TokenBucket {
|
impl TokenBucket {
|
||||||
fn try_acquire(&self) -> bool {
|
fn try_acquire(&self) -> std::result::Result<bool, StatusCode> {
|
||||||
let mut state = self.state.lock().expect("token bucket lock poisoned");
|
let mut state = self
|
||||||
|
.state
|
||||||
|
.lock()
|
||||||
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
if now.duration_since(state.last_refill).as_secs() >= 1 {
|
if now.duration_since(state.last_refill).as_secs() >= 1 {
|
||||||
state.tokens = self.max_tokens;
|
state.tokens = self.max_tokens;
|
||||||
state.last_refill = now;
|
state.last_refill = now;
|
||||||
}
|
}
|
||||||
if state.tokens > 0 {
|
if state.tokens > 0 {
|
||||||
state.tokens -= 1;
|
state.tokens = state.tokens.saturating_sub(1);
|
||||||
true
|
Ok(true)
|
||||||
} else {
|
} else {
|
||||||
false
|
Ok(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -88,13 +88,13 @@ pub async fn rate_limit_middleware(
|
||||||
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
|
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
|
||||||
req: Request,
|
req: Request,
|
||||||
next: Next,
|
next: Next,
|
||||||
) -> Result<Response, StatusCode> {
|
) -> std::result::Result<Response, StatusCode> {
|
||||||
let Some(TypedHeader(auth)) = auth_header else {
|
let Some(TypedHeader(auth)) = auth_header else {
|
||||||
return Ok(next.run(req).await);
|
return Ok(next.run(req).await);
|
||||||
};
|
};
|
||||||
|
|
||||||
let bucket = limiter.get_or_create_bucket(auth.token());
|
let bucket = limiter.get_or_create_bucket(auth.token())?;
|
||||||
if bucket.try_acquire() {
|
if bucket.try_acquire()? {
|
||||||
Ok(next.run(req).await)
|
Ok(next.run(req).await)
|
||||||
} else {
|
} else {
|
||||||
Err(StatusCode::TOO_MANY_REQUESTS)
|
Err(StatusCode::TOO_MANY_REQUESTS)
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ use crate::{
|
||||||
},
|
},
|
||||||
server::requests::UpdateBinaryDocumentVersion,
|
server::requests::UpdateBinaryDocumentVersion,
|
||||||
utils::{
|
utils::{
|
||||||
find_first_available_path::find_first_available_path, is_binary::is_binary,
|
find_first_available_path::find_first_available_path, is_binary::as_non_binary_text,
|
||||||
is_file_type_mergable::is_file_type_mergable, normalize::normalize,
|
is_file_type_mergable::is_file_type_mergable, normalize::normalize,
|
||||||
sanitize_path::sanitize_path,
|
sanitize_path::sanitize_path,
|
||||||
},
|
},
|
||||||
|
|
@ -173,13 +173,20 @@ pub async fn update_document(
|
||||||
|
|
||||||
let last_update_id = state
|
let last_update_id = state
|
||||||
.database
|
.database
|
||||||
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
|
.get_max_update_id_in_vault(
|
||||||
|
&vault_id,
|
||||||
|
Some(transaction.connection_mut().map_err(server_error)?),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(server_error)?;
|
.map_err(server_error)?;
|
||||||
|
|
||||||
let latest_version = state
|
let latest_version = state
|
||||||
.database
|
.database
|
||||||
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
|
.get_latest_document(
|
||||||
|
&vault_id,
|
||||||
|
&document_id,
|
||||||
|
Some(transaction.connection_mut().map_err(server_error)?),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(server_error)?
|
.map_err(server_error)?
|
||||||
.map_or_else(
|
.map_or_else(
|
||||||
|
|
@ -225,64 +232,56 @@ pub async fn update_document(
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// For mergability, use whichever path the new version will live at — the
|
// For mergability, use whichever path the new version will live at:
|
||||||
// requested rename target if the client sent one, otherwise the existing
|
// - the requested rename target if the client sent one
|
||||||
// server-side path.
|
// - otherwise the existing server-side path.
|
||||||
let mergable_check_path = sanitized_relative_path
|
let mergable_check_path = sanitized_relative_path
|
||||||
.as_deref()
|
.as_deref()
|
||||||
.unwrap_or(&latest_version.relative_path);
|
.unwrap_or(&latest_version.relative_path);
|
||||||
let are_all_participants_mergable = is_file_type_mergable(
|
|
||||||
|
let mergeable_texts = if is_file_type_mergable(
|
||||||
mergable_check_path,
|
mergable_check_path,
|
||||||
&state.config.server.mergeable_file_extensions,
|
&state.config.server.mergeable_file_extensions,
|
||||||
) && !is_binary(&parent_content)
|
) {
|
||||||
&& !is_binary(&latest_version.content)
|
as_non_binary_texts(&parent_content, &latest_version.content, &content)
|
||||||
&& !is_binary(&content);
|
|
||||||
|
|
||||||
let (merged_content, is_different_from_request_content) = if are_all_participants_mergable {
|
|
||||||
info!("Merging changes for document `{document_id}` in vault `{vault_id}`");
|
|
||||||
let parent_text = str::from_utf8(&parent_content)
|
|
||||||
.context("Parent document content is not valid UTF-8")
|
|
||||||
.map_err(client_error)?;
|
|
||||||
let latest_text = str::from_utf8(&latest_version.content)
|
|
||||||
.context("Latest version content is not valid UTF-8")
|
|
||||||
.map_err(client_error)?;
|
|
||||||
let new_text = str::from_utf8(&content)
|
|
||||||
.context("New content is not valid UTF-8")
|
|
||||||
.map_err(client_error)?;
|
|
||||||
let parent_owned = parent_text.to_owned();
|
|
||||||
let latest_owned = latest_text.to_owned();
|
|
||||||
let new_owned = new_text.to_owned();
|
|
||||||
let content_clone = content.clone();
|
|
||||||
|
|
||||||
let (merged, is_different) = tokio::task::spawn_blocking(move || {
|
|
||||||
let merged = reconcile(
|
|
||||||
&parent_owned,
|
|
||||||
&latest_owned.into(),
|
|
||||||
&new_owned.into(),
|
|
||||||
&*BuiltinTokenizer::Word,
|
|
||||||
)
|
|
||||||
.apply()
|
|
||||||
.text()
|
|
||||||
.into_bytes();
|
|
||||||
let is_different = merged != content_clone;
|
|
||||||
(merged, is_different)
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.map_err(|e| server_error(anyhow::anyhow!("Reconcile task failed: {e}")))?;
|
|
||||||
|
|
||||||
(merged, is_different)
|
|
||||||
} else {
|
} else {
|
||||||
(content, false) // false means that the client doesn't need to refetch the file as we can ensure the remote and local versions are the same as LWW is the merging method for binary files
|
None
|
||||||
};
|
};
|
||||||
|
let are_all_participants_mergable = mergeable_texts.is_some();
|
||||||
|
|
||||||
// Rename resolution: only apply the client's rename if (a) the client
|
let (merged_content, is_same_as_request) =
|
||||||
// requested one (`sanitized_relative_path` is `Some`) and (b) the
|
if let Some((parent_text, latest_text, new_text)) = mergeable_texts {
|
||||||
// document's path hasn't changed since this client's parent version.
|
info!("Merging changes for document `{document_id}` in vault `{vault_id}`");
|
||||||
// If the parent and latest paths differ, another client already renamed
|
|
||||||
// the document — keep the latest path (first rename wins). Content
|
let parent_owned = parent_text.to_owned();
|
||||||
// changes from both clients are still merged correctly via the 3-way
|
let latest_owned = latest_text.to_owned();
|
||||||
// reconcile above, independent of which rename wins. A missing
|
let new_owned = new_text.to_owned();
|
||||||
// relative_path means "keep current path" (content-only edit).
|
let content_clone = content.clone();
|
||||||
|
|
||||||
|
let merged = tokio::task::spawn_blocking(move || {
|
||||||
|
let merged = reconcile(
|
||||||
|
&parent_owned,
|
||||||
|
&latest_owned.into(),
|
||||||
|
&new_owned.into(),
|
||||||
|
&*BuiltinTokenizer::Word,
|
||||||
|
)
|
||||||
|
.apply()
|
||||||
|
.text()
|
||||||
|
.into_bytes();
|
||||||
|
merged
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| server_error(anyhow::anyhow!("Reconcile task failed: {e}")))?;
|
||||||
|
|
||||||
|
let is_same = merged == content_clone;
|
||||||
|
(merged, is_same)
|
||||||
|
} else {
|
||||||
|
(content, true) // true means that the client doesn't need to refetch the file as we can ensure the remote and local versions are the same as LWW is the merging method for binary files
|
||||||
|
};
|
||||||
|
|
||||||
|
// First rename wins: apply the client's rename only if the doc's path
|
||||||
|
// hasn't changed since its parent version. Content from both clients
|
||||||
|
// still merges via the 3-way reconcile above
|
||||||
let new_relative_path = match sanitized_relative_path.as_deref() {
|
let new_relative_path = match sanitized_relative_path.as_deref() {
|
||||||
Some(requested)
|
Some(requested)
|
||||||
if parent_relative_path == latest_version.relative_path
|
if parent_relative_path == latest_version.relative_path
|
||||||
|
|
@ -306,7 +305,9 @@ pub async fn update_document(
|
||||||
|
|
||||||
let new_version = StoredDocumentVersion {
|
let new_version = StoredDocumentVersion {
|
||||||
document_id,
|
document_id,
|
||||||
vault_update_id: last_update_id + 1,
|
vault_update_id: last_update_id
|
||||||
|
.checked_add(1)
|
||||||
|
.ok_or_else(|| server_error(anyhow!("Vault update id overflow")))?,
|
||||||
creation_vault_update_id: latest_version.creation_vault_update_id,
|
creation_vault_update_id: latest_version.creation_vault_update_id,
|
||||||
relative_path: new_relative_path,
|
relative_path: new_relative_path,
|
||||||
content: merged_content,
|
content: merged_content,
|
||||||
|
|
@ -314,7 +315,7 @@ pub async fn update_document(
|
||||||
is_deleted: false,
|
is_deleted: false,
|
||||||
user_id: user.name,
|
user_id: user.name,
|
||||||
device_id: device_id.0,
|
device_id: device_id.0,
|
||||||
has_been_merged: are_all_participants_mergable && is_different_from_request_content,
|
has_been_merged: are_all_participants_mergable && !is_same_as_request,
|
||||||
};
|
};
|
||||||
|
|
||||||
state
|
state
|
||||||
|
|
@ -323,9 +324,21 @@ pub async fn update_document(
|
||||||
.await
|
.await
|
||||||
.map_err(server_error)?;
|
.map_err(server_error)?;
|
||||||
|
|
||||||
Ok(Json(if is_different_from_request_content {
|
Ok(Json(if is_same_as_request {
|
||||||
DocumentUpdateResponse::MergingUpdate(new_version.into())
|
|
||||||
} else {
|
|
||||||
DocumentUpdateResponse::FastForwardUpdate(new_version.into())
|
DocumentUpdateResponse::FastForwardUpdate(new_version.into())
|
||||||
|
} else {
|
||||||
|
DocumentUpdateResponse::MergingUpdate(new_version.into())
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn as_non_binary_texts<'a>(
|
||||||
|
parent_content: &'a [u8],
|
||||||
|
latest_content: &'a [u8],
|
||||||
|
new_content: &'a [u8],
|
||||||
|
) -> Option<(&'a str, &'a str, &'a str)> {
|
||||||
|
Some((
|
||||||
|
as_non_binary_text(parent_content)?,
|
||||||
|
as_non_binary_text(latest_content)?,
|
||||||
|
as_non_binary_text(new_content)?,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -306,7 +306,7 @@ async fn websocket(
|
||||||
&device_id,
|
&device_id,
|
||||||
docs,
|
docs,
|
||||||
)
|
)
|
||||||
.await;
|
.await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -351,7 +351,7 @@ async fn websocket(
|
||||||
state
|
state
|
||||||
.cursors
|
.cursors
|
||||||
.remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id)
|
.remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id)
|
||||||
.await;
|
.await?;
|
||||||
|
|
||||||
match &result {
|
match &result {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,3 @@
|
||||||
use std::sync::LazyLock;
|
|
||||||
|
|
||||||
use regex::Regex;
|
|
||||||
|
|
||||||
static DEDUP_SUFFIX_REGEX: LazyLock<Regex> =
|
|
||||||
LazyLock::new(|| Regex::new(r" \((\d+)\)$").expect("invalid regex"));
|
|
||||||
|
|
||||||
pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
|
pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
|
||||||
let mut path_parts = path.split('/').collect::<Vec<_>>();
|
let mut path_parts = path.split('/').collect::<Vec<_>>();
|
||||||
let file_name = path_parts
|
let file_name = path_parts
|
||||||
|
|
@ -24,29 +17,19 @@ pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
|
||||||
let (stem, extension) = if is_simple_dotfile {
|
let (stem, extension) = if is_simple_dotfile {
|
||||||
(file_name.clone(), String::new())
|
(file_name.clone(), String::new())
|
||||||
} else {
|
} else {
|
||||||
// Regular file or dotfile with extension
|
match file_name.rsplit_once('.') {
|
||||||
let name_parts = file_name.rsplitn(2, '.').collect::<Vec<_>>();
|
Some((stem, extension)) => (stem.to_owned(), format!(".{extension}")),
|
||||||
let mut reverse_parts = name_parts.into_iter().rev();
|
None => (file_name.clone(), String::new()),
|
||||||
match (reverse_parts.next(), reverse_parts.next()) {
|
|
||||||
(Some(stem), maybe_extension) => (
|
|
||||||
stem.to_owned(),
|
|
||||||
maybe_extension
|
|
||||||
.map(|ext| format!(".{ext}"))
|
|
||||||
.unwrap_or_default(),
|
|
||||||
),
|
|
||||||
_ => unreachable!("Path must have at least one part"),
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let start_number = DEDUP_SUFFIX_REGEX
|
let (clean_stem, start_number) = strip_dedup_suffix(&stem);
|
||||||
.captures(&stem)
|
let clean_stem = clean_stem.to_owned();
|
||||||
.and_then(|caps| caps.get(1))
|
|
||||||
.and_then(|m| m.as_str().parse::<u32>().ok())
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
let clean_stem = DEDUP_SUFFIX_REGEX.replace(&stem, "").to_string();
|
std::iter::successors(Some(start_number), |dedup_number| {
|
||||||
|
dedup_number.checked_add(1)
|
||||||
(start_number..).map(move |dedup_number| {
|
})
|
||||||
|
.map(move |dedup_number| {
|
||||||
if dedup_number == 0 {
|
if dedup_number == 0 {
|
||||||
format!("{directory}{clean_stem}{extension}")
|
format!("{directory}{clean_stem}{extension}")
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -55,6 +38,20 @@ pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn strip_dedup_suffix(stem: &str) -> (&str, u64) {
|
||||||
|
let Some(without_closing_paren) = stem.strip_suffix(')') else {
|
||||||
|
return (stem, 0);
|
||||||
|
};
|
||||||
|
let Some((clean_stem, number)) = without_closing_paren.rsplit_once(" (") else {
|
||||||
|
return (stem, 0);
|
||||||
|
};
|
||||||
|
if number.is_empty() || !number.chars().all(|c| c.is_ascii_digit()) {
|
||||||
|
return (stem, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
(clean_stem, number.parse::<u64>().unwrap_or(0))
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
@ -103,7 +100,7 @@ mod test {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_regex_capturing_group() {
|
fn test_dedup_suffix_parsing() {
|
||||||
// Single digit in parentheses
|
// Single digit in parentheses
|
||||||
let mut deduped = dedup_paths("document (5).md");
|
let mut deduped = dedup_paths("document (5).md");
|
||||||
assert_eq!(deduped.next(), Some("document (5).md".to_owned()));
|
assert_eq!(deduped.next(), Some("document (5).md".to_owned()));
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,23 @@
|
||||||
use crate::app_state::database::models::VaultId;
|
use crate::app_state::database::{WriteTransaction, models::VaultId};
|
||||||
use crate::utils::dedup_paths::dedup_paths;
|
use crate::utils::dedup_paths::dedup_paths;
|
||||||
use anyhow::Result;
|
use anyhow::{Result, anyhow};
|
||||||
use log::{debug, info};
|
use log::{debug, info};
|
||||||
use sqlx::sqlite::SqliteConnection;
|
|
||||||
|
|
||||||
pub async fn find_first_available_path(
|
pub async fn find_first_available_path(
|
||||||
vault_id: &VaultId,
|
vault_id: &VaultId,
|
||||||
sanitized_relative_path: &str,
|
sanitized_relative_path: &str,
|
||||||
database: &crate::app_state::database::Database,
|
database: &crate::app_state::database::Database,
|
||||||
connection: &mut SqliteConnection,
|
transaction: &mut WriteTransaction,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
info!("Finding first available path for `{sanitized_relative_path}` in vault `{vault_id}`");
|
info!("Finding first available path for `{sanitized_relative_path}` in vault `{vault_id}`");
|
||||||
for candidate in dedup_paths(sanitized_relative_path) {
|
for candidate in dedup_paths(sanitized_relative_path) {
|
||||||
debug!("Checking candidate path for deconflicting names: `{candidate}`");
|
debug!("Checking candidate path for deconflicting names: `{candidate}`");
|
||||||
if database
|
if database
|
||||||
.get_latest_non_deleted_document_by_path(vault_id, &candidate, Some(connection))
|
.get_latest_non_deleted_document_by_path(
|
||||||
|
vault_id,
|
||||||
|
&candidate,
|
||||||
|
Some(transaction.connection_mut()?),
|
||||||
|
)
|
||||||
.await?
|
.await?
|
||||||
.is_none()
|
.is_none()
|
||||||
{
|
{
|
||||||
|
|
@ -27,5 +30,7 @@ pub async fn find_first_available_path(
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
unreachable!("dedup_paths produces infinite paths");
|
Err(anyhow!(
|
||||||
|
"No available path candidates produced for `{sanitized_relative_path}` in vault `{vault_id}`"
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,22 @@
|
||||||
/// Heuristically determine if the given data is a binary or a text file's
|
/// Return the given data as UTF-8 text if it is not considered binary.
|
||||||
/// content.
|
|
||||||
///
|
///
|
||||||
/// Only text inputs can be reconciled using the crate's functions.
|
/// Only text inputs can be reconciled using the crate's functions.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn is_binary(data: &[u8]) -> bool {
|
pub fn as_non_binary_text(data: &[u8]) -> Option<&str> {
|
||||||
if data.contains(&0) {
|
if data.contains(&0) {
|
||||||
// Even though the NUL character is valid in UTF-8, it's highly suspicious in
|
// Even though the NUL character is valid in UTF-8, it's highly suspicious in
|
||||||
// human-readable text.
|
// human-readable text.
|
||||||
return true;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::str::from_utf8(data).is_err()
|
std::str::from_utf8(data).ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Heuristically determine if the given data is a binary or a text file's
|
||||||
|
/// content.
|
||||||
|
#[must_use]
|
||||||
|
pub fn is_binary(data: &[u8]) -> bool {
|
||||||
|
as_non_binary_text(data).is_none()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -23,4 +29,11 @@ mod tests {
|
||||||
assert!(is_binary(&[0, 12]));
|
assert!(is_binary(&[0, 12]));
|
||||||
assert!(!is_binary(b"hello"));
|
assert!(!is_binary(b"hello"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_as_non_binary_text() {
|
||||||
|
assert_eq!(as_non_binary_text(b"hello"), Some("hello"));
|
||||||
|
assert_eq!(as_non_binary_text(&[0, 12]), None);
|
||||||
|
assert_eq!(as_non_binary_text(&[0xff]), None);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -52,14 +52,14 @@ impl RotatingFileWriter {
|
||||||
/// Parse timestamp from log filename and return as `SystemTime`
|
/// Parse timestamp from log filename and return as `SystemTime`
|
||||||
fn parse_log_timestamp(filename: &str, file_prefix: &str) -> Option<SystemTime> {
|
fn parse_log_timestamp(filename: &str, file_prefix: &str) -> Option<SystemTime> {
|
||||||
// Expected format: {prefix}.{timestamp}.log where timestamp is %Y-%m-%d_%H-%M-%S
|
// Expected format: {prefix}.{timestamp}.log where timestamp is %Y-%m-%d_%H-%M-%S
|
||||||
let prefix_len = file_prefix.len() + 1; // +1 for the dot
|
let prefix_len = file_prefix.len().checked_add(1)?; // +1 for the dot
|
||||||
let timestamp_str = filename.get(prefix_len..filename.len().checked_sub(4)?)?;
|
let timestamp_str = filename.get(prefix_len..filename.len().checked_sub(4)?)?;
|
||||||
|
|
||||||
let dt = NaiveDateTime::parse_from_str(timestamp_str, "%Y-%m-%d_%H-%M-%S").ok()?;
|
let dt = NaiveDateTime::parse_from_str(timestamp_str, "%Y-%m-%d_%H-%M-%S").ok()?;
|
||||||
let timestamp = dt.and_utc();
|
let timestamp = dt.and_utc();
|
||||||
let secs: u64 = timestamp.timestamp().try_into().ok()?;
|
let secs: u64 = timestamp.timestamp().try_into().ok()?;
|
||||||
|
|
||||||
Some(UNIX_EPOCH + Duration::from_secs(secs))
|
UNIX_EPOCH.checked_add(Duration::from_secs(secs))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_latest_log_file(directory: &Path, file_prefix: &str) -> Option<String> {
|
fn find_latest_log_file(directory: &Path, file_prefix: &str) -> Option<String> {
|
||||||
|
|
@ -86,7 +86,9 @@ impl RotatingFileWriter {
|
||||||
Self::find_latest_log_file(directory, file_prefix)
|
Self::find_latest_log_file(directory, file_prefix)
|
||||||
.and_then(|filename| Self::parse_log_timestamp(&filename, file_prefix))
|
.and_then(|filename| Self::parse_log_timestamp(&filename, file_prefix))
|
||||||
.map_or_else(SystemTime::now, |last_rotation| {
|
.map_or_else(SystemTime::now, |last_rotation| {
|
||||||
last_rotation + rotation_duration
|
last_rotation
|
||||||
|
.checked_add(rotation_duration)
|
||||||
|
.unwrap_or_else(SystemTime::now)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -136,7 +138,9 @@ impl RotatingFileWriter {
|
||||||
.open(&filepath)?;
|
.open(&filepath)?;
|
||||||
|
|
||||||
inner.current_file = Some(file);
|
inner.current_file = Some(file);
|
||||||
inner.next_rotation_time = SystemTime::now() + inner.rotation_duration;
|
inner.next_rotation_time = SystemTime::now()
|
||||||
|
.checked_add(inner.rotation_duration)
|
||||||
|
.unwrap_or_else(SystemTime::now);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue