use core::time::Duration; use std::{collections::HashMap, sync::Arc}; use anyhow::{Context as _, Result}; use log::info; use models::{ DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, VaultUpdateId, }; use sqlx::{ConnectOptions, sqlite::SqliteConnectOptions, types::chrono::Utc}; pub mod models; use sqlx::{Pool, Sqlite, sqlite::SqlitePoolOptions}; use tokio::sync::Mutex; use tokio::time::Instant; use uuid::fmt::Hyphenated; use super::websocket::{ broadcasts::Broadcasts, models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin, WebSocketVaultUpdate}, }; use crate::config::database_config::DatabaseConfig; #[derive(Clone)] struct PoolWithTimestamp { pool: Pool, last_accessed: Instant, } impl std::fmt::Debug for PoolWithTimestamp { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PoolWithTimestamp") .field("pool", &"Pool") .field("last_accessed", &self.last_accessed) .finish() } } #[derive(Clone, Debug)] pub struct Database { config: DatabaseConfig, broadcasts: Broadcasts, connection_pools: Arc>>, } pub type Transaction<'a> = sqlx::Transaction<'a, Sqlite>; impl Database { pub async fn try_new(config: &DatabaseConfig, broadcasts: &Broadcasts) -> Result { tokio::fs::create_dir_all(&config.databases_directory_path) .await .with_context(|| { format!( "Failed to create databases directory at `{}`", config.databases_directory_path.to_string_lossy() ) })?; let mut connection_pools = std::collections::HashMap::new(); info!("Applying pending database migrations"); let mut entries = tokio::fs::read_dir(&config.databases_directory_path).await?; while let Some(entry) = entries.next_entry().await? { if !entry.file_name().to_string_lossy().ends_with(".sqlite") { continue; } let vault: VaultId = entry .file_name() .to_string_lossy() .trim_end_matches(".sqlite") .to_owned(); let pool = Self::create_vault_database(config, &vault).await?; connection_pools.insert( vault.clone(), PoolWithTimestamp { pool, last_accessed: Instant::now(), }, ); } info!("Database migrations applied"); let database = Self { config: config.clone(), connection_pools: Arc::new(Mutex::new(connection_pools)), broadcasts: broadcasts.clone(), }; // Start background task to cleanup idle connection pools database.start_idle_pool_cleanup(); Ok(database) } async fn create_vault_database( config: &DatabaseConfig, vault: &VaultId, ) -> Result> { let file_name = config .databases_directory_path .join(format!("{vault}.sqlite")); let connection_options = SqliteConnectOptions::new() .filename(file_name.clone()) .create_if_missing(true) .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Full) .busy_timeout(Duration::from_secs(3600)) .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal) .log_slow_statements(log::LevelFilter::Warn, Duration::from_secs(30)); let pool = SqlitePoolOptions::new() .max_connections(config.max_connections_per_vault) .acquire_slow_threshold(Duration::from_secs(30)) .test_before_acquire(true) .connect_with(connection_options) .await .with_context(|| format!("Cannot open database at `{}`", file_name.display()))?; Self::run_migrations(&pool).await?; Ok(pool) } async fn run_migrations(pool: &Pool) -> Result<()> { sqlx::migrate!("src/app_state/database/migrations") .run(pool) .await .context("Cannot check for pending migrations") } async fn get_connection_pool(&self, vault: &VaultId) -> Result> { let mut pools = self.connection_pools.lock().await; if !pools.contains_key(vault) { let pool = Self::create_vault_database(&self.config, vault).await?; pools.insert( vault.clone(), PoolWithTimestamp { pool, last_accessed: Instant::now(), }, ); } let pool_with_timestamp = pools .get_mut(vault) .expect("Pool was just inserted or already exists"); // Update last accessed time pool_with_timestamp.last_accessed = Instant::now(); Ok(pool_with_timestamp.pool.clone()) } /// Attempting to write from this transaction might result in a /// database locked error. Use this transaction for read-only operations. pub async fn create_readonly_transaction( &self, vault: &VaultId, ) -> Result> { self.get_connection_pool(vault) .await? .begin() .await .context("Cannot create transaction") } pub async fn create_write_transaction(&self, vault: &VaultId) -> Result> { let mut transaction = self.create_readonly_transaction(vault).await?; // sqlx doesn't support immediate transactions for sqlite: https://github.com/launchbadge/sqlx/issues/481 sqlx::query!("END; BEGIN IMMEDIATE;") .execute(&mut *transaction) .await?; Ok(transaction) } /// Return the latest state of all documents in the vault pub async fn get_latest_documents( &self, vault: &VaultId, transaction: Option<&mut Transaction<'_>>, ) -> Result> { let query = sqlx::query!( r#" select vault_update_id, document_id as "document_id: Hyphenated", relative_path, updated_date as "updated_date: chrono::DateTime", is_deleted, user_id, device_id, length(content) as "content_size: u64" from latest_document_versions order by vault_update_id "#, ); if let Some(transaction) = transaction { query.fetch_all(&mut **transaction).await } else { query .fetch_all(&self.get_connection_pool(vault).await?) .await } .context("Cannot fetch latest documents") .map(|rows| { rows.into_iter() .map(|row| DocumentVersionWithoutContent { vault_update_id: row.vault_update_id, document_id: row.document_id.into(), relative_path: row.relative_path, updated_date: row.updated_date, is_deleted: row.is_deleted, user_id: row.user_id, device_id: row.device_id, content_size: row .content_size .expect("Content size can't be null but sqlx can't infer it"), }) .collect() }) } /// Return the latest state of all documents (including deleted) in the /// vault which have changed since the given update id pub async fn get_latest_documents_since( &self, vault: &VaultId, vault_update_id: VaultUpdateId, transaction: Option<&mut Transaction<'_>>, ) -> Result> { let query = sqlx::query!( r#" select vault_update_id, document_id as "document_id: Hyphenated", relative_path, updated_date as "updated_date: chrono::DateTime", is_deleted, user_id, device_id, length(content) as "content_size: u64" from latest_document_versions where vault_update_id > ? order by vault_update_id "#, vault_update_id ); if let Some(transaction) = transaction { query.fetch_all(&mut **transaction).await } else { query .fetch_all(&self.get_connection_pool(vault).await?) .await } .with_context(|| { format!("Cannot fetch latest documents since vault_update_id `{vault_update_id}`") }) .map(|rows| { rows.into_iter() .map(|row| DocumentVersionWithoutContent { vault_update_id: row.vault_update_id, document_id: row.document_id.into(), relative_path: row.relative_path, updated_date: row.updated_date, is_deleted: row.is_deleted, user_id: row.user_id, device_id: row.device_id, content_size: row .content_size .expect("Content size can't be null but sqlx can't infer it"), }) .collect() }) } pub async fn get_max_update_id_in_vault( &self, vault: &VaultId, transaction: Option<&mut Transaction<'_>>, ) -> Result { let query = sqlx::query!( r#" select coalesce(max(vault_update_id), 0) as max_vault_update_id from documents "#, ); if let Some(transaction) = transaction { query.fetch_one(&mut **transaction).await } else { query .fetch_one(&self.get_connection_pool(vault).await?) .await } .map(|row| row.max_vault_update_id) .context("Cannot fetch max update id in vault") } pub async fn get_latest_non_deleted_document_by_path( &self, vault: &VaultId, relative_path: &str, transaction: Option<&mut Transaction<'_>>, ) -> Result> { let query = sqlx::query_as!( StoredDocumentVersion, r#" select vault_update_id, document_id as "document_id: Hyphenated", relative_path, updated_date as "updated_date: chrono::DateTime", content, is_deleted, user_id, device_id, has_been_merged from latest_document_versions where relative_path = ? and is_deleted = false order by vault_update_id desc -- `latest_document_versions` only contains a single latest version of each document, however, -- multiple documents can have the same `relative_path`, if they have been deleted. That's -- why we only care about the latest version of the document with the given relative path. limit 1 "#, relative_path ); if let Some(transaction) = transaction { query.fetch_optional(&mut **transaction).await } else { query .fetch_optional(&self.get_connection_pool(vault).await?) .await } .context("Cannot fetch latest document version") } pub async fn get_latest_document( &self, vault: &VaultId, document_id: &DocumentId, transaction: Option<&mut Transaction<'_>>, ) -> Result> { let document_id = document_id.as_hyphenated(); let query = sqlx::query_as!( StoredDocumentVersion, r#" select vault_update_id, document_id as "document_id: Hyphenated", relative_path, updated_date as "updated_date: chrono::DateTime", content, is_deleted, user_id, device_id, has_been_merged from latest_document_versions where document_id = ? "#, document_id ); if let Some(transaction) = transaction { query.fetch_optional(&mut **transaction).await } else { query .fetch_optional(&self.get_connection_pool(vault).await?) .await } .context("Cannot fetch latest document version") } pub async fn get_document_version( &self, vault: &VaultId, vault_update_id: VaultUpdateId, transaction: Option<&mut Transaction<'_>>, ) -> Result> { let query = sqlx::query_as!( StoredDocumentVersion, r#" select vault_update_id, document_id as "document_id: Hyphenated", relative_path, updated_date as "updated_date: chrono::DateTime", content, is_deleted, user_id, device_id, has_been_merged from documents where vault_update_id = ?"#, vault_update_id ); if let Some(transaction) = transaction { query.fetch_optional(&mut **transaction).await } else { query .fetch_optional(&self.get_connection_pool(vault).await?) .await } .context("Cannot fetch document version") } // inserting the document must be the last step of the transaction if there's one pub async fn insert_document_version( &self, vault_id: &VaultId, version: &StoredDocumentVersion, transaction: Option>, ) -> Result<()> { let document_id = version.document_id.as_hyphenated(); let query = sqlx::query!( r#" insert into documents ( vault_update_id, document_id, relative_path, updated_date, content, is_deleted, user_id, device_id ) values (?, ?, ?, ?, ?, ?, ?, ?) "#, version.vault_update_id, document_id, version.relative_path, version.updated_date, version.content, version.is_deleted, version.user_id, version.device_id ); if let Some(mut transaction) = transaction { query .execute(&mut *transaction) .await .context("Cannot insert document version")?; transaction .commit() .await .context("Failed to commit transaction")?; } else { query .execute(&self.get_connection_pool(vault_id).await?) .await .context("Cannot insert document version")?; } self.broadcasts .send_document_update( vault_id.clone(), WebSocketServerMessageWithOrigin::with_origin( version.device_id.clone(), WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate { documents: vec![version.clone().into()], is_initial_sync: false, }), ), ) .await; Ok(()) } /// Cleanup idle connection pools that haven't been accessed in more than 5 minutes async fn cleanup_idle_pools(&self) { let mut pools = self.connection_pools.lock().await; let now = Instant::now(); let idle_timeout = Duration::from_secs(5 * 60); // 5 minutes // Collect vaults to remove let vaults_to_remove: Vec = pools .iter() .filter(|(_, pool_with_timestamp)| { now.duration_since(pool_with_timestamp.last_accessed) > idle_timeout }) .map(|(vault_id, _)| vault_id.clone()) .collect(); // Close and remove idle pools for vault_id in &vaults_to_remove { if let Some(pool_with_timestamp) = pools.remove(vault_id) { info!("Closing idle database connection pool for vault `{vault_id}`"); pool_with_timestamp.pool.close().await; } } } /// Start a background task that periodically cleans up idle connection pools fn start_idle_pool_cleanup(&self) { let database = self.clone(); tokio::spawn(async move { let mut interval = tokio::time::interval(Duration::from_secs(60)); // Check every minute interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); loop { interval.tick().await; database.cleanup_idle_pools().await; } }); } }