diff --git a/sync-server/src/app_state.rs b/sync-server/src/app_state.rs index 2019e08e..1bd3222e 100644 --- a/sync-server/src/app_state.rs +++ b/sync-server/src/app_state.rs @@ -2,6 +2,8 @@ pub mod cursors; pub mod database; pub mod websocket; +use std::sync::{Arc, atomic::AtomicUsize}; + use anyhow::Result; use cursors::Cursors; use database::Database; @@ -15,21 +17,42 @@ pub struct AppState { pub database: Database, pub cursors: Cursors, pub broadcasts: Broadcasts, + /// Tracks WebSocket connections that have upgraded but not yet completed + /// the authentication handshake + pub pending_ws_connections: Arc, + /// Send on this channel to stop background tasks (cursor cleanup, + /// idle-pool cleanup) + shutdown_tx: Arc>, } impl AppState { pub async fn try_new(config: Config) -> Result { + let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(()); + let broadcasts = Broadcasts::new(&config.server); - let database = Database::try_new(&config.database, &broadcasts).await?; + let database = + Database::try_new(&config.database, &broadcasts, shutdown_rx.clone()).await?; let cursors: Cursors = Cursors::new(&config.database, &broadcasts); - Cursors::start_background_task(cursors.clone()); + Cursors::start_background_task(cursors.clone(), shutdown_rx); Ok(Self { config, database, cursors, broadcasts, + pending_ws_connections: Arc::new(AtomicUsize::new(0)), + shutdown_tx: Arc::new(shutdown_tx), }) } + + /// Signal all background tasks (idle pool cleanup, cursor cleanup) to stop + pub fn shutdown(&self) { + let _ = self.shutdown_tx.send(()); + } + + /// Get a receiver to be notified when shutdown is triggered + pub fn subscribe_shutdown(&self) -> tokio::sync::watch::Receiver<()> { + self.shutdown_tx.subscribe() + } } diff --git a/sync-server/src/app_state/database.rs b/sync-server/src/app_state/database.rs index 75ce6df4..28acde41 100644 --- a/sync-server/src/app_state/database.rs +++ b/sync-server/src/app_state/database.rs @@ -1,16 +1,29 @@ use core::time::Duration; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::HashMap, + sync::Arc, + sync::atomic::{AtomicU64, Ordering}, +}; use anyhow::{Context as _, Result}; use log::info; use models::{ DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, VaultUpdateId, }; -use sqlx::{ConnectOptions, sqlite::SqliteConnectOptions, types::chrono::Utc}; +use sqlx::{ConnectOptions, Connection, sqlite::SqliteConnectOptions, types::chrono::Utc}; pub mod models; -use sqlx::{Pool, Sqlite, sqlite::SqlitePoolOptions}; -use tokio::sync::Mutex; + +/// Sentinel error indicating the `SQLite` database is busy (`SQLITE_BUSY`). +/// Handlers can downcast to this to return 429 instead of 500. +#[derive(Debug, thiserror::Error)] +#[error("Database is busy")] +pub struct WriteBusyError; + +use sqlx::{ + Pool, Sqlite, pool::PoolConnection, sqlite::SqliteConnection, sqlite::SqlitePoolOptions, +}; +use tokio::sync::{Mutex, OnceCell}; use tokio::time::Instant; use uuid::fmt::Hyphenated; @@ -19,33 +32,200 @@ use super::websocket::{ models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin, WebSocketVaultUpdate}, }; use crate::config::database_config::DatabaseConfig; +use crate::consts::IDLE_POOL_TIMEOUT; -#[derive(Clone)] -struct PoolWithTimestamp { - pool: Pool, - last_accessed: Instant, +/// Holds separate reader and writer pools for a single vault. +/// The writer pool has exactly 1 connection so writes never compete +/// with reads for pool slots. +#[derive(Debug, Clone)] +struct VaultPools { + reader: Pool, + writer: Pool, } -impl std::fmt::Debug for PoolWithTimestamp { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PoolWithTimestamp") - .field("pool", &"Pool") - .field("last_accessed", &self.last_accessed) - .finish() - } +#[derive(Debug)] +struct VaultPool { + cell: Arc>, + /// Monotonic timestamp in milliseconds (from `Instant::now()` at server start) + last_accessed_ms: AtomicU64, } #[derive(Clone, Debug)] pub struct Database { config: DatabaseConfig, broadcasts: Broadcasts, - connection_pools: Arc>>, + connection_pools: Arc>>>, + /// Per-vault write serialization. `SQLite` allows only one writer at a + /// time; `BEGIN IMMEDIATE` on a second connection blocks until the first + /// commits (up to `busy_timeout`). Under concurrent load the blocked + /// connections consume the pool, starving even read-only requests. + /// This mutex moves the wait from the `SQLite` layer (where it holds a + /// pool connection) to the Tokio layer (where it holds nothing). + write_locks: Arc>>>>, + /// Monotonic epoch for lock-free `last_accessed_ms` timestamps + epoch: Instant, } -pub type Transaction<'a> = sqlx::Transaction<'a, Sqlite>; +/// A write transaction backed by a raw `BEGIN IMMEDIATE` instead of sqlx's +/// savepoint-based `Transaction`. This avoids the savepoint mismatch caused +/// by the old `END; BEGIN IMMEDIATE;` workaround. +/// +/// Holds an `OwnedMutexGuard` that serializes write transactions per vault +/// at the application level (see `Database::write_locks`). The guard is +/// released when the transaction is committed, rolled back, or dropped. +pub struct WriteTransaction { + conn: Option>, + _write_guard: tokio::sync::OwnedMutexGuard<()>, +} + +impl WriteTransaction { + async fn new( + pool: &Pool, + write_guard: tokio::sync::OwnedMutexGuard<()>, + ) -> Result { + let mut conn = pool + .acquire() + .await + .context("Cannot acquire connection for write transaction")?; + if let Err(e) = sqlx::query("BEGIN IMMEDIATE").execute(&mut *conn).await { + let is_busy = match &e { + sqlx::Error::Database(db_err) => { + // SQLITE_BUSY base code is 5. Extended codes share base 5. + let busy_by_code = db_err + .code() + .is_some_and(|c| c.parse::().is_ok_and(|n| n & 0xFF == 5)); + busy_by_code || db_err.message().contains("database is locked") + } + _ => false, + }; + if is_busy { + return Err(WriteBusyError.into()); + } + return Err(e).context("Cannot begin immediate transaction"); + } + Ok(Self { + conn: Some(conn), + _write_guard: write_guard, + }) + } + + pub async fn commit(mut self) -> Result<()> { + if let Some(mut conn) = self.conn.take() { + sqlx::query("COMMIT") + .execute(&mut *conn) + .await + .context("Failed to commit transaction")?; + } + Ok(()) + } + + pub async fn rollback(mut self) -> Result<()> { + if let Some(mut conn) = self.conn.take() { + sqlx::query("ROLLBACK") + .execute(&mut *conn) + .await + .context("Failed to rollback transaction")?; + } + Ok(()) + } +} + +impl Drop for WriteTransaction { + fn drop(&mut self) { + if self.conn.is_some() { + // The connection is returned to the pool with an open transaction. + // The pool's `before_acquire` hook issues a ROLLBACK before + // handing it to the next consumer, so no async work is needed + // here. If the pool is being shut down, SQLite itself rolls back + // uncommitted transactions when the connection closes. + log::warn!("WriteTransaction dropped without commit or rollback"); + } + } +} + +impl std::ops::Deref for WriteTransaction { + type Target = SqliteConnection; + fn deref(&self) -> &Self::Target { + self.conn + .as_ref() + .expect("BUG: WriteTransaction dereferenced after being consumed") + .deref() + } +} + +impl std::ops::DerefMut for WriteTransaction { + fn deref_mut(&mut self) -> &mut Self::Target { + self.conn + .as_mut() + .expect("BUG: WriteTransaction dereferenced after being consumed") + .deref_mut() + } +} + +/// Ensure the connection has no leftover open transaction (e.g. from a +/// `WriteTransaction` that was dropped without commit/rollback). ROLLBACK +/// is a harmless no-op if no transaction is active. +fn rollback_before_acquire( + conn: &mut SqliteConnection, + _meta: sqlx::pool::PoolConnectionMetadata, +) -> futures::future::BoxFuture<'_, Result> { + Box::pin(async move { + if let Err(e) = sqlx::query("ROLLBACK").execute(&mut *conn).await { + // "cannot rollback - no transaction is active" is the common + // case (connection returned cleanly). Only unexpected errors + // deserve attention. + log::debug!("before_acquire ROLLBACK failed: {e}"); + } + Ok(true) + }) +} impl Database { - pub async fn try_new(config: &DatabaseConfig, broadcasts: &Broadcasts) -> Result { + fn now_ms(&self) -> u64 { + self.epoch.elapsed().as_millis() as u64 + } + + /// Lists all vault IDs that exist on disk (have a `.sqlite` file). + pub async fn list_vaults(&self) -> Result> { + let mut vaults = Vec::new(); + let mut entries = tokio::fs::read_dir(&self.config.databases_directory_path) + .await + .context("Failed to read databases directory")?; + while let Some(entry) = entries.next_entry().await? { + let name = entry.file_name().to_string_lossy().to_string(); + if let Some(vault) = name.strip_suffix(".sqlite") { + vaults.push(vault.to_owned()); + } + } + vaults.sort(); + Ok(vaults) + } + + pub async fn get_vault_stats(&self, vault: &VaultId) -> Result { + let pool = self.get_connection_pool(vault).await?; + let row = sqlx::query!( + r#" + SELECT + (SELECT MIN(updated_date) FROM documents) + AS "created_at: chrono::DateTime", + (SELECT COUNT(DISTINCT document_id) FROM latest_document_versions + WHERE is_deleted = false) + AS "document_count!: u32" + "#, + ) + .fetch_one(&pool) + .await?; + Ok(models::VaultStats { + created_at: row.created_at, + document_count: row.document_count, + }) + } + + pub async fn try_new( + config: &DatabaseConfig, + broadcasts: &Broadcasts, + shutdown: tokio::sync::watch::Receiver<()>, + ) -> Result { tokio::fs::create_dir_all(&config.databases_directory_path) .await .with_context(|| { @@ -70,122 +250,207 @@ impl Database { .trim_end_matches(".sqlite") .to_owned(); - let pool = Self::create_vault_database(config, &vault).await?; + Self::validate_vault_id(&vault)?; + + let pools = Self::create_vault_database(config, &vault).await?; + let cell = Arc::new(OnceCell::new()); + cell.set(pools).expect("cell is new"); connection_pools.insert( vault.clone(), - PoolWithTimestamp { - pool, - last_accessed: Instant::now(), - }, + Arc::new(VaultPool { + cell, + last_accessed_ms: AtomicU64::new(0), + }), ); } + info!("Database migrations applied"); let database = Self { config: config.clone(), connection_pools: Arc::new(Mutex::new(connection_pools)), broadcasts: broadcasts.clone(), + write_locks: Arc::new(Mutex::new(HashMap::new())), + epoch: Instant::now(), }; - // Start background task to cleanup idle connection pools - database.start_idle_pool_cleanup(); + database.start_idle_pool_cleanup(shutdown); Ok(database) } - async fn create_vault_database( - config: &DatabaseConfig, - vault: &VaultId, - ) -> Result> { + async fn create_vault_database(config: &DatabaseConfig, vault: &VaultId) -> Result { let file_name = config .databases_directory_path .join(format!("{vault}.sqlite")); - let connection_options = SqliteConnectOptions::new() + // Database-level PRAGMAs (auto_vacuum, journal_mode) require a write + // lock and persist across connections. Set them once with a dedicated + // init connection so pool connections never need the write lock just to + // open. + let init_options = SqliteConnectOptions::new() .filename(file_name.clone()) .create_if_missing(true) - .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Full) - .busy_timeout(Duration::from_secs(3600)) - .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal) - .log_slow_statements(log::LevelFilter::Warn, Duration::from_secs(30)); + .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental) + .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal); - let pool = SqlitePoolOptions::new() + // Run migrations on a dedicated connection, NOT through the pool. + // The pool's `before_acquire` hook issues ROLLBACK on every checkout, + // which can roll back the migration's bookkeeping transaction (the + // _sqlx_migrations INSERT) while the DDL (ALTER TABLE) has already + // auto-committed — leaving the migration in a dirty state. + // + // Uses `run_direct` instead of `run` because `run` takes + // `impl Acquire<'_>`, whose lifetime bound prevents the enclosing + // future from satisfying the `Send` requirement of axum handlers. + let mut init_conn = sqlx::SqliteConnection::connect_with(&init_options).await?; + sqlx::migrate!("src/app_state/database/migrations") + .run_direct(&mut init_conn) + .await + .context("Cannot run pending migrations")?; + drop(init_conn); + + // Per-connection PRAGMAs shared by both reader and writer pools. + // journal_mode = WAL is a no-op on an already-WAL database. + let base_options = SqliteConnectOptions::new() + .filename(file_name.clone()) + .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal) + .busy_timeout(Duration::from_secs(30)) + .log_slow_statements(log::LevelFilter::Warn, Duration::from_secs(30)) + // In WAL mode, NORMAL is safe: data survives OS crashes, only the + // last transaction can be lost on power failure. The default FULL + // forces an extra fsync() per commit, roughly halving write throughput. + .pragma("synchronous", "NORMAL") + // 16 MB page cache per connection (negative = KiB). Reduces disk + // reads for the latest_document_versions GROUP BY view. + .pragma("cache_size", "-16384") + // Memory-mapped I/O avoids read() syscalls. SQLite falls back to + // regular I/O for writes and beyond the mapped region. 256 MB is + // conservative; the OS handles actual memory pressure. + .pragma("mmap_size", "268435456") + // Keep temp tables and sort spillovers in memory instead of temp files. + .pragma("temp_store", "MEMORY") + // Cap WAL file growth at 64 MB. Without this, the WAL can grow + // unbounded during heavy write bursts (e.g. E2E tests with many + // concurrent clients). SQLite truncates to this size on checkpoint. + .pragma("journal_size_limit", "67108864"); + + // Reader pool: multiple connections for concurrent reads. + let reader = SqlitePoolOptions::new() .max_connections(config.max_connections_per_vault) .acquire_slow_threshold(Duration::from_secs(30)) - .test_before_acquire(true) - .connect_with(connection_options) + // Disabled: the health-check query is subject to busy_timeout + // and blocks all connection checkouts when a write is active, + // starving the pool for up to 30s even for simple reads. + // The before_acquire ROLLBACK hook is sufficient for cleanup. + .test_before_acquire(false) + .before_acquire(rollback_before_acquire) + .connect_with(base_options.clone()) .await - .with_context(|| format!("Cannot open database at `{}`", file_name.display()))?; + .with_context(|| format!("Cannot open reader pool at `{}`", file_name.display()))?; - Self::run_migrations(&pool).await?; + // Writer pool: exactly 1 connection, dedicated to writes. + // Since the Tokio mutex already serializes writers per vault, this + // single connection is never contended. Separating it from the + // reader pool ensures writes never compete with reads for pool slots. + let writer = SqlitePoolOptions::new() + .max_connections(1) + .acquire_slow_threshold(Duration::from_secs(30)) + .test_before_acquire(false) + .before_acquire(rollback_before_acquire) + .connect_with(base_options) + .await + .with_context(|| format!("Cannot open writer pool at `{}`", file_name.display()))?; - Ok(pool) + Ok(VaultPools { reader, writer }) } - async fn run_migrations(pool: &Pool) -> Result<()> { - sqlx::migrate!("src/app_state/database/migrations") - .run(pool) - .await - .context("Cannot check for pending migrations") - } - - async fn get_connection_pool(&self, vault: &VaultId) -> Result> { - let mut pools = self.connection_pools.lock().await; - - if !pools.contains_key(vault) { - let pool = Self::create_vault_database(&self.config, vault).await?; - pools.insert( - vault.clone(), - PoolWithTimestamp { - pool, - last_accessed: Instant::now(), - }, + fn validate_vault_id(vault: &VaultId) -> Result<()> { + if vault.is_empty() { + anyhow::bail!("Vault ID must not be empty"); + } + if vault.contains('/') + || vault.contains('\\') + || vault.contains("..") + || vault.contains('\0') + { + anyhow::bail!( + "Invalid vault ID: must not contain path separators, '..', or null bytes" ); } - - let pool_with_timestamp = pools - .get_mut(vault) - .expect("Pool was just inserted or already exists"); - - // Update last accessed time - pool_with_timestamp.last_accessed = Instant::now(); - - Ok(pool_with_timestamp.pool.clone()) + Ok(()) } - /// Attempting to write from this transaction might result in a - /// database locked error. Use this transaction for read-only operations. - pub async fn create_readonly_transaction( - &self, - vault: &VaultId, - ) -> Result> { - self.get_connection_pool(vault) - .await? - .begin() - .await - .context("Cannot create transaction") - } + async fn get_vault_pools(&self, vault: &VaultId) -> Result { + Self::validate_vault_id(vault)?; - pub async fn create_write_transaction(&self, vault: &VaultId) -> Result> { - let mut transaction = self.create_readonly_transaction(vault).await?; + // Get or create the VaultPool entry. The global lock is held only + // long enough for a HashMap lookup/insert — never across + // create_vault_database. + let vault_pool = { + let mut pools = self.connection_pools.lock().await; + pools + .entry(vault.clone()) + .or_insert_with(|| { + Arc::new(VaultPool { + cell: Arc::new(OnceCell::new()), + last_accessed_ms: AtomicU64::new(self.now_ms()), + }) + }) + .clone() + }; - // sqlx doesn't support immediate transactions for sqlite: https://github.com/launchbadge/sqlx/issues/481 - sqlx::query!("END; BEGIN IMMEDIATE;") - .execute(&mut *transaction) + // OnceCell::get_or_try_init guarantees exactly-once + // initialization: concurrent callers for the same vault wait + // here; callers for other vaults are not blocked. + let config = self.config.clone(); + let vault_clone = vault.clone(); + let pools = vault_pool + .cell + .get_or_try_init(|| async { Self::create_vault_database(&config, &vault_clone).await }) .await?; - Ok(transaction) + vault_pool + .last_accessed_ms + .store(self.now_ms(), Ordering::Relaxed); + Ok(pools.clone()) } - /// Return the latest state of all documents in the vault + /// Return the reader pool for read-only queries. + async fn get_connection_pool(&self, vault: &VaultId) -> Result> { + Ok(self.get_vault_pools(vault).await?.reader) + } + + pub async fn create_write_transaction(&self, vault: &VaultId) -> Result { + let write_lock = { + let mut locks = self.write_locks.lock().await; + locks + .entry(vault.clone()) + .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(()))) + .clone() + }; + let write_guard = write_lock.lock_owned().await; + let pools = self.get_vault_pools(vault).await?; + WriteTransaction::new(&pools.writer, write_guard).await + } + + /// Return the latest state of all documents in the vault, optionally + /// bounded above by `up_to_vault_update_id` so that the result is a + /// stable snapshot at exactly that cursor (commits past the cursor + /// will be delivered separately via the broadcast channel). pub async fn get_latest_documents( &self, vault: &VaultId, - transaction: Option<&mut Transaction<'_>>, + up_to_vault_update_id: Option, + connection: Option<&mut SqliteConnection>, ) -> Result> { + // `i64::MAX` makes the upper bound a no-op for callers that don't + // care about an exact snapshot (they pass `None`). + let upper = up_to_vault_update_id.unwrap_or(i64::MAX); let query = sqlx::query!( r#" select vault_update_id, + creation_vault_update_id, document_id as "document_id: Hyphenated", relative_path, updated_date as "updated_date: chrono::DateTime", @@ -194,12 +459,14 @@ impl Database { device_id, length(content) as "content_size: u64" from latest_document_versions + where vault_update_id <= ? order by vault_update_id "#, + upper, ); - if let Some(transaction) = transaction { - query.fetch_all(&mut **transaction).await + if let Some(conn) = connection { + query.fetch_all(&mut *conn).await } else { query .fetch_all(&self.get_connection_pool(vault).await?) @@ -216,42 +483,72 @@ impl Database { is_deleted: row.is_deleted, user_id: row.user_id, device_id: row.device_id, - content_size: row - .content_size - .expect("Content size can't be null but sqlx can't infer it"), + content_size: row.content_size.unwrap_or(0), + is_new_file: row.creation_vault_update_id == row.vault_update_id, }) .collect() }) } /// Return the latest state of all documents (including deleted) in the - /// vault which have changed since the given update id + /// vault which have changed since the given update id, bounded above + /// by `up_to_vault_update_id` so the catch-up result is a stable + /// snapshot at exactly that cursor. Commits past the cursor will be + /// delivered separately via the broadcast channel. pub async fn get_latest_documents_since( &self, vault: &VaultId, vault_update_id: VaultUpdateId, - transaction: Option<&mut Transaction<'_>>, + up_to_vault_update_id: Option, + connection: Option<&mut SqliteConnection>, ) -> Result> { + // `i64::MAX` makes the upper bound a no-op for callers that don't + // care about an exact snapshot (they pass `None`). + let upper = up_to_vault_update_id.unwrap_or(i64::MAX); + // Compute "latest version as of `upper`" per document — NOT + // global latest. The `latest_document_versions` view is keyed + // on global max, so a write that commits between the catch-up's + // cursor capture (under broadcast send-lock) and this query + // (which runs after drop-lock) would expose a `vault_update_id + // > cursor` row that the cursor filter then drops, removing + // the doc from the catch-up entirely. The post-cursor live + // broadcast then carries `is_new_file = false` (per real-time + // semantics it's an update of a previously-existing version), + // and the receiving client — which has no record of the doc — + // ignores it as stale, stranding the doc forever. Computing + // the snapshot from the documents table directly with the + // upper bound applied at the GROUP BY layer keeps the + // catch-up self-contained at exactly the cursor. let query = sqlx::query!( r#" select - vault_update_id, - document_id as "document_id: Hyphenated", - relative_path, - updated_date as "updated_date: chrono::DateTime", - is_deleted, - user_id, - device_id, - length(content) as "content_size: u64" - from latest_document_versions - where vault_update_id > ? - order by vault_update_id + d.vault_update_id, + d.creation_vault_update_id, + d.document_id as "document_id: Hyphenated", + d.relative_path, + d.updated_date as "updated_date: chrono::DateTime", + d.is_deleted, + d.user_id, + d.device_id, + length(d.content) as "content_size: u64" + from documents d + inner join ( + select document_id, max(vault_update_id) as max_vid + from documents + where vault_update_id <= ? + group by document_id + ) latest_at_cursor + on d.document_id = latest_at_cursor.document_id + and d.vault_update_id = latest_at_cursor.max_vid + where d.vault_update_id > ? + order by d.vault_update_id "#, - vault_update_id + upper, + vault_update_id, ); - if let Some(transaction) = transaction { - query.fetch_all(&mut **transaction).await + if let Some(conn) = connection { + query.fetch_all(&mut *conn).await } else { query .fetch_all(&self.get_connection_pool(vault).await?) @@ -270,9 +567,18 @@ impl Database { is_deleted: row.is_deleted, user_id: row.user_id, device_id: row.device_id, - content_size: row - .content_size - .expect("Content size can't be null but sqlx can't infer it"), + content_size: row.content_size.unwrap_or(0), + // For catch-up streams, "new file" means "new to this + // recipient" — the doc was created past the recipient's + // watermark. The catch-up only carries the doc's + // *latest* version (not its full history), so using + // `creation == latest` instead would mis-flag every + // doc that was created and then updated before the + // client reconnected, and the client's + // `processRemoteChange` would drop it as "stale + // RemoteChange for untracked, non-new document", + // silently leaking docs to clients catching up. + is_new_file: row.creation_vault_update_id > vault_update_id, }) .collect() }) @@ -281,7 +587,7 @@ impl Database { pub async fn get_max_update_id_in_vault( &self, vault: &VaultId, - transaction: Option<&mut Transaction<'_>>, + connection: Option<&mut SqliteConnection>, ) -> Result { let query = sqlx::query!( r#" @@ -290,8 +596,8 @@ impl Database { "#, ); - if let Some(transaction) = transaction { - query.fetch_one(&mut **transaction).await + if let Some(conn) = connection { + query.fetch_one(&mut *conn).await } else { query .fetch_one(&self.get_connection_pool(vault).await?) @@ -301,17 +607,18 @@ impl Database { .context("Cannot fetch max update id in vault") } - pub async fn get_latest_document_by_path( + pub async fn get_latest_non_deleted_document_by_path( &self, vault: &VaultId, relative_path: &str, - transaction: Option<&mut Transaction<'_>>, + connection: Option<&mut SqliteConnection>, ) -> Result> { let query = sqlx::query_as!( StoredDocumentVersion, r#" select vault_update_id, + creation_vault_update_id, document_id as "document_id: Hyphenated", relative_path, updated_date as "updated_date: chrono::DateTime", @@ -330,8 +637,8 @@ impl Database { relative_path ); - if let Some(transaction) = transaction { - query.fetch_optional(&mut **transaction).await + if let Some(conn) = connection { + query.fetch_optional(&mut *conn).await } else { query .fetch_optional(&self.get_connection_pool(vault).await?) @@ -340,11 +647,79 @@ impl Database { .context("Cannot fetch latest document version") } + /// Find a doc whose CREATE was authored by this device with + /// matching content, and whose creation the requesting client + /// hasn't observed yet (`creation_vault_update_id > last_seen`). + /// Used by `create_document` to recover from a "lost create" + /// race: this device's create response was discarded mid-flight, + /// so the retry comes in as a brand-new create — possibly at a + /// renamed path. Binding the retry to the existing doc avoids + /// duplicating the content under a deconflicted path. + /// + /// Matches against the doc's CREATION version (not the latest) + /// because a same-path concurrent create from another agent may + /// have merged into our doc since: the latest version's content + /// is the merge result, not what we originally sent. Joining on + /// `creation_vault_update_id` recovers the original bytes. + /// + /// The `device_id` + `creation > last_seen` combination scopes + /// the dedup to "we genuinely lost track of our own create"; + /// another agent's same-content doc won't match because of + /// `device_id`, and a doc this client already saw won't match + /// because of the watermark check. + pub async fn find_unseen_lost_create_by_device_and_content( + &self, + vault: &VaultId, + device_id: &str, + last_seen_vault_update_id: VaultUpdateId, + content: &[u8], + connection: Option<&mut SqliteConnection>, + ) -> Result> { + let query = sqlx::query_as!( + StoredDocumentVersion, + r#" + select + lv.vault_update_id, + lv.creation_vault_update_id, + lv.document_id as "document_id: Hyphenated", + lv.relative_path, + lv.updated_date as "updated_date: chrono::DateTime", + lv.content, + lv.is_deleted, + lv.user_id, + lv.device_id, + lv.has_been_merged + from latest_document_versions lv + inner join documents creation + on creation.document_id = lv.document_id + and creation.vault_update_id = lv.creation_vault_update_id + where creation.device_id = ? + and creation.content = ? + and lv.is_deleted = false + and lv.creation_vault_update_id > ? + order by lv.creation_vault_update_id desc + limit 1 + "#, + device_id, + content, + last_seen_vault_update_id, + ); + + if let Some(conn) = connection { + query.fetch_optional(&mut *conn).await + } else { + query + .fetch_optional(&self.get_connection_pool(vault).await?) + .await + } + .context("Cannot fetch lost-create candidate") + } + pub async fn get_latest_document( &self, vault: &VaultId, document_id: &DocumentId, - transaction: Option<&mut Transaction<'_>>, + connection: Option<&mut SqliteConnection>, ) -> Result> { let document_id = document_id.as_hyphenated(); let query = sqlx::query_as!( @@ -352,6 +727,7 @@ impl Database { r#" select vault_update_id, + creation_vault_update_id, document_id as "document_id: Hyphenated", relative_path, updated_date as "updated_date: chrono::DateTime", @@ -366,8 +742,8 @@ impl Database { document_id ); - if let Some(transaction) = transaction { - query.fetch_optional(&mut **transaction).await + if let Some(conn) = connection { + query.fetch_optional(&mut *conn).await } else { query .fetch_optional(&self.get_connection_pool(vault).await?) @@ -380,13 +756,14 @@ impl Database { &self, vault: &VaultId, vault_update_id: VaultUpdateId, - transaction: Option<&mut Transaction<'_>>, + connection: Option<&mut SqliteConnection>, ) -> Result> { let query = sqlx::query_as!( StoredDocumentVersion, r#" select vault_update_id, + creation_vault_update_id, document_id as "document_id: Hyphenated", relative_path, updated_date as "updated_date: chrono::DateTime", @@ -400,8 +777,8 @@ impl Database { vault_update_id ); - if let Some(transaction) = transaction { - query.fetch_optional(&mut **transaction).await + if let Some(conn) = connection { + query.fetch_optional(&mut *conn).await } else { query .fetch_optional(&self.get_connection_pool(vault).await?) @@ -410,105 +787,307 @@ impl Database { .context("Cannot fetch document version") } - // inserting the document must be the last step of the transaction if there's one + // inserting the document must be the last step of the transaction pub async fn insert_document_version( &self, vault_id: &VaultId, version: &StoredDocumentVersion, - transaction: Option>, + mut transaction: WriteTransaction, ) -> Result<()> { let document_id = version.document_id.as_hyphenated(); let query = sqlx::query!( r#" insert into documents ( vault_update_id, + creation_vault_update_id, document_id, relative_path, updated_date, content, is_deleted, user_id, - device_id + device_id, + has_been_merged ) - values (?, ?, ?, ?, ?, ?, ?, ?) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) "#, version.vault_update_id, + version.creation_vault_update_id, document_id, version.relative_path, version.updated_date, version.content, version.is_deleted, version.user_id, - version.device_id + version.device_id, + version.has_been_merged ); - if let Some(mut transaction) = transaction { - query - .execute(&mut *transaction) - .await - .context("Cannot insert document version")?; + // Acquire the broadcast send lock before the insert so that + // broadcasts are serialized in vault_update_id order even after + // the write transaction (and its per-vault lock) is released. + let _send_guard = self.broadcasts.acquire_send_lock(vault_id).await; - transaction - .commit() - .await - .context("Failed to commit transaction")?; + query + .execute(&mut *transaction) + .await + .context("Cannot insert document version")?; + + transaction + .commit() + .await + .context("Failed to commit transaction")?; + + // For non-delete writes the originating device already has + // authoritative state from its HTTP response, so we tag the + // broadcast with `origin_device_id` and the send task in + // `websocket.rs` filters it out for that device. Deletes are + // delivered to *every* connected client including the author — + // the originator only removes the document from its sync queue + // once it receives this receipt. + let envelope = WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate { + document: version.clone().into(), + }); + let with_origin = if version.is_deleted { + WebSocketServerMessageWithOrigin::new(envelope) } else { - query - .execute(&self.get_connection_pool(vault_id).await?) - .await - .context("Cannot insert document version")?; - } - + WebSocketServerMessageWithOrigin::with_origin(version.device_id.clone(), envelope) + }; self.broadcasts - .send_document_update( - vault_id.clone(), - WebSocketServerMessageWithOrigin::with_origin( - version.device_id.clone(), - WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate { - documents: vec![version.clone().into()], - is_initial_sync: false, - }), - ), - ) - .await; + .send_document_update(vault_id.clone(), with_origin); Ok(()) } + /// Return all versions (without content) of a specific document, ordered by `vault_update_id` + pub async fn get_document_versions( + &self, + vault: &VaultId, + document_id: &DocumentId, + connection: Option<&mut SqliteConnection>, + ) -> Result> { + let document_id = document_id.as_hyphenated(); + let query = sqlx::query!( + r#" + select + vault_update_id, + creation_vault_update_id, + document_id as "document_id: Hyphenated", + relative_path, + updated_date as "updated_date: chrono::DateTime", + is_deleted, + user_id, + device_id, + length(content) as "content_size: u64" + from documents + where document_id = ? + order by vault_update_id + "#, + document_id, + ); + + if let Some(conn) = connection { + query.fetch_all(&mut *conn).await + } else { + query + .fetch_all(&self.get_connection_pool(vault).await?) + .await + } + .with_context(|| format!("Cannot fetch document versions for document `{document_id}`")) + .map(|rows| { + rows.into_iter() + .map(|row| DocumentVersionWithoutContent { + vault_update_id: row.vault_update_id, + document_id: row.document_id.into(), + relative_path: row.relative_path, + updated_date: row.updated_date, + is_deleted: row.is_deleted, + user_id: row.user_id, + device_id: row.device_id, + content_size: row.content_size.unwrap_or(0), + is_new_file: row.creation_vault_update_id == row.vault_update_id, + }) + .collect() + }) + } + + /// Return all versions across all documents, paginated, ordered by `vault_update_id` DESC + pub async fn get_vault_history( + &self, + vault: &VaultId, + limit: i64, + before_update_id: Option, + connection: Option<&mut SqliteConnection>, + ) -> Result> { + let map_row = |row: models::VaultHistoryRow| DocumentVersionWithoutContent { + vault_update_id: row.vault_update_id, + document_id: row.document_id, + relative_path: row.relative_path, + updated_date: row.updated_date, + is_deleted: row.is_deleted, + user_id: row.user_id, + device_id: row.device_id, + content_size: row.content_size.unwrap_or(0), + is_new_file: row.creation_vault_update_id == row.vault_update_id, + }; + + if let Some(before) = before_update_id { + let query = sqlx::query_as!( + models::VaultHistoryRow, + r#" + select + vault_update_id, + creation_vault_update_id, + document_id as "document_id: Hyphenated", + relative_path, + updated_date as "updated_date: chrono::DateTime", + is_deleted, + user_id, + device_id, + length(content) as "content_size: u64" + from documents + where vault_update_id < ? + order by vault_update_id desc + limit ? + "#, + before, + limit, + ); + + let rows = if let Some(conn) = connection { + query.fetch_all(&mut *conn).await + } else { + query + .fetch_all(&self.get_connection_pool(vault).await?) + .await + } + .context("Cannot fetch vault history")?; + + Ok(rows.into_iter().map(map_row).collect()) + } else { + let query = sqlx::query_as!( + models::VaultHistoryRow, + r#" + select + vault_update_id, + creation_vault_update_id, + document_id as "document_id: Hyphenated", + relative_path, + updated_date as "updated_date: chrono::DateTime", + is_deleted, + user_id, + device_id, + length(content) as "content_size: u64" + from documents + order by vault_update_id desc + limit ? + "#, + limit, + ); + + let rows = if let Some(conn) = connection { + query.fetch_all(&mut *conn).await + } else { + query + .fetch_all(&self.get_connection_pool(vault).await?) + .await + } + .context("Cannot fetch vault history")?; + + Ok(rows.into_iter().map(map_row).collect()) + } + } + /// Cleanup idle connection pools that haven't been accessed in more than 5 minutes async fn cleanup_idle_pools(&self) { - let mut pools = self.connection_pools.lock().await; - let now = Instant::now(); - let idle_timeout = Duration::from_secs(5 * 60); // 5 minutes + // Collect idle vaults and remove them from the map while holding + // the lock briefly. Close pools OUTSIDE the lock so that + // pool.close().await doesn't block other get_connection_pool calls. + let idle_pools: Vec<(VaultId, Arc)> = { + let mut pools = self.connection_pools.lock().await; + let now_ms = self.now_ms(); + let idle_threshold_ms = IDLE_POOL_TIMEOUT.as_millis() as u64; - // Collect vaults to remove - let vaults_to_remove: Vec = pools - .iter() - .filter(|(_, pool_with_timestamp)| { - now.duration_since(pool_with_timestamp.last_accessed) > idle_timeout + let vaults_to_remove: Vec = pools + .iter() + .filter(|(_, vp)| { + let last = vp.last_accessed_ms.load(Ordering::Relaxed); + now_ms.saturating_sub(last) > idle_threshold_ms + }) + .map(|(vault_id, _)| vault_id.clone()) + .collect(); + + vaults_to_remove + .into_iter() + .filter_map(|id| pools.remove(&id).map(|vp| (id, vp))) + .collect() + }; + + // Close pools concurrently so cleanup doesn't serialise across vaults + let closures: Vec<_> = idle_pools + .into_iter() + .filter_map(|(vault_id, vault_pool)| { + vault_pool + .cell + .get() + .cloned() + .map(|pools| (vault_id, pools)) }) - .map(|(vault_id, _)| vault_id.clone()) .collect(); - // Close and remove idle pools - for vault_id in &vaults_to_remove { - if let Some(pool_with_timestamp) = pools.remove(vault_id) { - info!("Closing idle database connection pool for vault `{vault_id}`"); - pool_with_timestamp.pool.close().await; - } + let handles: Vec<_> = closures + .into_iter() + .map(|(vault_id, pools)| { + tokio::spawn(async move { + // Checkpoint the WAL before closing to reclaim disk space. + // Run on the blocking pool so disk I/O doesn't starve the runtime + let writer_clone = pools.writer.clone(); + let ckpt_result = tokio::task::spawn_blocking(move || { + futures::executor::block_on( + sqlx::query("PRAGMA wal_checkpoint(TRUNCATE)").execute(&writer_clone), + ) + }) + .await; + + match ckpt_result { + Ok(Err(e)) => { + log::warn!("WAL checkpoint failed for vault `{vault_id}`: {e}"); + } + Err(e) => { + log::warn!("WAL checkpoint task panicked for vault `{vault_id}`: {e}"); + } + _ => {} + } + + info!("Closing idle database connection pools for vault `{vault_id}`"); + pools.reader.close().await; + pools.writer.close().await; + }) + }) + .collect(); + + for handle in handles { + let _ = handle.await; } } /// Start a background task that periodically cleans up idle connection pools - fn start_idle_pool_cleanup(&self) { + fn start_idle_pool_cleanup(&self, mut shutdown: tokio::sync::watch::Receiver<()>) { let database = self.clone(); tokio::spawn(async move { let mut interval = tokio::time::interval(Duration::from_secs(60)); // Check every minute interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); loop { - interval.tick().await; - database.cleanup_idle_pools().await; + tokio::select! { + _ = interval.tick() => { + database.cleanup_idle_pools().await; + } + _ = shutdown.changed() => { + info!("Idle pool cleanup task shutting down"); + break; + } + } } }); } diff --git a/sync-server/src/app_state/database/migrations/20260314000000_add_idempotency_key.sql b/sync-server/src/app_state/database/migrations/20260314000000_add_idempotency_key.sql new file mode 100644 index 00000000..f3ee8dd3 --- /dev/null +++ b/sync-server/src/app_state/database/migrations/20260314000000_add_idempotency_key.sql @@ -0,0 +1,2 @@ +CREATE INDEX IF NOT EXISTS idx_documents_document_id +ON documents (document_id, vault_update_id); diff --git a/sync-server/src/app_state/database/migrations/20260421000000_add_creation_vault_update_id.sql b/sync-server/src/app_state/database/migrations/20260421000000_add_creation_vault_update_id.sql new file mode 100644 index 00000000..40dc85fb --- /dev/null +++ b/sync-server/src/app_state/database/migrations/20260421000000_add_creation_vault_update_id.sql @@ -0,0 +1,20 @@ +ALTER TABLE documents ADD COLUMN creation_vault_update_id INTEGER NOT NULL DEFAULT 0; + +UPDATE documents +SET creation_vault_update_id = ( + SELECT MIN(d2.vault_update_id) + FROM documents d2 + WHERE d2.document_id = documents.document_id +); + +DROP VIEW latest_document_versions; + +CREATE VIEW IF NOT EXISTS latest_document_versions AS --recreate view as it now includes one more field +SELECT d.* +FROM documents d +INNER JOIN ( + SELECT MAX(vault_update_id) AS max_version_id + FROM documents + GROUP BY document_id +) max_versions +ON d.vault_update_id = max_versions.max_version_id; diff --git a/sync-server/src/app_state/database/models.rs b/sync-server/src/app_state/database/models.rs index a216125a..89867067 100644 --- a/sync-server/src/app_state/database/models.rs +++ b/sync-server/src/app_state/database/models.rs @@ -13,6 +13,7 @@ pub type DeviceId = String; #[derive(Debug, Clone)] pub struct StoredDocumentVersion { pub vault_update_id: VaultUpdateId, + pub creation_vault_update_id: VaultUpdateId, pub document_id: DocumentId, pub relative_path: String, pub updated_date: DateTime, @@ -33,7 +34,7 @@ impl PartialEq for StoredDocumentVersion { #[derive(TS, Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct DocumentVersionWithoutContent { - #[ts(as = "i32")] + #[ts(type = "number")] pub vault_update_id: VaultUpdateId, pub document_id: DocumentId, @@ -43,12 +44,16 @@ pub struct DocumentVersionWithoutContent { pub user_id: UserId, pub device_id: DeviceId, - #[ts(as = "i32")] + #[ts(type = "number")] pub content_size: u64, + + /// True iff this is the first version of the document + pub is_new_file: bool, } impl From for DocumentVersionWithoutContent { fn from(value: StoredDocumentVersion) -> Self { + let is_new_file = value.creation_vault_update_id == value.vault_update_id; Self { vault_update_id: value.vault_update_id, document_id: value.document_id, @@ -58,6 +63,7 @@ impl From for DocumentVersionWithoutContent { user_id: value.user_id, device_id: value.device_id, content_size: value.content.len() as u64, + is_new_file, } } } @@ -65,7 +71,7 @@ impl From for DocumentVersionWithoutContent { #[derive(TS, Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct DocumentVersion { - #[ts(as = "i32")] + #[ts(type = "number")] pub vault_update_id: VaultUpdateId, pub document_id: DocumentId, @@ -77,6 +83,25 @@ pub struct DocumentVersion { pub device_id: DeviceId, } +/// Row struct for vault history queries (used by `sqlx::query_as!`) +#[derive(Debug)] +pub struct VaultHistoryRow { + pub vault_update_id: VaultUpdateId, + pub creation_vault_update_id: VaultUpdateId, + pub document_id: DocumentId, + pub relative_path: String, + pub updated_date: DateTime, + pub is_deleted: bool, + pub user_id: String, + pub device_id: String, + pub content_size: Option, +} + +pub struct VaultStats { + pub created_at: Option>, + pub document_count: u32, +} + impl From for DocumentVersion { fn from(value: StoredDocumentVersion) -> Self { Self {