516 lines
17 KiB
Rust
516 lines
17 KiB
Rust
use core::time::Duration;
|
|
use std::{collections::HashMap, sync::Arc};
|
|
|
|
use anyhow::{Context as _, Result};
|
|
use log::info;
|
|
use models::{
|
|
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, VaultUpdateId,
|
|
};
|
|
use sqlx::{ConnectOptions, sqlite::SqliteConnectOptions, types::chrono::Utc};
|
|
|
|
pub mod models;
|
|
use sqlx::{Pool, Sqlite, sqlite::SqlitePoolOptions};
|
|
use tokio::sync::Mutex;
|
|
use tokio::time::Instant;
|
|
use uuid::fmt::Hyphenated;
|
|
|
|
use super::websocket::{
|
|
broadcasts::Broadcasts,
|
|
models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin, WebSocketVaultUpdate},
|
|
};
|
|
use crate::config::database_config::DatabaseConfig;
|
|
|
|
#[derive(Clone)]
|
|
struct PoolWithTimestamp {
|
|
pool: Pool<Sqlite>,
|
|
last_accessed: Instant,
|
|
}
|
|
|
|
impl std::fmt::Debug for PoolWithTimestamp {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("PoolWithTimestamp")
|
|
.field("pool", &"Pool<Sqlite>")
|
|
.field("last_accessed", &self.last_accessed)
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct Database {
|
|
config: DatabaseConfig,
|
|
broadcasts: Broadcasts,
|
|
connection_pools: Arc<Mutex<HashMap<VaultId, PoolWithTimestamp>>>,
|
|
}
|
|
|
|
/// A SQLite transaction as used throughout the database layer.
pub type Transaction<'conn> = sqlx::Transaction<'conn, Sqlite>;
|
|
|
|
impl Database {
|
|
pub async fn try_new(config: &DatabaseConfig, broadcasts: &Broadcasts) -> Result<Self> {
|
|
tokio::fs::create_dir_all(&config.databases_directory_path)
|
|
.await
|
|
.with_context(|| {
|
|
format!(
|
|
"Failed to create databases directory at `{}`",
|
|
config.databases_directory_path.to_string_lossy()
|
|
)
|
|
})?;
|
|
|
|
let mut connection_pools = std::collections::HashMap::new();
|
|
|
|
info!("Applying pending database migrations");
|
|
let mut entries = tokio::fs::read_dir(&config.databases_directory_path).await?;
|
|
while let Some(entry) = entries.next_entry().await? {
|
|
if !entry.file_name().to_string_lossy().ends_with(".sqlite") {
|
|
continue;
|
|
}
|
|
|
|
let vault: VaultId = entry
|
|
.file_name()
|
|
.to_string_lossy()
|
|
.trim_end_matches(".sqlite")
|
|
.to_owned();
|
|
|
|
let pool = Self::create_vault_database(config, &vault).await?;
|
|
connection_pools.insert(
|
|
vault.clone(),
|
|
PoolWithTimestamp {
|
|
pool,
|
|
last_accessed: Instant::now(),
|
|
},
|
|
);
|
|
}
|
|
info!("Database migrations applied");
|
|
|
|
let database = Self {
|
|
config: config.clone(),
|
|
connection_pools: Arc::new(Mutex::new(connection_pools)),
|
|
broadcasts: broadcasts.clone(),
|
|
};
|
|
|
|
// Start background task to cleanup idle connection pools
|
|
database.start_idle_pool_cleanup();
|
|
|
|
Ok(database)
|
|
}
|
|
|
|
async fn create_vault_database(
|
|
config: &DatabaseConfig,
|
|
vault: &VaultId,
|
|
) -> Result<Pool<Sqlite>> {
|
|
let file_name = config
|
|
.databases_directory_path
|
|
.join(format!("{vault}.sqlite"));
|
|
|
|
let connection_options = SqliteConnectOptions::new()
|
|
.filename(file_name.clone())
|
|
.create_if_missing(true)
|
|
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Full)
|
|
.busy_timeout(Duration::from_secs(3600))
|
|
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
|
|
.log_slow_statements(log::LevelFilter::Warn, Duration::from_secs(30));
|
|
|
|
let pool = SqlitePoolOptions::new()
|
|
.max_connections(config.max_connections_per_vault)
|
|
.acquire_slow_threshold(Duration::from_secs(30))
|
|
.test_before_acquire(true)
|
|
.connect_with(connection_options)
|
|
.await
|
|
.with_context(|| format!("Cannot open database at `{}`", file_name.display()))?;
|
|
|
|
Self::run_migrations(&pool).await?;
|
|
|
|
Ok(pool)
|
|
}
|
|
|
|
async fn run_migrations(pool: &Pool<Sqlite>) -> Result<()> {
|
|
sqlx::migrate!("src/app_state/database/migrations")
|
|
.run(pool)
|
|
.await
|
|
.context("Cannot check for pending migrations")
|
|
}
|
|
|
|
async fn get_connection_pool(&self, vault: &VaultId) -> Result<Pool<Sqlite>> {
|
|
let mut pools = self.connection_pools.lock().await;
|
|
|
|
if !pools.contains_key(vault) {
|
|
let pool = Self::create_vault_database(&self.config, vault).await?;
|
|
pools.insert(
|
|
vault.clone(),
|
|
PoolWithTimestamp {
|
|
pool,
|
|
last_accessed: Instant::now(),
|
|
},
|
|
);
|
|
}
|
|
|
|
let pool_with_timestamp = pools
|
|
.get_mut(vault)
|
|
.expect("Pool was just inserted or already exists");
|
|
|
|
// Update last accessed time
|
|
pool_with_timestamp.last_accessed = Instant::now();
|
|
|
|
Ok(pool_with_timestamp.pool.clone())
|
|
}
|
|
|
|
/// Attempting to write from this transaction might result in a
|
|
/// database locked error. Use this transaction for read-only operations.
|
|
pub async fn create_readonly_transaction(
|
|
&self,
|
|
vault: &VaultId,
|
|
) -> Result<Transaction<'static>> {
|
|
self.get_connection_pool(vault)
|
|
.await?
|
|
.begin()
|
|
.await
|
|
.context("Cannot create transaction")
|
|
}
|
|
|
|
pub async fn create_write_transaction(&self, vault: &VaultId) -> Result<Transaction<'static>> {
|
|
let mut transaction = self.create_readonly_transaction(vault).await?;
|
|
|
|
// sqlx doesn't support immediate transactions for sqlite: https://github.com/launchbadge/sqlx/issues/481
|
|
sqlx::query!("END; BEGIN IMMEDIATE;")
|
|
.execute(&mut *transaction)
|
|
.await?;
|
|
|
|
Ok(transaction)
|
|
}
|
|
|
|
/// Return the latest state of all documents in the vault
|
|
pub async fn get_latest_documents(
|
|
&self,
|
|
vault: &VaultId,
|
|
transaction: Option<&mut Transaction<'_>>,
|
|
) -> Result<Vec<DocumentVersionWithoutContent>> {
|
|
let query = sqlx::query!(
|
|
r#"
|
|
select
|
|
vault_update_id,
|
|
document_id as "document_id: Hyphenated",
|
|
relative_path,
|
|
updated_date as "updated_date: chrono::DateTime<Utc>",
|
|
is_deleted,
|
|
user_id,
|
|
device_id,
|
|
length(content) as "content_size: u64"
|
|
from latest_document_versions
|
|
order by vault_update_id
|
|
"#,
|
|
);
|
|
|
|
if let Some(transaction) = transaction {
|
|
query.fetch_all(&mut **transaction).await
|
|
} else {
|
|
query
|
|
.fetch_all(&self.get_connection_pool(vault).await?)
|
|
.await
|
|
}
|
|
.context("Cannot fetch latest documents")
|
|
.map(|rows| {
|
|
rows.into_iter()
|
|
.map(|row| DocumentVersionWithoutContent {
|
|
vault_update_id: row.vault_update_id,
|
|
document_id: row.document_id.into(),
|
|
relative_path: row.relative_path,
|
|
updated_date: row.updated_date,
|
|
is_deleted: row.is_deleted,
|
|
user_id: row.user_id,
|
|
device_id: row.device_id,
|
|
content_size: row
|
|
.content_size
|
|
.expect("Content size can't be null but sqlx can't infer it"),
|
|
})
|
|
.collect()
|
|
})
|
|
}
|
|
|
|
/// Return the latest state of all documents (including deleted) in the
|
|
/// vault which have changed since the given update id
|
|
pub async fn get_latest_documents_since(
|
|
&self,
|
|
vault: &VaultId,
|
|
vault_update_id: VaultUpdateId,
|
|
transaction: Option<&mut Transaction<'_>>,
|
|
) -> Result<Vec<DocumentVersionWithoutContent>> {
|
|
let query = sqlx::query!(
|
|
r#"
|
|
select
|
|
vault_update_id,
|
|
document_id as "document_id: Hyphenated",
|
|
relative_path,
|
|
updated_date as "updated_date: chrono::DateTime<Utc>",
|
|
is_deleted,
|
|
user_id,
|
|
device_id,
|
|
length(content) as "content_size: u64"
|
|
from latest_document_versions
|
|
where vault_update_id > ?
|
|
order by vault_update_id
|
|
"#,
|
|
vault_update_id
|
|
);
|
|
|
|
if let Some(transaction) = transaction {
|
|
query.fetch_all(&mut **transaction).await
|
|
} else {
|
|
query
|
|
.fetch_all(&self.get_connection_pool(vault).await?)
|
|
.await
|
|
}
|
|
.with_context(|| {
|
|
format!("Cannot fetch latest documents since vault_update_id `{vault_update_id}`")
|
|
})
|
|
.map(|rows| {
|
|
rows.into_iter()
|
|
.map(|row| DocumentVersionWithoutContent {
|
|
vault_update_id: row.vault_update_id,
|
|
document_id: row.document_id.into(),
|
|
relative_path: row.relative_path,
|
|
updated_date: row.updated_date,
|
|
is_deleted: row.is_deleted,
|
|
user_id: row.user_id,
|
|
device_id: row.device_id,
|
|
content_size: row
|
|
.content_size
|
|
.expect("Content size can't be null but sqlx can't infer it"),
|
|
})
|
|
.collect()
|
|
})
|
|
}
|
|
|
|
pub async fn get_max_update_id_in_vault(
|
|
&self,
|
|
vault: &VaultId,
|
|
transaction: Option<&mut Transaction<'_>>,
|
|
) -> Result<i64> {
|
|
let query = sqlx::query!(
|
|
r#"
|
|
select coalesce(max(vault_update_id), 0) as max_vault_update_id
|
|
from documents
|
|
"#,
|
|
);
|
|
|
|
if let Some(transaction) = transaction {
|
|
query.fetch_one(&mut **transaction).await
|
|
} else {
|
|
query
|
|
.fetch_one(&self.get_connection_pool(vault).await?)
|
|
.await
|
|
}
|
|
.map(|row| row.max_vault_update_id)
|
|
.context("Cannot fetch max update id in vault")
|
|
}
|
|
|
|
pub async fn get_latest_non_deleted_document_by_path(
|
|
&self,
|
|
vault: &VaultId,
|
|
relative_path: &str,
|
|
transaction: Option<&mut Transaction<'_>>,
|
|
) -> Result<Option<StoredDocumentVersion>> {
|
|
let query = sqlx::query_as!(
|
|
StoredDocumentVersion,
|
|
r#"
|
|
select
|
|
vault_update_id,
|
|
document_id as "document_id: Hyphenated",
|
|
relative_path,
|
|
updated_date as "updated_date: chrono::DateTime<Utc>",
|
|
content,
|
|
is_deleted,
|
|
user_id,
|
|
device_id,
|
|
has_been_merged
|
|
from latest_document_versions
|
|
where relative_path = ? and is_deleted = false
|
|
order by vault_update_id desc -- `latest_document_versions` only contains a single latest version of each document, however,
|
|
-- multiple documents can have the same `relative_path`, if they have been deleted. That's
|
|
-- why we only care about the latest version of the document with the given relative path.
|
|
limit 1
|
|
"#,
|
|
relative_path
|
|
);
|
|
|
|
if let Some(transaction) = transaction {
|
|
query.fetch_optional(&mut **transaction).await
|
|
} else {
|
|
query
|
|
.fetch_optional(&self.get_connection_pool(vault).await?)
|
|
.await
|
|
}
|
|
.context("Cannot fetch latest document version")
|
|
}
|
|
|
|
pub async fn get_latest_document(
|
|
&self,
|
|
vault: &VaultId,
|
|
document_id: &DocumentId,
|
|
transaction: Option<&mut Transaction<'_>>,
|
|
) -> Result<Option<StoredDocumentVersion>> {
|
|
let document_id = document_id.as_hyphenated();
|
|
let query = sqlx::query_as!(
|
|
StoredDocumentVersion,
|
|
r#"
|
|
select
|
|
vault_update_id,
|
|
document_id as "document_id: Hyphenated",
|
|
relative_path,
|
|
updated_date as "updated_date: chrono::DateTime<Utc>",
|
|
content,
|
|
is_deleted,
|
|
user_id,
|
|
device_id,
|
|
has_been_merged
|
|
from latest_document_versions
|
|
where document_id = ?
|
|
"#,
|
|
document_id
|
|
);
|
|
|
|
if let Some(transaction) = transaction {
|
|
query.fetch_optional(&mut **transaction).await
|
|
} else {
|
|
query
|
|
.fetch_optional(&self.get_connection_pool(vault).await?)
|
|
.await
|
|
}
|
|
.context("Cannot fetch latest document version")
|
|
}
|
|
|
|
pub async fn get_document_version(
|
|
&self,
|
|
vault: &VaultId,
|
|
vault_update_id: VaultUpdateId,
|
|
transaction: Option<&mut Transaction<'_>>,
|
|
) -> Result<Option<StoredDocumentVersion>> {
|
|
let query = sqlx::query_as!(
|
|
StoredDocumentVersion,
|
|
r#"
|
|
select
|
|
vault_update_id,
|
|
document_id as "document_id: Hyphenated",
|
|
relative_path,
|
|
updated_date as "updated_date: chrono::DateTime<Utc>",
|
|
content,
|
|
is_deleted,
|
|
user_id,
|
|
device_id,
|
|
has_been_merged
|
|
from documents
|
|
where vault_update_id = ?"#,
|
|
vault_update_id
|
|
);
|
|
|
|
if let Some(transaction) = transaction {
|
|
query.fetch_optional(&mut **transaction).await
|
|
} else {
|
|
query
|
|
.fetch_optional(&self.get_connection_pool(vault).await?)
|
|
.await
|
|
}
|
|
.context("Cannot fetch document version")
|
|
}
|
|
|
|
// inserting the document must be the last step of the transaction if there's one
|
|
pub async fn insert_document_version(
|
|
&self,
|
|
vault_id: &VaultId,
|
|
version: &StoredDocumentVersion,
|
|
transaction: Option<Transaction<'_>>,
|
|
) -> Result<()> {
|
|
let document_id = version.document_id.as_hyphenated();
|
|
let query = sqlx::query!(
|
|
r#"
|
|
insert into documents (
|
|
vault_update_id,
|
|
document_id,
|
|
relative_path,
|
|
updated_date,
|
|
content,
|
|
is_deleted,
|
|
user_id,
|
|
device_id
|
|
)
|
|
values (?, ?, ?, ?, ?, ?, ?, ?)
|
|
"#,
|
|
version.vault_update_id,
|
|
document_id,
|
|
version.relative_path,
|
|
version.updated_date,
|
|
version.content,
|
|
version.is_deleted,
|
|
version.user_id,
|
|
version.device_id
|
|
);
|
|
|
|
if let Some(mut transaction) = transaction {
|
|
query
|
|
.execute(&mut *transaction)
|
|
.await
|
|
.context("Cannot insert document version")?;
|
|
|
|
transaction
|
|
.commit()
|
|
.await
|
|
.context("Failed to commit transaction")?;
|
|
} else {
|
|
query
|
|
.execute(&self.get_connection_pool(vault_id).await?)
|
|
.await
|
|
.context("Cannot insert document version")?;
|
|
}
|
|
|
|
self.broadcasts
|
|
.send_document_update(
|
|
vault_id.clone(),
|
|
WebSocketServerMessageWithOrigin::with_origin(
|
|
version.device_id.clone(),
|
|
WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate {
|
|
documents: vec![version.clone().into()],
|
|
is_initial_sync: false,
|
|
}),
|
|
),
|
|
)
|
|
.await;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Cleanup idle connection pools that haven't been accessed in more than 5 minutes
|
|
async fn cleanup_idle_pools(&self) {
|
|
let mut pools = self.connection_pools.lock().await;
|
|
let now = Instant::now();
|
|
let idle_timeout = Duration::from_secs(5 * 60); // 5 minutes
|
|
|
|
// Collect vaults to remove
|
|
let vaults_to_remove: Vec<VaultId> = pools
|
|
.iter()
|
|
.filter(|(_, pool_with_timestamp)| {
|
|
now.duration_since(pool_with_timestamp.last_accessed) > idle_timeout
|
|
})
|
|
.map(|(vault_id, _)| vault_id.clone())
|
|
.collect();
|
|
|
|
// Close and remove idle pools
|
|
for vault_id in &vaults_to_remove {
|
|
if let Some(pool_with_timestamp) = pools.remove(vault_id) {
|
|
info!("Closing idle database connection pool for vault `{vault_id}`");
|
|
pool_with_timestamp.pool.close().await;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Start a background task that periodically cleans up idle connection pools
|
|
fn start_idle_pool_cleanup(&self) {
|
|
let database = self.clone();
|
|
tokio::spawn(async move {
|
|
let mut interval = tokio::time::interval(Duration::from_secs(60)); // Check every minute
|
|
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
|
|
|
loop {
|
|
interval.tick().await;
|
|
database.cleanup_idle_pools().await;
|
|
}
|
|
});
|
|
}
|
|
}
|