Add proper shutdown, rate limits, config validation, cors config, fix dangling cursors, cache regex, merge created texts
This commit is contained in:
parent
4763bc9d04
commit
e15b0f9903
28 changed files with 1277 additions and 464 deletions
|
|
@ -42,7 +42,9 @@ impl Cursors {
|
|||
) {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
let all_device_cursors = vault_to_cursors.entry(vault_id).or_insert_with(Vec::new);
|
||||
let all_device_cursors = vault_to_cursors
|
||||
.entry(vault_id.clone())
|
||||
.or_insert_with(Vec::new);
|
||||
|
||||
all_device_cursors.retain(|c| &c.client_cursors.device_id != device_id);
|
||||
all_device_cursors.push(ClientCursorsWithTimeToLive::new(ClientCursors {
|
||||
|
|
@ -52,7 +54,7 @@ impl Cursors {
|
|||
}));
|
||||
|
||||
drop(vault_to_cursors); // Explicitly drop the lock before broadcasting to avoid deadlock
|
||||
self.broadcast_cursors().await;
|
||||
self.broadcast_cursors_for_vault(&vault_id).await;
|
||||
}
|
||||
|
||||
pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec<ClientCursors> {
|
||||
|
|
@ -69,45 +71,83 @@ impl Cursors {
|
|||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn start_background_task(self) {
|
||||
pub fn start_background_task(self, mut shutdown: tokio::sync::watch::Receiver<()>) {
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
self.remove_expired_cursors().await;
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
tokio::select! {
|
||||
() = tokio::time::sleep(Duration::from_secs(1)) => {
|
||||
self.remove_expired_cursors().await;
|
||||
}
|
||||
Ok(()) = shutdown.changed() => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn remove_expired_cursors(&self) {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
let changed_vaults: Vec<VaultId> = {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
for (_vault_id, cursors) in vault_to_cursors.iter_mut() {
|
||||
cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout));
|
||||
let mut changed = Vec::new();
|
||||
for (vault_id, cursors) in vault_to_cursors.iter_mut() {
|
||||
let before = cursors.len();
|
||||
cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout));
|
||||
if cursors.len() != before {
|
||||
changed.push(vault_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Remove empty vault entries to prevent unbounded growth
|
||||
vault_to_cursors.retain(|_, cursors| !cursors.is_empty());
|
||||
|
||||
changed
|
||||
};
|
||||
|
||||
for vault_id in &changed_vaults {
|
||||
self.broadcast_cursors_for_vault(vault_id).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn broadcast_cursors(&self) {
|
||||
let vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
async fn broadcast_cursors_for_vault(&self, vault_id: &VaultId) {
|
||||
let client_cursors: Vec<ClientCursors> = {
|
||||
let vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
vault_to_cursors
|
||||
.get(vault_id)
|
||||
.map(|cursors| cursors.iter().map(|c| c.client_cursors.clone()).collect())
|
||||
.unwrap_or_default()
|
||||
};
|
||||
|
||||
for (vault_id, cursors) in vault_to_cursors.iter() {
|
||||
self.broadcasts
|
||||
.send_document_update(
|
||||
vault_id.clone(),
|
||||
WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions(
|
||||
CursorPositionFromServer {
|
||||
clients: cursors.iter().map(|c| c.client_cursors.clone()).collect(),
|
||||
},
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
self.broadcasts
|
||||
.send_document_update(
|
||||
vault_id.clone(),
|
||||
WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions(
|
||||
CursorPositionFromServer {
|
||||
clients: client_cursors,
|
||||
},
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn remove_cursors_of_device(&self, vault_id: &str, device_id: &str) {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
pub async fn remove_cursors_of_device(&self, vault_id: &VaultId, device_id: &DeviceId) {
|
||||
let changed = {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
if let Some(cursors) = vault_to_cursors.get_mut(vault_id) {
|
||||
cursors.retain(|c| c.client_cursors.device_id != device_id);
|
||||
if let Some(cursors) = vault_to_cursors.get_mut(vault_id) {
|
||||
let before = cursors.len();
|
||||
cursors.retain(|c| c.client_cursors.device_id != *device_id);
|
||||
let changed = cursors.len() != before;
|
||||
if cursors.is_empty() {
|
||||
vault_to_cursors.remove(vault_id);
|
||||
}
|
||||
changed
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
if changed {
|
||||
self.broadcast_cursors_for_vault(vault_id).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,8 +9,17 @@ use models::{
|
|||
use sqlx::{ConnectOptions, sqlite::SqliteConnectOptions, types::chrono::Utc};
|
||||
|
||||
pub mod models;
|
||||
use sqlx::{Pool, Sqlite, sqlite::SqlitePoolOptions};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Sentinel error indicating the SQLite database is busy (SQLITE_BUSY).
|
||||
/// Handlers can downcast to this to return 429 instead of 500.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Database is busy")]
|
||||
pub struct WriteBusyError;
|
||||
|
||||
use sqlx::{
|
||||
Pool, Sqlite, pool::PoolConnection, sqlite::SqliteConnection, sqlite::SqlitePoolOptions,
|
||||
};
|
||||
use tokio::sync::{Mutex, OnceCell};
|
||||
use tokio::time::Instant;
|
||||
use uuid::fmt::Hyphenated;
|
||||
|
||||
|
|
@ -19,33 +28,154 @@ use super::websocket::{
|
|||
models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin, WebSocketVaultUpdate},
|
||||
};
|
||||
use crate::config::database_config::DatabaseConfig;
|
||||
use crate::consts::IDLE_POOL_TIMEOUT;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct PoolWithTimestamp {
|
||||
pool: Pool<Sqlite>,
|
||||
last_accessed: Instant,
|
||||
/// Holds separate reader and writer pools for a single vault.
|
||||
/// The writer pool has exactly 1 connection so writes never compete
|
||||
/// with reads for pool slots.
|
||||
#[derive(Debug, Clone)]
|
||||
struct VaultPools {
|
||||
reader: Pool<Sqlite>,
|
||||
writer: Pool<Sqlite>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PoolWithTimestamp {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PoolWithTimestamp")
|
||||
.field("pool", &"Pool<Sqlite>")
|
||||
.field("last_accessed", &self.last_accessed)
|
||||
.finish()
|
||||
}
|
||||
#[derive(Debug)]
|
||||
struct VaultPool {
|
||||
cell: Arc<OnceCell<VaultPools>>,
|
||||
last_accessed: Mutex<Instant>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Database {
|
||||
config: DatabaseConfig,
|
||||
broadcasts: Broadcasts,
|
||||
connection_pools: Arc<Mutex<HashMap<VaultId, PoolWithTimestamp>>>,
|
||||
connection_pools: Arc<Mutex<HashMap<VaultId, Arc<VaultPool>>>>,
|
||||
/// Per-vault write serialization. SQLite allows only one writer at a
|
||||
/// time; `BEGIN IMMEDIATE` on a second connection blocks until the first
|
||||
/// commits (up to `busy_timeout`). Under concurrent load the blocked
|
||||
/// connections consume the pool, starving even read-only requests.
|
||||
/// This mutex moves the wait from the SQLite layer (where it holds a
|
||||
/// pool connection) to the Tokio layer (where it holds nothing).
|
||||
write_locks: Arc<Mutex<HashMap<VaultId, Arc<tokio::sync::Mutex<()>>>>>,
|
||||
}
|
||||
|
||||
pub type Transaction<'a> = sqlx::Transaction<'a, Sqlite>;
|
||||
/// A write transaction backed by a raw `BEGIN IMMEDIATE` instead of sqlx's
|
||||
/// savepoint-based `Transaction`. This avoids the savepoint mismatch caused
|
||||
/// by the old `END; BEGIN IMMEDIATE;` workaround.
|
||||
///
|
||||
/// Holds an `OwnedMutexGuard` that serializes write transactions per vault
|
||||
/// at the application level (see `Database::write_locks`). The guard is
|
||||
/// released when the transaction is committed, rolled back, or dropped.
|
||||
pub struct WriteTransaction {
|
||||
conn: Option<PoolConnection<Sqlite>>,
|
||||
_write_guard: tokio::sync::OwnedMutexGuard<()>,
|
||||
}
|
||||
|
||||
impl WriteTransaction {
|
||||
async fn new(pool: &Pool<Sqlite>, write_guard: tokio::sync::OwnedMutexGuard<()>) -> Result<Self> {
|
||||
let mut conn = pool
|
||||
.acquire()
|
||||
.await
|
||||
.context("Cannot acquire connection for write transaction")?;
|
||||
if let Err(e) = sqlx::query("BEGIN IMMEDIATE")
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
{
|
||||
let is_busy = match &e {
|
||||
sqlx::Error::Database(db_err) => {
|
||||
// SQLITE_BUSY base code is 5. Extended codes share base 5.
|
||||
let busy_by_code = db_err.code().is_some_and(|c| {
|
||||
c.parse::<u32>().is_ok_and(|n| n & 0xFF == 5)
|
||||
});
|
||||
busy_by_code || db_err.message().contains("database is locked")
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
if is_busy {
|
||||
return Err(WriteBusyError.into());
|
||||
}
|
||||
return Err(e).context("Cannot begin immediate transaction");
|
||||
}
|
||||
Ok(Self { conn: Some(conn), _write_guard: write_guard })
|
||||
}
|
||||
|
||||
pub async fn commit(mut self) -> Result<()> {
|
||||
if let Some(mut conn) = self.conn.take() {
|
||||
sqlx::query("COMMIT")
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
.context("Failed to commit transaction")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn rollback(mut self) -> Result<()> {
|
||||
if let Some(mut conn) = self.conn.take() {
|
||||
sqlx::query("ROLLBACK")
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
.context("Failed to rollback transaction")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WriteTransaction {
|
||||
fn drop(&mut self) {
|
||||
if self.conn.is_some() {
|
||||
// The connection is returned to the pool with an open transaction.
|
||||
// The pool's `before_acquire` hook issues a ROLLBACK before
|
||||
// handing it to the next consumer, so no async work is needed
|
||||
// here. If the pool is being shut down, SQLite itself rolls back
|
||||
// uncommitted transactions when the connection closes.
|
||||
log::warn!("WriteTransaction dropped without commit or rollback");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for WriteTransaction {
|
||||
type Target = SqliteConnection;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.conn
|
||||
.as_ref()
|
||||
.expect("BUG: WriteTransaction dereferenced after being consumed")
|
||||
.deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::DerefMut for WriteTransaction {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.conn
|
||||
.as_mut()
|
||||
.expect("BUG: WriteTransaction dereferenced after being consumed")
|
||||
.deref_mut()
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensure the connection has no leftover open transaction (e.g. from a
|
||||
/// `WriteTransaction` that was dropped without commit/rollback). ROLLBACK
|
||||
/// is a harmless no-op if no transaction is active.
|
||||
fn rollback_before_acquire(
|
||||
conn: &mut SqliteConnection,
|
||||
_meta: sqlx::pool::PoolConnectionMetadata,
|
||||
) -> futures::future::BoxFuture<'_, Result<bool, sqlx::Error>> {
|
||||
Box::pin(async move {
|
||||
if let Err(e) = sqlx::query("ROLLBACK").execute(&mut *conn).await {
|
||||
// "cannot rollback - no transaction is active" is the common
|
||||
// case (connection returned cleanly). Only unexpected errors
|
||||
// deserve attention.
|
||||
log::debug!("before_acquire ROLLBACK failed: {e}");
|
||||
}
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
impl Database {
|
||||
pub async fn try_new(config: &DatabaseConfig, broadcasts: &Broadcasts) -> Result<Self> {
|
||||
pub async fn try_new(
|
||||
config: &DatabaseConfig,
|
||||
broadcasts: &Broadcasts,
|
||||
shutdown: tokio::sync::watch::Receiver<()>,
|
||||
) -> Result<Self> {
|
||||
tokio::fs::create_dir_all(&config.databases_directory_path)
|
||||
.await
|
||||
.with_context(|| {
|
||||
|
|
@ -70,24 +200,29 @@ impl Database {
|
|||
.trim_end_matches(".sqlite")
|
||||
.to_owned();
|
||||
|
||||
let pool = Self::create_vault_database(config, &vault).await?;
|
||||
Self::validate_vault_id(&vault)?;
|
||||
|
||||
let pools = Self::create_vault_database(config, &vault).await?;
|
||||
let cell = Arc::new(OnceCell::new());
|
||||
cell.set(pools).expect("cell is new");
|
||||
connection_pools.insert(
|
||||
vault.clone(),
|
||||
PoolWithTimestamp {
|
||||
pool,
|
||||
last_accessed: Instant::now(),
|
||||
},
|
||||
Arc::new(VaultPool {
|
||||
cell,
|
||||
last_accessed: Mutex::new(Instant::now()),
|
||||
}),
|
||||
);
|
||||
}
|
||||
info!("Database migrations applied");
|
||||
|
||||
let database = Self {
|
||||
config: config.clone(),
|
||||
connection_pools: Arc::new(Mutex::new(connection_pools)),
|
||||
broadcasts: broadcasts.clone(),
|
||||
write_locks: Arc::new(Mutex::new(HashMap::new())),
|
||||
};
|
||||
|
||||
// Start background task to cleanup idle connection pools
|
||||
database.start_idle_pool_cleanup();
|
||||
database.start_idle_pool_cleanup(shutdown);
|
||||
|
||||
Ok(database)
|
||||
}
|
||||
|
|
@ -95,92 +230,167 @@ impl Database {
|
|||
async fn create_vault_database(
|
||||
config: &DatabaseConfig,
|
||||
vault: &VaultId,
|
||||
) -> Result<Pool<Sqlite>> {
|
||||
) -> Result<VaultPools> {
|
||||
let file_name = config
|
||||
.databases_directory_path
|
||||
.join(format!("{vault}.sqlite"));
|
||||
|
||||
let connection_options = SqliteConnectOptions::new()
|
||||
// Database-level PRAGMAs (auto_vacuum, journal_mode) require a write
|
||||
// lock and persist across connections. Set them once with a dedicated
|
||||
// init connection so pool connections never need the write lock just to
|
||||
// open.
|
||||
let init_options = SqliteConnectOptions::new()
|
||||
.filename(file_name.clone())
|
||||
.create_if_missing(true)
|
||||
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Full)
|
||||
.busy_timeout(Duration::from_secs(3600))
|
||||
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
|
||||
.log_slow_statements(log::LevelFilter::Warn, Duration::from_secs(30));
|
||||
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
|
||||
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
|
||||
|
||||
let pool = SqlitePoolOptions::new()
|
||||
// Run migrations on a dedicated connection, NOT through the pool.
|
||||
// The pool's `before_acquire` hook issues ROLLBACK on every checkout,
|
||||
// which can roll back the migration's bookkeeping transaction (the
|
||||
// _sqlx_migrations INSERT) while the DDL (ALTER TABLE) has already
|
||||
// auto-committed — leaving the migration in a dirty state.
|
||||
//
|
||||
// Uses `run_direct` instead of `run` because `run` takes
|
||||
// `impl Acquire<'_>`, whose lifetime bound prevents the enclosing
|
||||
// future from satisfying the `Send` requirement of axum handlers.
|
||||
let mut init_conn = sqlx::SqliteConnection::connect_with(&init_options).await?;
|
||||
sqlx::migrate!("src/app_state/database/migrations")
|
||||
.run_direct(&mut init_conn)
|
||||
.await
|
||||
.context("Cannot run pending migrations")?;
|
||||
drop(init_conn);
|
||||
|
||||
// Per-connection PRAGMAs shared by both reader and writer pools.
|
||||
// journal_mode = WAL is a no-op on an already-WAL database.
|
||||
let base_options = SqliteConnectOptions::new()
|
||||
.filename(file_name.clone())
|
||||
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
|
||||
.busy_timeout(Duration::from_secs(30))
|
||||
.log_slow_statements(log::LevelFilter::Warn, Duration::from_secs(30))
|
||||
// In WAL mode, NORMAL is safe: data survives OS crashes, only the
|
||||
// last transaction can be lost on power failure. The default FULL
|
||||
// forces an extra fsync() per commit, roughly halving write throughput.
|
||||
.pragma("synchronous", "NORMAL")
|
||||
// 16 MB page cache per connection (negative = KiB). Reduces disk
|
||||
// reads for the latest_document_versions GROUP BY view.
|
||||
.pragma("cache_size", "-16384")
|
||||
// Memory-mapped I/O avoids read() syscalls. SQLite falls back to
|
||||
// regular I/O for writes and beyond the mapped region. 256 MB is
|
||||
// conservative; the OS handles actual memory pressure.
|
||||
.pragma("mmap_size", "268435456")
|
||||
// Keep temp tables and sort spillovers in memory instead of temp files.
|
||||
.pragma("temp_store", "MEMORY")
|
||||
// Cap WAL file growth at 64 MB. Without this, the WAL can grow
|
||||
// unbounded during heavy write bursts (e.g. E2E tests with many
|
||||
// concurrent clients). SQLite truncates to this size on checkpoint.
|
||||
.pragma("journal_size_limit", "67108864");
|
||||
|
||||
// Reader pool: multiple connections for concurrent reads.
|
||||
let reader = SqlitePoolOptions::new()
|
||||
.max_connections(config.max_connections_per_vault)
|
||||
.acquire_slow_threshold(Duration::from_secs(30))
|
||||
.test_before_acquire(true)
|
||||
.connect_with(connection_options)
|
||||
// Disabled: the health-check query is subject to busy_timeout
|
||||
// and blocks all connection checkouts when a write is active,
|
||||
// starving the pool for up to 30s even for simple reads.
|
||||
// The before_acquire ROLLBACK hook is sufficient for cleanup.
|
||||
.test_before_acquire(false)
|
||||
.before_acquire(rollback_before_acquire)
|
||||
.connect_with(base_options.clone())
|
||||
.await
|
||||
.with_context(|| format!("Cannot open database at `{}`", file_name.display()))?;
|
||||
.with_context(|| format!("Cannot open reader pool at `{}`", file_name.display()))?;
|
||||
|
||||
Self::run_migrations(&pool).await?;
|
||||
// Writer pool: exactly 1 connection, dedicated to writes.
|
||||
// Since the Tokio mutex already serializes writers per vault, this
|
||||
// single connection is never contended. Separating it from the
|
||||
// reader pool ensures writes never compete with reads for pool slots.
|
||||
let writer = SqlitePoolOptions::new()
|
||||
.max_connections(1)
|
||||
.acquire_slow_threshold(Duration::from_secs(30))
|
||||
.test_before_acquire(false)
|
||||
.before_acquire(rollback_before_acquire)
|
||||
.connect_with(base_options)
|
||||
.await
|
||||
.with_context(|| format!("Cannot open writer pool at `{}`", file_name.display()))?;
|
||||
|
||||
Ok(pool)
|
||||
Ok(VaultPools { reader, writer })
|
||||
}
|
||||
|
||||
async fn run_migrations(pool: &Pool<Sqlite>) -> Result<()> {
|
||||
sqlx::migrate!("src/app_state/database/migrations")
|
||||
.run(pool)
|
||||
.await
|
||||
.context("Cannot check for pending migrations")
|
||||
}
|
||||
|
||||
async fn get_connection_pool(&self, vault: &VaultId) -> Result<Pool<Sqlite>> {
|
||||
let mut pools = self.connection_pools.lock().await;
|
||||
|
||||
if !pools.contains_key(vault) {
|
||||
let pool = Self::create_vault_database(&self.config, vault).await?;
|
||||
pools.insert(
|
||||
vault.clone(),
|
||||
PoolWithTimestamp {
|
||||
pool,
|
||||
last_accessed: Instant::now(),
|
||||
},
|
||||
fn validate_vault_id(vault: &VaultId) -> Result<()> {
|
||||
if vault.is_empty() {
|
||||
anyhow::bail!("Vault ID must not be empty");
|
||||
}
|
||||
if vault.contains('/')
|
||||
|| vault.contains('\\')
|
||||
|| vault.contains("..")
|
||||
|| vault.contains('\0')
|
||||
{
|
||||
anyhow::bail!(
|
||||
"Invalid vault ID: must not contain path separators, '..', or null bytes"
|
||||
);
|
||||
}
|
||||
|
||||
let pool_with_timestamp = pools
|
||||
.get_mut(vault)
|
||||
.expect("Pool was just inserted or already exists");
|
||||
|
||||
// Update last accessed time
|
||||
pool_with_timestamp.last_accessed = Instant::now();
|
||||
|
||||
Ok(pool_with_timestamp.pool.clone())
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Attempting to write from this transaction might result in a
|
||||
/// database locked error. Use this transaction for read-only operations.
|
||||
pub async fn create_readonly_transaction(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
) -> Result<Transaction<'static>> {
|
||||
self.get_connection_pool(vault)
|
||||
.await?
|
||||
.begin()
|
||||
.await
|
||||
.context("Cannot create transaction")
|
||||
}
|
||||
async fn get_vault_pools(&self, vault: &VaultId) -> Result<VaultPools> {
|
||||
Self::validate_vault_id(vault)?;
|
||||
|
||||
pub async fn create_write_transaction(&self, vault: &VaultId) -> Result<Transaction<'static>> {
|
||||
let mut transaction = self.create_readonly_transaction(vault).await?;
|
||||
// Get or create the VaultPool entry. The global lock is held only
|
||||
// long enough for a HashMap lookup/insert — never across
|
||||
// create_vault_database.
|
||||
let vault_pool = {
|
||||
let mut pools = self.connection_pools.lock().await;
|
||||
pools
|
||||
.entry(vault.clone())
|
||||
.or_insert_with(|| {
|
||||
Arc::new(VaultPool {
|
||||
cell: Arc::new(OnceCell::new()),
|
||||
last_accessed: Mutex::new(Instant::now()),
|
||||
})
|
||||
})
|
||||
.clone()
|
||||
};
|
||||
|
||||
// sqlx doesn't support immediate transactions for sqlite: https://github.com/launchbadge/sqlx/issues/481
|
||||
sqlx::query!("END; BEGIN IMMEDIATE;")
|
||||
.execute(&mut *transaction)
|
||||
// OnceCell::get_or_try_init guarantees exactly-once
|
||||
// initialization: concurrent callers for the same vault wait
|
||||
// here; callers for other vaults are not blocked.
|
||||
let config = self.config.clone();
|
||||
let vault_clone = vault.clone();
|
||||
let pools = vault_pool
|
||||
.cell
|
||||
.get_or_try_init(|| async {
|
||||
Self::create_vault_database(&config, &vault_clone).await
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(transaction)
|
||||
*vault_pool.last_accessed.lock().await = Instant::now();
|
||||
Ok(pools.clone())
|
||||
}
|
||||
|
||||
/// Return the reader pool for read-only queries.
|
||||
async fn get_connection_pool(&self, vault: &VaultId) -> Result<Pool<Sqlite>> {
|
||||
Ok(self.get_vault_pools(vault).await?.reader)
|
||||
}
|
||||
|
||||
pub async fn create_write_transaction(&self, vault: &VaultId) -> Result<WriteTransaction> {
|
||||
let write_lock = {
|
||||
let mut locks = self.write_locks.lock().await;
|
||||
locks
|
||||
.entry(vault.clone())
|
||||
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
|
||||
.clone()
|
||||
};
|
||||
let write_guard = write_lock.lock_owned().await;
|
||||
let pools = self.get_vault_pools(vault).await?;
|
||||
WriteTransaction::new(&pools.writer, write_guard).await
|
||||
}
|
||||
|
||||
/// Return the latest state of all documents in the vault
|
||||
pub async fn get_latest_documents(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
connection: Option<&mut SqliteConnection>,
|
||||
) -> Result<Vec<DocumentVersionWithoutContent>> {
|
||||
let query = sqlx::query!(
|
||||
r#"
|
||||
|
|
@ -198,8 +408,8 @@ impl Database {
|
|||
"#,
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_all(&mut **transaction).await
|
||||
if let Some(conn) = connection {
|
||||
query.fetch_all(&mut *conn).await
|
||||
} else {
|
||||
query
|
||||
.fetch_all(&self.get_connection_pool(vault).await?)
|
||||
|
|
@ -216,9 +426,7 @@ impl Database {
|
|||
is_deleted: row.is_deleted,
|
||||
user_id: row.user_id,
|
||||
device_id: row.device_id,
|
||||
content_size: row
|
||||
.content_size
|
||||
.expect("Content size can't be null but sqlx can't infer it"),
|
||||
content_size: row.content_size.unwrap_or(0),
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
|
|
@ -230,7 +438,7 @@ impl Database {
|
|||
&self,
|
||||
vault: &VaultId,
|
||||
vault_update_id: VaultUpdateId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
connection: Option<&mut SqliteConnection>,
|
||||
) -> Result<Vec<DocumentVersionWithoutContent>> {
|
||||
let query = sqlx::query!(
|
||||
r#"
|
||||
|
|
@ -250,8 +458,8 @@ impl Database {
|
|||
vault_update_id
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_all(&mut **transaction).await
|
||||
if let Some(conn) = connection {
|
||||
query.fetch_all(&mut *conn).await
|
||||
} else {
|
||||
query
|
||||
.fetch_all(&self.get_connection_pool(vault).await?)
|
||||
|
|
@ -270,9 +478,7 @@ impl Database {
|
|||
is_deleted: row.is_deleted,
|
||||
user_id: row.user_id,
|
||||
device_id: row.device_id,
|
||||
content_size: row
|
||||
.content_size
|
||||
.expect("Content size can't be null but sqlx can't infer it"),
|
||||
content_size: row.content_size.unwrap_or(0),
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
|
|
@ -281,7 +487,7 @@ impl Database {
|
|||
pub async fn get_max_update_id_in_vault(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
connection: Option<&mut SqliteConnection>,
|
||||
) -> Result<i64> {
|
||||
let query = sqlx::query!(
|
||||
r#"
|
||||
|
|
@ -290,8 +496,8 @@ impl Database {
|
|||
"#,
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_one(&mut **transaction).await
|
||||
if let Some(conn) = connection {
|
||||
query.fetch_one(&mut *conn).await
|
||||
} else {
|
||||
query
|
||||
.fetch_one(&self.get_connection_pool(vault).await?)
|
||||
|
|
@ -301,11 +507,11 @@ impl Database {
|
|||
.context("Cannot fetch max update id in vault")
|
||||
}
|
||||
|
||||
pub async fn get_latest_document_by_path(
|
||||
pub async fn get_latest_non_deleted_document_by_path(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
relative_path: &str,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
connection: Option<&mut SqliteConnection>,
|
||||
) -> Result<Option<StoredDocumentVersion>> {
|
||||
let query = sqlx::query_as!(
|
||||
StoredDocumentVersion,
|
||||
|
|
@ -330,8 +536,8 @@ impl Database {
|
|||
relative_path
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_optional(&mut **transaction).await
|
||||
if let Some(conn) = connection {
|
||||
query.fetch_optional(&mut *conn).await
|
||||
} else {
|
||||
query
|
||||
.fetch_optional(&self.get_connection_pool(vault).await?)
|
||||
|
|
@ -344,7 +550,7 @@ impl Database {
|
|||
&self,
|
||||
vault: &VaultId,
|
||||
document_id: &DocumentId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
connection: Option<&mut SqliteConnection>,
|
||||
) -> Result<Option<StoredDocumentVersion>> {
|
||||
let document_id = document_id.as_hyphenated();
|
||||
let query = sqlx::query_as!(
|
||||
|
|
@ -366,8 +572,8 @@ impl Database {
|
|||
document_id
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_optional(&mut **transaction).await
|
||||
if let Some(conn) = connection {
|
||||
query.fetch_optional(&mut *conn).await
|
||||
} else {
|
||||
query
|
||||
.fetch_optional(&self.get_connection_pool(vault).await?)
|
||||
|
|
@ -380,7 +586,7 @@ impl Database {
|
|||
&self,
|
||||
vault: &VaultId,
|
||||
vault_update_id: VaultUpdateId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
connection: Option<&mut SqliteConnection>,
|
||||
) -> Result<Option<StoredDocumentVersion>> {
|
||||
let query = sqlx::query_as!(
|
||||
StoredDocumentVersion,
|
||||
|
|
@ -400,8 +606,8 @@ impl Database {
|
|||
vault_update_id
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_optional(&mut **transaction).await
|
||||
if let Some(conn) = connection {
|
||||
query.fetch_optional(&mut *conn).await
|
||||
} else {
|
||||
query
|
||||
.fetch_optional(&self.get_connection_pool(vault).await?)
|
||||
|
|
@ -415,7 +621,7 @@ impl Database {
|
|||
&self,
|
||||
vault_id: &VaultId,
|
||||
version: &StoredDocumentVersion,
|
||||
transaction: Option<Transaction<'_>>,
|
||||
transaction: Option<WriteTransaction>,
|
||||
) -> Result<()> {
|
||||
let document_id = version.document_id.as_hyphenated();
|
||||
let query = sqlx::query!(
|
||||
|
|
@ -428,9 +634,10 @@ impl Database {
|
|||
content,
|
||||
is_deleted,
|
||||
user_id,
|
||||
device_id
|
||||
device_id,
|
||||
has_been_merged
|
||||
)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
version.vault_update_id,
|
||||
document_id,
|
||||
|
|
@ -439,7 +646,8 @@ impl Database {
|
|||
version.content,
|
||||
version.is_deleted,
|
||||
version.user_id,
|
||||
version.device_id
|
||||
version.device_id,
|
||||
version.has_been_merged
|
||||
);
|
||||
|
||||
if let Some(mut transaction) = transaction {
|
||||
|
|
@ -477,38 +685,66 @@ impl Database {
|
|||
|
||||
/// Cleanup idle connection pools that haven't been accessed for longer than `IDLE_POOL_TIMEOUT`
|
||||
async fn cleanup_idle_pools(&self) {
|
||||
let mut pools = self.connection_pools.lock().await;
|
||||
let now = Instant::now();
|
||||
let idle_timeout = Duration::from_secs(5 * 60); // 5 minutes
|
||||
// Collect idle vaults and remove them from the map while holding
|
||||
// the lock briefly. Close pools OUTSIDE the lock so that
|
||||
// pool.close().await doesn't block other get_connection_pool calls.
|
||||
let idle_pools: Vec<(VaultId, Arc<VaultPool>)> = {
|
||||
let mut pools = self.connection_pools.lock().await;
|
||||
let now = Instant::now();
|
||||
|
||||
// Collect vaults to remove
|
||||
let vaults_to_remove: Vec<VaultId> = pools
|
||||
.iter()
|
||||
.filter(|(_, pool_with_timestamp)| {
|
||||
now.duration_since(pool_with_timestamp.last_accessed) > idle_timeout
|
||||
})
|
||||
.map(|(vault_id, _)| vault_id.clone())
|
||||
.collect();
|
||||
let vaults_to_remove: Vec<VaultId> = pools
|
||||
.iter()
|
||||
.filter(|(_, vp)| {
|
||||
// If the lock is contested, the pool is actively used — not idle.
|
||||
let Ok(last) = vp.last_accessed.try_lock() else {
|
||||
return false;
|
||||
};
|
||||
now.duration_since(*last) > IDLE_POOL_TIMEOUT
|
||||
})
|
||||
.map(|(vault_id, _)| vault_id.clone())
|
||||
.collect();
|
||||
|
||||
// Close and remove idle pools
|
||||
for vault_id in &vaults_to_remove {
|
||||
if let Some(pool_with_timestamp) = pools.remove(vault_id) {
|
||||
info!("Closing idle database connection pool for vault `{vault_id}`");
|
||||
pool_with_timestamp.pool.close().await;
|
||||
vaults_to_remove
|
||||
.into_iter()
|
||||
.filter_map(|id| pools.remove(&id).map(|vp| (id, vp)))
|
||||
.collect()
|
||||
};
|
||||
|
||||
for (vault_id, vault_pool) in idle_pools {
|
||||
if let Some(pools) = vault_pool.cell.get() {
|
||||
// Checkpoint the WAL before closing to reclaim disk space
|
||||
// and ensure the next open doesn't need a large WAL replay.
|
||||
// TRUNCATE mode resets the WAL file to zero bytes.
|
||||
if let Err(e) = sqlx::query("PRAGMA wal_checkpoint(TRUNCATE)")
|
||||
.execute(&pools.writer)
|
||||
.await
|
||||
{
|
||||
log::warn!("WAL checkpoint failed for vault `{vault_id}`: {e}");
|
||||
}
|
||||
info!("Closing idle database connection pools for vault `{vault_id}`");
|
||||
pools.reader.close().await;
|
||||
pools.writer.close().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a background task that periodically cleans up idle connection pools
|
||||
fn start_idle_pool_cleanup(&self) {
|
||||
fn start_idle_pool_cleanup(&self, mut shutdown: tokio::sync::watch::Receiver<()>) {
|
||||
let database = self.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(60)); // Check every minute
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
database.cleanup_idle_pools().await;
|
||||
tokio::select! {
|
||||
_ = interval.tick() => {
|
||||
database.cleanup_idle_pools().await;
|
||||
}
|
||||
_ = shutdown.changed() => {
|
||||
info!("Idle pool cleanup task shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ pub struct StoredDocumentVersion {
|
|||
pub device_id: DeviceId,
|
||||
#[allow(dead_code)] // This is for manual analysis
|
||||
pub has_been_merged: bool,
|
||||
pub idempotency_key: Option<String>,
|
||||
}
|
||||
|
||||
impl PartialEq<Self> for StoredDocumentVersion {
|
||||
|
|
|
|||
|
|
@ -1,35 +1,52 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use log::{debug, warn};
|
||||
use tokio::sync::{Mutex, broadcast};
|
||||
|
||||
use super::models::WebSocketServerMessageWithOrigin;
|
||||
use crate::{
|
||||
app_state::database::models::VaultId, config::server_config::ServerConfig, errors::server_error,
|
||||
};
|
||||
use crate::{app_state::database::models::VaultId, config::server_config::ServerConfig};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Broadcasts {
|
||||
max_clients_per_vault: usize,
|
||||
broadcast_channel_capacity: usize,
|
||||
tx: Arc<Mutex<HashMap<VaultId, broadcast::Sender<WebSocketServerMessageWithOrigin>>>>,
|
||||
}
|
||||
|
||||
type TxMap = HashMap<VaultId, broadcast::Sender<WebSocketServerMessageWithOrigin>>;
|
||||
|
||||
impl Broadcasts {
|
||||
pub fn new(server_config: &ServerConfig) -> Self {
|
||||
Self {
|
||||
max_clients_per_vault: server_config.max_clients_per_vault,
|
||||
broadcast_channel_capacity: server_config.broadcast_channel_capacity,
|
||||
tx: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove senders for vaults with no active receivers
|
||||
fn prune_inactive_vaults(tx_map: &mut TxMap) {
|
||||
tx_map.retain(|_, sender| sender.receiver_count() > 0);
|
||||
}
|
||||
|
||||
pub async fn get_receiver(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
) -> broadcast::Receiver<WebSocketServerMessageWithOrigin> {
|
||||
let tx = self.get_or_create(vault).await;
|
||||
max_clients: usize,
|
||||
) -> Result<broadcast::Receiver<WebSocketServerMessageWithOrigin>, crate::errors::SyncServerError>
|
||||
{
|
||||
let mut tx_map = self.tx.lock().await;
|
||||
Self::prune_inactive_vaults(&mut tx_map);
|
||||
|
||||
tx.subscribe()
|
||||
let sender = tx_map
|
||||
.entry(vault)
|
||||
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
|
||||
|
||||
if sender.receiver_count() >= max_clients {
|
||||
return Err(crate::errors::client_error(anyhow::anyhow!(
|
||||
"Vault has reached the maximum number of clients ({max_clients})"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(sender.subscribe())
|
||||
}
|
||||
|
||||
/// Notify all clients (who are subscribed to the vault) about an update.
|
||||
|
|
@ -39,31 +56,20 @@ impl Broadcasts {
|
|||
vault: VaultId,
|
||||
document: WebSocketServerMessageWithOrigin,
|
||||
) {
|
||||
let tx = self.get_or_create(vault.clone()).await;
|
||||
let mut tx_map = self.tx.lock().await;
|
||||
Self::prune_inactive_vaults(&mut tx_map);
|
||||
|
||||
if tx.receiver_count() == 0 {
|
||||
let sender = tx_map
|
||||
.entry(vault.clone())
|
||||
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
|
||||
|
||||
if sender.receiver_count() == 0 {
|
||||
debug!("Skipping broadcast, no clients connected for vault `{vault}`");
|
||||
return;
|
||||
}
|
||||
|
||||
let result = tx
|
||||
.send(document)
|
||||
.context("Cannot broadcast server message to websocket listeners")
|
||||
.map_err(server_error);
|
||||
|
||||
if result.is_err() {
|
||||
warn!("Failed to send message: {result:?}");
|
||||
if let Err(e) = sender.send(document) {
|
||||
warn!("Failed to broadcast to vault `{vault}`: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_or_create(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
) -> broadcast::Sender<WebSocketServerMessageWithOrigin> {
|
||||
let mut tx = self.tx.lock().await;
|
||||
|
||||
tx.entry(vault)
|
||||
.or_insert_with(|| broadcast::channel(self.max_clients_per_vault).0.clone())
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue