Add proper shutdown, rate limits, config validation, cors config, fix dangling cursors, cache regex, merge created texts
This commit is contained in:
parent
4763bc9d04
commit
e15b0f9903
28 changed files with 1277 additions and 464 deletions
|
|
@ -30,8 +30,11 @@ fi
|
|||
which cargo-machete || cargo install cargo-machete
|
||||
cargo machete --with-metadata
|
||||
|
||||
cd ..
|
||||
scripts/update-api-types.sh # this will dirty up the git state if not up-to-date
|
||||
|
||||
echo "Running checks in frontend"
|
||||
cd ../frontend
|
||||
cd frontend
|
||||
|
||||
if [[ "$FIX_MODE" == true ]]; then
|
||||
npm install
|
||||
|
|
@ -57,6 +60,4 @@ if [[ "$FIX_MODE" == false ]] && [[ $(git status --porcelain) ]]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
cd ..
|
||||
|
||||
echo "Success"
|
||||
|
|
|
|||
|
|
@ -2,14 +2,14 @@
|
|||
|
||||
set -e
|
||||
|
||||
SERVER_URL="http://localhost:3000"
|
||||
SERVER_URL="http://localhost:3010"
|
||||
MAX_RETRIES=30
|
||||
RETRY_INTERVAL_IN_SECONDS=5
|
||||
|
||||
echo "Waiting for $SERVER_URL to become available..."
|
||||
count=0
|
||||
while [ $count -lt $MAX_RETRIES ]; do
|
||||
if curl -s -f -o /dev/null $SERVER_URL; then
|
||||
if curl -s -o /dev/null $SERVER_URL; then
|
||||
echo "$SERVER_URL is now available!"
|
||||
break
|
||||
fi
|
||||
|
|
|
|||
112
sync-server/Cargo.lock
generated
112
sync-server/Cargo.lock
generated
|
|
@ -337,10 +337,11 @@ checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b"
|
|||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.2"
|
||||
version = "1.2.57"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f34d93e62b03caf570cccc334cbc6c2fceca82f39211051345108adcba3eebdc"
|
||||
checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423"
|
||||
dependencies = [
|
||||
"find-msvc-tools",
|
||||
"shlex",
|
||||
]
|
||||
|
||||
|
|
@ -624,6 +625,12 @@ version = "2.2.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4"
|
||||
|
||||
[[package]]
|
||||
name = "find-msvc-tools"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582"
|
||||
|
||||
[[package]]
|
||||
name = "flume"
|
||||
version = "0.11.1"
|
||||
|
|
@ -1272,6 +1279,16 @@ version = "0.3.17"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
|
||||
|
||||
[[package]]
|
||||
name = "mime_guess"
|
||||
version = "2.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
|
||||
dependencies = [
|
||||
"mime",
|
||||
"unicase",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.8.0"
|
||||
|
|
@ -1582,12 +1599,12 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "reconcile-text"
|
||||
version = "0.8.0"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "599cf9539996a2a19e501110404c59ba62f4974009f8fb864a8b7151c15ee5a5"
|
||||
checksum = "52e0cf361887ea64c479ca871c1170dda761f84e122f2616b5579906a38d7557"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"thiserror 2.0.17",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1648,6 +1665,40 @@ dependencies = [
|
|||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-embed"
|
||||
version = "8.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "04113cb9355a377d83f06ef1f0a45b8ab8cd7d8b1288160717d66df5c7988d27"
|
||||
dependencies = [
|
||||
"rust-embed-impl",
|
||||
"rust-embed-utils",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-embed-impl"
|
||||
version = "8.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da0902e4c7c8e997159ab384e6d0fc91c221375f6894346ae107f47dd0f3ccaa"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rust-embed-utils",
|
||||
"syn 2.0.90",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-embed-utils"
|
||||
version = "8.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5bcdef0be6fe7f6fa333b1073c949729274b05f123a0ad7efcb8efd878e5c3b1"
|
||||
dependencies = [
|
||||
"sha2",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.24"
|
||||
|
|
@ -1679,6 +1730,15 @@ version = "1.0.18"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sanitize-filename"
|
||||
version = "0.6.0"
|
||||
|
|
@ -1916,7 +1976,7 @@ dependencies = [
|
|||
"serde_json",
|
||||
"sha2",
|
||||
"smallvec",
|
||||
"thiserror 2.0.17",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tracing",
|
||||
|
|
@ -2000,7 +2060,7 @@ dependencies = [
|
|||
"smallvec",
|
||||
"sqlx-core",
|
||||
"stringprep",
|
||||
"thiserror 2.0.17",
|
||||
"thiserror 2.0.18",
|
||||
"tracing",
|
||||
"uuid",
|
||||
"whoami",
|
||||
|
|
@ -2039,7 +2099,7 @@ dependencies = [
|
|||
"smallvec",
|
||||
"sqlx-core",
|
||||
"stringprep",
|
||||
"thiserror 2.0.17",
|
||||
"thiserror 2.0.18",
|
||||
"tracing",
|
||||
"uuid",
|
||||
"whoami",
|
||||
|
|
@ -2065,7 +2125,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_urlencoded",
|
||||
"sqlx-core",
|
||||
"thiserror 2.0.17",
|
||||
"thiserror 2.0.18",
|
||||
"tracing",
|
||||
"url",
|
||||
"uuid",
|
||||
|
|
@ -2136,15 +2196,18 @@ dependencies = [
|
|||
"futures",
|
||||
"humantime-serde",
|
||||
"log",
|
||||
"mime_guess",
|
||||
"rand 0.9.0",
|
||||
"reconcile-text",
|
||||
"regex",
|
||||
"rust-embed",
|
||||
"sanitize-filename",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
"sqlx",
|
||||
"thiserror 2.0.17",
|
||||
"subtle",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
|
|
@ -2203,11 +2266,11 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.17"
|
||||
version = "2.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8"
|
||||
checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4"
|
||||
dependencies = [
|
||||
"thiserror-impl 2.0.17",
|
||||
"thiserror-impl 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2223,9 +2286,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "2.0.17"
|
||||
version = "2.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913"
|
||||
checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
|
@ -2276,7 +2339,6 @@ dependencies = [
|
|||
"bytes",
|
||||
"libc",
|
||||
"mio",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
"socket2",
|
||||
|
|
@ -2434,7 +2496,7 @@ checksum = "e640d9b0964e9d39df633548591090ab92f7a4567bc31d3891af23471a3365c6"
|
|||
dependencies = [
|
||||
"chrono",
|
||||
"lazy_static",
|
||||
"thiserror 2.0.17",
|
||||
"thiserror 2.0.18",
|
||||
"ts-rs-macros",
|
||||
"uuid",
|
||||
]
|
||||
|
|
@ -2481,6 +2543,12 @@ version = "0.10.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f720def6ce1ee2fc44d40ac9ed6d3a59c361c80a75a7aa8e75bb9baed31cf2ea"
|
||||
|
||||
[[package]]
|
||||
name = "unicase"
|
||||
version = "2.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-bidi"
|
||||
version = "0.3.17"
|
||||
|
|
@ -2577,6 +2645,16 @@ version = "0.9.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||
|
||||
[[package]]
|
||||
name = "walkdir"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
|
||||
dependencies = [
|
||||
"same-file",
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.11.0+wasi-snapshot-preview1"
|
||||
|
|
|
|||
|
|
@ -33,7 +33,10 @@ serde_json = "1.0.140"
|
|||
bimap = "0.6.3"
|
||||
ts-rs = { version = "10.1", features = ["uuid-impl", "chrono-impl"] }
|
||||
base64 = "0.22.1"
|
||||
reconcile-text = { version = "0.8.0", features = ["serde"] }
|
||||
reconcile-text = { version = "0.11.0", features = ["serde"] }
|
||||
rust-embed = "8.5"
|
||||
mime_guess = "2.0"
|
||||
subtle = "2.6.1"
|
||||
|
||||
[profile.release]
|
||||
codegen-units = 1
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
database:
|
||||
databases_directory_path: databases
|
||||
max_connections_per_vault: 12
|
||||
max_connections_per_vault: 8
|
||||
cursor_timeout: 1m
|
||||
server:
|
||||
host: 0.0.0.0
|
||||
port: 3000
|
||||
port: 3010
|
||||
max_body_size_mb: 512
|
||||
max_clients_per_vault: 256
|
||||
broadcast_channel_capacity: 1024
|
||||
response_timeout: 30m
|
||||
mergeable_file_extensions:
|
||||
- md
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ pub mod cursors;
|
|||
pub mod database;
|
||||
pub mod websocket;
|
||||
|
||||
use std::sync::{Arc, atomic::AtomicUsize};
|
||||
|
||||
use anyhow::Result;
|
||||
use cursors::Cursors;
|
||||
use database::Database;
|
||||
|
|
@ -15,21 +17,42 @@ pub struct AppState {
|
|||
pub database: Database,
|
||||
pub cursors: Cursors,
|
||||
pub broadcasts: Broadcasts,
|
||||
/// Tracks WebSocket connections that have upgraded but not yet completed
|
||||
/// the authentication handshake
|
||||
pub pending_ws_connections: Arc<AtomicUsize>,
|
||||
/// Send on this channel to stop background tasks (cursor cleanup,
|
||||
/// idle-pool cleanup)
|
||||
shutdown_tx: Arc<tokio::sync::watch::Sender<()>>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub async fn try_new(config: Config) -> Result<Self> {
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(());
|
||||
|
||||
let broadcasts = Broadcasts::new(&config.server);
|
||||
let database = Database::try_new(&config.database, &broadcasts).await?;
|
||||
let database =
|
||||
Database::try_new(&config.database, &broadcasts, shutdown_rx.clone()).await?;
|
||||
let cursors: Cursors = Cursors::new(&config.database, &broadcasts);
|
||||
|
||||
Cursors::start_background_task(cursors.clone());
|
||||
Cursors::start_background_task(cursors.clone(), shutdown_rx);
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
database,
|
||||
cursors,
|
||||
broadcasts,
|
||||
pending_ws_connections: Arc::new(AtomicUsize::new(0)),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
})
|
||||
}
|
||||
|
||||
/// Signal all background tasks (idle pool cleanup, cursor cleanup) to stop
|
||||
pub fn shutdown(&self) {
|
||||
let _ = self.shutdown_tx.send(());
|
||||
}
|
||||
|
||||
/// Get a receiver to be notified when shutdown is triggered
|
||||
pub fn subscribe_shutdown(&self) -> tokio::sync::watch::Receiver<()> {
|
||||
self.shutdown_tx.subscribe()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,7 +42,9 @@ impl Cursors {
|
|||
) {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
let all_device_cursors = vault_to_cursors.entry(vault_id).or_insert_with(Vec::new);
|
||||
let all_device_cursors = vault_to_cursors
|
||||
.entry(vault_id.clone())
|
||||
.or_insert_with(Vec::new);
|
||||
|
||||
all_device_cursors.retain(|c| &c.client_cursors.device_id != device_id);
|
||||
all_device_cursors.push(ClientCursorsWithTimeToLive::new(ClientCursors {
|
||||
|
|
@ -52,7 +54,7 @@ impl Cursors {
|
|||
}));
|
||||
|
||||
drop(vault_to_cursors); // Explicitly drop the lock before broadcasting to avoid deadlock
|
||||
self.broadcast_cursors().await;
|
||||
self.broadcast_cursors_for_vault(&vault_id).await;
|
||||
}
|
||||
|
||||
pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec<ClientCursors> {
|
||||
|
|
@ -69,45 +71,83 @@ impl Cursors {
|
|||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn start_background_task(self) {
|
||||
pub fn start_background_task(self, mut shutdown: tokio::sync::watch::Receiver<()>) {
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
self.remove_expired_cursors().await;
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
tokio::select! {
|
||||
() = tokio::time::sleep(Duration::from_secs(1)) => {
|
||||
self.remove_expired_cursors().await;
|
||||
}
|
||||
Ok(()) = shutdown.changed() => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn remove_expired_cursors(&self) {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
let changed_vaults: Vec<VaultId> = {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
for (_vault_id, cursors) in vault_to_cursors.iter_mut() {
|
||||
cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout));
|
||||
let mut changed = Vec::new();
|
||||
for (vault_id, cursors) in vault_to_cursors.iter_mut() {
|
||||
let before = cursors.len();
|
||||
cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout));
|
||||
if cursors.len() != before {
|
||||
changed.push(vault_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Remove empty vault entries to prevent unbounded growth
|
||||
vault_to_cursors.retain(|_, cursors| !cursors.is_empty());
|
||||
|
||||
changed
|
||||
};
|
||||
|
||||
for vault_id in &changed_vaults {
|
||||
self.broadcast_cursors_for_vault(vault_id).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn broadcast_cursors(&self) {
|
||||
let vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
async fn broadcast_cursors_for_vault(&self, vault_id: &VaultId) {
|
||||
let client_cursors: Vec<ClientCursors> = {
|
||||
let vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
vault_to_cursors
|
||||
.get(vault_id)
|
||||
.map(|cursors| cursors.iter().map(|c| c.client_cursors.clone()).collect())
|
||||
.unwrap_or_default()
|
||||
};
|
||||
|
||||
for (vault_id, cursors) in vault_to_cursors.iter() {
|
||||
self.broadcasts
|
||||
.send_document_update(
|
||||
vault_id.clone(),
|
||||
WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions(
|
||||
CursorPositionFromServer {
|
||||
clients: cursors.iter().map(|c| c.client_cursors.clone()).collect(),
|
||||
},
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
self.broadcasts
|
||||
.send_document_update(
|
||||
vault_id.clone(),
|
||||
WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions(
|
||||
CursorPositionFromServer {
|
||||
clients: client_cursors,
|
||||
},
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn remove_cursors_of_device(&self, vault_id: &str, device_id: &str) {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
pub async fn remove_cursors_of_device(&self, vault_id: &VaultId, device_id: &DeviceId) {
|
||||
let changed = {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
if let Some(cursors) = vault_to_cursors.get_mut(vault_id) {
|
||||
cursors.retain(|c| c.client_cursors.device_id != device_id);
|
||||
if let Some(cursors) = vault_to_cursors.get_mut(vault_id) {
|
||||
let before = cursors.len();
|
||||
cursors.retain(|c| c.client_cursors.device_id != *device_id);
|
||||
let changed = cursors.len() != before;
|
||||
if cursors.is_empty() {
|
||||
vault_to_cursors.remove(vault_id);
|
||||
}
|
||||
changed
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
if changed {
|
||||
self.broadcast_cursors_for_vault(vault_id).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,8 +9,17 @@ use models::{
|
|||
use sqlx::{ConnectOptions, sqlite::SqliteConnectOptions, types::chrono::Utc};
|
||||
|
||||
pub mod models;
|
||||
use sqlx::{Pool, Sqlite, sqlite::SqlitePoolOptions};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Sentinel error indicating the SQLite database is busy (SQLITE_BUSY).
|
||||
/// Handlers can downcast to this to return 429 instead of 500.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Database is busy")]
|
||||
pub struct WriteBusyError;
|
||||
|
||||
use sqlx::{
|
||||
Pool, Sqlite, pool::PoolConnection, sqlite::SqliteConnection, sqlite::SqlitePoolOptions,
|
||||
};
|
||||
use tokio::sync::{Mutex, OnceCell};
|
||||
use tokio::time::Instant;
|
||||
use uuid::fmt::Hyphenated;
|
||||
|
||||
|
|
@ -19,33 +28,154 @@ use super::websocket::{
|
|||
models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin, WebSocketVaultUpdate},
|
||||
};
|
||||
use crate::config::database_config::DatabaseConfig;
|
||||
use crate::consts::IDLE_POOL_TIMEOUT;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct PoolWithTimestamp {
|
||||
pool: Pool<Sqlite>,
|
||||
last_accessed: Instant,
|
||||
/// Holds separate reader and writer pools for a single vault.
|
||||
/// The writer pool has exactly 1 connection so writes never compete
|
||||
/// with reads for pool slots.
|
||||
#[derive(Debug, Clone)]
|
||||
struct VaultPools {
|
||||
reader: Pool<Sqlite>,
|
||||
writer: Pool<Sqlite>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PoolWithTimestamp {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PoolWithTimestamp")
|
||||
.field("pool", &"Pool<Sqlite>")
|
||||
.field("last_accessed", &self.last_accessed)
|
||||
.finish()
|
||||
}
|
||||
#[derive(Debug)]
|
||||
struct VaultPool {
|
||||
cell: Arc<OnceCell<VaultPools>>,
|
||||
last_accessed: Mutex<Instant>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Database {
|
||||
config: DatabaseConfig,
|
||||
broadcasts: Broadcasts,
|
||||
connection_pools: Arc<Mutex<HashMap<VaultId, PoolWithTimestamp>>>,
|
||||
connection_pools: Arc<Mutex<HashMap<VaultId, Arc<VaultPool>>>>,
|
||||
/// Per-vault write serialization. SQLite allows only one writer at a
|
||||
/// time; `BEGIN IMMEDIATE` on a second connection blocks until the first
|
||||
/// commits (up to `busy_timeout`). Under concurrent load the blocked
|
||||
/// connections consume the pool, starving even read-only requests.
|
||||
/// This mutex moves the wait from the SQLite layer (where it holds a
|
||||
/// pool connection) to the Tokio layer (where it holds nothing).
|
||||
write_locks: Arc<Mutex<HashMap<VaultId, Arc<tokio::sync::Mutex<()>>>>>,
|
||||
}
|
||||
|
||||
pub type Transaction<'a> = sqlx::Transaction<'a, Sqlite>;
|
||||
/// A write transaction backed by a raw `BEGIN IMMEDIATE` instead of sqlx's
|
||||
/// savepoint-based `Transaction`. This avoids the savepoint mismatch caused
|
||||
/// by the old `END; BEGIN IMMEDIATE;` workaround.
|
||||
///
|
||||
/// Holds an `OwnedMutexGuard` that serializes write transactions per vault
|
||||
/// at the application level (see `Database::write_locks`). The guard is
|
||||
/// released when the transaction is committed, rolled back, or dropped.
|
||||
pub struct WriteTransaction {
|
||||
conn: Option<PoolConnection<Sqlite>>,
|
||||
_write_guard: tokio::sync::OwnedMutexGuard<()>,
|
||||
}
|
||||
|
||||
impl WriteTransaction {
|
||||
async fn new(pool: &Pool<Sqlite>, write_guard: tokio::sync::OwnedMutexGuard<()>) -> Result<Self> {
|
||||
let mut conn = pool
|
||||
.acquire()
|
||||
.await
|
||||
.context("Cannot acquire connection for write transaction")?;
|
||||
if let Err(e) = sqlx::query("BEGIN IMMEDIATE")
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
{
|
||||
let is_busy = match &e {
|
||||
sqlx::Error::Database(db_err) => {
|
||||
// SQLITE_BUSY base code is 5. Extended codes share base 5.
|
||||
let busy_by_code = db_err.code().is_some_and(|c| {
|
||||
c.parse::<u32>().is_ok_and(|n| n & 0xFF == 5)
|
||||
});
|
||||
busy_by_code || db_err.message().contains("database is locked")
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
if is_busy {
|
||||
return Err(WriteBusyError.into());
|
||||
}
|
||||
return Err(e).context("Cannot begin immediate transaction");
|
||||
}
|
||||
Ok(Self { conn: Some(conn), _write_guard: write_guard })
|
||||
}
|
||||
|
||||
pub async fn commit(mut self) -> Result<()> {
|
||||
if let Some(mut conn) = self.conn.take() {
|
||||
sqlx::query("COMMIT")
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
.context("Failed to commit transaction")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn rollback(mut self) -> Result<()> {
|
||||
if let Some(mut conn) = self.conn.take() {
|
||||
sqlx::query("ROLLBACK")
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
.context("Failed to rollback transaction")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WriteTransaction {
|
||||
fn drop(&mut self) {
|
||||
if self.conn.is_some() {
|
||||
// The connection is returned to the pool with an open transaction.
|
||||
// The pool's `before_acquire` hook issues a ROLLBACK before
|
||||
// handing it to the next consumer, so no async work is needed
|
||||
// here. If the pool is being shut down, SQLite itself rolls back
|
||||
// uncommitted transactions when the connection closes.
|
||||
log::warn!("WriteTransaction dropped without commit or rollback");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for WriteTransaction {
|
||||
type Target = SqliteConnection;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.conn
|
||||
.as_ref()
|
||||
.expect("BUG: WriteTransaction dereferenced after being consumed")
|
||||
.deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::DerefMut for WriteTransaction {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.conn
|
||||
.as_mut()
|
||||
.expect("BUG: WriteTransaction dereferenced after being consumed")
|
||||
.deref_mut()
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensure the connection has no leftover open transaction (e.g. from a
|
||||
/// `WriteTransaction` that was dropped without commit/rollback). ROLLBACK
|
||||
/// is a harmless no-op if no transaction is active.
|
||||
fn rollback_before_acquire(
|
||||
conn: &mut SqliteConnection,
|
||||
_meta: sqlx::pool::PoolConnectionMetadata,
|
||||
) -> futures::future::BoxFuture<'_, Result<bool, sqlx::Error>> {
|
||||
Box::pin(async move {
|
||||
if let Err(e) = sqlx::query("ROLLBACK").execute(&mut *conn).await {
|
||||
// "cannot rollback - no transaction is active" is the common
|
||||
// case (connection returned cleanly). Only unexpected errors
|
||||
// deserve attention.
|
||||
log::debug!("before_acquire ROLLBACK failed: {e}");
|
||||
}
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
impl Database {
|
||||
pub async fn try_new(config: &DatabaseConfig, broadcasts: &Broadcasts) -> Result<Self> {
|
||||
pub async fn try_new(
|
||||
config: &DatabaseConfig,
|
||||
broadcasts: &Broadcasts,
|
||||
shutdown: tokio::sync::watch::Receiver<()>,
|
||||
) -> Result<Self> {
|
||||
tokio::fs::create_dir_all(&config.databases_directory_path)
|
||||
.await
|
||||
.with_context(|| {
|
||||
|
|
@ -70,24 +200,29 @@ impl Database {
|
|||
.trim_end_matches(".sqlite")
|
||||
.to_owned();
|
||||
|
||||
let pool = Self::create_vault_database(config, &vault).await?;
|
||||
Self::validate_vault_id(&vault)?;
|
||||
|
||||
let pools = Self::create_vault_database(config, &vault).await?;
|
||||
let cell = Arc::new(OnceCell::new());
|
||||
cell.set(pools).expect("cell is new");
|
||||
connection_pools.insert(
|
||||
vault.clone(),
|
||||
PoolWithTimestamp {
|
||||
pool,
|
||||
last_accessed: Instant::now(),
|
||||
},
|
||||
Arc::new(VaultPool {
|
||||
cell,
|
||||
last_accessed: Mutex::new(Instant::now()),
|
||||
}),
|
||||
);
|
||||
}
|
||||
info!("Database migrations applied");
|
||||
|
||||
let database = Self {
|
||||
config: config.clone(),
|
||||
connection_pools: Arc::new(Mutex::new(connection_pools)),
|
||||
broadcasts: broadcasts.clone(),
|
||||
write_locks: Arc::new(Mutex::new(HashMap::new())),
|
||||
};
|
||||
|
||||
// Start background task to cleanup idle connection pools
|
||||
database.start_idle_pool_cleanup();
|
||||
database.start_idle_pool_cleanup(shutdown);
|
||||
|
||||
Ok(database)
|
||||
}
|
||||
|
|
@ -95,92 +230,167 @@ impl Database {
|
|||
async fn create_vault_database(
|
||||
config: &DatabaseConfig,
|
||||
vault: &VaultId,
|
||||
) -> Result<Pool<Sqlite>> {
|
||||
) -> Result<VaultPools> {
|
||||
let file_name = config
|
||||
.databases_directory_path
|
||||
.join(format!("{vault}.sqlite"));
|
||||
|
||||
let connection_options = SqliteConnectOptions::new()
|
||||
// Database-level PRAGMAs (auto_vacuum, journal_mode) require a write
|
||||
// lock and persist across connections. Set them once with a dedicated
|
||||
// init connection so pool connections never need the write lock just to
|
||||
// open.
|
||||
let init_options = SqliteConnectOptions::new()
|
||||
.filename(file_name.clone())
|
||||
.create_if_missing(true)
|
||||
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Full)
|
||||
.busy_timeout(Duration::from_secs(3600))
|
||||
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
|
||||
.log_slow_statements(log::LevelFilter::Warn, Duration::from_secs(30));
|
||||
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
|
||||
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
|
||||
|
||||
let pool = SqlitePoolOptions::new()
|
||||
// Run migrations on a dedicated connection, NOT through the pool.
|
||||
// The pool's `before_acquire` hook issues ROLLBACK on every checkout,
|
||||
// which can roll back the migration's bookkeeping transaction (the
|
||||
// _sqlx_migrations INSERT) while the DDL (ALTER TABLE) has already
|
||||
// auto-committed — leaving the migration in a dirty state.
|
||||
//
|
||||
// Uses `run_direct` instead of `run` because `run` takes
|
||||
// `impl Acquire<'_>`, whose lifetime bound prevents the enclosing
|
||||
// future from satisfying the `Send` requirement of axum handlers.
|
||||
let mut init_conn = sqlx::SqliteConnection::connect_with(&init_options).await?;
|
||||
sqlx::migrate!("src/app_state/database/migrations")
|
||||
.run_direct(&mut init_conn)
|
||||
.await
|
||||
.context("Cannot run pending migrations")?;
|
||||
drop(init_conn);
|
||||
|
||||
// Per-connection PRAGMAs shared by both reader and writer pools.
|
||||
// journal_mode = WAL is a no-op on an already-WAL database.
|
||||
let base_options = SqliteConnectOptions::new()
|
||||
.filename(file_name.clone())
|
||||
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
|
||||
.busy_timeout(Duration::from_secs(30))
|
||||
.log_slow_statements(log::LevelFilter::Warn, Duration::from_secs(30))
|
||||
// In WAL mode, NORMAL is safe: data survives OS crashes, only the
|
||||
// last transaction can be lost on power failure. The default FULL
|
||||
// forces an extra fsync() per commit, roughly halving write throughput.
|
||||
.pragma("synchronous", "NORMAL")
|
||||
// 16 MB page cache per connection (negative = KiB). Reduces disk
|
||||
// reads for the latest_document_versions GROUP BY view.
|
||||
.pragma("cache_size", "-16384")
|
||||
// Memory-mapped I/O avoids read() syscalls. SQLite falls back to
|
||||
// regular I/O for writes and beyond the mapped region. 256 MB is
|
||||
// conservative; the OS handles actual memory pressure.
|
||||
.pragma("mmap_size", "268435456")
|
||||
// Keep temp tables and sort spillovers in memory instead of temp files.
|
||||
.pragma("temp_store", "MEMORY")
|
||||
// Cap WAL file growth at 64 MB. Without this, the WAL can grow
|
||||
// unbounded during heavy write bursts (e.g. E2E tests with many
|
||||
// concurrent clients). SQLite truncates to this size on checkpoint.
|
||||
.pragma("journal_size_limit", "67108864");
|
||||
|
||||
// Reader pool: multiple connections for concurrent reads.
|
||||
let reader = SqlitePoolOptions::new()
|
||||
.max_connections(config.max_connections_per_vault)
|
||||
.acquire_slow_threshold(Duration::from_secs(30))
|
||||
.test_before_acquire(true)
|
||||
.connect_with(connection_options)
|
||||
// Disabled: the health-check query is subject to busy_timeout
|
||||
// and blocks all connection checkouts when a write is active,
|
||||
// starving the pool for up to 30s even for simple reads.
|
||||
// The before_acquire ROLLBACK hook is sufficient for cleanup.
|
||||
.test_before_acquire(false)
|
||||
.before_acquire(rollback_before_acquire)
|
||||
.connect_with(base_options.clone())
|
||||
.await
|
||||
.with_context(|| format!("Cannot open database at `{}`", file_name.display()))?;
|
||||
.with_context(|| format!("Cannot open reader pool at `{}`", file_name.display()))?;
|
||||
|
||||
Self::run_migrations(&pool).await?;
|
||||
// Writer pool: exactly 1 connection, dedicated to writes.
|
||||
// Since the Tokio mutex already serializes writers per vault, this
|
||||
// single connection is never contended. Separating it from the
|
||||
// reader pool ensures writes never compete with reads for pool slots.
|
||||
let writer = SqlitePoolOptions::new()
|
||||
.max_connections(1)
|
||||
.acquire_slow_threshold(Duration::from_secs(30))
|
||||
.test_before_acquire(false)
|
||||
.before_acquire(rollback_before_acquire)
|
||||
.connect_with(base_options)
|
||||
.await
|
||||
.with_context(|| format!("Cannot open writer pool at `{}`", file_name.display()))?;
|
||||
|
||||
Ok(pool)
|
||||
Ok(VaultPools { reader, writer })
|
||||
}
|
||||
|
||||
async fn run_migrations(pool: &Pool<Sqlite>) -> Result<()> {
|
||||
sqlx::migrate!("src/app_state/database/migrations")
|
||||
.run(pool)
|
||||
.await
|
||||
.context("Cannot check for pending migrations")
|
||||
}
|
||||
|
||||
async fn get_connection_pool(&self, vault: &VaultId) -> Result<Pool<Sqlite>> {
|
||||
let mut pools = self.connection_pools.lock().await;
|
||||
|
||||
if !pools.contains_key(vault) {
|
||||
let pool = Self::create_vault_database(&self.config, vault).await?;
|
||||
pools.insert(
|
||||
vault.clone(),
|
||||
PoolWithTimestamp {
|
||||
pool,
|
||||
last_accessed: Instant::now(),
|
||||
},
|
||||
fn validate_vault_id(vault: &VaultId) -> Result<()> {
|
||||
if vault.is_empty() {
|
||||
anyhow::bail!("Vault ID must not be empty");
|
||||
}
|
||||
if vault.contains('/')
|
||||
|| vault.contains('\\')
|
||||
|| vault.contains("..")
|
||||
|| vault.contains('\0')
|
||||
{
|
||||
anyhow::bail!(
|
||||
"Invalid vault ID: must not contain path separators, '..', or null bytes"
|
||||
);
|
||||
}
|
||||
|
||||
let pool_with_timestamp = pools
|
||||
.get_mut(vault)
|
||||
.expect("Pool was just inserted or already exists");
|
||||
|
||||
// Update last accessed time
|
||||
pool_with_timestamp.last_accessed = Instant::now();
|
||||
|
||||
Ok(pool_with_timestamp.pool.clone())
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Attempting to write from this transaction might result in a
|
||||
/// database locked error. Use this transaction for read-only operations.
|
||||
pub async fn create_readonly_transaction(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
) -> Result<Transaction<'static>> {
|
||||
self.get_connection_pool(vault)
|
||||
.await?
|
||||
.begin()
|
||||
.await
|
||||
.context("Cannot create transaction")
|
||||
}
|
||||
async fn get_vault_pools(&self, vault: &VaultId) -> Result<VaultPools> {
|
||||
Self::validate_vault_id(vault)?;
|
||||
|
||||
pub async fn create_write_transaction(&self, vault: &VaultId) -> Result<Transaction<'static>> {
|
||||
let mut transaction = self.create_readonly_transaction(vault).await?;
|
||||
// Get or create the VaultPool entry. The global lock is held only
|
||||
// long enough for a HashMap lookup/insert — never across
|
||||
// create_vault_database.
|
||||
let vault_pool = {
|
||||
let mut pools = self.connection_pools.lock().await;
|
||||
pools
|
||||
.entry(vault.clone())
|
||||
.or_insert_with(|| {
|
||||
Arc::new(VaultPool {
|
||||
cell: Arc::new(OnceCell::new()),
|
||||
last_accessed: Mutex::new(Instant::now()),
|
||||
})
|
||||
})
|
||||
.clone()
|
||||
};
|
||||
|
||||
// sqlx doesn't support immediate transactions for sqlite: https://github.com/launchbadge/sqlx/issues/481
|
||||
sqlx::query!("END; BEGIN IMMEDIATE;")
|
||||
.execute(&mut *transaction)
|
||||
// OnceCell::get_or_try_init guarantees exactly-once
|
||||
// initialization: concurrent callers for the same vault wait
|
||||
// here; callers for other vaults are not blocked.
|
||||
let config = self.config.clone();
|
||||
let vault_clone = vault.clone();
|
||||
let pools = vault_pool
|
||||
.cell
|
||||
.get_or_try_init(|| async {
|
||||
Self::create_vault_database(&config, &vault_clone).await
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(transaction)
|
||||
*vault_pool.last_accessed.lock().await = Instant::now();
|
||||
Ok(pools.clone())
|
||||
}
|
||||
|
||||
/// Return the reader pool for read-only queries.
|
||||
async fn get_connection_pool(&self, vault: &VaultId) -> Result<Pool<Sqlite>> {
|
||||
Ok(self.get_vault_pools(vault).await?.reader)
|
||||
}
|
||||
|
||||
pub async fn create_write_transaction(&self, vault: &VaultId) -> Result<WriteTransaction> {
|
||||
let write_lock = {
|
||||
let mut locks = self.write_locks.lock().await;
|
||||
locks
|
||||
.entry(vault.clone())
|
||||
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
|
||||
.clone()
|
||||
};
|
||||
let write_guard = write_lock.lock_owned().await;
|
||||
let pools = self.get_vault_pools(vault).await?;
|
||||
WriteTransaction::new(&pools.writer, write_guard).await
|
||||
}
|
||||
|
||||
/// Return the latest state of all documents in the vault
|
||||
pub async fn get_latest_documents(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
connection: Option<&mut SqliteConnection>,
|
||||
) -> Result<Vec<DocumentVersionWithoutContent>> {
|
||||
let query = sqlx::query!(
|
||||
r#"
|
||||
|
|
@ -198,8 +408,8 @@ impl Database {
|
|||
"#,
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_all(&mut **transaction).await
|
||||
if let Some(conn) = connection {
|
||||
query.fetch_all(&mut *conn).await
|
||||
} else {
|
||||
query
|
||||
.fetch_all(&self.get_connection_pool(vault).await?)
|
||||
|
|
@ -216,9 +426,7 @@ impl Database {
|
|||
is_deleted: row.is_deleted,
|
||||
user_id: row.user_id,
|
||||
device_id: row.device_id,
|
||||
content_size: row
|
||||
.content_size
|
||||
.expect("Content size can't be null but sqlx can't infer it"),
|
||||
content_size: row.content_size.unwrap_or(0),
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
|
|
@ -230,7 +438,7 @@ impl Database {
|
|||
&self,
|
||||
vault: &VaultId,
|
||||
vault_update_id: VaultUpdateId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
connection: Option<&mut SqliteConnection>,
|
||||
) -> Result<Vec<DocumentVersionWithoutContent>> {
|
||||
let query = sqlx::query!(
|
||||
r#"
|
||||
|
|
@ -250,8 +458,8 @@ impl Database {
|
|||
vault_update_id
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_all(&mut **transaction).await
|
||||
if let Some(conn) = connection {
|
||||
query.fetch_all(&mut *conn).await
|
||||
} else {
|
||||
query
|
||||
.fetch_all(&self.get_connection_pool(vault).await?)
|
||||
|
|
@ -270,9 +478,7 @@ impl Database {
|
|||
is_deleted: row.is_deleted,
|
||||
user_id: row.user_id,
|
||||
device_id: row.device_id,
|
||||
content_size: row
|
||||
.content_size
|
||||
.expect("Content size can't be null but sqlx can't infer it"),
|
||||
content_size: row.content_size.unwrap_or(0),
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
|
|
@ -281,7 +487,7 @@ impl Database {
|
|||
pub async fn get_max_update_id_in_vault(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
connection: Option<&mut SqliteConnection>,
|
||||
) -> Result<i64> {
|
||||
let query = sqlx::query!(
|
||||
r#"
|
||||
|
|
@ -290,8 +496,8 @@ impl Database {
|
|||
"#,
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_one(&mut **transaction).await
|
||||
if let Some(conn) = connection {
|
||||
query.fetch_one(&mut *conn).await
|
||||
} else {
|
||||
query
|
||||
.fetch_one(&self.get_connection_pool(vault).await?)
|
||||
|
|
@ -301,11 +507,11 @@ impl Database {
|
|||
.context("Cannot fetch max update id in vault")
|
||||
}
|
||||
|
||||
pub async fn get_latest_document_by_path(
|
||||
pub async fn get_latest_non_deleted_document_by_path(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
relative_path: &str,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
connection: Option<&mut SqliteConnection>,
|
||||
) -> Result<Option<StoredDocumentVersion>> {
|
||||
let query = sqlx::query_as!(
|
||||
StoredDocumentVersion,
|
||||
|
|
@ -330,8 +536,8 @@ impl Database {
|
|||
relative_path
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_optional(&mut **transaction).await
|
||||
if let Some(conn) = connection {
|
||||
query.fetch_optional(&mut *conn).await
|
||||
} else {
|
||||
query
|
||||
.fetch_optional(&self.get_connection_pool(vault).await?)
|
||||
|
|
@ -344,7 +550,7 @@ impl Database {
|
|||
&self,
|
||||
vault: &VaultId,
|
||||
document_id: &DocumentId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
connection: Option<&mut SqliteConnection>,
|
||||
) -> Result<Option<StoredDocumentVersion>> {
|
||||
let document_id = document_id.as_hyphenated();
|
||||
let query = sqlx::query_as!(
|
||||
|
|
@ -366,8 +572,8 @@ impl Database {
|
|||
document_id
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_optional(&mut **transaction).await
|
||||
if let Some(conn) = connection {
|
||||
query.fetch_optional(&mut *conn).await
|
||||
} else {
|
||||
query
|
||||
.fetch_optional(&self.get_connection_pool(vault).await?)
|
||||
|
|
@ -380,7 +586,7 @@ impl Database {
|
|||
&self,
|
||||
vault: &VaultId,
|
||||
vault_update_id: VaultUpdateId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
connection: Option<&mut SqliteConnection>,
|
||||
) -> Result<Option<StoredDocumentVersion>> {
|
||||
let query = sqlx::query_as!(
|
||||
StoredDocumentVersion,
|
||||
|
|
@ -400,8 +606,8 @@ impl Database {
|
|||
vault_update_id
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_optional(&mut **transaction).await
|
||||
if let Some(conn) = connection {
|
||||
query.fetch_optional(&mut *conn).await
|
||||
} else {
|
||||
query
|
||||
.fetch_optional(&self.get_connection_pool(vault).await?)
|
||||
|
|
@ -415,7 +621,7 @@ impl Database {
|
|||
&self,
|
||||
vault_id: &VaultId,
|
||||
version: &StoredDocumentVersion,
|
||||
transaction: Option<Transaction<'_>>,
|
||||
transaction: Option<WriteTransaction>,
|
||||
) -> Result<()> {
|
||||
let document_id = version.document_id.as_hyphenated();
|
||||
let query = sqlx::query!(
|
||||
|
|
@ -428,9 +634,10 @@ impl Database {
|
|||
content,
|
||||
is_deleted,
|
||||
user_id,
|
||||
device_id
|
||||
device_id,
|
||||
has_been_merged
|
||||
)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
version.vault_update_id,
|
||||
document_id,
|
||||
|
|
@ -439,7 +646,8 @@ impl Database {
|
|||
version.content,
|
||||
version.is_deleted,
|
||||
version.user_id,
|
||||
version.device_id
|
||||
version.device_id,
|
||||
version.has_been_merged
|
||||
);
|
||||
|
||||
if let Some(mut transaction) = transaction {
|
||||
|
|
@ -477,38 +685,66 @@ impl Database {
|
|||
|
||||
/// Cleanup idle connection pools that haven't been accessed in more than 5 minutes
|
||||
async fn cleanup_idle_pools(&self) {
|
||||
let mut pools = self.connection_pools.lock().await;
|
||||
let now = Instant::now();
|
||||
let idle_timeout = Duration::from_secs(5 * 60); // 5 minutes
|
||||
// Collect idle vaults and remove them from the map while holding
|
||||
// the lock briefly. Close pools OUTSIDE the lock so that
|
||||
// pool.close().await doesn't block other get_connection_pool calls.
|
||||
let idle_pools: Vec<(VaultId, Arc<VaultPool>)> = {
|
||||
let mut pools = self.connection_pools.lock().await;
|
||||
let now = Instant::now();
|
||||
|
||||
// Collect vaults to remove
|
||||
let vaults_to_remove: Vec<VaultId> = pools
|
||||
.iter()
|
||||
.filter(|(_, pool_with_timestamp)| {
|
||||
now.duration_since(pool_with_timestamp.last_accessed) > idle_timeout
|
||||
})
|
||||
.map(|(vault_id, _)| vault_id.clone())
|
||||
.collect();
|
||||
let vaults_to_remove: Vec<VaultId> = pools
|
||||
.iter()
|
||||
.filter(|(_, vp)| {
|
||||
// If the lock is contested, the pool is actively used — not idle.
|
||||
let Ok(last) = vp.last_accessed.try_lock() else {
|
||||
return false;
|
||||
};
|
||||
now.duration_since(*last) > IDLE_POOL_TIMEOUT
|
||||
})
|
||||
.map(|(vault_id, _)| vault_id.clone())
|
||||
.collect();
|
||||
|
||||
// Close and remove idle pools
|
||||
for vault_id in &vaults_to_remove {
|
||||
if let Some(pool_with_timestamp) = pools.remove(vault_id) {
|
||||
info!("Closing idle database connection pool for vault `{vault_id}`");
|
||||
pool_with_timestamp.pool.close().await;
|
||||
vaults_to_remove
|
||||
.into_iter()
|
||||
.filter_map(|id| pools.remove(&id).map(|vp| (id, vp)))
|
||||
.collect()
|
||||
};
|
||||
|
||||
for (vault_id, vault_pool) in idle_pools {
|
||||
if let Some(pools) = vault_pool.cell.get() {
|
||||
// Checkpoint the WAL before closing to reclaim disk space
|
||||
// and ensure the next open doesn't need a large WAL replay.
|
||||
// TRUNCATE mode resets the WAL file to zero bytes.
|
||||
if let Err(e) = sqlx::query("PRAGMA wal_checkpoint(TRUNCATE)")
|
||||
.execute(&pools.writer)
|
||||
.await
|
||||
{
|
||||
log::warn!("WAL checkpoint failed for vault `{vault_id}`: {e}");
|
||||
}
|
||||
info!("Closing idle database connection pools for vault `{vault_id}`");
|
||||
pools.reader.close().await;
|
||||
pools.writer.close().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a background task that periodically cleans up idle connection pools
|
||||
fn start_idle_pool_cleanup(&self) {
|
||||
fn start_idle_pool_cleanup(&self, mut shutdown: tokio::sync::watch::Receiver<()>) {
|
||||
let database = self.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(60)); // Check every minute
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
database.cleanup_idle_pools().await;
|
||||
tokio::select! {
|
||||
_ = interval.tick() => {
|
||||
database.cleanup_idle_pools().await;
|
||||
}
|
||||
_ = shutdown.changed() => {
|
||||
info!("Idle pool cleanup task shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ pub struct StoredDocumentVersion {
|
|||
pub device_id: DeviceId,
|
||||
#[allow(dead_code)] // This is for manual analysis
|
||||
pub has_been_merged: bool,
|
||||
pub idempotency_key: Option<String>,
|
||||
}
|
||||
|
||||
impl PartialEq<Self> for StoredDocumentVersion {
|
||||
|
|
|
|||
|
|
@ -1,35 +1,52 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use log::{debug, warn};
|
||||
use tokio::sync::{Mutex, broadcast};
|
||||
|
||||
use super::models::WebSocketServerMessageWithOrigin;
|
||||
use crate::{
|
||||
app_state::database::models::VaultId, config::server_config::ServerConfig, errors::server_error,
|
||||
};
|
||||
use crate::{app_state::database::models::VaultId, config::server_config::ServerConfig};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Broadcasts {
|
||||
max_clients_per_vault: usize,
|
||||
broadcast_channel_capacity: usize,
|
||||
tx: Arc<Mutex<HashMap<VaultId, broadcast::Sender<WebSocketServerMessageWithOrigin>>>>,
|
||||
}
|
||||
|
||||
type TxMap = HashMap<VaultId, broadcast::Sender<WebSocketServerMessageWithOrigin>>;
|
||||
|
||||
impl Broadcasts {
|
||||
pub fn new(server_config: &ServerConfig) -> Self {
|
||||
Self {
|
||||
max_clients_per_vault: server_config.max_clients_per_vault,
|
||||
broadcast_channel_capacity: server_config.broadcast_channel_capacity,
|
||||
tx: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove senders for vaults with no active receivers
|
||||
fn prune_inactive_vaults(tx_map: &mut TxMap) {
|
||||
tx_map.retain(|_, sender| sender.receiver_count() > 0);
|
||||
}
|
||||
|
||||
pub async fn get_receiver(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
) -> broadcast::Receiver<WebSocketServerMessageWithOrigin> {
|
||||
let tx = self.get_or_create(vault).await;
|
||||
max_clients: usize,
|
||||
) -> Result<broadcast::Receiver<WebSocketServerMessageWithOrigin>, crate::errors::SyncServerError>
|
||||
{
|
||||
let mut tx_map = self.tx.lock().await;
|
||||
Self::prune_inactive_vaults(&mut tx_map);
|
||||
|
||||
tx.subscribe()
|
||||
let sender = tx_map
|
||||
.entry(vault)
|
||||
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
|
||||
|
||||
if sender.receiver_count() >= max_clients {
|
||||
return Err(crate::errors::client_error(anyhow::anyhow!(
|
||||
"Vault has reached the maximum number of clients ({max_clients})"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(sender.subscribe())
|
||||
}
|
||||
|
||||
/// Notify all clients (who are subscribed to the vault) about an update.
|
||||
|
|
@ -39,31 +56,20 @@ impl Broadcasts {
|
|||
vault: VaultId,
|
||||
document: WebSocketServerMessageWithOrigin,
|
||||
) {
|
||||
let tx = self.get_or_create(vault.clone()).await;
|
||||
let mut tx_map = self.tx.lock().await;
|
||||
Self::prune_inactive_vaults(&mut tx_map);
|
||||
|
||||
if tx.receiver_count() == 0 {
|
||||
let sender = tx_map
|
||||
.entry(vault.clone())
|
||||
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
|
||||
|
||||
if sender.receiver_count() == 0 {
|
||||
debug!("Skipping broadcast, no clients connected for vault `{vault}`");
|
||||
return;
|
||||
}
|
||||
|
||||
let result = tx
|
||||
.send(document)
|
||||
.context("Cannot broadcast server message to websocket listeners")
|
||||
.map_err(server_error);
|
||||
|
||||
if result.is_err() {
|
||||
warn!("Failed to send message: {result:?}");
|
||||
if let Err(e) = sender.send(document) {
|
||||
warn!("Failed to broadcast to vault `{vault}`: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_or_create(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
) -> broadcast::Sender<WebSocketServerMessageWithOrigin> {
|
||||
let mut tx = self.tx.lock().await;
|
||||
|
||||
tx.entry(vault)
|
||||
.or_insert_with(|| broadcast::channel(self.max_clients_per_vault).0.clone())
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,19 @@ pub struct Config {
|
|||
}
|
||||
|
||||
impl Config {
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
self.server
|
||||
.validate()
|
||||
.context("Invalid server configuration")?;
|
||||
self.logging
|
||||
.validate()
|
||||
.context("Invalid logging configuration")?;
|
||||
self.database
|
||||
.validate()
|
||||
.context("Invalid database configuration")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn read_or_create(path: &Path) -> Result<Self> {
|
||||
let display_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use std::{path::PathBuf, time::Duration};
|
||||
|
||||
use anyhow::{Result, ensure};
|
||||
use log::debug;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
|
@ -34,6 +35,24 @@ fn default_cursor_timeout() -> Duration {
|
|||
DEFAULT_CURSOR_TIMEOUT
|
||||
}
|
||||
|
||||
impl DatabaseConfig {
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
ensure!(
|
||||
self.databases_directory_path.as_os_str().len() > 0,
|
||||
"databases_directory_path must not be empty"
|
||||
);
|
||||
ensure!(
|
||||
self.max_connections_per_vault > 0,
|
||||
"max_connections_per_vault must be greater than 0"
|
||||
);
|
||||
ensure!(
|
||||
!self.cursor_timeout.is_zero(),
|
||||
"cursor_timeout must be greater than 0"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DatabaseConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Result, ensure};
|
||||
use log::debug;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
|
@ -20,6 +21,17 @@ pub struct LoggingConfig {
|
|||
pub log_level: LogLevel,
|
||||
}
|
||||
|
||||
impl LoggingConfig {
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
ensure!(
|
||||
!self.log_directory.is_empty(),
|
||||
"log_directory must not be an empty string"
|
||||
);
|
||||
ensure!(self.log_rotation > 0, "log_rotation must be greater than 0");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LoggingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
use anyhow::{Result, ensure};
|
||||
use log::debug;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::consts::{
|
||||
DEFAULT_HOST, DEFAULT_MAX_BODY_SIZE_MB, DEFAULT_MAX_CLIENTS_PER_VAULT,
|
||||
DEFAULT_MERGEABLE_FILE_EXTENSIONS, DEFAULT_PORT, DEFAULT_RESPONSE_TIMEOUT_SECONDS,
|
||||
DEFAULT_ALLOWED_ORIGINS, DEFAULT_BROADCAST_CHANNEL_CAPACITY, DEFAULT_HOST,
|
||||
DEFAULT_MAX_BODY_SIZE_MB, DEFAULT_MAX_CLIENTS_PER_VAULT, DEFAULT_MAX_PENDING_WS_CONNECTIONS,
|
||||
DEFAULT_MERGEABLE_FILE_EXTENSIONS, DEFAULT_PORT, DEFAULT_RATE_LIMIT_PER_USER_PER_SECOND,
|
||||
DEFAULT_RESPONSE_TIMEOUT_SECONDS,
|
||||
};
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
||||
|
|
@ -21,11 +24,56 @@ pub struct ServerConfig {
|
|||
#[serde(default = "default_max_clients_per_vault")]
|
||||
pub max_clients_per_vault: usize,
|
||||
|
||||
#[serde(default = "default_broadcast_channel_capacity")]
|
||||
pub broadcast_channel_capacity: usize,
|
||||
|
||||
#[serde(default = "default_response_timeout", with = "humantime_serde")]
|
||||
pub response_timeout: Duration,
|
||||
|
||||
#[serde(default = "default_mergeable_file_extensions")]
|
||||
pub mergeable_file_extensions: Vec<String>,
|
||||
|
||||
/// Per-user maximum requests per second (keyed by bearer token).
|
||||
/// `None` disables rate limiting.
|
||||
#[serde(default = "DEFAULT_RATE_LIMIT_PER_USER_PER_SECOND")]
|
||||
pub rate_limit_per_user_per_second: Option<u64>,
|
||||
|
||||
/// Allowed CORS origins. Default: `["*"]` (allow all).
|
||||
#[serde(default = "default_allowed_origins")]
|
||||
pub allowed_origins: Vec<String>,
|
||||
|
||||
/// Maximum concurrent unauthenticated WebSocket connections waiting for
|
||||
/// handshake. Limits resource consumption from clients that connect but
|
||||
/// never authenticate.
|
||||
#[serde(default = "default_max_pending_websocket_connections")]
|
||||
pub max_pending_websocket_connections: usize,
|
||||
}
|
||||
|
||||
impl ServerConfig {
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
ensure!(
|
||||
self.response_timeout > 0,
|
||||
"response_timeout must be greater than 0"
|
||||
);
|
||||
ensure!(
|
||||
self.max_body_size_mb > 0,
|
||||
"max_body_size_mb must be greater than 0"
|
||||
);
|
||||
ensure!(
|
||||
self.max_clients_per_vault > 0,
|
||||
"max_clients_per_vault must be greater than 0"
|
||||
);
|
||||
ensure!(
|
||||
self.broadcast_channel_capacity > 0,
|
||||
"broadcast_channel_capacity must be greater than 0"
|
||||
);
|
||||
ensure!(
|
||||
self.max_pending_websocket_connections > 0,
|
||||
"max_pending_websocket_connections must be greater than 0"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn default_host() -> String {
|
||||
|
|
@ -48,6 +96,11 @@ fn default_max_clients_per_vault() -> usize {
|
|||
DEFAULT_MAX_CLIENTS_PER_VAULT
|
||||
}
|
||||
|
||||
fn default_broadcast_channel_capacity() -> usize {
|
||||
debug!("Using default broadcast channel capacity: {DEFAULT_BROADCAST_CHANNEL_CAPACITY}");
|
||||
DEFAULT_BROADCAST_CHANNEL_CAPACITY
|
||||
}
|
||||
|
||||
fn default_response_timeout() -> Duration {
|
||||
debug!("Using default response timeout: {DEFAULT_RESPONSE_TIMEOUT_SECONDS:?}");
|
||||
DEFAULT_RESPONSE_TIMEOUT_SECONDS
|
||||
|
|
@ -60,3 +113,21 @@ fn default_mergeable_file_extensions() -> Vec<String> {
|
|||
.map(|s| (*s).to_owned())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn DEFAULT_RATE_LIMIT_PER_USER_PER_SECOND() -> Option<u64> {
|
||||
debug!("Using default rate limit per second: {DEFAULT_RATE_LIMIT_PER_USER_PER_SECOND:?}");
|
||||
DEFAULT_RATE_LIMIT_PER_USER_PER_SECOND
|
||||
}
|
||||
|
||||
fn default_allowed_origins() -> Vec<String> {
|
||||
debug!("Using default allowed origins: {DEFAULT_ALLOWED_ORIGINS:?}");
|
||||
DEFAULT_ALLOWED_ORIGINS
|
||||
.iter()
|
||||
.map(|s| (*s).to_owned())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn default_max_pending_websocket_connections() -> usize {
|
||||
debug!("Using default max pending WebSocket connections: {DEFAULT_MAX_PENDING_WS_CONNECTIONS}");
|
||||
DEFAULT_MAX_PENDING_WS_CONNECTIONS
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,19 +5,31 @@ use crate::utils::log_level::LogLevel;
|
|||
pub const DEFAULT_CONFIG_PATH: &str = "config.yml";
|
||||
|
||||
pub const DEFAULT_DATABASES_DIRECTORY_PATH: &str = "databases";
|
||||
pub const DEFAULT_MAX_CONNECTIONS_PER_VAULT: u32 = 12;
|
||||
pub const DEFAULT_MAX_CONNECTIONS_PER_VAULT: u32 = 6;
|
||||
pub const DEFAULT_CURSOR_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
pub const DEFAULT_HOST: &str = "127.0.0.1";
|
||||
pub const DEFAULT_PORT: u16 = 3000;
|
||||
pub const DEFAULT_MAX_BODY_SIZE_MB: usize = 4096;
|
||||
pub const DEFAULT_RESPONSE_TIMEOUT_SECONDS: Duration = Duration::from_secs(1800);
|
||||
pub const DEFAULT_RESPONSE_TIMEOUT_SECONDS: Duration = Duration::from_mins(30);
|
||||
pub const DEFAULT_MAX_CLIENTS_PER_VAULT: usize = 256;
|
||||
pub const DEFAULT_BROADCAST_CHANNEL_CAPACITY: usize = 4096;
|
||||
pub const DEFAULT_MAX_PENDING_WS_CONNECTIONS: usize = 128;
|
||||
|
||||
pub const DEFAULT_LOG_DIRECTORY: &str = "logs";
|
||||
pub const DEFAULT_LOG_ROTATION_INTERVAL: Duration = Duration::from_secs(60 * 60 * 24); // 1 day
|
||||
pub const DEFAULT_LOG_ROTATION_INTERVAL: Duration = Duration::from_hours(24);
|
||||
pub const IDLE_POOL_TIMEOUT: Duration = Duration::from_mins(5);
|
||||
pub const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
pub const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
pub const MAX_CURSOR_DOCUMENTS: usize = 1000;
|
||||
pub const MAX_CURSORS_PER_DOCUMENT: usize = 100;
|
||||
pub const MAX_RELATIVE_PATH_LEN: usize = 4096;
|
||||
|
||||
pub const DEFAULT_LOG_LEVEL: LogLevel = LogLevel::Info;
|
||||
|
||||
pub const DEFAULT_MERGEABLE_FILE_EXTENSIONS: &[&str] = &["md", "txt"];
|
||||
|
||||
pub const SUPPORTED_API_VERSION: u32 = 2;
|
||||
pub const DEFAULT_RATE_LIMIT_PER_USER_PER_SECOND: Option<u64> = None;
|
||||
pub const DEFAULT_ALLOWED_ORIGINS: &[&str] = &["*"];
|
||||
pub const SUPPORTED_API_VERSION: u32 = 3;
|
||||
|
|
|
|||
|
|
@ -41,11 +41,12 @@ async fn main() -> ExitCode {
|
|||
}
|
||||
};
|
||||
|
||||
let mut result = set_up_logging(&args, &config.logging);
|
||||
|
||||
if result.is_ok() {
|
||||
result = start_server(config).await;
|
||||
let result = async {
|
||||
config.validate().map_err(init_error)?;
|
||||
set_up_logging(&args, &config.logging)?;
|
||||
start_server(config).await
|
||||
}
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(()) => ExitCode::SUCCESS,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ mod fetch_latest_document_version;
|
|||
mod fetch_latest_documents;
|
||||
mod index;
|
||||
mod ping;
|
||||
mod rate_limit;
|
||||
mod requests;
|
||||
mod responses;
|
||||
mod update_document;
|
||||
|
|
@ -24,7 +25,7 @@ use axum::{
|
|||
routing::{IntoMakeService, delete, get, post, put},
|
||||
};
|
||||
use device_id_header::DEVICE_ID_HEADER_NAME;
|
||||
use log::info;
|
||||
use log::{info, warn};
|
||||
use tokio::signal;
|
||||
use tower_http::{
|
||||
LatencyUnit,
|
||||
|
|
@ -41,7 +42,7 @@ use tracing::{Level, info_span};
|
|||
use crate::{
|
||||
app_state::AppState,
|
||||
config::{Config, server_config::ServerConfig},
|
||||
errors::{client_error, not_found_error},
|
||||
consts::GRACEFUL_SHUTDOWN_TIMEOUT,
|
||||
};
|
||||
|
||||
pub async fn create_server(config: Config) -> Result<()> {
|
||||
|
|
@ -56,21 +57,26 @@ pub async fn create_server(config: Config) -> Result<()> {
|
|||
.route("/", get(index::index))
|
||||
.route("/vaults/:vault_id/ping", get(ping::ping))
|
||||
.route("/vaults/:vault_id/ws", get(websocket::websocket_handler))
|
||||
.fallback(index::spa_fallback);
|
||||
|
||||
let cors_layer = build_cors_layer(&server_config).context("Invalid CORS configuration")?;
|
||||
|
||||
if let Some(rate_limit) = server_config.rate_limit_per_user_per_second {
|
||||
info!("Rate limiting enabled: {rate_limit} requests/second per user");
|
||||
let limiter = rate_limit::RateLimiter::new(rate_limit);
|
||||
app = app.layer(middleware::from_fn_with_state(
|
||||
limiter,
|
||||
rate_limit::rate_limit_middleware,
|
||||
));
|
||||
}
|
||||
|
||||
let app = app
|
||||
.layer(DefaultBodyLimit::disable())
|
||||
.layer(RequestBodyLimitLayer::new(
|
||||
app_state.config.server.max_body_size_mb * 1024 * 1024,
|
||||
))
|
||||
.layer(TimeoutLayer::new(server_config.response_timeout))
|
||||
.layer(
|
||||
CorsLayer::new()
|
||||
.allow_origin("*".parse::<HeaderValue>().expect("Failed to parse origin"))
|
||||
.allow_headers([
|
||||
http::header::CONTENT_TYPE,
|
||||
http::header::AUTHORIZATION,
|
||||
DEVICE_ID_HEADER_NAME.clone(),
|
||||
])
|
||||
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]),
|
||||
)
|
||||
.layer(cors_layer)
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(|request: &Request<_>| {
|
||||
|
|
@ -90,12 +96,39 @@ pub async fn create_server(config: Config) -> Result<()> {
|
|||
.on_eos(DefaultOnEos::new())
|
||||
.on_failure(DefaultOnFailure::new().level(Level::ERROR)),
|
||||
)
|
||||
.with_state(app_state)
|
||||
.fallback(handle_404)
|
||||
.fallback(handle_405)
|
||||
.with_state(app_state.clone())
|
||||
.into_make_service();
|
||||
|
||||
start_server(app, &server_config).await
|
||||
start_server(app, &server_config, app_state).await
|
||||
}
|
||||
|
||||
fn build_cors_layer(server_config: &ServerConfig) -> Result<CorsLayer> {
|
||||
let origins = &server_config.allowed_origins;
|
||||
|
||||
let cors = if origins.len() == 1 && origins[0] == "*" {
|
||||
info!("CORS: allowing all origins (wildcard)");
|
||||
let header: HeaderValue = "*"
|
||||
.parse()
|
||||
.context("Failed to parse wildcard CORS origin")?;
|
||||
CorsLayer::new().allow_origin(header)
|
||||
} else {
|
||||
let parsed: Vec<HeaderValue> = origins
|
||||
.iter()
|
||||
.map(|o| {
|
||||
o.parse::<HeaderValue>()
|
||||
.with_context(|| format!("Failed to parse CORS origin: `{o}`"))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
CorsLayer::new().allow_origin(parsed)
|
||||
};
|
||||
|
||||
Ok(cors
|
||||
.allow_headers([
|
||||
http::header::CONTENT_TYPE,
|
||||
http::header::AUTHORIZATION,
|
||||
DEVICE_ID_HEADER_NAME.clone(),
|
||||
])
|
||||
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]))
|
||||
}
|
||||
|
||||
fn get_authed_routes(app_state: AppState) -> Router<AppState> {
|
||||
|
|
@ -135,7 +168,11 @@ fn get_authed_routes(app_state: AppState) -> Router<AppState> {
|
|||
.layer(middleware::from_fn_with_state(app_state, auth_middleware))
|
||||
}
|
||||
|
||||
async fn start_server(app: IntoMakeService<axum::Router>, config: &ServerConfig) -> Result<()> {
|
||||
async fn start_server(
|
||||
app: IntoMakeService<axum::Router>,
|
||||
config: &ServerConfig,
|
||||
app_state: AppState,
|
||||
) -> Result<()> {
|
||||
let address = format!("{}:{}", config.host, config.port);
|
||||
let listener = tokio::net::TcpListener::bind(address.clone())
|
||||
.await
|
||||
|
|
@ -148,26 +185,46 @@ async fn start_server(app: IntoMakeService<axum::Router>, config: &ServerConfig)
|
|||
.context("Failed to get local address")?
|
||||
);
|
||||
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.tcp_nodelay(true)
|
||||
.await
|
||||
.context("Failed to start server")
|
||||
let mut shutdown_rx = app_state.subscribe_shutdown();
|
||||
|
||||
let server = axum::serve(listener, app)
|
||||
.with_graceful_shutdown(async move {
|
||||
shutdown_signal().await;
|
||||
app_state.shutdown();
|
||||
})
|
||||
.tcp_nodelay(true);
|
||||
|
||||
tokio::select! {
|
||||
result = server => result.context("Failed to start server"),
|
||||
() = async {
|
||||
let _ = shutdown_rx.changed().await;
|
||||
info!(
|
||||
"Shutdown signal received, waiting up to {}s for in-flight requests to complete...",
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT.as_secs()
|
||||
);
|
||||
tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT).await;
|
||||
warn!("Graceful shutdown timed out, forcing exit");
|
||||
} => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C handler");
|
||||
if let Err(e) = signal::ctrl_c().await {
|
||||
log::error!("Failed to install Ctrl+C handler: {e}");
|
||||
}
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install signal handler")
|
||||
.recv()
|
||||
.await;
|
||||
match signal::unix::signal(signal::unix::SignalKind::terminate()) {
|
||||
Ok(mut signal) => {
|
||||
signal.recv().await;
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Failed to install SIGTERM handler: {e}");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
|
|
@ -178,11 +235,3 @@ async fn shutdown_signal() {
|
|||
() = terminate => {},
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_404() -> impl IntoResponse {
|
||||
not_found_error(anyhow!("Page not found"))
|
||||
}
|
||||
|
||||
async fn handle_405() -> impl IntoResponse {
|
||||
client_error(anyhow!("Method not allowed"))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use anyhow::Context as _;
|
||||
use axum::{
|
||||
Extension, Json,
|
||||
extract::{Path, State},
|
||||
|
|
@ -5,18 +6,21 @@ use axum::{
|
|||
use axum_extra::TypedHeader;
|
||||
use axum_typed_multipart::TypedMultipart;
|
||||
use log::{debug, info};
|
||||
use reconcile_text::{BuiltinTokenizer, reconcile};
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{device_id_header::DeviceIdHeader, requests::CreateDocumentVersion};
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentVersionWithoutContent, StoredDocumentVersion, VaultId},
|
||||
database::models::{StoredDocumentVersion, VaultId},
|
||||
},
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, client_error, server_error},
|
||||
errors::{SyncServerError, client_error, server_error, write_transaction_error},
|
||||
server::{responses::DocumentUpdateResponse, update_document},
|
||||
utils::{
|
||||
find_first_available_path::find_first_available_path, normalize::normalize,
|
||||
find_first_available_path::find_first_available_path, is_binary::is_binary,
|
||||
is_file_type_mergable::is_file_type_mergable, normalize::normalize,
|
||||
sanitize_path::sanitize_path,
|
||||
},
|
||||
};
|
||||
|
|
@ -30,48 +34,75 @@ pub struct CreateDocumentPathParams {
|
|||
/// Create a new document in case a document with the same doesn't exist
|
||||
/// already. If a document with the same path exists, a new version is created
|
||||
/// with their content merged.
|
||||
///
|
||||
/// Text content must be UTF-8 encoded. Clients are responsible for
|
||||
/// transcoding other encodings (e.g. UTF-16) to UTF-8 before sending.
|
||||
#[axum::debug_handler]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn create_document(
|
||||
Path(CreateDocumentPathParams { vault_id }): Path<CreateDocumentPathParams>,
|
||||
Extension(user): Extension<User>,
|
||||
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
|
||||
State(state): State<AppState>,
|
||||
TypedMultipart(request): TypedMultipart<CreateDocumentVersion>,
|
||||
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
|
||||
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
|
||||
debug!("Creating document in vault `{vault_id}`");
|
||||
|
||||
let mut transaction = state
|
||||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
.map_err(write_transaction_error)?;
|
||||
|
||||
let document_id = match request.document_id {
|
||||
Some(document_id) => {
|
||||
let existing_version = state
|
||||
.database
|
||||
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
let sanitized_relative_path = sanitize_path(&request.relative_path).map_err(client_error)?;
|
||||
let new_content = request.content.contents.to_vec();
|
||||
|
||||
if existing_version.is_some() {
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Document with the same ID `{document_id}` already exists"
|
||||
)));
|
||||
}
|
||||
|
||||
document_id
|
||||
}
|
||||
None => uuid::Uuid::new_v4(),
|
||||
};
|
||||
|
||||
let last_update_id = state
|
||||
let latest_version = state
|
||||
.database
|
||||
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
|
||||
.get_latest_non_deleted_document_by_path(
|
||||
&vault_id,
|
||||
&sanitized_relative_path,
|
||||
Some(&mut *transaction),
|
||||
)
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
if let Some(latest_version) = latest_version {
|
||||
let is_mergeable_text = is_file_type_mergable(
|
||||
&sanitized_relative_path,
|
||||
&state.config.server.mergeable_file_extensions,
|
||||
) && !is_binary(&latest_version.content)
|
||||
&& !is_binary(&new_content);
|
||||
|
||||
if is_mergeable_text || new_content == latest_version.content {
|
||||
return update_document::update_document(
|
||||
&sanitized_relative_path,
|
||||
Vec::new(),
|
||||
vault_id,
|
||||
latest_version.document_id,
|
||||
&request.relative_path,
|
||||
new_content,
|
||||
user,
|
||||
device_id,
|
||||
state,
|
||||
transaction,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// For non-mergeable (binary) files with different content, don't
|
||||
// merge, create a separate document at a deconflicted path so
|
||||
// neither client's data is silently overwritten.
|
||||
}
|
||||
|
||||
let document_id = uuid::Uuid::new_v4();
|
||||
|
||||
let last_update_id = state
|
||||
.database
|
||||
.get_max_update_id_in_vault(&vault_id, Some(&mut *transaction))
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
let sanitized_relative_path = sanitize_path(&request.relative_path);
|
||||
let deduped_path = find_first_available_path(
|
||||
&vault_id,
|
||||
&sanitized_relative_path,
|
||||
|
|
@ -91,7 +122,7 @@ pub async fn create_document(
|
|||
vault_update_id: last_update_id + 1,
|
||||
document_id,
|
||||
relative_path: deduped_path,
|
||||
content: request.content.contents.to_vec(),
|
||||
content: new_content,
|
||||
updated_date: chrono::Utc::now(),
|
||||
is_deleted: false,
|
||||
user_id: user.name,
|
||||
|
|
@ -105,5 +136,7 @@ pub async fn create_document(
|
|||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
Ok(Json(new_version.into()))
|
||||
Ok(Json(DocumentUpdateResponse::FastForwardUpdate(
|
||||
new_version.into(),
|
||||
)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use anyhow::Context;
|
||||
use anyhow::{Context, anyhow};
|
||||
use axum::{
|
||||
Extension, Json,
|
||||
extract::{Path, State},
|
||||
|
|
@ -16,7 +16,7 @@ use crate::{
|
|||
},
|
||||
},
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, server_error},
|
||||
errors::{SyncServerError, client_error, not_found_error, server_error, write_transaction_error},
|
||||
utils::{normalize::normalize, sanitize_path::sanitize_path},
|
||||
};
|
||||
|
||||
|
|
@ -37,7 +37,7 @@ pub async fn delete_document(
|
|||
Extension(user): Extension<User>,
|
||||
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<DeleteDocumentVersion>,
|
||||
Json(_request): Json<DeleteDocumentVersion>,
|
||||
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
|
||||
debug!("Deleting document `{document_id}` in vault `{vault_id}`");
|
||||
|
||||
|
|
@ -45,7 +45,7 @@ pub async fn delete_document(
|
|||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
.map_err(write_transaction_error)?;
|
||||
|
||||
let last_update_id = state
|
||||
.database
|
||||
|
|
@ -77,7 +77,7 @@ pub async fn delete_document(
|
|||
let new_version = StoredDocumentVersion {
|
||||
vault_update_id: last_update_id + 1,
|
||||
document_id,
|
||||
relative_path: sanitize_path(&request.relative_path),
|
||||
relative_path: sanitize_path(&request.relative_path).map_err(client_error)?,
|
||||
content: latest_content, // copy the content from the latest version
|
||||
updated_date: chrono::Utc::now(),
|
||||
is_deleted: true,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ use crate::{
|
|||
AppState,
|
||||
database::models::{DocumentId, DocumentVersion, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, not_found_error, server_error},
|
||||
errors::{SyncServerError, client_error, not_found_error, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ pub async fn fetch_document_version(
|
|||
)?;
|
||||
|
||||
if result.document_id != document_id {
|
||||
return Err(not_found_error(anyhow!(
|
||||
return Err(client_error(anyhow!(
|
||||
"Document with document id `{document_id}` does not have a version with id \
|
||||
`{vault_update_id}`",
|
||||
)));
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ use crate::{
|
|||
AppState,
|
||||
database::models::{DocumentId, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, not_found_error, server_error},
|
||||
errors::{SyncServerError, client_error, not_found_error, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ pub async fn fetch_document_version_content(
|
|||
)?;
|
||||
|
||||
if result.document_id != document_id {
|
||||
return Err(not_found_error(anyhow!(
|
||||
return Err(client_error(anyhow!(
|
||||
"Document with document id `{document_id}` does not have a version with id \
|
||||
`{vault_update_id}`",
|
||||
)));
|
||||
|
|
|
|||
|
|
@ -1,25 +1,37 @@
|
|||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicU64, Ordering},
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
|
||||
use axum_extra::{
|
||||
TypedHeader,
|
||||
headers::{Authorization, authorization::Bearer},
|
||||
};
|
||||
|
||||
/// Simple token-bucket rate limiter that refills every second.
|
||||
/// Per-user token-bucket rate limiter. Each bearer token gets its own bucket
|
||||
/// that refills to `max_per_second` tokens every second.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RateLimiter {
|
||||
inner: Arc<TokenBucket>,
|
||||
max_per_second: u64,
|
||||
buckets: Arc<Mutex<HashMap<String, Arc<TokenBucket>>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TokenBucket {
|
||||
tokens: AtomicU64,
|
||||
state: Mutex<BucketState>,
|
||||
max_tokens: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BucketState {
|
||||
tokens: u64,
|
||||
last_refill: Instant,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
/// Create a new rate limiter. Spawns a background task that refills tokens
|
||||
/// every second.
|
||||
/// Create a new per-user rate limiter.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
|
|
@ -27,44 +39,62 @@ impl RateLimiter {
|
|||
pub fn new(max_per_second: u64) -> Self {
|
||||
assert!(
|
||||
max_per_second > 0,
|
||||
"max_per_second must be > 0 (use 0 in config to disable rate limiting entirely)"
|
||||
"max_per_second must be > 0 (set rate_limit_per_user_per_second to null in config to disable)"
|
||||
);
|
||||
|
||||
let bucket = Arc::new(TokenBucket {
|
||||
tokens: AtomicU64::new(max_per_second),
|
||||
max_tokens: max_per_second,
|
||||
});
|
||||
|
||||
let bucket_clone = bucket.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
bucket_clone
|
||||
.tokens
|
||||
.store(bucket_clone.max_tokens, Ordering::Release);
|
||||
}
|
||||
});
|
||||
|
||||
Self { inner: bucket }
|
||||
Self {
|
||||
max_per_second,
|
||||
buckets: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_acquire(&self) -> bool {
|
||||
self.inner
|
||||
.tokens
|
||||
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
|
||||
if current > 0 { Some(current - 1) } else { None }
|
||||
fn get_or_create_bucket(&self, token: &str) -> Arc<TokenBucket> {
|
||||
self.buckets
|
||||
.lock()
|
||||
.expect("rate limiter lock poisoned")
|
||||
.entry(token.to_owned())
|
||||
.or_insert_with(|| {
|
||||
Arc::new(TokenBucket {
|
||||
state: Mutex::new(BucketState {
|
||||
tokens: self.max_per_second,
|
||||
last_refill: Instant::now(),
|
||||
}),
|
||||
max_tokens: self.max_per_second,
|
||||
})
|
||||
})
|
||||
.is_ok()
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenBucket {
|
||||
fn try_acquire(&self) -> bool {
|
||||
let mut state = self.state.lock().expect("token bucket lock poisoned");
|
||||
let now = Instant::now();
|
||||
if now.duration_since(state.last_refill).as_secs() >= 1 {
|
||||
state.tokens = self.max_tokens;
|
||||
state.last_refill = now;
|
||||
}
|
||||
if state.tokens > 0 {
|
||||
state.tokens -= 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn rate_limit_middleware(
|
||||
axum::extract::State(limiter): axum::extract::State<RateLimiter>,
|
||||
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if limiter.try_acquire() {
|
||||
let Some(TypedHeader(auth)) = auth_header else {
|
||||
return Ok(next.run(req).await);
|
||||
};
|
||||
|
||||
let bucket = limiter.get_or_create_bucket(auth.token());
|
||||
if bucket.try_acquire() {
|
||||
Ok(next.run(req).await)
|
||||
} else {
|
||||
Err(StatusCode::TOO_MANY_REQUESTS)
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@ pub struct CreateDocumentVersion {
|
|||
#[ts(as = "Vec<u8>")]
|
||||
#[form_data(limit = "unlimited")]
|
||||
pub content: FieldData<Bytes>,
|
||||
|
||||
pub idempotency_key: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, TryFromMultipart)]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use axum::{
|
|||
};
|
||||
use axum_extra::TypedHeader;
|
||||
use axum_typed_multipart::TypedMultipart;
|
||||
use futures::io::Write;
|
||||
use log::{debug, info};
|
||||
use reconcile_text::{BuiltinTokenizer, EditedText, reconcile};
|
||||
use serde::Deserialize;
|
||||
|
|
@ -16,10 +17,15 @@ use super::{
|
|||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
|
||||
database::{
|
||||
WriteTransaction,
|
||||
models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
|
||||
},
|
||||
},
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, client_error, not_found_error, server_error},
|
||||
errors::{
|
||||
SyncServerError, client_error, not_found_error, server_error, write_transaction_error,
|
||||
},
|
||||
server::requests::UpdateBinaryDocumentVersion,
|
||||
utils::{
|
||||
find_first_available_path::find_first_available_path, is_binary::is_binary,
|
||||
|
|
@ -46,18 +52,27 @@ pub async fn update_binary(
|
|||
State(state): State<AppState>,
|
||||
TypedMultipart(request): TypedMultipart<UpdateBinaryDocumentVersion>,
|
||||
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
|
||||
let parent_document = get_parent_document(&state, &vault_id, request.parent_version_id).await?;
|
||||
let parent_document =
|
||||
get_parent_document(&state, &vault_id, &document_id, request.parent_version_id).await?;
|
||||
let content = request.content.contents.to_vec();
|
||||
|
||||
let mut transaction = state
|
||||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(write_transaction_error)?;
|
||||
|
||||
update_document(
|
||||
parent_document,
|
||||
&parent_document.relative_path,
|
||||
parent_document.content,
|
||||
vault_id,
|
||||
document_id,
|
||||
&request.relative_path,
|
||||
content,
|
||||
user,
|
||||
device_id,
|
||||
state,
|
||||
&request.relative_path,
|
||||
content,
|
||||
transaction,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
@ -74,28 +89,36 @@ pub async fn update_text(
|
|||
State(state): State<AppState>,
|
||||
Json(request): Json<UpdateTextDocumentVersion>,
|
||||
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
|
||||
let parent_document = get_parent_document(&state, &vault_id, request.parent_version_id).await?;
|
||||
let parent_document =
|
||||
get_parent_document(&state, &vault_id, &document_id, request.parent_version_id).await?;
|
||||
|
||||
let edited_text = EditedText::from_diff(
|
||||
str::from_utf8(&parent_document.content)
|
||||
.expect("parent must be valid UTF-8 because it's a text document"),
|
||||
request.content,
|
||||
&*BuiltinTokenizer::Word,
|
||||
)
|
||||
.context("Failed to apply given diff to parent document")
|
||||
.map_err(client_error)?;
|
||||
let parent_text = str::from_utf8(&parent_document.content)
|
||||
.context("Parent version contains binary content; use putBinary instead of putText")
|
||||
.map_err(client_error)?;
|
||||
|
||||
let edited_text = EditedText::from_diff(parent_text, request.content, &*BuiltinTokenizer::Word)
|
||||
.context("Failed to apply given diff to parent document")
|
||||
.map_err(client_error)?;
|
||||
|
||||
let content = edited_text.apply().text().into_bytes();
|
||||
|
||||
let mut transaction = state
|
||||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(write_transaction_error)?;
|
||||
|
||||
update_document(
|
||||
parent_document,
|
||||
&parent_document.relative_path,
|
||||
parent_document.content,
|
||||
vault_id,
|
||||
document_id,
|
||||
&request.relative_path,
|
||||
content,
|
||||
user,
|
||||
device_id,
|
||||
state,
|
||||
&request.relative_path,
|
||||
content,
|
||||
transaction,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
@ -103,9 +126,10 @@ pub async fn update_text(
|
|||
async fn get_parent_document(
|
||||
state: &AppState,
|
||||
vault_id: &VaultId,
|
||||
document_id: &DocumentId,
|
||||
parent_version_id: VaultUpdateId,
|
||||
) -> Result<StoredDocumentVersion, SyncServerError> {
|
||||
state
|
||||
let parent = state
|
||||
.database
|
||||
.get_document_version(vault_id, parent_version_id, None)
|
||||
.await
|
||||
|
|
@ -117,29 +141,33 @@ async fn get_parent_document(
|
|||
)))
|
||||
},
|
||||
Ok,
|
||||
)
|
||||
)?;
|
||||
|
||||
if &parent.document_id != document_id {
|
||||
return Err(client_error(anyhow!(
|
||||
"Parent version `{parent_version_id}` does not belong to document `{document_id}`"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(parent)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
|
||||
async fn update_document(
|
||||
parent_document: StoredDocumentVersion,
|
||||
pub async fn update_document(
|
||||
parent_relative_path: &str,
|
||||
parent_content: Vec<u8>,
|
||||
vault_id: VaultId,
|
||||
document_id: DocumentId,
|
||||
relative_path: &str,
|
||||
content: Vec<u8>,
|
||||
user: User,
|
||||
device_id: DeviceIdHeader,
|
||||
state: AppState,
|
||||
relative_path: &str,
|
||||
content: Vec<u8>,
|
||||
mut transaction: WriteTransaction,
|
||||
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
|
||||
debug!("Updating document `{document_id}` in vault `{vault_id}`");
|
||||
|
||||
let sanitized_relative_path = sanitize_path(relative_path);
|
||||
|
||||
let mut transaction = state
|
||||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
let sanitized_relative_path = sanitize_path(relative_path).map_err(client_error)?;
|
||||
|
||||
let last_update_id = state
|
||||
.database
|
||||
|
|
@ -195,35 +223,44 @@ async fn update_document(
|
|||
let are_all_participants_mergable = is_file_type_mergable(
|
||||
&sanitized_relative_path,
|
||||
&state.config.server.mergeable_file_extensions,
|
||||
) && !is_binary(&parent_document.content)
|
||||
) && !is_binary(&parent_content)
|
||||
&& !is_binary(&latest_version.content)
|
||||
&& !is_binary(&content);
|
||||
|
||||
let merged_content = if are_all_participants_mergable {
|
||||
let (merged_content, is_different_from_request_content) = if are_all_participants_mergable {
|
||||
info!("Merging changes for document `{document_id}` in vault `{vault_id}`");
|
||||
reconcile(
|
||||
str::from_utf8(&parent_document.content)
|
||||
.expect("parent must be valid UTF-8 because it's not binary"),
|
||||
&str::from_utf8(&latest_version.content)
|
||||
.expect("latest_version must be valid UTF-8 because it's not binary")
|
||||
.into(),
|
||||
&str::from_utf8(&content)
|
||||
.expect("content must be valid UTF-8 because it's not binary")
|
||||
.into(),
|
||||
let parent_text = str::from_utf8(&parent_content)
|
||||
.context("Parent document content is not valid UTF-8")
|
||||
.map_err(client_error)?;
|
||||
let latest_text = str::from_utf8(&latest_version.content)
|
||||
.context("Latest version content is not valid UTF-8")
|
||||
.map_err(client_error)?;
|
||||
let new_text = str::from_utf8(&content)
|
||||
.context("New content is not valid UTF-8")
|
||||
.map_err(client_error)?;
|
||||
let merged = reconcile(
|
||||
parent_text,
|
||||
&latest_text.into(),
|
||||
&new_text.into(),
|
||||
&*BuiltinTokenizer::Word,
|
||||
)
|
||||
.apply()
|
||||
.text()
|
||||
.into_bytes()
|
||||
.into_bytes();
|
||||
let is_different = merged != content;
|
||||
(merged, is_different)
|
||||
} else {
|
||||
content.clone()
|
||||
(content, false) // false means that the client doesn't need to refetch the file as we can ensure the remote and local versions are the same as LWW is the merging method for binary files
|
||||
};
|
||||
|
||||
let is_different_from_request_content = merged_content != content;
|
||||
|
||||
// We can only update the relative path if we're the first one to do so
|
||||
let new_relative_path = if parent_document.relative_path == latest_version.relative_path
|
||||
&& latest_version.relative_path != sanitized_relative_path
|
||||
// Rename resolution: only apply the client's rename if the document's path
|
||||
// hasn't changed since this client's parent version. Check the parent
|
||||
// version's path against the latest version's path. If they differ, another
|
||||
// client already renamed the document — keep the latest path (first rename
|
||||
// wins). Content changes from both clients are still merged correctly via
|
||||
// the 3-way reconcile above, independent of which rename wins.
|
||||
let new_relative_path = if parent_relative_path == latest_version.relative_path
|
||||
&& sanitized_relative_path != latest_version.relative_path
|
||||
{
|
||||
let new_path = find_first_available_path(
|
||||
&vault_id,
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@ use axum::{
|
|||
},
|
||||
response::Response,
|
||||
};
|
||||
use futures::sink::SinkExt;
|
||||
use futures::stream::StreamExt;
|
||||
use log::{debug, info};
|
||||
use log::{debug, info, warn};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
|
|
@ -24,10 +24,26 @@ use crate::{
|
|||
},
|
||||
},
|
||||
},
|
||||
consts::{
|
||||
HANDSHAKE_TIMEOUT, MAX_CURSORS_PER_DOCUMENT, MAX_CURSOR_DOCUMENTS,
|
||||
MAX_RELATIVE_PATH_LEN,
|
||||
},
|
||||
errors::{SyncServerError, client_error, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
/// Tracks a pending (not yet authenticated) WebSocket connection.
|
||||
/// Decrements the counter when dropped, ensuring cleanup even if
|
||||
/// the upgrade never completes or auth fails.
|
||||
struct PendingWsGuard(std::sync::Arc<std::sync::atomic::AtomicUsize>);
|
||||
|
||||
impl Drop for PendingWsGuard {
|
||||
fn drop(&mut self) {
|
||||
self.0
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct WebSocketPathParams {
|
||||
#[serde(deserialize_with = "normalize")]
|
||||
|
|
@ -39,13 +55,31 @@ pub async fn websocket_handler(
|
|||
Path(WebSocketPathParams { vault_id }): Path<WebSocketPathParams>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Response, SyncServerError> {
|
||||
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id)))
|
||||
let current = state
|
||||
.pending_ws_connections
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
if current >= state.config.server.max_pending_websocket_connections {
|
||||
state
|
||||
.pending_ws_connections
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Too many pending WebSocket connections"
|
||||
)));
|
||||
}
|
||||
|
||||
let guard = PendingWsGuard(state.pending_ws_connections.clone());
|
||||
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id, guard)))
|
||||
}
|
||||
|
||||
async fn websocket_wrapped(state: AppState, stream: WebSocket, vault_id: VaultId) {
|
||||
async fn websocket_wrapped(
|
||||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
pending_guard: PendingWsGuard,
|
||||
) {
|
||||
info!("WebSocket connection opened on vault `{vault_id}`");
|
||||
|
||||
let result = websocket(state, stream, vault_id.clone()).await;
|
||||
let result = websocket(state, stream, vault_id.clone(), pending_guard).await;
|
||||
|
||||
if let Err(err) = result {
|
||||
debug!("WebSocket connection error on vault `{vault_id}`: {err}");
|
||||
|
|
@ -57,25 +91,53 @@ async fn websocket(
|
|||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
pending_guard: PendingWsGuard,
|
||||
) -> Result<(), SyncServerError> {
|
||||
let (mut sender, mut websocket_receiver) = stream.split();
|
||||
|
||||
let authed_handshake = get_authenticated_handshake(
|
||||
&state,
|
||||
&vault_id,
|
||||
websocket_receiver
|
||||
.next()
|
||||
.await
|
||||
.transpose()
|
||||
.unwrap_or_default(),
|
||||
)?;
|
||||
let handshake_msg = tokio::time::timeout(HANDSHAKE_TIMEOUT, websocket_receiver.next())
|
||||
.await
|
||||
.map_err(|_| client_error(anyhow::anyhow!("WebSocket handshake timed out")))?
|
||||
.transpose()
|
||||
.map_err(|e| client_error(anyhow::anyhow!("WebSocket error during handshake: {e}")))?;
|
||||
|
||||
let authed_handshake = get_authenticated_handshake(&state, &vault_id, handshake_msg)?;
|
||||
|
||||
info!(
|
||||
"WebSocket handshake successful for vault `{vault_id}` for `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
|
||||
let mut broadcast_receiver = state.broadcasts.get_receiver(vault_id.clone()).await;
|
||||
// Auth complete — no longer a pending connection.
|
||||
drop(pending_guard);
|
||||
|
||||
let max_clients = state.config.server.max_clients_per_vault;
|
||||
let mut broadcast_receiver = match state
|
||||
.broadcasts
|
||||
.get_receiver(vault_id.clone(), max_clients)
|
||||
.await
|
||||
{
|
||||
Ok(receiver) => receiver,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Vault `{vault_id}` has reached the maximum number of clients ({max_clients}), rejecting connection from `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
if let Err(e) = sender
|
||||
.send(Message::Close(Some(axum::extract::ws::CloseFrame {
|
||||
code: 4000,
|
||||
reason: format!(
|
||||
"Vault has reached the maximum number of clients ({max_clients})"
|
||||
)
|
||||
.into(),
|
||||
})))
|
||||
.await
|
||||
{
|
||||
warn!("Failed to send WebSocket close frame: {e}");
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
send_update_over_websocket(
|
||||
&WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate {
|
||||
|
|
@ -101,24 +163,35 @@ async fn websocket(
|
|||
|
||||
let device_id = authed_handshake.handshake.device_id.clone();
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Ok(update) = broadcast_receiver.recv().await {
|
||||
if Some(&device_id) == update.origin_device_id.as_ref() {
|
||||
continue;
|
||||
}
|
||||
loop {
|
||||
match broadcast_receiver.recv().await {
|
||||
Ok(update) => {
|
||||
if Some(&device_id) == update.origin_device_id.as_ref() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let message = match update.message {
|
||||
WebSocketServerMessage::CursorPositions(CursorPositionFromServer { clients }) => {
|
||||
WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
clients: clients
|
||||
.into_iter()
|
||||
.filter(|client| client.device_id != device_id)
|
||||
.collect(),
|
||||
})
|
||||
let message = match update.message {
|
||||
WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
clients,
|
||||
}) => WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
clients: clients
|
||||
.into_iter()
|
||||
.filter(|client| client.device_id != device_id)
|
||||
.collect(),
|
||||
}),
|
||||
WebSocketServerMessage::VaultUpdate(_) => update.message,
|
||||
};
|
||||
|
||||
send_update_over_websocket(&message, &mut sender).await?;
|
||||
}
|
||||
WebSocketServerMessage::VaultUpdate(_) => update.message,
|
||||
};
|
||||
|
||||
send_update_over_websocket(&message, &mut sender).await?;
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
|
||||
warn!(
|
||||
"WebSocket receiver lagged, dropped {n} messages — disconnecting client to force full resync"
|
||||
);
|
||||
break;
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
|
||||
Ok::<(), SyncServerError>(())
|
||||
|
|
@ -128,26 +201,57 @@ async fn websocket(
|
|||
let vault_id_clone = vault_id.clone();
|
||||
let cursor_manager = state.cursors.clone();
|
||||
let mut receive_task = tokio::spawn(async move {
|
||||
while let Some(Ok(Message::Text(message))) = websocket_receiver.next().await {
|
||||
let message: WebSocketClientMessage = serde_json::from_str(&message)
|
||||
.context("Failed to parse WebSocket message from client")
|
||||
.map_err(server_error)?;
|
||||
while let Some(msg) = websocket_receiver.next().await {
|
||||
match msg {
|
||||
Ok(Message::Text(message)) => {
|
||||
let message: WebSocketClientMessage = serde_json::from_str(&message)
|
||||
.context("Failed to parse WebSocket message from client")
|
||||
.map_err(client_error)?;
|
||||
|
||||
match message {
|
||||
WebSocketClientMessage::Handshake(_) => {
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Unexpected handshake message"
|
||||
)));
|
||||
match message {
|
||||
WebSocketClientMessage::Handshake(_) => {
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Unexpected handshake message"
|
||||
)));
|
||||
}
|
||||
WebSocketClientMessage::CursorPositions(cursors) => {
|
||||
let docs = cursors.documents_with_cursors;
|
||||
if docs.len() > MAX_CURSOR_DOCUMENTS {
|
||||
warn!(
|
||||
"Cursor update rejected: {} documents exceeds limit of {MAX_CURSOR_DOCUMENTS}",
|
||||
docs.len()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let valid = docs.iter().all(|doc| {
|
||||
doc.cursors.len() <= MAX_CURSORS_PER_DOCUMENT
|
||||
&& doc.relative_path.len() <= MAX_RELATIVE_PATH_LEN
|
||||
});
|
||||
if !valid {
|
||||
warn!("Cursor update rejected: a document exceeds cursor or path length limits");
|
||||
continue;
|
||||
}
|
||||
|
||||
cursor_manager
|
||||
.update_cursors(
|
||||
vault_id_clone.clone(),
|
||||
authed_handshake.user.name.clone(),
|
||||
&device_id,
|
||||
docs,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
WebSocketClientMessage::CursorPositions(cursors) => {
|
||||
cursor_manager
|
||||
.update_cursors(
|
||||
vault_id_clone.clone(),
|
||||
authed_handshake.user.name.clone(),
|
||||
&device_id,
|
||||
cursors.documents_with_cursors,
|
||||
)
|
||||
.await;
|
||||
Ok(Message::Close(_)) => break,
|
||||
Ok(Message::Binary(_)) => {
|
||||
warn!("Received unexpected binary WebSocket message, ignoring");
|
||||
}
|
||||
Ok(_) => {} // Ping/Pong frames handled by axum
|
||||
Err(e) => {
|
||||
debug!("WebSocket receive error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -155,38 +259,47 @@ async fn websocket(
|
|||
Ok::<(), SyncServerError>(())
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut send_task => receive_task.abort(),
|
||||
_ = &mut receive_task => send_task.abort(),
|
||||
let result: Result<(), SyncServerError> = tokio::select! {
|
||||
send_result = &mut send_task => {
|
||||
receive_task.abort();
|
||||
let _ = receive_task.await;
|
||||
match send_result {
|
||||
Err(e) => Err(server_error(
|
||||
anyhow::Error::from(e).context("WebSocket send task failed"),
|
||||
)),
|
||||
Ok(inner) => inner,
|
||||
}
|
||||
},
|
||||
receive_result = &mut receive_task => {
|
||||
send_task.abort();
|
||||
let _ = send_task.await;
|
||||
match receive_result {
|
||||
Err(e) => Err(server_error(
|
||||
anyhow::Error::from(e).context("WebSocket receive task failed"),
|
||||
)),
|
||||
Ok(inner) => inner,
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let result: Result<(), SyncServerError> = (async {
|
||||
send_task
|
||||
.await
|
||||
.context("WebSocket send task failed")
|
||||
.map_err(client_error)
|
||||
.and_then(|err| err)?;
|
||||
|
||||
receive_task
|
||||
.await
|
||||
.context("WebSocket receive task failed")
|
||||
.map_err(client_error)
|
||||
.and_then(|err| err)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await;
|
||||
|
||||
state
|
||||
.cursors
|
||||
.remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id)
|
||||
.await;
|
||||
|
||||
if result.is_err() {
|
||||
info!(
|
||||
"WebSocket disconnected on vault `{vault_id}` for `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
match &result {
|
||||
Ok(()) => {
|
||||
info!(
|
||||
"WebSocket disconnected on vault `{vault_id}` for `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"WebSocket error on vault `{vault_id}` for `{}`: {err}",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
|
|
|
|||
|
|
@ -1,8 +1,17 @@
|
|||
use std::sync::LazyLock;
|
||||
|
||||
use regex::Regex;
|
||||
|
||||
static DEDUP_SUFFIX_REGEX: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r" \((\d+)\)$").expect("invalid regex"));
|
||||
|
||||
pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
|
||||
let mut path_parts = path.split('/').collect::<Vec<_>>();
|
||||
let file_name = path_parts.pop().unwrap().to_owned();
|
||||
let file_name = path_parts
|
||||
.pop()
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or(path)
|
||||
.to_owned();
|
||||
|
||||
let mut directory = path_parts.join("/");
|
||||
if !directory.is_empty() {
|
||||
|
|
@ -29,14 +38,13 @@ pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
|
|||
}
|
||||
};
|
||||
|
||||
let regex = Regex::new(r" \((\d+)\)$").unwrap();
|
||||
let start_number = regex
|
||||
let start_number = DEDUP_SUFFIX_REGEX
|
||||
.captures(&stem)
|
||||
.and_then(|caps| caps.get(1))
|
||||
.and_then(|m| m.as_str().parse::<u32>().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
let clean_stem = regex.replace(&stem, "").to_string();
|
||||
let clean_stem = DEDUP_SUFFIX_REGEX.replace(&stem, "").to_string();
|
||||
|
||||
(start_number..).map(move |dedup_number| {
|
||||
if dedup_number == 0 {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use crate::app_state::database::models::VaultId;
|
||||
use crate::utils::dedup_paths::dedup_paths;
|
||||
use anyhow::{Result, bail};
|
||||
use log::info;
|
||||
use anyhow::Result;
|
||||
use log::{debug, info};
|
||||
use sqlx::sqlite::SqliteConnection;
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,17 @@
|
|||
use anyhow::{Result, ensure};
|
||||
|
||||
/// Sanitize the document's path to allow all clients to create the same path in
|
||||
/// their filesystem. If we didn't do this server-side, client's would need to
|
||||
/// deal with mapping invalid names to valid ones and then back.
|
||||
pub fn sanitize_path(path: &str) -> String {
|
||||
pub fn sanitize_path(path: &str) -> Result<String> {
|
||||
let options = sanitize_filename::Options {
|
||||
truncate: true,
|
||||
windows: true, // Windows is the lowest common denominator
|
||||
replacement: "",
|
||||
};
|
||||
|
||||
path.split('/')
|
||||
let result = path
|
||||
.split('/')
|
||||
.map(|part| {
|
||||
let proposal = sanitize_filename::sanitize_with_options(part, options.clone());
|
||||
if !part.is_empty() && proposal.is_empty() {
|
||||
|
|
@ -18,7 +21,10 @@ pub fn sanitize_path(path: &str) -> String {
|
|||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("/")
|
||||
.join("/");
|
||||
|
||||
ensure!(!result.is_empty(), "Relative path is empty after sanitization");
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -27,8 +33,32 @@ mod test {
|
|||
|
||||
#[test]
|
||||
fn test_sanitize_path() {
|
||||
assert_eq!(sanitize_path("/my/path/what?"), "/my/path/what");
|
||||
assert_eq!(sanitize_path("file (1).md"), "file (1).md");
|
||||
assert_eq!(sanitize_path("/my/path/\\\\:?"), "/my/path/_");
|
||||
assert_eq!(sanitize_path("/my/path/what?").unwrap(), "/my/path/what");
|
||||
assert_eq!(sanitize_path("file (1).md").unwrap(), "file (1).md");
|
||||
assert_eq!(sanitize_path("/my/path/\\\\:?").unwrap(), "/my/path/_");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_path_empty() {
|
||||
assert!(sanitize_path("").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_path_idempotent_simple() {
|
||||
let mut result = sanitize_path("notes/my file.md").unwrap();
|
||||
for _ in 0..5 {
|
||||
result = sanitize_path(&result).unwrap();
|
||||
}
|
||||
assert_eq!(result, "notes/my file.md");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_path_idempotent_special_chars() {
|
||||
let first = sanitize_path("/my/path/what?/file:name<>.md").unwrap();
|
||||
let mut result = first.clone();
|
||||
for _ in 0..5 {
|
||||
result = sanitize_path(&result).unwrap();
|
||||
}
|
||||
assert_eq!(result, first);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue