Add proper shutdown, rate limits, config validation, cors config, fix dangling cursors, cache regex, merge created texts

This commit is contained in:
Andras Schmelczer 2026-03-28 09:49:46 +00:00
parent 4763bc9d04
commit e15b0f9903
28 changed files with 1277 additions and 464 deletions

View file

@ -42,7 +42,9 @@ impl Cursors {
) {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
let all_device_cursors = vault_to_cursors.entry(vault_id).or_insert_with(Vec::new);
let all_device_cursors = vault_to_cursors
.entry(vault_id.clone())
.or_insert_with(Vec::new);
all_device_cursors.retain(|c| &c.client_cursors.device_id != device_id);
all_device_cursors.push(ClientCursorsWithTimeToLive::new(ClientCursors {
@ -52,7 +54,7 @@ impl Cursors {
}));
drop(vault_to_cursors); // Explicitly drop the lock before broadcasting to avoid deadlock
self.broadcast_cursors().await;
self.broadcast_cursors_for_vault(&vault_id).await;
}
pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec<ClientCursors> {
@ -69,45 +71,83 @@ impl Cursors {
.unwrap_or_default()
}
pub fn start_background_task(self) {
pub fn start_background_task(self, mut shutdown: tokio::sync::watch::Receiver<()>) {
tokio::spawn(async move {
loop {
self.remove_expired_cursors().await;
tokio::time::sleep(Duration::from_secs(1)).await;
// Wake up once per second to expire stale cursors, but stop as soon as
// shutdown is requested.
tokio::select! {
    () = tokio::time::sleep(Duration::from_secs(1)) => {
        self.remove_expired_cursors().await;
    }
    // Match any result: `changed()` yields `Err` once the sender has been
    // dropped. With an `Ok(())` pattern that case would merely disable
    // this branch, and the task would keep looping forever instead of
    // shutting down.
    _ = shutdown.changed() => break,
}
}
});
}
async fn remove_expired_cursors(&self) {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
let changed_vaults: Vec<VaultId> = {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
for (_vault_id, cursors) in vault_to_cursors.iter_mut() {
cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout));
let mut changed = Vec::new();
for (vault_id, cursors) in vault_to_cursors.iter_mut() {
let before = cursors.len();
cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout));
if cursors.len() != before {
changed.push(vault_id.clone());
}
}
// Remove empty vault entries to prevent unbounded growth
vault_to_cursors.retain(|_, cursors| !cursors.is_empty());
changed
};
for vault_id in &changed_vaults {
self.broadcast_cursors_for_vault(vault_id).await;
}
}
async fn broadcast_cursors(&self) {
let vault_to_cursors = self.vault_to_cursors.lock().await;
async fn broadcast_cursors_for_vault(&self, vault_id: &VaultId) {
let client_cursors: Vec<ClientCursors> = {
let vault_to_cursors = self.vault_to_cursors.lock().await;
vault_to_cursors
.get(vault_id)
.map(|cursors| cursors.iter().map(|c| c.client_cursors.clone()).collect())
.unwrap_or_default()
};
for (vault_id, cursors) in vault_to_cursors.iter() {
self.broadcasts
.send_document_update(
vault_id.clone(),
WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions(
CursorPositionFromServer {
clients: cursors.iter().map(|c| c.client_cursors.clone()).collect(),
},
)),
)
.await;
}
self.broadcasts
.send_document_update(
vault_id.clone(),
WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions(
CursorPositionFromServer {
clients: client_cursors,
},
)),
)
.await;
}
pub async fn remove_cursors_of_device(&self, vault_id: &str, device_id: &str) {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
pub async fn remove_cursors_of_device(&self, vault_id: &VaultId, device_id: &DeviceId) {
let changed = {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
if let Some(cursors) = vault_to_cursors.get_mut(vault_id) {
cursors.retain(|c| c.client_cursors.device_id != device_id);
if let Some(cursors) = vault_to_cursors.get_mut(vault_id) {
let before = cursors.len();
cursors.retain(|c| c.client_cursors.device_id != *device_id);
let changed = cursors.len() != before;
if cursors.is_empty() {
vault_to_cursors.remove(vault_id);
}
changed
} else {
false
}
};
if changed {
self.broadcast_cursors_for_vault(vault_id).await;
}
}
}

View file

@ -9,8 +9,17 @@ use models::{
use sqlx::{ConnectOptions, sqlite::SqliteConnectOptions, types::chrono::Utc};
pub mod models;
use sqlx::{Pool, Sqlite, sqlite::SqlitePoolOptions};
use tokio::sync::Mutex;
/// Sentinel error indicating the SQLite database is busy (SQLITE_BUSY).
///
/// Returned by `WriteTransaction::new` when `BEGIN IMMEDIATE` fails with a
/// busy/locked error from SQLite.
/// Handlers can downcast to this to return 429 instead of 500.
#[derive(Debug, thiserror::Error)]
#[error("Database is busy")]
pub struct WriteBusyError;
use sqlx::{
Pool, Sqlite, pool::PoolConnection, sqlite::SqliteConnection, sqlite::SqlitePoolOptions,
};
use tokio::sync::{Mutex, OnceCell};
use tokio::time::Instant;
use uuid::fmt::Hyphenated;
@ -19,33 +28,154 @@ use super::websocket::{
models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin, WebSocketVaultUpdate},
};
use crate::config::database_config::DatabaseConfig;
use crate::consts::IDLE_POOL_TIMEOUT;
#[derive(Clone)]
struct PoolWithTimestamp {
pool: Pool<Sqlite>,
last_accessed: Instant,
/// Holds separate reader and writer pools for a single vault.
/// The writer pool has exactly 1 connection so writes never compete
/// with reads for pool slots.
#[derive(Debug, Clone)]
struct VaultPools {
/// Pool for read-only queries; sized by `max_connections_per_vault`.
reader: Pool<Sqlite>,
/// Single-connection pool that write transactions are opened on.
writer: Pool<Sqlite>,
}
impl std::fmt::Debug for PoolWithTimestamp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PoolWithTimestamp")
.field("pool", &"Pool<Sqlite>")
.field("last_accessed", &self.last_accessed)
.finish()
}
/// Map entry for one vault: its lazily created pools plus the time they
/// were last requested.
#[derive(Debug)]
struct VaultPool {
/// Filled at most once with the result of `create_vault_database`;
/// stays empty until initialization succeeds.
cell: Arc<OnceCell<VaultPools>>,
/// Refreshed by `get_vault_pools` once the pools are available and read
/// by the idle-pool cleanup to decide whether the vault is idle.
last_accessed: Mutex<Instant>,
}
#[derive(Clone, Debug)]
pub struct Database {
config: DatabaseConfig,
broadcasts: Broadcasts,
connection_pools: Arc<Mutex<HashMap<VaultId, PoolWithTimestamp>>>,
connection_pools: Arc<Mutex<HashMap<VaultId, Arc<VaultPool>>>>,
/// Per-vault write serialization. SQLite allows only one writer at a
/// time; `BEGIN IMMEDIATE` on a second connection blocks until the first
/// commits (up to `busy_timeout`). Under concurrent load the blocked
/// connections consume the pool, starving even read-only requests.
/// This mutex moves the wait from the SQLite layer (where it holds a
/// pool connection) to the Tokio layer (where it holds nothing).
write_locks: Arc<Mutex<HashMap<VaultId, Arc<tokio::sync::Mutex<()>>>>>,
}
pub type Transaction<'a> = sqlx::Transaction<'a, Sqlite>;
/// A write transaction backed by a raw `BEGIN IMMEDIATE` instead of sqlx's
/// savepoint-based `Transaction`. This avoids the savepoint mismatch caused
/// by the old `END; BEGIN IMMEDIATE;` workaround.
///
/// Holds an `OwnedMutexGuard` that serializes write transactions per vault
/// at the application level (see `Database::write_locks`). The guard is
/// released when the transaction is committed, rolled back, or dropped.
pub struct WriteTransaction {
/// Connection on which `BEGIN IMMEDIATE` was issued. Becomes `None` once
/// `commit` or `rollback` has taken it.
conn: Option<PoolConnection<Sqlite>>,
/// Per-vault write lock; never read, only kept alive for as long as this
/// value exists.
_write_guard: tokio::sync::OwnedMutexGuard<()>,
}
impl WriteTransaction {
async fn new(pool: &Pool<Sqlite>, write_guard: tokio::sync::OwnedMutexGuard<()>) -> Result<Self> {
let mut conn = pool
.acquire()
.await
.context("Cannot acquire connection for write transaction")?;
if let Err(e) = sqlx::query("BEGIN IMMEDIATE")
.execute(&mut *conn)
.await
{
let is_busy = match &e {
sqlx::Error::Database(db_err) => {
// SQLITE_BUSY base code is 5. Extended codes share base 5.
let busy_by_code = db_err.code().is_some_and(|c| {
c.parse::<u32>().is_ok_and(|n| n & 0xFF == 5)
});
busy_by_code || db_err.message().contains("database is locked")
}
_ => false,
};
if is_busy {
return Err(WriteBusyError.into());
}
return Err(e).context("Cannot begin immediate transaction");
}
Ok(Self { conn: Some(conn), _write_guard: write_guard })
}
pub async fn commit(mut self) -> Result<()> {
if let Some(mut conn) = self.conn.take() {
sqlx::query("COMMIT")
.execute(&mut *conn)
.await
.context("Failed to commit transaction")?;
}
Ok(())
}
pub async fn rollback(mut self) -> Result<()> {
if let Some(mut conn) = self.conn.take() {
sqlx::query("ROLLBACK")
.execute(&mut *conn)
.await
.context("Failed to rollback transaction")?;
}
Ok(())
}
}
impl Drop for WriteTransaction {
    /// Warn when the transaction is abandoned without an explicit
    /// `commit` or `rollback`.
    fn drop(&mut self) {
        if self.conn.is_none() {
            // `commit`/`rollback` already took the connection.
            return;
        }
        // The connection goes back to the pool with its transaction still
        // open. No async cleanup is needed here: the pool's
        // `before_acquire` hook issues a ROLLBACK before the connection is
        // handed to the next consumer, and if the pool is shutting down
        // SQLite itself rolls back uncommitted work when the connection
        // closes.
        log::warn!("WriteTransaction dropped without commit or rollback");
    }
}
impl std::ops::Deref for WriteTransaction {
    type Target = SqliteConnection;

    /// Borrow the underlying SQLite connection.
    ///
    /// # Panics
    /// Panics if the connection has already been taken out of the
    /// transaction.
    fn deref(&self) -> &Self::Target {
        let connection: Option<&SqliteConnection> = self.conn.as_deref();
        connection.expect("BUG: WriteTransaction dereferenced after being consumed")
    }
}
impl std::ops::DerefMut for WriteTransaction {
fn deref_mut(&mut self) -> &mut Self::Target {
self.conn
.as_mut()
.expect("BUG: WriteTransaction dereferenced after being consumed")
.deref_mut()
}
}
/// Ensure the connection has no leftover open transaction (e.g. from a
/// `WriteTransaction` that was dropped without commit/rollback).
///
/// Issues a `ROLLBACK` on every checkout. When no transaction is active
/// SQLite rejects the statement with "cannot rollback - no transaction is
/// active"; that is the normal case and is ignored silently.
///
/// Returns `Ok(true)` when the connection is known to be clean. Returns
/// `Ok(false)` when the `ROLLBACK` failed for any other reason, because the
/// connection may still hold an open transaction and must not be handed out;
/// the pool is expected to discard it and supply another one.
fn rollback_before_acquire(
    conn: &mut SqliteConnection,
    _meta: sqlx::pool::PoolConnectionMetadata,
) -> futures::future::BoxFuture<'_, Result<bool, sqlx::Error>> {
    Box::pin(async move {
        let Err(e) = sqlx::query("ROLLBACK").execute(&mut *conn).await else {
            // A leftover transaction existed and has now been rolled back.
            return Ok(true);
        };

        if e.to_string().contains("no transaction is active") {
            // Common case: the connection was returned cleanly. Not worth
            // a log line on every single checkout.
            return Ok(true);
        }

        // Anything else is unexpected; the connection state is unknown.
        log::warn!("before_acquire ROLLBACK failed: {e}");
        Ok(false)
    })
}
impl Database {
pub async fn try_new(config: &DatabaseConfig, broadcasts: &Broadcasts) -> Result<Self> {
pub async fn try_new(
config: &DatabaseConfig,
broadcasts: &Broadcasts,
shutdown: tokio::sync::watch::Receiver<()>,
) -> Result<Self> {
tokio::fs::create_dir_all(&config.databases_directory_path)
.await
.with_context(|| {
@ -70,24 +200,29 @@ impl Database {
.trim_end_matches(".sqlite")
.to_owned();
let pool = Self::create_vault_database(config, &vault).await?;
Self::validate_vault_id(&vault)?;
let pools = Self::create_vault_database(config, &vault).await?;
let cell = Arc::new(OnceCell::new());
cell.set(pools).expect("cell is new");
connection_pools.insert(
vault.clone(),
PoolWithTimestamp {
pool,
last_accessed: Instant::now(),
},
Arc::new(VaultPool {
cell,
last_accessed: Mutex::new(Instant::now()),
}),
);
}
info!("Database migrations applied");
let database = Self {
config: config.clone(),
connection_pools: Arc::new(Mutex::new(connection_pools)),
broadcasts: broadcasts.clone(),
write_locks: Arc::new(Mutex::new(HashMap::new())),
};
// Start background task to cleanup idle connection pools
database.start_idle_pool_cleanup();
database.start_idle_pool_cleanup(shutdown);
Ok(database)
}
@ -95,92 +230,167 @@ impl Database {
async fn create_vault_database(
config: &DatabaseConfig,
vault: &VaultId,
) -> Result<Pool<Sqlite>> {
) -> Result<VaultPools> {
let file_name = config
.databases_directory_path
.join(format!("{vault}.sqlite"));
let connection_options = SqliteConnectOptions::new()
// Database-level PRAGMAs (auto_vacuum, journal_mode) require a write
// lock and persist across connections. Set them once with a dedicated
// init connection so pool connections never need the write lock just to
// open.
let init_options = SqliteConnectOptions::new()
.filename(file_name.clone())
.create_if_missing(true)
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Full)
.busy_timeout(Duration::from_secs(3600))
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
.log_slow_statements(log::LevelFilter::Warn, Duration::from_secs(30));
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
let pool = SqlitePoolOptions::new()
// Run migrations on a dedicated connection, NOT through the pool.
// The pool's `before_acquire` hook issues ROLLBACK on every checkout,
// which can roll back the migration's bookkeeping transaction (the
// _sqlx_migrations INSERT) while the DDL (ALTER TABLE) has already
// auto-committed — leaving the migration in a dirty state.
//
// Uses `run_direct` instead of `run` because `run` takes
// `impl Acquire<'_>`, whose lifetime bound prevents the enclosing
// future from satisfying the `Send` requirement of axum handlers.
let mut init_conn = sqlx::SqliteConnection::connect_with(&init_options).await?;
sqlx::migrate!("src/app_state/database/migrations")
.run_direct(&mut init_conn)
.await
.context("Cannot run pending migrations")?;
drop(init_conn);
// Per-connection PRAGMAs shared by both reader and writer pools.
// journal_mode = WAL is a no-op on an already-WAL database.
let base_options = SqliteConnectOptions::new()
.filename(file_name.clone())
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
.busy_timeout(Duration::from_secs(30))
.log_slow_statements(log::LevelFilter::Warn, Duration::from_secs(30))
// In WAL mode, NORMAL is safe: data survives OS crashes, only the
// last transaction can be lost on power failure. The default FULL
// forces an extra fsync() per commit, roughly halving write throughput.
.pragma("synchronous", "NORMAL")
// 16 MB page cache per connection (negative = KiB). Reduces disk
// reads for the latest_document_versions GROUP BY view.
.pragma("cache_size", "-16384")
// Memory-mapped I/O avoids read() syscalls. SQLite falls back to
// regular I/O for writes and beyond the mapped region. 256 MB is
// conservative; the OS handles actual memory pressure.
.pragma("mmap_size", "268435456")
// Keep temp tables and sort spillovers in memory instead of temp files.
.pragma("temp_store", "MEMORY")
// Cap WAL file growth at 64 MB. Without this, the WAL can grow
// unbounded during heavy write bursts (e.g. E2E tests with many
// concurrent clients). SQLite truncates to this size on checkpoint.
.pragma("journal_size_limit", "67108864");
// Reader pool: multiple connections for concurrent reads.
let reader = SqlitePoolOptions::new()
.max_connections(config.max_connections_per_vault)
.acquire_slow_threshold(Duration::from_secs(30))
.test_before_acquire(true)
.connect_with(connection_options)
// Disabled: the health-check query is subject to busy_timeout
// and blocks all connection checkouts when a write is active,
// starving the pool for up to 30s even for simple reads.
// The before_acquire ROLLBACK hook is sufficient for cleanup.
.test_before_acquire(false)
.before_acquire(rollback_before_acquire)
.connect_with(base_options.clone())
.await
.with_context(|| format!("Cannot open database at `{}`", file_name.display()))?;
.with_context(|| format!("Cannot open reader pool at `{}`", file_name.display()))?;
Self::run_migrations(&pool).await?;
// Writer pool: exactly 1 connection, dedicated to writes.
// Since the Tokio mutex already serializes writers per vault, this
// single connection is never contended. Separating it from the
// reader pool ensures writes never compete with reads for pool slots.
let writer = SqlitePoolOptions::new()
.max_connections(1)
.acquire_slow_threshold(Duration::from_secs(30))
.test_before_acquire(false)
.before_acquire(rollback_before_acquire)
.connect_with(base_options)
.await
.with_context(|| format!("Cannot open writer pool at `{}`", file_name.display()))?;
Ok(pool)
Ok(VaultPools { reader, writer })
}
async fn run_migrations(pool: &Pool<Sqlite>) -> Result<()> {
sqlx::migrate!("src/app_state/database/migrations")
.run(pool)
.await
.context("Cannot check for pending migrations")
}
async fn get_connection_pool(&self, vault: &VaultId) -> Result<Pool<Sqlite>> {
let mut pools = self.connection_pools.lock().await;
if !pools.contains_key(vault) {
let pool = Self::create_vault_database(&self.config, vault).await?;
pools.insert(
vault.clone(),
PoolWithTimestamp {
pool,
last_accessed: Instant::now(),
},
fn validate_vault_id(vault: &VaultId) -> Result<()> {
if vault.is_empty() {
anyhow::bail!("Vault ID must not be empty");
}
if vault.contains('/')
|| vault.contains('\\')
|| vault.contains("..")
|| vault.contains('\0')
{
anyhow::bail!(
"Invalid vault ID: must not contain path separators, '..', or null bytes"
);
}
let pool_with_timestamp = pools
.get_mut(vault)
.expect("Pool was just inserted or already exists");
// Update last accessed time
pool_with_timestamp.last_accessed = Instant::now();
Ok(pool_with_timestamp.pool.clone())
Ok(())
}
/// Attempting to write from this transaction might result in a
/// database locked error. Use this transaction for read-only operations.
pub async fn create_readonly_transaction(
&self,
vault: &VaultId,
) -> Result<Transaction<'static>> {
self.get_connection_pool(vault)
.await?
.begin()
.await
.context("Cannot create transaction")
}
async fn get_vault_pools(&self, vault: &VaultId) -> Result<VaultPools> {
Self::validate_vault_id(vault)?;
pub async fn create_write_transaction(&self, vault: &VaultId) -> Result<Transaction<'static>> {
let mut transaction = self.create_readonly_transaction(vault).await?;
// Get or create the VaultPool entry. The global lock is held only
// long enough for a HashMap lookup/insert — never across
// create_vault_database.
let vault_pool = {
let mut pools = self.connection_pools.lock().await;
pools
.entry(vault.clone())
.or_insert_with(|| {
Arc::new(VaultPool {
cell: Arc::new(OnceCell::new()),
last_accessed: Mutex::new(Instant::now()),
})
})
.clone()
};
// sqlx doesn't support immediate transactions for sqlite: https://github.com/launchbadge/sqlx/issues/481
sqlx::query!("END; BEGIN IMMEDIATE;")
.execute(&mut *transaction)
// OnceCell::get_or_try_init guarantees exactly-once
// initialization: concurrent callers for the same vault wait
// here; callers for other vaults are not blocked.
let config = self.config.clone();
let vault_clone = vault.clone();
let pools = vault_pool
.cell
.get_or_try_init(|| async {
Self::create_vault_database(&config, &vault_clone).await
})
.await?;
Ok(transaction)
*vault_pool.last_accessed.lock().await = Instant::now();
Ok(pools.clone())
}
/// Return the reader pool for read-only queries.
async fn get_connection_pool(&self, vault: &VaultId) -> Result<Pool<Sqlite>> {
Ok(self.get_vault_pools(vault).await?.reader)
}
pub async fn create_write_transaction(&self, vault: &VaultId) -> Result<WriteTransaction> {
let write_lock = {
let mut locks = self.write_locks.lock().await;
locks
.entry(vault.clone())
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
.clone()
};
let write_guard = write_lock.lock_owned().await;
let pools = self.get_vault_pools(vault).await?;
WriteTransaction::new(&pools.writer, write_guard).await
}
/// Return the latest state of all documents in the vault
pub async fn get_latest_documents(
&self,
vault: &VaultId,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<Vec<DocumentVersionWithoutContent>> {
let query = sqlx::query!(
r#"
@ -198,8 +408,8 @@ impl Database {
"#,
);
if let Some(transaction) = transaction {
query.fetch_all(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_all(&mut *conn).await
} else {
query
.fetch_all(&self.get_connection_pool(vault).await?)
@ -216,9 +426,7 @@ impl Database {
is_deleted: row.is_deleted,
user_id: row.user_id,
device_id: row.device_id,
content_size: row
.content_size
.expect("Content size can't be null but sqlx can't infer it"),
content_size: row.content_size.unwrap_or(0),
})
.collect()
})
@ -230,7 +438,7 @@ impl Database {
&self,
vault: &VaultId,
vault_update_id: VaultUpdateId,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<Vec<DocumentVersionWithoutContent>> {
let query = sqlx::query!(
r#"
@ -250,8 +458,8 @@ impl Database {
vault_update_id
);
if let Some(transaction) = transaction {
query.fetch_all(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_all(&mut *conn).await
} else {
query
.fetch_all(&self.get_connection_pool(vault).await?)
@ -270,9 +478,7 @@ impl Database {
is_deleted: row.is_deleted,
user_id: row.user_id,
device_id: row.device_id,
content_size: row
.content_size
.expect("Content size can't be null but sqlx can't infer it"),
content_size: row.content_size.unwrap_or(0),
})
.collect()
})
@ -281,7 +487,7 @@ impl Database {
pub async fn get_max_update_id_in_vault(
&self,
vault: &VaultId,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<i64> {
let query = sqlx::query!(
r#"
@ -290,8 +496,8 @@ impl Database {
"#,
);
if let Some(transaction) = transaction {
query.fetch_one(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_one(&mut *conn).await
} else {
query
.fetch_one(&self.get_connection_pool(vault).await?)
@ -301,11 +507,11 @@ impl Database {
.context("Cannot fetch max update id in vault")
}
pub async fn get_latest_document_by_path(
pub async fn get_latest_non_deleted_document_by_path(
&self,
vault: &VaultId,
relative_path: &str,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<Option<StoredDocumentVersion>> {
let query = sqlx::query_as!(
StoredDocumentVersion,
@ -330,8 +536,8 @@ impl Database {
relative_path
);
if let Some(transaction) = transaction {
query.fetch_optional(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_optional(&mut *conn).await
} else {
query
.fetch_optional(&self.get_connection_pool(vault).await?)
@ -344,7 +550,7 @@ impl Database {
&self,
vault: &VaultId,
document_id: &DocumentId,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<Option<StoredDocumentVersion>> {
let document_id = document_id.as_hyphenated();
let query = sqlx::query_as!(
@ -366,8 +572,8 @@ impl Database {
document_id
);
if let Some(transaction) = transaction {
query.fetch_optional(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_optional(&mut *conn).await
} else {
query
.fetch_optional(&self.get_connection_pool(vault).await?)
@ -380,7 +586,7 @@ impl Database {
&self,
vault: &VaultId,
vault_update_id: VaultUpdateId,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<Option<StoredDocumentVersion>> {
let query = sqlx::query_as!(
StoredDocumentVersion,
@ -400,8 +606,8 @@ impl Database {
vault_update_id
);
if let Some(transaction) = transaction {
query.fetch_optional(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_optional(&mut *conn).await
} else {
query
.fetch_optional(&self.get_connection_pool(vault).await?)
@ -415,7 +621,7 @@ impl Database {
&self,
vault_id: &VaultId,
version: &StoredDocumentVersion,
transaction: Option<Transaction<'_>>,
transaction: Option<WriteTransaction>,
) -> Result<()> {
let document_id = version.document_id.as_hyphenated();
let query = sqlx::query!(
@ -428,9 +634,10 @@ impl Database {
content,
is_deleted,
user_id,
device_id
device_id,
has_been_merged
)
values (?, ?, ?, ?, ?, ?, ?, ?)
values (?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
version.vault_update_id,
document_id,
@ -439,7 +646,8 @@ impl Database {
version.content,
version.is_deleted,
version.user_id,
version.device_id
version.device_id,
version.has_been_merged
);
if let Some(mut transaction) = transaction {
@ -477,38 +685,66 @@ impl Database {
/// Clean up connection pools that have not been accessed for longer than `IDLE_POOL_TIMEOUT`
async fn cleanup_idle_pools(&self) {
let mut pools = self.connection_pools.lock().await;
let now = Instant::now();
let idle_timeout = Duration::from_secs(5 * 60); // 5 minutes
// Collect idle vaults and remove them from the map while holding
// the lock briefly. Close pools OUTSIDE the lock so that
// pool.close().await doesn't block other get_connection_pool calls.
let idle_pools: Vec<(VaultId, Arc<VaultPool>)> = {
let mut pools = self.connection_pools.lock().await;
let now = Instant::now();
// Collect vaults to remove
let vaults_to_remove: Vec<VaultId> = pools
.iter()
.filter(|(_, pool_with_timestamp)| {
now.duration_since(pool_with_timestamp.last_accessed) > idle_timeout
})
.map(|(vault_id, _)| vault_id.clone())
.collect();
let vaults_to_remove: Vec<VaultId> = pools
.iter()
.filter(|(_, vp)| {
// If the lock is contested, the pool is actively used — not idle.
let Ok(last) = vp.last_accessed.try_lock() else {
return false;
};
now.duration_since(*last) > IDLE_POOL_TIMEOUT
})
.map(|(vault_id, _)| vault_id.clone())
.collect();
// Close and remove idle pools
for vault_id in &vaults_to_remove {
if let Some(pool_with_timestamp) = pools.remove(vault_id) {
info!("Closing idle database connection pool for vault `{vault_id}`");
pool_with_timestamp.pool.close().await;
vaults_to_remove
.into_iter()
.filter_map(|id| pools.remove(&id).map(|vp| (id, vp)))
.collect()
};
for (vault_id, vault_pool) in idle_pools {
if let Some(pools) = vault_pool.cell.get() {
// Checkpoint the WAL before closing to reclaim disk space
// and ensure the next open doesn't need a large WAL replay.
// TRUNCATE mode resets the WAL file to zero bytes.
if let Err(e) = sqlx::query("PRAGMA wal_checkpoint(TRUNCATE)")
.execute(&pools.writer)
.await
{
log::warn!("WAL checkpoint failed for vault `{vault_id}`: {e}");
}
info!("Closing idle database connection pools for vault `{vault_id}`");
pools.reader.close().await;
pools.writer.close().await;
}
}
}
/// Start a background task that periodically cleans up idle connection pools
fn start_idle_pool_cleanup(&self) {
fn start_idle_pool_cleanup(&self, mut shutdown: tokio::sync::watch::Receiver<()>) {
let database = self.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60)); // Check every minute
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
interval.tick().await;
database.cleanup_idle_pools().await;
tokio::select! {
_ = interval.tick() => {
database.cleanup_idle_pools().await;
}
_ = shutdown.changed() => {
info!("Idle pool cleanup task shutting down");
break;
}
}
}
});
}

View file

@ -22,7 +22,6 @@ pub struct StoredDocumentVersion {
pub device_id: DeviceId,
#[allow(dead_code)] // This is for manual analysis
pub has_been_merged: bool,
pub idempotency_key: Option<String>,
}
impl PartialEq<Self> for StoredDocumentVersion {

View file

@ -1,35 +1,52 @@
use std::{collections::HashMap, sync::Arc};
use anyhow::Context;
use log::{debug, warn};
use tokio::sync::{Mutex, broadcast};
use super::models::WebSocketServerMessageWithOrigin;
use crate::{
app_state::database::models::VaultId, config::server_config::ServerConfig, errors::server_error,
};
use crate::{app_state::database::models::VaultId, config::server_config::ServerConfig};
#[derive(Debug, Clone)]
pub struct Broadcasts {
max_clients_per_vault: usize,
broadcast_channel_capacity: usize,
tx: Arc<Mutex<HashMap<VaultId, broadcast::Sender<WebSocketServerMessageWithOrigin>>>>,
}
type TxMap = HashMap<VaultId, broadcast::Sender<WebSocketServerMessageWithOrigin>>;
impl Broadcasts {
pub fn new(server_config: &ServerConfig) -> Self {
Self {
max_clients_per_vault: server_config.max_clients_per_vault,
broadcast_channel_capacity: server_config.broadcast_channel_capacity,
tx: Arc::new(Mutex::new(HashMap::new())),
}
}
/// Remove senders for vaults with no active receivers.
///
/// Takes the already-locked map; it is invoked before a sender is looked
/// up or created so that senders without receivers are dropped from the map.
fn prune_inactive_vaults(tx_map: &mut TxMap) {
// Keep only senders that still have at least one receiver.
tx_map.retain(|_, sender| sender.receiver_count() > 0);
}
pub async fn get_receiver(
&self,
vault: VaultId,
) -> broadcast::Receiver<WebSocketServerMessageWithOrigin> {
let tx = self.get_or_create(vault).await;
max_clients: usize,
) -> Result<broadcast::Receiver<WebSocketServerMessageWithOrigin>, crate::errors::SyncServerError>
{
let mut tx_map = self.tx.lock().await;
Self::prune_inactive_vaults(&mut tx_map);
tx.subscribe()
let sender = tx_map
.entry(vault)
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
if sender.receiver_count() >= max_clients {
return Err(crate::errors::client_error(anyhow::anyhow!(
"Vault has reached the maximum number of clients ({max_clients})"
)));
}
Ok(sender.subscribe())
}
/// Notify all clients (who are subscribed to the vault) about an update.
@ -39,31 +56,20 @@ impl Broadcasts {
vault: VaultId,
document: WebSocketServerMessageWithOrigin,
) {
let tx = self.get_or_create(vault.clone()).await;
let mut tx_map = self.tx.lock().await;
Self::prune_inactive_vaults(&mut tx_map);
if tx.receiver_count() == 0 {
let sender = tx_map
.entry(vault.clone())
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
if sender.receiver_count() == 0 {
debug!("Skipping broadcast, no clients connected for vault `{vault}`");
return;
}
let result = tx
.send(document)
.context("Cannot broadcast server message to websocket listeners")
.map_err(server_error);
if result.is_err() {
warn!("Failed to send message: {result:?}");
if let Err(e) = sender.send(document) {
warn!("Failed to broadcast to vault `{vault}`: {e}");
}
}
async fn get_or_create(
&self,
vault: VaultId,
) -> broadcast::Sender<WebSocketServerMessageWithOrigin> {
let mut tx = self.tx.lock().await;
tx.entry(vault)
.or_insert_with(|| broadcast::channel(self.max_clients_per_vault).0.clone())
.clone()
}
}