Remove rust workspace
This commit is contained in:
parent
da60f8c005
commit
fc19e650ca
52 changed files with 101 additions and 50 deletions
5
sync-server/.dockerignore
Normal file
5
sync-server/.dockerignore
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
target
|
||||
Dockerfile
|
||||
.dockerignore
|
||||
databases
|
||||
*.yml
|
||||
1
sync-server/.env
Normal file
1
sync-server/.env
Normal file
|
|
@ -0,0 +1 @@
|
|||
DATABASE_URL=sqlite://db.sqlite3
|
||||
3099
sync-server/Cargo.lock
generated
Normal file
3099
sync-server/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
89
sync-server/Cargo.toml
Normal file
89
sync-server/Cargo.toml
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
[package]
|
||||
name = "sync_server"
|
||||
rust-version = "1.83"
|
||||
authors = ["Andras Schmelczer <andras@schmelczer.dev>"]
|
||||
edition = "2024"
|
||||
license = "MIT"
|
||||
repository = "https://github.com/schmelczer/vault-link"
|
||||
version = "0.4.0"
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1.0.219", default-features = false, features = ["derive"] }
|
||||
thiserror = { version = "2.0.12", default-features = false }
|
||||
tokio = { version = "1.44.2", features = ["full"]}
|
||||
uuid = { version = "1.16.0", features = ["v4", "serde"] }
|
||||
log = { version = "0.4.27" }
|
||||
anyhow = { version = "1.0.98", features = ["backtrace"] }
|
||||
axum = { version = "0.7.4", features = ["ws", "macros", "tracing", "multipart"]}
|
||||
axum-extra = { version = "0.9.6", features = ["typed-header"] }
|
||||
axum_typed_multipart = "0.11.0"
|
||||
tower-http = { version = "0.6.1", features = ["cors", "trace", "limit", "timeout"] }
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = { version = "0.3.19", features = ["fmt", "env-filter"]}
|
||||
sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio", "uuid", "chrono"] }
|
||||
chrono = { version = "0.4.41", features = ["serde"] }
|
||||
rand = "0.9.0"
|
||||
sanitize-filename = "0.6.0"
|
||||
regex = "1.11.1"
|
||||
clap = { version = "4.5.38", features = ["derive"] }
|
||||
futures = "0.3.31"
|
||||
serde_yaml = "0.9.34"
|
||||
serde_json = "1.0.140"
|
||||
clap-verbosity-flag = "3.0.3"
|
||||
bimap = "0.6.3"
|
||||
ts-rs = { version = "10.1", features = ["uuid-impl", "chrono-impl"] }
|
||||
serde_with = "3.12.0"
|
||||
base64 = "0.22.1"
|
||||
reconcile-text = "0.4.10"
|
||||
|
||||
[profile.release]
|
||||
codegen-units = 1
|
||||
lto = true
|
||||
opt-level = 3
|
||||
strip="debuginfo" # Keep some info for better panics
|
||||
|
||||
[lints.rust]
|
||||
unsafe_code = "forbid"
|
||||
rust_2018_idioms = { level = "warn", priority = -1 }
|
||||
missing_debug_implementations = "warn"
|
||||
|
||||
[lints.clippy]
|
||||
await_holding_lock = "warn"
|
||||
dbg_macro = "warn"
|
||||
empty_enum = "warn"
|
||||
enum_glob_use = "warn"
|
||||
exit = "warn"
|
||||
filter_map_next = "warn"
|
||||
fn_params_excessive_bools = "warn"
|
||||
if_let_mutex = "warn"
|
||||
imprecise_flops = "warn"
|
||||
inefficient_to_string = "warn"
|
||||
linkedlist = "warn"
|
||||
lossy_float_literal = "warn"
|
||||
macro_use_imports = "warn"
|
||||
match_wildcard_for_single_variants = "warn"
|
||||
mem_forget = "warn"
|
||||
needless_borrow = "warn"
|
||||
needless_continue = "warn"
|
||||
option_option = "warn"
|
||||
rest_pat_in_fully_bound_structs = "warn"
|
||||
str_to_string = "warn"
|
||||
suboptimal_flops = "warn"
|
||||
todo = "warn"
|
||||
uninlined_format_args = "warn"
|
||||
unnested_or_patterns = "warn"
|
||||
unused_self = "warn"
|
||||
verbose_file_reads = "warn"
|
||||
|
||||
large_stack_arrays = { level = "allow", priority = 1 } # https://github.com/rust-lang/rust-clippy/issues/13774
|
||||
|
||||
# Silly lints
|
||||
implicit_return = { level = "allow", priority = 1 }
|
||||
question_mark_used = { level = "allow", priority = 1 }
|
||||
struct_field_names = { level = "allow", priority = 1 }
|
||||
single_char_lifetime_names = { level = "allow", priority = 1 }
|
||||
single_call_fn = { level = "allow", priority = 1 }
|
||||
similar_names = { level = "allow", priority = 1 }
|
||||
missing_docs_in_private_items = { level = "allow", priority = 1 }
|
||||
|
||||
pedantic = { level = "warn", priority = 0 }
|
||||
33
sync-server/Dockerfile
Normal file
33
sync-server/Dockerfile
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
FROM rust:1.87 AS builder
|
||||
|
||||
WORKDIR /usr/src/backend
|
||||
|
||||
RUN apt update && apt install -y musl-tools
|
||||
RUN cargo install sqlx-cli
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN sqlx database create --database-url sqlite://db.sqlite3
|
||||
RUN sqlx migrate run --source sync_server/src/app_state/database/migrations --database-url sqlite://db.sqlite3
|
||||
|
||||
RUN cargo build --package sync_server --release --target x86_64-unknown-linux-musl
|
||||
|
||||
# Runtime image
|
||||
FROM alpine:3.22.0
|
||||
|
||||
LABEL org.opencontainers.image.authors="andras@schmelczer.dev"
|
||||
|
||||
RUN apk add --no-cache curl
|
||||
|
||||
COPY --from=builder /usr/src/backend/target/x86_64-unknown-linux-musl/release/sync_server /app/sync_server
|
||||
|
||||
VOLUME /data
|
||||
EXPOSE 3000/tcp
|
||||
WORKDIR /data
|
||||
|
||||
HEALTHCHECK \
|
||||
--interval=30s \
|
||||
--timeout=5s \
|
||||
CMD curl -f http://localhost:3000/vaults/fake/ping || exit 1
|
||||
|
||||
ENTRYPOINT ["/app/sync_server"]
|
||||
9
sync-server/README.md
Normal file
9
sync-server/README.md
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
# Sync server
|
||||
|
||||
## Creating/resetting the Database for development
|
||||
|
||||
```sh
|
||||
sqlx database create --database-url sqlite://db.sqlite3
|
||||
sqlx migrate run --source sync_server/src/app_state/database/migrations --database-url sqlite://db.sqlite3
|
||||
cargo sqlx prepare --workspace
|
||||
```
|
||||
5
sync-server/build.rs
Normal file
5
sync-server/build.rs
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
// generated by `sqlx migrate build-script`
|
||||
fn main() {
|
||||
// trigger recompilation when a new migration is added
|
||||
println!("cargo:rerun-if-changed=migrations");
|
||||
}
|
||||
26
sync-server/config-e2e.yml
Normal file
26
sync-server/config-e2e.yml
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
database:
|
||||
databases_directory_path: databases
|
||||
max_connections_per_vault: 12
|
||||
cursor_timeout_seconds: 60
|
||||
server:
|
||||
host: 0.0.0.0
|
||||
port: 3000
|
||||
max_body_size_mb: 512
|
||||
max_clients_per_vault: 256
|
||||
response_timeout_seconds: 60
|
||||
users:
|
||||
user_configs:
|
||||
- name: admin
|
||||
token: test-token-change-me
|
||||
vault_access:
|
||||
type: allow_access_to_all
|
||||
- name: other-admin
|
||||
token: test-token-change-me2
|
||||
vault_access:
|
||||
type: allow_access_to_all
|
||||
- name: test
|
||||
token: other-test-token
|
||||
vault_access:
|
||||
type: allow_list
|
||||
allowed:
|
||||
- default
|
||||
4
sync-server/rust-toolchain.toml
Normal file
4
sync-server/rust-toolchain.toml
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
[toolchain]
|
||||
channel = "nightly-2025-06-06"
|
||||
targets = [ "x86_64-unknown-linux-gnu", "x86_64-unknown-linux-musl" ]
|
||||
profile = "default"
|
||||
8
sync-server/rustfmt.toml
Normal file
8
sync-server/rustfmt.toml
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
imports_granularity = "crate"
|
||||
condense_wildcard_suffixes = true
|
||||
fn_single_line = true
|
||||
format_strings = true
|
||||
reorder_impl_items = true
|
||||
group_imports = "StdExternalCrate"
|
||||
use_field_init_shorthand = true
|
||||
wrap_comments=true
|
||||
41
sync-server/src/app_state.rs
Normal file
41
sync-server/src/app_state.rs
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
pub mod cursors;
|
||||
pub mod database;
|
||||
pub mod websocket;
|
||||
|
||||
use std::ffi::OsString;
|
||||
|
||||
use anyhow::Result;
|
||||
use cursors::Cursors;
|
||||
use database::Database;
|
||||
use websocket::broadcasts::Broadcasts;
|
||||
|
||||
use crate::{config::Config, consts::DEFAULT_CONFIG_PATH};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AppState {
|
||||
pub config: Config,
|
||||
pub database: Database,
|
||||
pub cursors: Cursors,
|
||||
pub broadcasts: Broadcasts,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub async fn try_new(config_path: Option<OsString>) -> Result<Self> {
|
||||
let config_path = config_path.unwrap_or_else(|| OsString::from(DEFAULT_CONFIG_PATH));
|
||||
let path = std::path::PathBuf::from(config_path);
|
||||
|
||||
let config = Config::read_or_create(&path).await?;
|
||||
let broadcasts = Broadcasts::new(&config.server);
|
||||
let database = Database::try_new(&config.database, &broadcasts).await?;
|
||||
let cursors: Cursors = Cursors::new(&config.database, &broadcasts);
|
||||
|
||||
Cursors::start_background_task(cursors.clone());
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
database,
|
||||
cursors,
|
||||
broadcasts,
|
||||
})
|
||||
}
|
||||
}
|
||||
128
sync-server/src/app_state/cursors.rs
Normal file
128
sync-server/src/app_state/cursors.rs
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
use core::time::Duration;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::{
|
||||
database::models::{DeviceId, VaultId},
|
||||
websocket::{
|
||||
broadcasts::Broadcasts,
|
||||
models::{
|
||||
ClientCursors, CursorPositionFromServer, CursorSpan, WebSocketServerMessage,
|
||||
WebSocketServerMessageWithOrigin,
|
||||
},
|
||||
},
|
||||
};
|
||||
use crate::config::database_config::DatabaseConfig;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Cursors {
|
||||
config: DatabaseConfig,
|
||||
broadcasts: Broadcasts,
|
||||
vault_to_cursors: Arc<Mutex<HashMap<VaultId, Vec<ClientCursorsWithTimeToLive>>>>,
|
||||
}
|
||||
|
||||
impl Cursors {
|
||||
pub fn new(config: &DatabaseConfig, broadcasts: &Broadcasts) -> Self {
|
||||
Self {
|
||||
config: config.clone(),
|
||||
broadcasts: broadcasts.clone(),
|
||||
vault_to_cursors: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn update_cursors(
|
||||
&self,
|
||||
vault_id: VaultId,
|
||||
user_name: String,
|
||||
device_id: &DeviceId,
|
||||
document_to_cursors: HashMap<String, Vec<CursorSpan>>,
|
||||
) {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
let all_device_cursors = vault_to_cursors.entry(vault_id).or_insert_with(Vec::new);
|
||||
|
||||
all_device_cursors.retain(|c| &c.client_cursors.device_id != device_id);
|
||||
all_device_cursors.push(ClientCursorsWithTimeToLive::new(ClientCursors {
|
||||
user_name,
|
||||
device_id: device_id.to_string(),
|
||||
cursors: document_to_cursors,
|
||||
}));
|
||||
|
||||
drop(vault_to_cursors); // Explicitly drop the lock before broadcasting to avoid deadlock
|
||||
self.broadcast_cursors().await;
|
||||
}
|
||||
|
||||
pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec<ClientCursors> {
|
||||
let vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
vault_to_cursors
|
||||
.get(vault_id)
|
||||
.map(|cursors| {
|
||||
cursors
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|with_ttl| with_ttl.client_cursors)
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn start_background_task(self) {
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
self.remove_expired_cursors().await;
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn remove_expired_cursors(&self) {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
for (_vault_id, cursors) in vault_to_cursors.iter_mut() {
|
||||
cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout));
|
||||
}
|
||||
}
|
||||
|
||||
async fn broadcast_cursors(&self) {
|
||||
let vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
for (vault_id, cursors) in vault_to_cursors.iter() {
|
||||
self.broadcasts
|
||||
.send_document_update(
|
||||
vault_id.clone(),
|
||||
WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions(
|
||||
CursorPositionFromServer {
|
||||
clients: cursors.iter().map(|c| c.client_cursors.clone()).collect(),
|
||||
},
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_cursors_of_device(&self, vault_id: &str, device_id: &str) {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
if let Some(cursors) = vault_to_cursors.get_mut(vault_id) {
|
||||
cursors.retain(|c| c.client_cursors.device_id != device_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ClientCursorsWithTimeToLive {
|
||||
client_cursors: ClientCursors,
|
||||
last_updated: std::time::Instant,
|
||||
}
|
||||
|
||||
impl ClientCursorsWithTimeToLive {
|
||||
fn new(client_cursors: ClientCursors) -> Self {
|
||||
Self {
|
||||
client_cursors,
|
||||
last_updated: std::time::Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_expired(&self, ttl: Duration) -> bool { self.last_updated.elapsed() > ttl }
|
||||
}
|
||||
425
sync-server/src/app_state/database.rs
Normal file
425
sync-server/src/app_state/database.rs
Normal file
|
|
@ -0,0 +1,425 @@
|
|||
use core::time::Duration;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use models::{
|
||||
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, VaultUpdateId,
|
||||
};
|
||||
use sqlx::{sqlite::SqliteConnectOptions, types::chrono::Utc};
|
||||
|
||||
pub mod models;
|
||||
use sqlx::{Pool, Sqlite, sqlite::SqlitePoolOptions};
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::fmt::Hyphenated;
|
||||
|
||||
use super::websocket::{
|
||||
broadcasts::Broadcasts,
|
||||
models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin, WebSocketVaultUpdate},
|
||||
};
|
||||
use crate::config::database_config::DatabaseConfig;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Database {
|
||||
config: DatabaseConfig,
|
||||
broadcasts: Broadcasts,
|
||||
connection_pools: Arc<Mutex<HashMap<VaultId, Pool<Sqlite>>>>,
|
||||
}
|
||||
|
||||
pub type Transaction<'a> = sqlx::Transaction<'a, Sqlite>;
|
||||
|
||||
impl Database {
|
||||
pub async fn try_new(config: &DatabaseConfig, broadcasts: &Broadcasts) -> Result<Self> {
|
||||
tokio::fs::create_dir_all(&config.databases_directory_path)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to create databases directory: {}",
|
||||
config.databases_directory_path.to_string_lossy()
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut connection_pools = std::collections::HashMap::new();
|
||||
|
||||
let mut entries = tokio::fs::read_dir(&config.databases_directory_path).await?;
|
||||
while let Some(entry) = entries.next_entry().await? {
|
||||
if !entry.file_name().to_string_lossy().ends_with(".sqlite") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let vault: VaultId = entry
|
||||
.file_name()
|
||||
.to_string_lossy()
|
||||
.trim_end_matches(".sqlite")
|
||||
.to_owned();
|
||||
|
||||
connection_pools.insert(
|
||||
vault.clone(),
|
||||
Self::create_vault_database(config, &vault).await?,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
config: config.clone(),
|
||||
connection_pools: Arc::new(Mutex::new(connection_pools)),
|
||||
broadcasts: broadcasts.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_vault_database(
|
||||
config: &DatabaseConfig,
|
||||
vault: &VaultId,
|
||||
) -> Result<Pool<Sqlite>> {
|
||||
let file_name = config
|
||||
.databases_directory_path
|
||||
.join(format!("{vault}.sqlite"));
|
||||
|
||||
let connection_options = SqliteConnectOptions::new()
|
||||
.filename(file_name.clone())
|
||||
.create_if_missing(true)
|
||||
.busy_timeout(Duration::from_secs(3600))
|
||||
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
|
||||
|
||||
let pool = SqlitePoolOptions::new()
|
||||
.max_connections(config.max_connections_per_vault)
|
||||
.test_before_acquire(true)
|
||||
.connect_with(connection_options)
|
||||
.await
|
||||
.with_context(|| format!("Cannot open database at {}", file_name.display()))?;
|
||||
|
||||
Self::run_migrations(&pool).await?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
async fn run_migrations(pool: &Pool<Sqlite>) -> Result<()> {
|
||||
sqlx::migrate!("src/app_state/database/migrations")
|
||||
.run(pool)
|
||||
.await
|
||||
.context("Cannot check for pending migrations")
|
||||
}
|
||||
|
||||
async fn get_connection_pool(&self, vault: &VaultId) -> Result<Pool<Sqlite>> {
|
||||
let mut pools = self.connection_pools.lock().await;
|
||||
if !pools.contains_key(vault) {
|
||||
let pool = Self::create_vault_database(&self.config, vault).await?;
|
||||
pools.insert(vault.clone(), pool);
|
||||
}
|
||||
|
||||
let pool = pools
|
||||
.get(vault)
|
||||
.expect("Pool was just inserted or already exists");
|
||||
|
||||
Ok(pool.clone())
|
||||
}
|
||||
|
||||
/// Attempting to write from this transaction might result in a
|
||||
/// database locked error. Use this transaction for read-only operations.
|
||||
pub async fn create_readonly_transaction(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
) -> Result<Transaction<'static>> {
|
||||
self.get_connection_pool(vault)
|
||||
.await?
|
||||
.begin()
|
||||
.await
|
||||
.context("Cannot create transaction")
|
||||
}
|
||||
|
||||
pub async fn create_write_transaction(&self, vault: &VaultId) -> Result<Transaction<'static>> {
|
||||
let mut transaction = self.create_readonly_transaction(vault).await?;
|
||||
|
||||
// sqlx doesn't support immediate transactions for sqlite: https://github.com/launchbadge/sqlx/issues/481
|
||||
sqlx::query!("END; BEGIN IMMEDIATE;")
|
||||
.execute(&mut *transaction)
|
||||
.await?;
|
||||
|
||||
Ok(transaction)
|
||||
}
|
||||
|
||||
/// Return the latest state of all documents in the vault
|
||||
pub async fn get_latest_documents(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
) -> Result<Vec<DocumentVersionWithoutContent>> {
|
||||
let query = sqlx::query!(
|
||||
r#"
|
||||
select
|
||||
vault_update_id,
|
||||
document_id as "document_id: Hyphenated",
|
||||
relative_path,
|
||||
updated_date as "updated_date: chrono::DateTime<Utc>",
|
||||
is_deleted,
|
||||
user_id,
|
||||
device_id,
|
||||
length(content) as "content_size: u64"
|
||||
from latest_document_versions
|
||||
order by vault_update_id
|
||||
"#,
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_all(&mut **transaction).await
|
||||
} else {
|
||||
query
|
||||
.fetch_all(&self.get_connection_pool(vault).await?)
|
||||
.await
|
||||
}
|
||||
.context("Cannot fetch latest documents")
|
||||
.map(|rows| {
|
||||
rows.into_iter()
|
||||
.map(|row| DocumentVersionWithoutContent {
|
||||
vault_update_id: row.vault_update_id,
|
||||
document_id: row.document_id.into(),
|
||||
relative_path: row.relative_path,
|
||||
updated_date: row.updated_date,
|
||||
is_deleted: row.is_deleted,
|
||||
user_id: row.user_id,
|
||||
device_id: row.device_id,
|
||||
content_size: row
|
||||
.content_size
|
||||
.expect("Content size can't be null but sqlx can't infer it"),
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
/// Return the latest state of all documents (including deleted) in the
|
||||
/// vault which have changed since the given update id
|
||||
pub async fn get_latest_documents_since(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
vault_update_id: VaultUpdateId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
) -> Result<Vec<DocumentVersionWithoutContent>> {
|
||||
let query = sqlx::query!(
|
||||
r#"
|
||||
select
|
||||
vault_update_id,
|
||||
document_id as "document_id: Hyphenated",
|
||||
relative_path,
|
||||
updated_date as "updated_date: chrono::DateTime<Utc>",
|
||||
is_deleted,
|
||||
user_id,
|
||||
device_id,
|
||||
length(content) as "content_size: u64"
|
||||
from latest_document_versions
|
||||
where vault_update_id > ?
|
||||
order by vault_update_id
|
||||
"#,
|
||||
vault_update_id
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_all(&mut **transaction).await
|
||||
} else {
|
||||
query
|
||||
.fetch_all(&self.get_connection_pool(vault).await?)
|
||||
.await
|
||||
}
|
||||
.with_context(|| {
|
||||
format!("Cannot fetch latest documents since vault_update_id {vault_update_id}")
|
||||
})
|
||||
.map(|rows| {
|
||||
rows.into_iter()
|
||||
.map(|row| DocumentVersionWithoutContent {
|
||||
vault_update_id: row.vault_update_id,
|
||||
document_id: row.document_id.into(),
|
||||
relative_path: row.relative_path,
|
||||
updated_date: row.updated_date,
|
||||
is_deleted: row.is_deleted,
|
||||
user_id: row.user_id,
|
||||
device_id: row.device_id,
|
||||
content_size: row
|
||||
.content_size
|
||||
.expect("Content size can't be null but sqlx can't infer it"),
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_max_update_id_in_vault(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
) -> Result<i64> {
|
||||
let query = sqlx::query!(
|
||||
r#"
|
||||
select coalesce(max(vault_update_id), 0) as max_vault_update_id
|
||||
from documents
|
||||
"#,
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_one(&mut **transaction).await
|
||||
} else {
|
||||
query
|
||||
.fetch_one(&self.get_connection_pool(vault).await?)
|
||||
.await
|
||||
}
|
||||
.map(|row| row.max_vault_update_id)
|
||||
.context("Cannot fetch max update id in vault")
|
||||
}
|
||||
|
||||
pub async fn get_latest_document_by_path(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
relative_path: &str,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
) -> Result<Option<StoredDocumentVersion>> {
|
||||
let query = sqlx::query_as!(
|
||||
StoredDocumentVersion,
|
||||
r#"
|
||||
select
|
||||
vault_update_id,
|
||||
document_id as "document_id: Hyphenated",
|
||||
relative_path,
|
||||
updated_date as "updated_date: chrono::DateTime<Utc>",
|
||||
content,
|
||||
is_deleted,
|
||||
user_id,
|
||||
device_id
|
||||
from latest_document_versions
|
||||
where relative_path = ?
|
||||
order by vault_update_id desc -- `latest_document_versions` only contains a single latest version of each document, however,
|
||||
-- multiple documents can have the same `relative_path`, if they have been deleted. That's
|
||||
-- why we only care about the latest version of the document with the given relative path.
|
||||
limit 1
|
||||
"#,
|
||||
relative_path
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_optional(&mut **transaction).await
|
||||
} else {
|
||||
query
|
||||
.fetch_optional(&self.get_connection_pool(vault).await?)
|
||||
.await
|
||||
}
|
||||
.context("Cannot fetch latest document version")
|
||||
}
|
||||
|
||||
pub async fn get_latest_document(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
document_id: &DocumentId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
) -> Result<Option<StoredDocumentVersion>> {
|
||||
let document_id = document_id.as_hyphenated();
|
||||
let query = sqlx::query_as!(
|
||||
StoredDocumentVersion,
|
||||
r#"
|
||||
select
|
||||
vault_update_id,
|
||||
document_id as "document_id: Hyphenated",
|
||||
relative_path,
|
||||
updated_date as "updated_date: chrono::DateTime<Utc>",
|
||||
content,
|
||||
is_deleted,
|
||||
user_id,
|
||||
device_id
|
||||
from latest_document_versions
|
||||
where document_id = ?
|
||||
"#,
|
||||
document_id
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_optional(&mut **transaction).await
|
||||
} else {
|
||||
query
|
||||
.fetch_optional(&self.get_connection_pool(vault).await?)
|
||||
.await
|
||||
}
|
||||
.context("Cannot fetch latest document version")
|
||||
}
|
||||
|
||||
pub async fn get_document_version(
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
vault_update_id: VaultUpdateId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
) -> Result<Option<StoredDocumentVersion>> {
|
||||
let query = sqlx::query_as!(
|
||||
StoredDocumentVersion,
|
||||
r#"
|
||||
select
|
||||
vault_update_id,
|
||||
document_id as "document_id: Hyphenated",
|
||||
relative_path,
|
||||
updated_date as "updated_date: chrono::DateTime<Utc>",
|
||||
content,
|
||||
is_deleted,
|
||||
user_id,
|
||||
device_id
|
||||
from documents
|
||||
where vault_update_id = ?"#,
|
||||
vault_update_id
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.fetch_optional(&mut **transaction).await
|
||||
} else {
|
||||
query
|
||||
.fetch_optional(&self.get_connection_pool(vault).await?)
|
||||
.await
|
||||
}
|
||||
.context("Cannot fetch document version")
|
||||
}
|
||||
|
||||
pub async fn insert_document_version(
|
||||
&self,
|
||||
vault_id: &VaultId,
|
||||
version: &StoredDocumentVersion,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
) -> Result<()> {
|
||||
let document_id = version.document_id.as_hyphenated();
|
||||
let query = sqlx::query!(
|
||||
r#"
|
||||
insert into documents (
|
||||
vault_update_id,
|
||||
document_id,
|
||||
relative_path,
|
||||
updated_date,
|
||||
content,
|
||||
is_deleted,
|
||||
user_id,
|
||||
device_id
|
||||
)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
version.vault_update_id,
|
||||
document_id,
|
||||
version.relative_path,
|
||||
version.updated_date,
|
||||
version.content,
|
||||
version.is_deleted,
|
||||
version.user_id,
|
||||
version.device_id
|
||||
);
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
query.execute(&mut **transaction).await
|
||||
} else {
|
||||
query
|
||||
.execute(&self.get_connection_pool(vault_id).await?)
|
||||
.await
|
||||
}
|
||||
.context("Cannot insert document version")?;
|
||||
|
||||
self.broadcasts
|
||||
.send_document_update(
|
||||
vault_id.clone(),
|
||||
WebSocketServerMessageWithOrigin::with_origin(
|
||||
version.device_id.clone(),
|
||||
WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate {
|
||||
documents: vec![version.clone().into()],
|
||||
is_initial_sync: false,
|
||||
}),
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
CREATE TABLE IF NOT EXISTS documents (
|
||||
vault_update_id INTEGER NOT NULL PRIMARY KEY,
|
||||
document_id TEXT NOT NULL,
|
||||
relative_path TEXT NOT NULL,
|
||||
updated_date TIMESTAMP NOT NULL,
|
||||
content BLOB NOT NULL,
|
||||
is_deleted BOOLEAN NOT NULL
|
||||
);
|
||||
|
||||
CREATE VIEW IF NOT EXISTS latest_document_versions AS
|
||||
SELECT d.*
|
||||
FROM documents d
|
||||
INNER JOIN (
|
||||
SELECT MAX(vault_update_id) AS max_version_id
|
||||
FROM documents
|
||||
GROUP BY document_id
|
||||
) max_versions
|
||||
ON d.vault_update_id = max_versions.max_version_id;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_documents_vault_id_relative_path
|
||||
ON documents (relative_path);
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
ALTER TABLE documents ADD COLUMN user_id TEXT NOT NULL DEFAULT "";
|
||||
ALTER TABLE documents ADD COLUMN device_id TEXT NOT NULL DEFAULT "";
|
||||
89
sync-server/src/app_state/database/models.rs
Normal file
89
sync-server/src/app_state/database/models.rs
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::Serialize;
|
||||
use ts_rs::TS;
|
||||
|
||||
pub type VaultId = String;
|
||||
pub type VaultUpdateId = i64;
|
||||
|
||||
pub type DocumentId = uuid::Uuid;
|
||||
pub type UserId = String;
|
||||
pub type DeviceId = String;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StoredDocumentVersion {
|
||||
pub vault_update_id: VaultUpdateId,
|
||||
pub document_id: DocumentId,
|
||||
pub relative_path: String,
|
||||
pub updated_date: DateTime<Utc>,
|
||||
pub content: Vec<u8>,
|
||||
pub is_deleted: bool,
|
||||
pub user_id: UserId,
|
||||
pub device_id: DeviceId,
|
||||
}
|
||||
|
||||
impl PartialEq<Self> for StoredDocumentVersion {
|
||||
fn eq(&self, other: &Self) -> bool { self.vault_update_id == other.vault_update_id }
|
||||
}
|
||||
|
||||
#[derive(TS, Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DocumentVersionWithoutContent {
|
||||
#[ts(as = "i32")]
|
||||
pub vault_update_id: VaultUpdateId,
|
||||
|
||||
pub document_id: DocumentId,
|
||||
pub relative_path: String,
|
||||
pub updated_date: DateTime<Utc>,
|
||||
pub is_deleted: bool,
|
||||
pub user_id: UserId,
|
||||
pub device_id: DeviceId,
|
||||
|
||||
#[ts(as = "i32")]
|
||||
pub content_size: u64,
|
||||
}
|
||||
|
||||
impl From<StoredDocumentVersion> for DocumentVersionWithoutContent {
|
||||
fn from(value: StoredDocumentVersion) -> Self {
|
||||
Self {
|
||||
vault_update_id: value.vault_update_id,
|
||||
document_id: value.document_id,
|
||||
relative_path: value.relative_path,
|
||||
updated_date: value.updated_date,
|
||||
is_deleted: value.is_deleted,
|
||||
user_id: value.user_id,
|
||||
device_id: value.device_id,
|
||||
content_size: value.content.len() as u64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(TS, Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DocumentVersion {
|
||||
#[ts(as = "i32")]
|
||||
pub vault_update_id: VaultUpdateId,
|
||||
|
||||
pub document_id: DocumentId,
|
||||
pub relative_path: String,
|
||||
pub updated_date: DateTime<Utc>,
|
||||
pub content_base64: String,
|
||||
pub is_deleted: bool,
|
||||
pub user_id: UserId,
|
||||
pub device_id: DeviceId,
|
||||
}
|
||||
|
||||
impl From<StoredDocumentVersion> for DocumentVersion {
|
||||
fn from(value: StoredDocumentVersion) -> Self {
|
||||
Self {
|
||||
vault_update_id: value.vault_update_id,
|
||||
document_id: value.document_id,
|
||||
relative_path: value.relative_path,
|
||||
updated_date: value.updated_date,
|
||||
content_base64: STANDARD.encode(&value.content),
|
||||
is_deleted: value.is_deleted,
|
||||
user_id: value.user_id,
|
||||
device_id: value.device_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
3
sync-server/src/app_state/websocket.rs
Normal file
3
sync-server/src/app_state/websocket.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
pub mod broadcasts;
|
||||
pub mod models;
|
||||
pub mod utils;
|
||||
63
sync-server/src/app_state/websocket/broadcasts.rs
Normal file
63
sync-server/src/app_state/websocket/broadcasts.rs
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use tokio::sync::{Mutex, broadcast};
|
||||
|
||||
use super::models::WebSocketServerMessageWithOrigin;
|
||||
use crate::{
|
||||
app_state::database::models::VaultId, config::server_config::ServerConfig, errors::server_error,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Broadcasts {
|
||||
max_clients_per_vault: usize,
|
||||
tx: Arc<Mutex<HashMap<VaultId, broadcast::Sender<WebSocketServerMessageWithOrigin>>>>,
|
||||
}
|
||||
|
||||
impl Broadcasts {
|
||||
pub fn new(server_config: &ServerConfig) -> Self {
|
||||
Self {
|
||||
max_clients_per_vault: server_config.max_clients_per_vault,
|
||||
tx: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_receiver(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
) -> broadcast::Receiver<WebSocketServerMessageWithOrigin> {
|
||||
let tx = self.get_or_create(vault).await;
|
||||
|
||||
tx.subscribe()
|
||||
}
|
||||
|
||||
/// Notify all clients (who are subscribed to the vault) about an update.
|
||||
/// We only log failures.
|
||||
pub async fn send_document_update(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
document: WebSocketServerMessageWithOrigin,
|
||||
) {
|
||||
let tx = self.get_or_create(vault).await;
|
||||
|
||||
let result = tx
|
||||
.send(document)
|
||||
.context("Cannot broadcast server message to websocket listeners")
|
||||
.map_err(server_error);
|
||||
|
||||
if result.is_err() {
|
||||
log::debug!("Failed to send message: {result:?}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_or_create(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
) -> broadcast::Sender<WebSocketServerMessageWithOrigin> {
|
||||
let mut tx = self.tx.lock().await;
|
||||
|
||||
tx.entry(vault)
|
||||
.or_insert_with(|| broadcast::channel(self.max_clients_per_vault).0.clone())
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
88
sync-server/src/app_state/websocket/models.rs
Normal file
88
sync-server/src/app_state/websocket/models.rs
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ts_rs::TS;
|
||||
|
||||
use crate::app_state::database::models::{DeviceId, DocumentVersionWithoutContent, VaultUpdateId};
|
||||
|
||||
#[derive(TS, Deserialize, Clone, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct WebSocketHandshake {
|
||||
pub token: String,
|
||||
pub device_id: DeviceId,
|
||||
|
||||
#[ts(as = "Option<i32>")]
|
||||
pub last_seen_vault_update_id: Option<VaultUpdateId>,
|
||||
}
|
||||
|
||||
#[derive(TS, Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CursorSpan {
|
||||
pub start: usize,
|
||||
pub end: usize,
|
||||
}
|
||||
|
||||
#[derive(TS, Deserialize, Clone, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CursorPositionFromClient {
|
||||
pub document_to_cursors: HashMap<String, Vec<CursorSpan>>,
|
||||
}
|
||||
|
||||
#[derive(TS, Serialize, Clone, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientCursors {
|
||||
pub user_name: String,
|
||||
pub device_id: DeviceId,
|
||||
pub cursors: HashMap<String, Vec<CursorSpan>>,
|
||||
}
|
||||
|
||||
#[derive(TS, Serialize, Clone, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CursorPositionFromServer {
|
||||
pub clients: Vec<ClientCursors>,
|
||||
}
|
||||
|
||||
#[derive(TS, Serialize, Clone, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct WebSocketVaultUpdate {
|
||||
pub documents: Vec<DocumentVersionWithoutContent>,
|
||||
pub is_initial_sync: bool,
|
||||
}
|
||||
|
||||
#[derive(TS, Deserialize, Clone, Debug)]
|
||||
#[serde(rename_all = "camelCase", tag = "type")]
|
||||
#[ts(export)]
|
||||
pub enum WebSocketClientMessage {
|
||||
Handshake(WebSocketHandshake),
|
||||
CursorPositions(CursorPositionFromClient),
|
||||
}
|
||||
|
||||
#[derive(TS, Serialize, Clone, Debug)]
|
||||
#[serde(rename_all = "camelCase", tag = "type")]
|
||||
#[ts(export)]
|
||||
pub enum WebSocketServerMessage {
|
||||
VaultUpdate(WebSocketVaultUpdate),
|
||||
CursorPositions(CursorPositionFromServer),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WebSocketServerMessageWithOrigin {
|
||||
pub origin_device_id: Option<DeviceId>,
|
||||
pub message: WebSocketServerMessage,
|
||||
}
|
||||
|
||||
impl WebSocketServerMessageWithOrigin {
|
||||
pub fn new(message: WebSocketServerMessage) -> Self {
|
||||
Self {
|
||||
origin_device_id: None,
|
||||
message,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_origin(origin_device_id: DeviceId, message: WebSocketServerMessage) -> Self {
|
||||
Self {
|
||||
origin_device_id: Some(origin_device_id),
|
||||
message,
|
||||
}
|
||||
}
|
||||
}
|
||||
80
sync-server/src/app_state/websocket/utils.rs
Normal file
80
sync-server/src/app_state/websocket/utils.rs
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
use anyhow::Context;
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use futures::{sink::SinkExt, stream::SplitSink};
|
||||
|
||||
use super::models::{WebSocketClientMessage, WebSocketHandshake, WebSocketServerMessage};
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
|
||||
},
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, server_error, unauthenticated_error},
|
||||
server::auth::auth,
|
||||
};
|
||||
|
||||
pub struct AuthenticatedWebSocketHandshake {
|
||||
pub handshake: WebSocketHandshake,
|
||||
pub user: User,
|
||||
}
|
||||
|
||||
pub fn get_authenticated_handshake(
|
||||
state: &AppState,
|
||||
vault_id: &VaultId,
|
||||
message: Option<Message>,
|
||||
) -> Result<AuthenticatedWebSocketHandshake, SyncServerError> {
|
||||
if let Some(Message::Text(message)) = message {
|
||||
let message: WebSocketClientMessage = serde_json::from_str(&message)
|
||||
.context("Failed to parse message")
|
||||
.map_err(server_error)?;
|
||||
|
||||
match message {
|
||||
WebSocketClientMessage::Handshake(handshake) => {
|
||||
let user = auth(state, handshake.token.trim(), vault_id)?;
|
||||
Ok(AuthenticatedWebSocketHandshake { handshake, user })
|
||||
}
|
||||
WebSocketClientMessage::CursorPositions(_) => Err(unauthenticated_error(
|
||||
anyhow::anyhow!("Expected a handshake message"),
|
||||
)),
|
||||
}
|
||||
} else {
|
||||
Err(unauthenticated_error(anyhow::anyhow!(
|
||||
"Failed to authenticate due to invalid message"
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_unseen_documents(
|
||||
state: &AppState,
|
||||
vault_id: &VaultId,
|
||||
last_seen_vault_update_id: Option<VaultUpdateId>,
|
||||
) -> Result<Vec<DocumentVersionWithoutContent>, SyncServerError> {
|
||||
if let Some(update_id) = last_seen_vault_update_id {
|
||||
state
|
||||
.database
|
||||
.get_latest_documents_since(vault_id, update_id, None)
|
||||
.await
|
||||
.map_err(server_error)
|
||||
} else {
|
||||
state
|
||||
.database
|
||||
.get_latest_documents(vault_id, None)
|
||||
.await
|
||||
.map_err(server_error)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_update_over_websocket(
|
||||
update: &WebSocketServerMessage,
|
||||
sender: &mut SplitSink<WebSocket, Message>,
|
||||
) -> Result<(), SyncServerError> {
|
||||
let serialized_update = serde_json::to_string(update)
|
||||
.context("Failed to serialize update")
|
||||
.map_err(server_error)?;
|
||||
|
||||
sender
|
||||
.send(Message::Text(serialized_update))
|
||||
.await
|
||||
.context("Failed to send message over websocket")
|
||||
.map_err(server_error)
|
||||
}
|
||||
2
sync-server/src/cli.rs
Normal file
2
sync-server/src/cli.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
pub mod args;
|
||||
pub mod color_when;
|
||||
26
sync-server/src/cli/args.rs
Normal file
26
sync-server/src/cli/args.rs
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
use std::ffi::OsString;
|
||||
|
||||
use clap::Parser;
|
||||
use clap_verbosity_flag::{InfoLevel, Verbosity};
|
||||
|
||||
use crate::cli::color_when::ColorWhen;
|
||||
|
||||
/// Server for backing the `VaultLink` plugin
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
#[arg(index = 1)]
|
||||
pub config_path: Option<OsString>,
|
||||
|
||||
#[command(flatten)]
|
||||
pub verbose: Verbosity<InfoLevel>,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
value_name = "WHEN",
|
||||
default_value_t = ColorWhen::Auto,
|
||||
default_missing_value = "always",
|
||||
value_enum
|
||||
)]
|
||||
pub color: ColorWhen,
|
||||
}
|
||||
31
sync-server/src/cli/color_when.rs
Normal file
31
sync-server/src/cli/color_when.rs
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
use std::io::IsTerminal;
|
||||
|
||||
use clap::ValueEnum;
|
||||
|
||||
#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum ColorWhen {
|
||||
Always,
|
||||
Auto,
|
||||
Never,
|
||||
}
|
||||
|
||||
impl ColorWhen {
|
||||
pub fn use_colors(self) -> bool {
|
||||
match self {
|
||||
ColorWhen::Always => true,
|
||||
ColorWhen::Auto => {
|
||||
std::env::var_os("NO_COLOR").is_none() && std::io::stderr().is_terminal()
|
||||
}
|
||||
ColorWhen::Never => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ColorWhen {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.to_possible_value()
|
||||
.expect("no values are skipped")
|
||||
.get_name()
|
||||
.fmt(f)
|
||||
}
|
||||
}
|
||||
66
sync-server/src/config.rs
Normal file
66
sync-server/src/config.rs
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
use std::path::Path;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use database_config::DatabaseConfig;
|
||||
use log::info;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use server_config::ServerConfig;
|
||||
use tokio::fs;
|
||||
use user_config::UserConfig;
|
||||
|
||||
pub mod database_config;
|
||||
pub mod server_config;
|
||||
pub mod user_config;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
||||
pub struct Config {
|
||||
#[serde(default)]
|
||||
pub database: DatabaseConfig,
|
||||
#[serde(default)]
|
||||
pub server: ServerConfig,
|
||||
#[serde(default)]
|
||||
pub users: UserConfig,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub async fn read_or_create(path: &Path) -> Result<Self> {
|
||||
let config = if path.exists() {
|
||||
info!(
|
||||
"Loading configuration from '{}'",
|
||||
path.canonicalize().unwrap().display()
|
||||
);
|
||||
Self::load_from_file(path).await?
|
||||
} else {
|
||||
Self::default()
|
||||
};
|
||||
|
||||
config.write(path).await?;
|
||||
info!(
|
||||
"Updated configuration at '{}'",
|
||||
path.canonicalize().unwrap().display()
|
||||
);
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub async fn load_from_file(path: &Path) -> Result<Self> {
|
||||
let contents = fs::read_to_string(path).await.with_context(|| {
|
||||
format!(
|
||||
"Cannot load configuration from disk from {}",
|
||||
path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
let config = serde_yaml::from_str(&contents).context("Failed to parse configuration")?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub async fn write(&self, path: &Path) -> Result<()> {
|
||||
let contents = serde_yaml::to_string(&self).context("Failed to serialize configuration")?;
|
||||
|
||||
fs::write(path, contents)
|
||||
.await
|
||||
.context("Failed to write configuration to disk")
|
||||
}
|
||||
}
|
||||
48
sync-server/src/config/database_config.rs
Normal file
48
sync-server/src/config/database_config.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
use std::{path::PathBuf, time::Duration};
|
||||
|
||||
use log::debug;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::serde_as;
|
||||
|
||||
use crate::consts::{
|
||||
DEFAULT_CURSOR_TIMEOUT, DEFAULT_DATABASES_DIRECTORY_PATH, DEFAULT_MAX_CONNECTIONS_PER_VAULT,
|
||||
};
|
||||
|
||||
#[serde_with::serde_as]
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct DatabaseConfig {
|
||||
#[serde(default = "default_databases_directory_path")]
|
||||
pub databases_directory_path: PathBuf,
|
||||
|
||||
#[serde(default = "default_max_connections_per_vault")]
|
||||
pub max_connections_per_vault: u32,
|
||||
|
||||
#[serde(default = "default_cursor_timeout", rename = "cursor_timeout_seconds")]
|
||||
#[serde_as(as = "serde_with::DurationSeconds<u64>")]
|
||||
pub cursor_timeout: Duration,
|
||||
}
|
||||
|
||||
fn default_databases_directory_path() -> PathBuf {
|
||||
debug!("Using default databases directory path: {DEFAULT_DATABASES_DIRECTORY_PATH:?}");
|
||||
PathBuf::from(DEFAULT_DATABASES_DIRECTORY_PATH)
|
||||
}
|
||||
|
||||
fn default_max_connections_per_vault() -> u32 {
|
||||
debug!("Using default max connections: {DEFAULT_MAX_CONNECTIONS_PER_VAULT}");
|
||||
DEFAULT_MAX_CONNECTIONS_PER_VAULT
|
||||
}
|
||||
|
||||
fn default_cursor_timeout() -> Duration {
|
||||
debug!("Using default cursor timeout: {DEFAULT_CURSOR_TIMEOUT:?}");
|
||||
DEFAULT_CURSOR_TIMEOUT
|
||||
}
|
||||
|
||||
impl Default for DatabaseConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
databases_directory_path: default_databases_directory_path(),
|
||||
max_connections_per_vault: default_max_connections_per_vault(),
|
||||
cursor_timeout: default_cursor_timeout(),
|
||||
}
|
||||
}
|
||||
}
|
||||
50
sync-server/src/config/server_config.rs
Normal file
50
sync-server/src/config/server_config.rs
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
use log::debug;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::consts::{
|
||||
DEFAULT_HOST, DEFAULT_MAX_BODY_SIZE_MB, DEFAULT_MAX_CLIENTS_PER_VAULT, DEFAULT_PORT,
|
||||
DEFAULT_RESPONSE_TIMEOUT_SECONDS,
|
||||
};
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
||||
pub struct ServerConfig {
|
||||
#[serde(default = "default_host")]
|
||||
pub host: String,
|
||||
|
||||
#[serde(default = "default_port")]
|
||||
pub port: u16,
|
||||
|
||||
#[serde(default = "default_max_body_size_mb")]
|
||||
pub max_body_size_mb: usize,
|
||||
|
||||
#[serde(default = "default_max_clients_per_vault")]
|
||||
pub max_clients_per_vault: usize,
|
||||
|
||||
#[serde(default = "default_response_timeout_seconds")]
|
||||
pub response_timeout_seconds: u64,
|
||||
}
|
||||
|
||||
fn default_host() -> String {
|
||||
debug!("Using default server host: {DEFAULT_HOST}");
|
||||
DEFAULT_HOST.to_owned()
|
||||
}
|
||||
|
||||
fn default_port() -> u16 {
|
||||
debug!("Using default server port: {DEFAULT_PORT}");
|
||||
DEFAULT_PORT
|
||||
}
|
||||
|
||||
fn default_max_body_size_mb() -> usize {
|
||||
debug!("Using default max body size (MB): {DEFAULT_MAX_BODY_SIZE_MB}");
|
||||
DEFAULT_MAX_BODY_SIZE_MB
|
||||
}
|
||||
|
||||
fn default_max_clients_per_vault() -> usize {
|
||||
debug!("Using default max clients per vault: {DEFAULT_MAX_CLIENTS_PER_VAULT}");
|
||||
DEFAULT_MAX_CLIENTS_PER_VAULT
|
||||
}
|
||||
|
||||
fn default_response_timeout_seconds() -> u64 {
|
||||
debug!("Using default response timeout (seconds): {DEFAULT_RESPONSE_TIMEOUT_SECONDS}");
|
||||
DEFAULT_RESPONSE_TIMEOUT_SECONDS
|
||||
}
|
||||
164
sync-server/src/config/user_config.rs
Normal file
164
sync-server/src/config/user_config.rs
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
use bimap::BiHashMap;
|
||||
use rand::{Rng, distr::Alphanumeric, rng};
|
||||
use serde::{Deserialize, Deserializer, Serialize, de::Error};
|
||||
|
||||
use crate::app_state::database::models::VaultId;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct UserConfig {
|
||||
#[serde(default = "default_users", deserialize_with = "validate_users")]
|
||||
pub user_configs: Vec<User>,
|
||||
}
|
||||
|
||||
fn validate_users<'de, D>(deserializer: D) -> Result<Vec<User>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let users = Vec::<User>::deserialize(deserializer)?;
|
||||
|
||||
let mut user_token_map = BiHashMap::new();
|
||||
for user in &users {
|
||||
if let Some(existing_name) = user_token_map.get_by_right(&user.token) {
|
||||
return Err(D::Error::custom(format!(
|
||||
"Duplicate user token found: '{}' for users '{}' and '{}'. User tokens must be \
|
||||
unique.",
|
||||
user.token, existing_name, user.name
|
||||
)));
|
||||
}
|
||||
|
||||
if user_token_map.contains_left(&user.name) {
|
||||
return Err(D::Error::custom(format!(
|
||||
"Duplicate user name found: '{}'. User names must be unique.",
|
||||
user.name
|
||||
)));
|
||||
}
|
||||
|
||||
user_token_map.insert(user.name.clone(), user.token.clone());
|
||||
}
|
||||
|
||||
Ok(users)
|
||||
}
|
||||
|
||||
impl UserConfig {
|
||||
pub fn get_user(&self, token: &str) -> Option<&User> {
|
||||
self.user_configs.iter().find(|u| u.token == token)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct User {
|
||||
pub name: String,
|
||||
pub token: String,
|
||||
pub vault_access: VaultAccess,
|
||||
}
|
||||
|
||||
impl Default for UserConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
user_configs: default_users(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum VaultAccess {
|
||||
#[default]
|
||||
AllowAccessToAll,
|
||||
|
||||
AllowList(AllowListedVaults),
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
||||
pub struct AllowListedVaults {
|
||||
pub allowed: Vec<VaultId>,
|
||||
}
|
||||
|
||||
fn default_users() -> Vec<User> {
|
||||
vec![User {
|
||||
name: "admin".to_owned(),
|
||||
token: get_random_token(),
|
||||
vault_access: VaultAccess::default(),
|
||||
}]
|
||||
}
|
||||
|
||||
pub fn get_random_token() -> String {
|
||||
rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(64)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_validate_users_unique_names_and_tokens() {
|
||||
let config_json = json!({
|
||||
"user_configs": [
|
||||
{
|
||||
"name": "alice",
|
||||
"token": "token1",
|
||||
"vault_access": { "type": "allow_access_to_all" }
|
||||
},
|
||||
{
|
||||
"name": "bob",
|
||||
"token": "token2",
|
||||
"vault_access": { "type": "allow_access_to_all" }
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let config: Result<UserConfig, _> = serde_json::from_value(config_json);
|
||||
assert!(config.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_users_duplicate_names() {
|
||||
let config_json = json!({
|
||||
"user_configs": [
|
||||
{
|
||||
"name": "alice",
|
||||
"token": "token1",
|
||||
"vault_access": { "type": "allow_access_to_all" }
|
||||
},
|
||||
{
|
||||
"name": "alice",
|
||||
"token": "token2",
|
||||
"vault_access": { "type": "allow_access_to_all" }
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let config: Result<UserConfig, _> = serde_json::from_value(config_json);
|
||||
assert!(config.is_err());
|
||||
let err = config.unwrap_err().to_string();
|
||||
assert!(err.contains("Duplicate user name found"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_users_duplicate_tokens() {
|
||||
let config_json = json!({
|
||||
"user_configs": [
|
||||
{
|
||||
"name": "alice",
|
||||
"token": "token1",
|
||||
"vault_access": { "type": "allow_access_to_all" }
|
||||
},
|
||||
{
|
||||
"name": "bob",
|
||||
"token": "token1",
|
||||
"vault_access": { "type": "allow_access_to_all" }
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let config: Result<UserConfig, _> = serde_json::from_value(config_json);
|
||||
assert!(config.is_err());
|
||||
let err = config.unwrap_err().to_string();
|
||||
assert!(err.contains("Duplicate user token found"));
|
||||
}
|
||||
}
|
||||
13
sync-server/src/consts.rs
Normal file
13
sync-server/src/consts.rs
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
use std::time::Duration;
|
||||
|
||||
pub const DEFAULT_CONFIG_PATH: &str = "config.yml";
|
||||
|
||||
pub const DEFAULT_DATABASES_DIRECTORY_PATH: &str = "databases";
|
||||
pub const DEFAULT_MAX_CONNECTIONS_PER_VAULT: u32 = 12;
|
||||
pub const DEFAULT_CURSOR_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
pub const DEFAULT_HOST: &str = "127.0.0.1";
|
||||
pub const DEFAULT_PORT: u16 = 3000;
|
||||
pub const DEFAULT_MAX_BODY_SIZE_MB: usize = 4096;
|
||||
pub const DEFAULT_RESPONSE_TIMEOUT_SECONDS: u64 = 60;
|
||||
pub const DEFAULT_MAX_CLIENTS_PER_VAULT: usize = 256;
|
||||
140
sync-server/src/errors.rs
Normal file
140
sync-server/src/errors.rs
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
use std::fmt::Display;
|
||||
|
||||
use axum::{
|
||||
Json,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use log::{debug, error};
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use ts_rs::TS;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SyncServerError {
|
||||
#[error("Initialisation error: {0}")]
|
||||
InitError(#[source] anyhow::Error),
|
||||
|
||||
#[error("Client error: {0:?}")]
|
||||
ClientError(#[source] anyhow::Error),
|
||||
|
||||
#[error("Server error: {0:?}")]
|
||||
ServerError(#[source] anyhow::Error),
|
||||
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(#[source] anyhow::Error),
|
||||
|
||||
#[error("Unauthorized: {0}")]
|
||||
Unauthenticated(#[source] anyhow::Error),
|
||||
|
||||
#[error("Permission denied error: {0}")]
|
||||
PermissionDeniedError(#[source] anyhow::Error),
|
||||
}
|
||||
|
||||
impl SyncServerError {
|
||||
pub fn serialize(&self) -> SerializedError {
|
||||
match self {
|
||||
Self::InitError(error)
|
||||
| Self::ClientError(error)
|
||||
| Self::ServerError(error)
|
||||
| Self::NotFound(error)
|
||||
| Self::Unauthenticated(error)
|
||||
| Self::PermissionDeniedError(error) => error.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(TS, Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
pub struct SerializedError {
|
||||
pub error_type: &'static str,
|
||||
pub message: String,
|
||||
pub causes: Vec<String>,
|
||||
}
|
||||
|
||||
impl Display for SerializedError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
if !self.causes.is_empty() {
|
||||
write!(f, "\nCauses:\n")?;
|
||||
for cause in &self.causes {
|
||||
write!(f, "{}", &format!("- {cause}\n"))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for SyncServerError {
|
||||
fn into_response(self) -> Response {
|
||||
let body = Json(self.serialize());
|
||||
|
||||
match self {
|
||||
Self::InitError(_) | Self::ServerError(_) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, body).into_response()
|
||||
}
|
||||
Self::ClientError(_) => (StatusCode::BAD_REQUEST, body).into_response(),
|
||||
Self::NotFound(_) => (StatusCode::NOT_FOUND, body).into_response(),
|
||||
Self::Unauthenticated(_) => (StatusCode::UNAUTHORIZED, body).into_response(),
|
||||
Self::PermissionDeniedError(_) => (StatusCode::FORBIDDEN, body).into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&anyhow::Error> for SerializedError {
|
||||
fn from(error: &anyhow::Error) -> SerializedError {
|
||||
let mut causes = vec![];
|
||||
let mut current_error = error.source();
|
||||
while let Some(error) = current_error {
|
||||
causes.push(error.to_string());
|
||||
current_error = error.source();
|
||||
}
|
||||
|
||||
SerializedError {
|
||||
error_type: error.downcast_ref::<SyncServerError>().map_or(
|
||||
"UnknownError",
|
||||
|e| match e {
|
||||
SyncServerError::InitError(_) => "InitError",
|
||||
SyncServerError::ClientError(_) => "ClientError",
|
||||
SyncServerError::ServerError(_) => "ServerError",
|
||||
SyncServerError::NotFound(_) => "NotFound",
|
||||
SyncServerError::Unauthenticated(_) => "Unauthenticated",
|
||||
SyncServerError::PermissionDeniedError(_) => "PermissionDeniedError",
|
||||
},
|
||||
),
|
||||
message: error.to_string(),
|
||||
causes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_error(error: anyhow::Error) -> SyncServerError {
|
||||
debug!("Initialization error: {error:?}");
|
||||
SyncServerError::InitError(error)
|
||||
}
|
||||
|
||||
pub fn server_error(error: anyhow::Error) -> SyncServerError {
|
||||
debug!("Server error: {error:?}");
|
||||
SyncServerError::ServerError(error)
|
||||
}
|
||||
|
||||
pub fn client_error(error: anyhow::Error) -> SyncServerError {
|
||||
debug!("Client error: {error:?}");
|
||||
SyncServerError::ClientError(error)
|
||||
}
|
||||
|
||||
pub fn not_found_error(error: anyhow::Error) -> SyncServerError {
|
||||
debug!("Not found: {error:?}");
|
||||
SyncServerError::NotFound(error)
|
||||
}
|
||||
|
||||
pub fn unauthenticated_error(error: anyhow::Error) -> SyncServerError {
|
||||
debug!("Unauthenticated user: {error:?}");
|
||||
SyncServerError::Unauthenticated(error)
|
||||
}
|
||||
|
||||
pub fn permission_denied_error(error: anyhow::Error) -> SyncServerError {
|
||||
debug!("Permission denied: {error:?}");
|
||||
SyncServerError::PermissionDeniedError(error)
|
||||
}
|
||||
86
sync-server/src/main.rs
Normal file
86
sync-server/src/main.rs
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
mod app_state;
|
||||
mod cli;
|
||||
mod config;
|
||||
mod consts;
|
||||
mod errors;
|
||||
mod server;
|
||||
mod utils;
|
||||
|
||||
use std::process::ExitCode;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use clap::Parser;
|
||||
use cli::args::Args;
|
||||
use errors::{SyncServerError, init_error};
|
||||
use log::info;
|
||||
use server::create_server;
|
||||
use tracing_subscriber::{EnvFilter, fmt::format, util::SubscriberInitExt};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> ExitCode {
|
||||
let args = Args::parse();
|
||||
|
||||
let mut result = set_up_logging(&args);
|
||||
|
||||
if result.is_ok() {
|
||||
result = start_server(args).await;
|
||||
}
|
||||
|
||||
match result {
|
||||
Ok(()) => ExitCode::SUCCESS,
|
||||
Err(e) => {
|
||||
eprintln!("{}", e.serialize());
|
||||
ExitCode::FAILURE
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn set_up_logging(args: &Args) -> Result<(), SyncServerError> {
|
||||
let level_filter = match args.verbose.log_level_filter() {
|
||||
// We don't want to allow disabling all logging
|
||||
log::LevelFilter::Off | log::LevelFilter::Error => tracing::Level::ERROR,
|
||||
log::LevelFilter::Warn => tracing::Level::WARN,
|
||||
log::LevelFilter::Info => tracing::Level::INFO,
|
||||
log::LevelFilter::Debug => tracing::Level::DEBUG,
|
||||
log::LevelFilter::Trace => tracing::Level::TRACE,
|
||||
};
|
||||
|
||||
let env_filter = EnvFilter::builder()
|
||||
.with_default_directive(level_filter.into())
|
||||
.from_env()
|
||||
.context("Failed to create logging env filter")
|
||||
.map_err(init_error)?;
|
||||
|
||||
let use_colors = args.color.use_colors();
|
||||
|
||||
let is_debug_mode = args.verbose.log_level_filter() >= log::LevelFilter::Debug;
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_ansi(use_colors)
|
||||
.with_env_filter(env_filter)
|
||||
.event_format(
|
||||
format()
|
||||
.without_time()
|
||||
.with_target(is_debug_mode)
|
||||
.with_line_number(is_debug_mode)
|
||||
.compact(),
|
||||
)
|
||||
.finish()
|
||||
.try_init()
|
||||
.context("Failed to initialise tracing")
|
||||
.map_err(init_error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_server(args: Args) -> Result<(), SyncServerError> {
|
||||
info!(
|
||||
"Starting VaultLink server version {}",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
);
|
||||
|
||||
create_server(args.config_path)
|
||||
.await
|
||||
.context("Failed to start server")
|
||||
.map_err(init_error)
|
||||
}
|
||||
184
sync-server/src/server.rs
Normal file
184
sync-server/src/server.rs
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
pub mod auth;
|
||||
mod create_document;
|
||||
mod delete_document;
|
||||
mod device_id_header;
|
||||
mod fetch_document_version;
|
||||
mod fetch_document_version_content;
|
||||
mod fetch_latest_document_version;
|
||||
mod fetch_latest_documents;
|
||||
mod index;
|
||||
mod ping;
|
||||
mod requests;
|
||||
mod responses;
|
||||
mod update_document;
|
||||
mod websocket;
|
||||
|
||||
use std::{ffi::OsString, time::Duration};
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use auth::auth_middleware;
|
||||
use axum::{
|
||||
Router,
|
||||
extract::{DefaultBodyLimit, Request},
|
||||
http::{self, HeaderValue, Method},
|
||||
middleware,
|
||||
response::IntoResponse,
|
||||
routing::{IntoMakeService, delete, get, post, put},
|
||||
};
|
||||
use device_id_header::DEVICE_ID_HEADER_NAME;
|
||||
use log::info;
|
||||
use tokio::signal;
|
||||
use tower_http::{
|
||||
LatencyUnit,
|
||||
cors::CorsLayer,
|
||||
limit::RequestBodyLimitLayer,
|
||||
timeout::TimeoutLayer,
|
||||
trace::{
|
||||
DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse,
|
||||
TraceLayer,
|
||||
},
|
||||
};
|
||||
use tracing::{Level, info_span};
|
||||
|
||||
use crate::{
|
||||
app_state::AppState,
|
||||
config::server_config::ServerConfig,
|
||||
errors::{client_error, not_found_error},
|
||||
};
|
||||
|
||||
pub async fn create_server(config_path: Option<OsString>) -> Result<()> {
|
||||
let app_state = AppState::try_new(config_path)
|
||||
.await
|
||||
.context("Failed to initialise app state")?;
|
||||
|
||||
let server_config = app_state.config.server.clone();
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/", get_authed_routes(app_state.clone()))
|
||||
.route("/", get(index::index))
|
||||
.route("/vaults/:vault_id/ping", get(ping::ping))
|
||||
.route("/vaults/:vault_id/ws", get(websocket::websocket_handler))
|
||||
.layer(DefaultBodyLimit::disable())
|
||||
.layer(RequestBodyLimitLayer::new(
|
||||
app_state.config.server.max_body_size_mb * 1024 * 1024,
|
||||
))
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(
|
||||
server_config.response_timeout_seconds,
|
||||
)))
|
||||
.layer(
|
||||
CorsLayer::new()
|
||||
.allow_origin("*".parse::<HeaderValue>().expect("Failed to parse origin"))
|
||||
.allow_headers([
|
||||
http::header::CONTENT_TYPE,
|
||||
http::header::AUTHORIZATION,
|
||||
DEVICE_ID_HEADER_NAME.clone(),
|
||||
])
|
||||
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]),
|
||||
)
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(|request: &Request<_>| {
|
||||
info_span!(
|
||||
"http",
|
||||
method = ?request.method(),
|
||||
uri = ?request.uri(),
|
||||
)
|
||||
})
|
||||
.on_request(DefaultOnRequest::new().level(Level::INFO))
|
||||
.on_response(
|
||||
DefaultOnResponse::new()
|
||||
.level(Level::INFO)
|
||||
.latency_unit(LatencyUnit::Millis),
|
||||
)
|
||||
.on_body_chunk(DefaultOnBodyChunk::new())
|
||||
.on_eos(DefaultOnEos::new())
|
||||
.on_failure(DefaultOnFailure::new().level(Level::ERROR)),
|
||||
)
|
||||
.with_state(app_state)
|
||||
.fallback(handle_404)
|
||||
.fallback(handle_405)
|
||||
.into_make_service();
|
||||
|
||||
start_server(app, &server_config).await
|
||||
}
|
||||
|
||||
fn get_authed_routes(app_state: AppState) -> Router<AppState> {
|
||||
Router::new()
|
||||
.route(
|
||||
"/vaults/:vault_id/documents",
|
||||
get(fetch_latest_documents::fetch_latest_documents),
|
||||
)
|
||||
.route(
|
||||
"/vaults/:vault_id/documents",
|
||||
post(create_document::create_document),
|
||||
)
|
||||
.route(
|
||||
"/vaults/:vault_id/documents/:document_id",
|
||||
get(fetch_latest_document_version::fetch_latest_document_version),
|
||||
)
|
||||
.route(
|
||||
"/vaults/:vault_id/documents/:document_id",
|
||||
put(update_document::update_document),
|
||||
)
|
||||
.route(
|
||||
"/vaults/:vault_id/documents/:document_id/versions/:version_id",
|
||||
put(fetch_document_version::fetch_document_version),
|
||||
)
|
||||
.route(
|
||||
"/vaults/:vault_id/documents/:document_id/versions/:version_id/content",
|
||||
put(fetch_document_version_content::fetch_document_version_content),
|
||||
)
|
||||
.route(
|
||||
"/vaults/:vault_id/documents/:document_id",
|
||||
delete(delete_document::delete_document),
|
||||
)
|
||||
.layer(middleware::from_fn_with_state(app_state, auth_middleware))
|
||||
}
|
||||
|
||||
async fn start_server(app: IntoMakeService<axum::Router>, config: &ServerConfig) -> Result<()> {
|
||||
let address = format!("{}:{}", config.host, config.port);
|
||||
let listener = tokio::net::TcpListener::bind(address.clone())
|
||||
.await
|
||||
.with_context(|| format!("Failed to bind to address: {address}"))?;
|
||||
|
||||
info!(
|
||||
"Listening on http://{}",
|
||||
listener
|
||||
.local_addr()
|
||||
.context("Failed to get local address")?
|
||||
);
|
||||
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.tcp_nodelay(true)
|
||||
.await
|
||||
.context("Failed to start server")
|
||||
}
|
||||
|
||||
async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C handler");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install signal handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
() = ctrl_c => {},
|
||||
() = terminate => {},
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_404() -> impl IntoResponse { not_found_error(anyhow!("Page not found")) }
|
||||
|
||||
async fn handle_405() -> impl IntoResponse { client_error(anyhow!("Method not allowed")) }
|
||||
9
sync-server/src/server/assets/index.html
Normal file
9
sync-server/src/server/assets/index.html
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>VaultLink</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>VaultLink server</h1>
|
||||
</body>
|
||||
</html>
|
||||
70
sync-server/src/server/auth.rs
Normal file
70
sync-server/src/server/auth.rs
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Request, State},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use axum_extra::{
|
||||
TypedHeader,
|
||||
headers::{Authorization, authorization::Bearer},
|
||||
};
|
||||
use log::info;
|
||||
|
||||
use crate::{
|
||||
app_state::{AppState, database::models::VaultId},
|
||||
config::user_config::{AllowListedVaults, User, VaultAccess},
|
||||
errors::{SyncServerError, permission_denied_error, unauthenticated_error},
|
||||
utils::normalize::normalize_string,
|
||||
};
|
||||
|
||||
pub async fn auth_middleware(
|
||||
State(state): State<AppState>,
|
||||
Path(path_params): Path<HashMap<String, String>>,
|
||||
TypedHeader(auth_header): TypedHeader<Authorization<Bearer>>,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, SyncServerError> {
|
||||
let token = auth_header.token().trim();
|
||||
let vault_id = normalize_string(
|
||||
path_params
|
||||
.get("vault_id")
|
||||
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Missing vault_id")))?,
|
||||
);
|
||||
|
||||
let user = auth(&state, token, &vault_id)?;
|
||||
|
||||
req.extensions_mut().insert(user);
|
||||
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
|
||||
pub fn auth(state: &AppState, token: &str, vault_id: &VaultId) -> Result<User, SyncServerError> {
|
||||
let user = state
|
||||
.config
|
||||
.users
|
||||
.get_user(token)
|
||||
.cloned()
|
||||
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Invalid token")))?;
|
||||
|
||||
if match user.vault_access {
|
||||
VaultAccess::AllowAccessToAll => true,
|
||||
VaultAccess::AllowList(AllowListedVaults { ref allowed }) => allowed.contains(vault_id),
|
||||
} {
|
||||
info!(
|
||||
"User '{}' is authenticated and is authorised to access to vault '{vault_id}'",
|
||||
user.name
|
||||
);
|
||||
|
||||
Ok(user)
|
||||
} else {
|
||||
info!(
|
||||
"User '{}' is authenticated but is not authorised to access vault '{vault_id}'",
|
||||
user.name
|
||||
);
|
||||
|
||||
Err(permission_denied_error(anyhow::anyhow!(
|
||||
"Permission denied for vault `{vault_id}`"
|
||||
)))
|
||||
}
|
||||
}
|
||||
95
sync-server/src/server/create_document.rs
Normal file
95
sync-server/src/server/create_document.rs
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
use anyhow::Context as _;
|
||||
use axum::{
|
||||
Extension, Json,
|
||||
extract::{Path, State},
|
||||
};
|
||||
use axum_extra::TypedHeader;
|
||||
use axum_typed_multipart::TypedMultipart;
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{device_id_header::DeviceIdHeader, requests::CreateDocumentVersion};
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentVersionWithoutContent, StoredDocumentVersion, VaultId},
|
||||
},
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, client_error, server_error},
|
||||
utils::{normalize::normalize, sanitize_path::sanitize_path},
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct CreateDocumentPathParams {
|
||||
#[serde(deserialize_with = "normalize")]
|
||||
vault_id: VaultId,
|
||||
}
|
||||
|
||||
/// Create a new document in case a document with the same doesn't exist
|
||||
/// already. If a document with the same path exists, a new version is created
|
||||
/// with their content merged.
|
||||
#[axum::debug_handler]
|
||||
pub async fn create_document(
|
||||
Path(CreateDocumentPathParams { vault_id }): Path<CreateDocumentPathParams>,
|
||||
Extension(user): Extension<User>,
|
||||
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
|
||||
State(state): State<AppState>,
|
||||
TypedMultipart(request): TypedMultipart<CreateDocumentVersion>,
|
||||
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
|
||||
let mut transaction = state
|
||||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
let document_id = match request.document_id {
|
||||
Some(document_id) => {
|
||||
let existing_version = state
|
||||
.database
|
||||
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
if existing_version.is_some() {
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Document with the same ID already exists"
|
||||
)));
|
||||
}
|
||||
|
||||
document_id
|
||||
}
|
||||
None => uuid::Uuid::new_v4(),
|
||||
};
|
||||
|
||||
let last_update_id = state
|
||||
.database
|
||||
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
let sanitized_relative_path = sanitize_path(&request.relative_path);
|
||||
|
||||
let new_version = StoredDocumentVersion {
|
||||
vault_update_id: last_update_id + 1,
|
||||
document_id,
|
||||
relative_path: sanitized_relative_path,
|
||||
content: request.content.contents.to_vec(),
|
||||
updated_date: chrono::Utc::now(),
|
||||
is_deleted: false,
|
||||
user_id: user.name,
|
||||
device_id: device_id.0,
|
||||
};
|
||||
|
||||
state
|
||||
.database
|
||||
.insert_document_version(&vault_id, &new_version, Some(&mut transaction))
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
transaction
|
||||
.commit()
|
||||
.await
|
||||
.context("Failed to commit successful transaction")
|
||||
.map_err(server_error)?;
|
||||
|
||||
Ok(Json(new_version.into()))
|
||||
}
|
||||
84
sync-server/src/server/delete_document.rs
Normal file
84
sync-server/src/server/delete_document.rs
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
use anyhow::Context as _;
|
||||
use axum::{
|
||||
Extension, Json,
|
||||
extract::{Path, State},
|
||||
};
|
||||
use axum_extra::TypedHeader;
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{device_id_header::DeviceIdHeader, requests::DeleteDocumentVersion};
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{
|
||||
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
|
||||
},
|
||||
},
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, server_error},
|
||||
utils::{normalize::normalize, sanitize_path::sanitize_path},
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct DeleteDocumentPathParams {
|
||||
#[serde(deserialize_with = "normalize")]
|
||||
vault_id: VaultId,
|
||||
|
||||
document_id: DocumentId,
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn delete_document(
|
||||
Path(DeleteDocumentPathParams {
|
||||
vault_id,
|
||||
document_id,
|
||||
}): Path<DeleteDocumentPathParams>,
|
||||
Extension(user): Extension<User>,
|
||||
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<DeleteDocumentVersion>,
|
||||
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
|
||||
let mut transaction = state
|
||||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
let last_update_id = state
|
||||
.database
|
||||
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
let latest_content = state
|
||||
.database
|
||||
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
|
||||
.await
|
||||
.map_err(server_error)?
|
||||
.map_or_else(Vec::new, |version| version.content); // in case the document has never existed before deleting it
|
||||
|
||||
let new_version = StoredDocumentVersion {
|
||||
vault_update_id: last_update_id + 1,
|
||||
document_id,
|
||||
relative_path: sanitize_path(&request.relative_path),
|
||||
content: latest_content, // copy the content from the latest version
|
||||
updated_date: chrono::Utc::now(),
|
||||
is_deleted: true,
|
||||
user_id: user.name,
|
||||
device_id: device_id.0,
|
||||
};
|
||||
|
||||
state
|
||||
.database
|
||||
.insert_document_version(&vault_id, &new_version, Some(&mut transaction))
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
transaction
|
||||
.commit()
|
||||
.await
|
||||
.context("Failed to commit successful transaction")
|
||||
.map_err(server_error)?;
|
||||
|
||||
Ok(Json(new_version.into()))
|
||||
}
|
||||
33
sync-server/src/server/device_id_header.rs
Normal file
33
sync-server/src/server/device_id_header.rs
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
use axum_extra::headers;
|
||||
use headers::{Header, HeaderName, HeaderValue};
|
||||
|
||||
pub struct DeviceIdHeader(pub String);
|
||||
|
||||
pub static DEVICE_ID_HEADER_NAME: HeaderName = HeaderName::from_static("device-id");
|
||||
|
||||
impl Header for DeviceIdHeader {
|
||||
fn name() -> &'static HeaderName { &DEVICE_ID_HEADER_NAME }
|
||||
|
||||
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
|
||||
where
|
||||
I: Iterator<Item = &'i HeaderValue>,
|
||||
{
|
||||
let value = values.next().ok_or_else(headers::Error::invalid)?;
|
||||
|
||||
Ok(DeviceIdHeader(
|
||||
value
|
||||
.to_str()
|
||||
.map_err(|_| headers::Error::invalid())?
|
||||
.to_owned(),
|
||||
))
|
||||
}
|
||||
|
||||
fn encode<E>(&self, values: &mut E)
|
||||
where
|
||||
E: Extend<HeaderValue>,
|
||||
{
|
||||
let value = HeaderValue::from_static(Box::leak(self.0.to_string().into_boxed_str()));
|
||||
|
||||
values.extend(std::iter::once(value));
|
||||
}
|
||||
}
|
||||
57
sync-server/src/server/fetch_document_version.rs
Normal file
57
sync-server/src/server/fetch_document_version.rs
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
use anyhow::anyhow;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Path, State},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentId, DocumentVersion, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, not_found_error, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct FetchDocumentVersionPathParams {
|
||||
#[serde(deserialize_with = "normalize")]
|
||||
vault_id: VaultId,
|
||||
|
||||
document_id: DocumentId,
|
||||
vault_update_id: VaultUpdateId,
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn fetch_document_version(
|
||||
Path(FetchDocumentVersionPathParams {
|
||||
vault_id,
|
||||
document_id,
|
||||
vault_update_id,
|
||||
}): Path<FetchDocumentVersionPathParams>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<DocumentVersion>, SyncServerError> {
|
||||
let result = state
|
||||
.database
|
||||
.get_document_version(&vault_id, vault_update_id, None)
|
||||
.await
|
||||
.map_err(server_error)?
|
||||
.map_or_else(
|
||||
|| {
|
||||
Err(not_found_error(anyhow!(
|
||||
"Document with vault update id `{vault_update_id}` not found",
|
||||
)))
|
||||
},
|
||||
Ok,
|
||||
)?;
|
||||
|
||||
if result.document_id != document_id {
|
||||
return Err(not_found_error(anyhow!(
|
||||
"Document with document id `{document_id}` does not have a version with id \
|
||||
`{vault_update_id}`",
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Json(result.into()))
|
||||
}
|
||||
57
sync-server/src/server/fetch_document_version_content.rs
Normal file
57
sync-server/src/server/fetch_document_version_content.rs
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
use anyhow::anyhow;
|
||||
use axum::{
|
||||
body::Bytes,
|
||||
extract::{Path, State},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentId, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, not_found_error, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct FetchDocumentVersionContentPathParams {
|
||||
#[serde(deserialize_with = "normalize")]
|
||||
vault_id: VaultId,
|
||||
|
||||
document_id: DocumentId,
|
||||
vault_update_id: VaultUpdateId,
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn fetch_document_version_content(
|
||||
Path(FetchDocumentVersionContentPathParams {
|
||||
vault_id,
|
||||
document_id,
|
||||
vault_update_id,
|
||||
}): Path<FetchDocumentVersionContentPathParams>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Bytes, SyncServerError> {
|
||||
let result = state
|
||||
.database
|
||||
.get_document_version(&vault_id, vault_update_id, None)
|
||||
.await
|
||||
.map_err(server_error)?
|
||||
.map_or_else(
|
||||
|| {
|
||||
Err(not_found_error(anyhow!(
|
||||
"Document with vault update id `{vault_update_id}` not found",
|
||||
)))
|
||||
},
|
||||
Ok,
|
||||
)?;
|
||||
|
||||
if result.document_id != document_id {
|
||||
return Err(not_found_error(anyhow!(
|
||||
"Document with document id `{document_id}` does not have a version with id \
|
||||
`{vault_update_id}`",
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(result.content.into())
|
||||
}
|
||||
48
sync-server/src/server/fetch_latest_document_version.rs
Normal file
48
sync-server/src/server/fetch_latest_document_version.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
use anyhow::anyhow;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Path, State},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentId, DocumentVersion, VaultId},
|
||||
},
|
||||
errors::{SyncServerError, not_found_error, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct FetchLatestDocumentVersionPathParams {
|
||||
#[serde(deserialize_with = "normalize")]
|
||||
vault_id: VaultId,
|
||||
|
||||
document_id: DocumentId,
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn fetch_latest_document_version(
|
||||
Path(FetchLatestDocumentVersionPathParams {
|
||||
vault_id,
|
||||
document_id,
|
||||
}): Path<FetchLatestDocumentVersionPathParams>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<DocumentVersion>, SyncServerError> {
|
||||
let latest_version = state
|
||||
.database
|
||||
.get_latest_document(&vault_id, &document_id, None)
|
||||
.await
|
||||
.map_err(server_error)?
|
||||
.map_or_else(
|
||||
|| {
|
||||
Err(not_found_error(anyhow!(
|
||||
"Document with id `{document_id}` not found",
|
||||
)))
|
||||
},
|
||||
Ok,
|
||||
)?;
|
||||
|
||||
Ok(Json(latest_version.into()))
|
||||
}
|
||||
56
sync-server/src/server/fetch_latest_documents.rs
Normal file
56
sync-server/src/server/fetch_latest_documents.rs
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
use axum::{
|
||||
Json,
|
||||
extract::{Path, Query, State},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::responses::FetchLatestDocumentsResponse;
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct FetchLatestDocumentsPathParams {
|
||||
#[serde(deserialize_with = "normalize")]
|
||||
vault_id: VaultId,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct QueryParams {
|
||||
since_update_id: Option<VaultUpdateId>,
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn fetch_latest_documents(
|
||||
Path(FetchLatestDocumentsPathParams { vault_id }): Path<FetchLatestDocumentsPathParams>,
|
||||
Query(QueryParams { since_update_id }): Query<QueryParams>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<FetchLatestDocumentsResponse>, SyncServerError> {
|
||||
let documents = if let Some(since_update_id) = since_update_id {
|
||||
state
|
||||
.database
|
||||
.get_latest_documents_since(&vault_id, since_update_id, None)
|
||||
.await
|
||||
.map_err(server_error)
|
||||
} else {
|
||||
state
|
||||
.database
|
||||
.get_latest_documents(&vault_id, None)
|
||||
.await
|
||||
.map_err(server_error)
|
||||
}?;
|
||||
|
||||
Ok(Json(FetchLatestDocumentsResponse {
|
||||
last_update_id: documents
|
||||
.iter()
|
||||
.map(|doc| doc.vault_update_id)
|
||||
.max()
|
||||
.unwrap_or(since_update_id.unwrap_or(0)),
|
||||
latest_documents: documents,
|
||||
}))
|
||||
}
|
||||
7
sync-server/src/server/index.rs
Normal file
7
sync-server/src/server/index.rs
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
use axum::response::{Html, IntoResponse};
|
||||
|
||||
pub async fn index() -> impl IntoResponse {
|
||||
const HTML_CONTENT: &str = include_str!("./assets/index.html");
|
||||
let html_content = HTML_CONTENT;
|
||||
Html(html_content)
|
||||
}
|
||||
37
sync-server/src/server/ping.rs
Normal file
37
sync-server/src/server/ping.rs
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
use axum::{
|
||||
Json,
|
||||
extract::{Path, State},
|
||||
};
|
||||
use axum_extra::{
|
||||
TypedHeader,
|
||||
headers::{Authorization, authorization::Bearer},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{auth::auth, responses::PingResponse};
|
||||
use crate::{
|
||||
app_state::{AppState, database::models::VaultId},
|
||||
errors::SyncServerError,
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct PingPathParams {
|
||||
#[serde(deserialize_with = "normalize")]
|
||||
vault_id: VaultId,
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn ping(
|
||||
maybe_auth_header: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
Path(PingPathParams { vault_id }): Path<PingPathParams>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<PingResponse>, SyncServerError> {
|
||||
let is_authenticated = maybe_auth_header
|
||||
.is_some_and(|auth_header| auth(&state, auth_header.token(), &vault_id).is_ok());
|
||||
|
||||
Ok(Json(PingResponse {
|
||||
server_version: env!("CARGO_PKG_VERSION").to_owned(),
|
||||
is_authenticated,
|
||||
}))
|
||||
}
|
||||
39
sync-server/src/server/requests.rs
Normal file
39
sync-server/src/server/requests.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
use axum::body::Bytes;
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart};
|
||||
use serde::{self, Deserialize};
|
||||
use ts_rs::TS;
|
||||
|
||||
use crate::app_state::database::models::{DocumentId, VaultUpdateId};
|
||||
|
||||
#[derive(TS, Debug, TryFromMultipart)]
|
||||
#[ts(export)]
|
||||
pub struct CreateDocumentVersion {
|
||||
/// The client can decide the document id (if it wishes to) in order
|
||||
/// to help with syncing. If the client does not provide a document id,
|
||||
/// the server will generate one. If the client provides a document id
|
||||
/// it must not already exist in the database.
|
||||
pub document_id: Option<DocumentId>,
|
||||
pub relative_path: String,
|
||||
|
||||
#[ts(as = "Vec<u8>")]
|
||||
#[form_data(limit = "unlimited")]
|
||||
pub content: FieldData<Bytes>,
|
||||
}
|
||||
|
||||
#[derive(TS, Debug, TryFromMultipart)]
|
||||
#[ts(export)]
|
||||
pub struct UpdateDocumentVersion {
|
||||
pub parent_version_id: VaultUpdateId,
|
||||
pub relative_path: String,
|
||||
|
||||
#[ts(as = "Vec<u8>")]
|
||||
#[form_data(limit = "unlimited")]
|
||||
pub content: FieldData<Bytes>,
|
||||
}
|
||||
|
||||
#[derive(TS, Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
pub struct DeleteDocumentVersion {
|
||||
pub relative_path: String,
|
||||
}
|
||||
45
sync-server/src/server/responses.rs
Normal file
45
sync-server/src/server/responses.rs
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
use serde::{self, Serialize};
|
||||
use ts_rs::TS;
|
||||
|
||||
use crate::app_state::database::models::{
|
||||
DocumentVersion, DocumentVersionWithoutContent, VaultUpdateId,
|
||||
};
|
||||
|
||||
/// Response to a ping request.
|
||||
#[derive(TS, Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
pub struct PingResponse {
|
||||
/// Semantic version of the server.
|
||||
pub server_version: String,
|
||||
|
||||
/// Whether the client is authenticated based on the sent Authorization
|
||||
/// header.
|
||||
pub is_authenticated: bool,
|
||||
}
|
||||
|
||||
/// Response to a fetch latest documents request.
|
||||
#[derive(TS, Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
pub struct FetchLatestDocumentsResponse {
|
||||
pub latest_documents: Vec<DocumentVersionWithoutContent>,
|
||||
|
||||
/// The update ID of the latest document in the response.
|
||||
pub last_update_id: VaultUpdateId,
|
||||
}
|
||||
|
||||
/// Response to an update document request.
|
||||
#[derive(TS, Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
#[ts(export)]
|
||||
pub enum DocumentUpdateResponse {
|
||||
/// Returned when the created/updated document's content is the same as was
|
||||
/// sent in the create/update request and thus the response doesn't contain
|
||||
/// the content because the client must already have it.
|
||||
FastForwardUpdate(DocumentVersionWithoutContent),
|
||||
|
||||
/// Returned when the created/updated document's content is different from
|
||||
/// what was sent in the create/update request.
|
||||
MergingUpdate(DocumentVersion),
|
||||
}
|
||||
199
sync-server/src/server/update_document.rs
Normal file
199
sync-server/src/server/update_document.rs
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
use anyhow::{Context as _, anyhow};
|
||||
use axum::{
|
||||
Extension, Json,
|
||||
extract::{Path, State},
|
||||
};
|
||||
use axum_extra::TypedHeader;
|
||||
use axum_typed_multipart::TypedMultipart;
|
||||
use log::info;
|
||||
use reconcile_text::{BuiltinTokenizer, is_binary, reconcile};
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{
|
||||
device_id_header::DeviceIdHeader, requests::UpdateDocumentVersion,
|
||||
responses::DocumentUpdateResponse,
|
||||
};
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentId, StoredDocumentVersion, VaultId},
|
||||
},
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, not_found_error, server_error},
|
||||
utils::{
|
||||
dedup_paths::dedup_paths, is_filetype_mergable::is_file_type_mergable,
|
||||
normalize::normalize, sanitize_path::sanitize_path,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UpdateDocumentPathParams {
|
||||
#[serde(deserialize_with = "normalize")]
|
||||
vault_id: VaultId,
|
||||
|
||||
document_id: DocumentId,
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn update_document(
|
||||
Path(UpdateDocumentPathParams {
|
||||
vault_id,
|
||||
document_id,
|
||||
}): Path<UpdateDocumentPathParams>,
|
||||
Extension(user): Extension<User>,
|
||||
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
|
||||
State(state): State<AppState>,
|
||||
TypedMultipart(request): TypedMultipart<UpdateDocumentVersion>,
|
||||
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
|
||||
// No need for a transaction as document versions are immutable
|
||||
let parent_document = state
|
||||
.database
|
||||
.get_document_version(&vault_id, request.parent_version_id, None)
|
||||
.await
|
||||
.map_err(server_error)?
|
||||
.map_or_else(
|
||||
|| {
|
||||
Err(not_found_error(anyhow!(
|
||||
"Parent version with id `{}` not found",
|
||||
request.parent_version_id
|
||||
)))
|
||||
},
|
||||
Ok,
|
||||
)?;
|
||||
|
||||
let sanitized_relative_path = sanitize_path(&request.relative_path);
|
||||
|
||||
let mut transaction = state
|
||||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
let last_update_id = state
|
||||
.database
|
||||
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
let latest_version = state
|
||||
.database
|
||||
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
|
||||
.await
|
||||
.map_err(server_error)?
|
||||
.map_or_else(
|
||||
|| {
|
||||
Err(not_found_error(anyhow!(
|
||||
"Document with id `{document_id}` not found",
|
||||
)))
|
||||
},
|
||||
Ok,
|
||||
)?;
|
||||
|
||||
if latest_version.is_deleted {
|
||||
transaction
|
||||
.rollback()
|
||||
.await
|
||||
.context("Failed to roll back transaction")
|
||||
.map_err(server_error)?;
|
||||
|
||||
return Ok(Json(DocumentUpdateResponse::FastForwardUpdate(
|
||||
latest_version.into(),
|
||||
)));
|
||||
}
|
||||
|
||||
let content = request.content.contents.to_vec();
|
||||
|
||||
// Return the latest version if the content and path are the same as the latest
|
||||
// version
|
||||
if content == latest_version.content && sanitized_relative_path == latest_version.relative_path
|
||||
{
|
||||
info!("Document content is the same as the latest version, skipping update");
|
||||
transaction
|
||||
.rollback()
|
||||
.await
|
||||
.context("Failed to roll back transaction")
|
||||
.map_err(server_error)?;
|
||||
|
||||
return Ok(Json(DocumentUpdateResponse::FastForwardUpdate(
|
||||
latest_version.into(),
|
||||
)));
|
||||
}
|
||||
|
||||
let merged_content = if is_file_type_mergable(&sanitized_relative_path)
|
||||
&& is_binary(&parent_document.content)
|
||||
&& is_binary(&latest_version.content)
|
||||
&& is_binary(&content)
|
||||
{
|
||||
reconcile(
|
||||
&str::from_utf8(&parent_document.content)
|
||||
.expect("parent must be valid UTF-8 because it's not binary"),
|
||||
&str::from_utf8(&latest_version.content)
|
||||
.expect("latest_version must be valid UTF-8 because it's not binary")
|
||||
.into(),
|
||||
&str::from_utf8(&content)
|
||||
.expect("content must be valid UTF-8 because it's not binary")
|
||||
.into(),
|
||||
&*BuiltinTokenizer::Word,
|
||||
)
|
||||
.apply()
|
||||
.text()
|
||||
.into_bytes()
|
||||
} else {
|
||||
content.clone()
|
||||
};
|
||||
|
||||
let is_different_from_request_content = merged_content != content;
|
||||
|
||||
// We can only update the relative path if we're the first one to do so
|
||||
let new_relative_path = if parent_document.relative_path == latest_version.relative_path
|
||||
&& latest_version.relative_path != sanitized_relative_path
|
||||
{
|
||||
let mut new_relative_path = String::default();
|
||||
for candidate in dedup_paths(&sanitized_relative_path) {
|
||||
if state
|
||||
.database
|
||||
.get_latest_document_by_path(&vault_id, &candidate, Some(&mut transaction))
|
||||
.await
|
||||
.map_err(server_error)?
|
||||
.is_none()
|
||||
{
|
||||
new_relative_path = candidate;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
new_relative_path
|
||||
} else {
|
||||
latest_version.relative_path.clone()
|
||||
};
|
||||
|
||||
let new_version = StoredDocumentVersion {
|
||||
document_id,
|
||||
vault_update_id: last_update_id + 1,
|
||||
relative_path: new_relative_path,
|
||||
content: merged_content,
|
||||
updated_date: chrono::Utc::now(),
|
||||
is_deleted: false,
|
||||
user_id: user.name,
|
||||
device_id: device_id.0,
|
||||
};
|
||||
|
||||
state
|
||||
.database
|
||||
.insert_document_version(&vault_id, &new_version, Some(&mut transaction))
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
transaction
|
||||
.commit()
|
||||
.await
|
||||
.context("Failed to commit successful transaction")
|
||||
.map_err(server_error)?;
|
||||
|
||||
Ok(Json(if is_different_from_request_content {
|
||||
DocumentUpdateResponse::MergingUpdate(new_version.into())
|
||||
} else {
|
||||
DocumentUpdateResponse::FastForwardUpdate(new_version.into())
|
||||
}))
|
||||
}
|
||||
181
sync-server/src/server/websocket.rs
Normal file
181
sync-server/src/server/websocket.rs
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
use anyhow::Context;
|
||||
use axum::{
|
||||
extract::{
|
||||
Path, State,
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
},
|
||||
response::Response,
|
||||
};
|
||||
use futures::stream::StreamExt;
|
||||
use log::{debug, info};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::VaultId,
|
||||
websocket::{
|
||||
models::{
|
||||
CursorPositionFromServer, WebSocketClientMessage, WebSocketServerMessage,
|
||||
WebSocketVaultUpdate,
|
||||
},
|
||||
utils::{
|
||||
get_authenticated_handshake, get_unseen_documents, send_update_over_websocket,
|
||||
},
|
||||
},
|
||||
},
|
||||
errors::{SyncServerError, client_error, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct WebSocketPathParams {
|
||||
#[serde(deserialize_with = "normalize")]
|
||||
vault_id: VaultId,
|
||||
}
|
||||
|
||||
pub async fn websocket_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Path(WebSocketPathParams { vault_id }): Path<WebSocketPathParams>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Response, SyncServerError> {
|
||||
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id)))
|
||||
}
|
||||
|
||||
async fn websocket_wrapped(state: AppState, stream: WebSocket, vault_id: VaultId) {
|
||||
info!("WebSocket connection opened on vault '{vault_id}'");
|
||||
|
||||
let result = websocket(state, stream, vault_id.clone()).await;
|
||||
|
||||
if let Err(err) = result {
|
||||
debug!("WebSocket connection error on vault '{vault_id}': {err}");
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn websocket(
|
||||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
) -> Result<(), SyncServerError> {
|
||||
let (mut sender, mut websocket_receiver) = stream.split();
|
||||
|
||||
let authed_handshake = get_authenticated_handshake(
|
||||
&state,
|
||||
&vault_id,
|
||||
websocket_receiver
|
||||
.next()
|
||||
.await
|
||||
.transpose()
|
||||
.unwrap_or_default(),
|
||||
)?;
|
||||
|
||||
info!(
|
||||
"WebSocket handshake successful for vault '{vault_id}' for '{}'",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
|
||||
let mut broadcast_receiver = state.broadcasts.get_receiver(vault_id.clone()).await;
|
||||
|
||||
send_update_over_websocket(
|
||||
&WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate {
|
||||
documents: get_unseen_documents(
|
||||
&state,
|
||||
&vault_id,
|
||||
authed_handshake.handshake.last_seen_vault_update_id,
|
||||
)
|
||||
.await?,
|
||||
is_initial_sync: true,
|
||||
}),
|
||||
&mut sender,
|
||||
)
|
||||
.await?;
|
||||
|
||||
send_update_over_websocket(
|
||||
&WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
clients: state.cursors.get_cursors(&vault_id).await,
|
||||
}),
|
||||
&mut sender,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let device_id = authed_handshake.handshake.device_id.clone();
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Ok(update) = broadcast_receiver.recv().await {
|
||||
if Some(&device_id) == update.origin_device_id.as_ref() {
|
||||
continue;
|
||||
}
|
||||
|
||||
send_update_over_websocket(&update.message, &mut sender).await?;
|
||||
}
|
||||
|
||||
Ok::<(), SyncServerError>(())
|
||||
});
|
||||
|
||||
let device_id = authed_handshake.handshake.device_id.clone();
|
||||
let vault_id_clone = vault_id.clone();
|
||||
let cursor_manager = state.cursors.clone();
|
||||
let mut receive_task = tokio::spawn(async move {
|
||||
while let Some(Ok(Message::Text(message))) = websocket_receiver.next().await {
|
||||
let message: WebSocketClientMessage = serde_json::from_str(&message)
|
||||
.context("Failed to parse WebSocket message from client")
|
||||
.map_err(server_error)?;
|
||||
|
||||
match message {
|
||||
WebSocketClientMessage::Handshake(_) => {
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Unexpected handshake message"
|
||||
)));
|
||||
}
|
||||
WebSocketClientMessage::CursorPositions(cursors) => {
|
||||
cursor_manager
|
||||
.update_cursors(
|
||||
vault_id_clone.clone(),
|
||||
authed_handshake.user.name.clone(),
|
||||
&device_id,
|
||||
cursors.document_to_cursors,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok::<(), SyncServerError>(())
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut send_task => receive_task.abort(),
|
||||
_ = &mut receive_task => send_task.abort(),
|
||||
};
|
||||
|
||||
let result: Result<(), SyncServerError> = (async {
|
||||
send_task
|
||||
.await
|
||||
.context("WebSocket send task failed")
|
||||
.map_err(client_error)
|
||||
.and_then(|err| err)?;
|
||||
|
||||
receive_task
|
||||
.await
|
||||
.context("WebSocket receive task failed")
|
||||
.map_err(client_error)
|
||||
.and_then(|err| err)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await;
|
||||
|
||||
state
|
||||
.cursors
|
||||
.remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id)
|
||||
.await;
|
||||
|
||||
if result.is_err() {
|
||||
info!(
|
||||
"WebSocket disconnected on vault '{vault_id}' for '{}'",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
4
sync-server/src/utils.rs
Normal file
4
sync-server/src/utils.rs
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
pub mod dedup_paths;
|
||||
pub mod is_filetype_mergable;
|
||||
pub mod normalize;
|
||||
pub mod sanitize_path;
|
||||
88
sync-server/src/utils/dedup_paths.rs
Normal file
88
sync-server/src/utils/dedup_paths.rs
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
use regex::Regex;
|
||||
|
||||
pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
|
||||
let mut path_parts = path.split('/').collect::<Vec<_>>();
|
||||
let file_name = path_parts.pop().unwrap().to_owned();
|
||||
|
||||
let mut directory = path_parts.join("/");
|
||||
if !directory.is_empty() {
|
||||
directory.push('/');
|
||||
}
|
||||
|
||||
let name_parts = file_name.rsplitn(2, '.').collect::<Vec<_>>();
|
||||
let mut reverse_parts = name_parts.into_iter().rev();
|
||||
let (stem, extension) = match (reverse_parts.next(), reverse_parts.next()) {
|
||||
(Some(stem), maybe_extension) => (
|
||||
stem.to_owned(),
|
||||
maybe_extension
|
||||
.map(|ext| format!(".{ext}"))
|
||||
.unwrap_or_default(),
|
||||
),
|
||||
_ => unreachable!("Path must have at least one part"),
|
||||
};
|
||||
|
||||
let regex = Regex::new(r" \((\d+)\)$").unwrap();
|
||||
let start_number = regex
|
||||
.captures(&stem)
|
||||
.and_then(|caps| caps.get(1))
|
||||
.and_then(|m| m.as_str().parse::<u32>().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
let clean_stem = regex.replace(&stem, "").to_string();
|
||||
|
||||
(start_number..).map(move |dedup_number| {
|
||||
if dedup_number == 0 {
|
||||
format!("{directory}{clean_stem}{extension}")
|
||||
} else {
|
||||
format!("{directory}{clean_stem} ({dedup_number}){extension}")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_dedup_paths() {
|
||||
let mut deduped = dedup_paths("file.txt");
|
||||
assert_eq!(deduped.next(), Some("file.txt".to_owned()));
|
||||
assert_eq!(deduped.next(), Some("file (1).txt".to_owned()));
|
||||
assert_eq!(deduped.next(), Some("file (2).txt".to_owned()));
|
||||
|
||||
let mut deduped = dedup_paths("file");
|
||||
assert_eq!(deduped.next(), Some("file".to_owned()));
|
||||
assert_eq!(deduped.next(), Some("file (1)".to_owned()));
|
||||
assert_eq!(deduped.next(), Some("file (2)".to_owned()));
|
||||
|
||||
let mut deduped = dedup_paths("file (51).md");
|
||||
assert_eq!(deduped.next(), Some("file (51).md".to_owned()));
|
||||
assert_eq!(deduped.next(), Some("file (52).md".to_owned()));
|
||||
assert_eq!(deduped.next(), Some("file (53).md".to_owned()));
|
||||
|
||||
let mut deduped = dedup_paths("file (5)");
|
||||
assert_eq!(deduped.next(), Some("file (5)".to_owned()));
|
||||
assert_eq!(deduped.next(), Some("file (6)".to_owned()));
|
||||
assert_eq!(deduped.next(), Some("file (7)".to_owned()));
|
||||
|
||||
let mut deduped = dedup_paths("my/path.with.dots/file (5).md");
|
||||
assert_eq!(
|
||||
deduped.next(),
|
||||
Some("my/path.with.dots/file (5).md".to_owned())
|
||||
);
|
||||
assert_eq!(
|
||||
deduped.next(),
|
||||
Some("my/path.with.dots/file (6).md".to_owned())
|
||||
);
|
||||
|
||||
let mut deduped = dedup_paths("my/path.with.dots/file (5)");
|
||||
assert_eq!(
|
||||
deduped.next(),
|
||||
Some("my/path.with.dots/file (5)".to_owned())
|
||||
);
|
||||
assert_eq!(
|
||||
deduped.next(),
|
||||
Some("my/path.with.dots/file (6)".to_owned())
|
||||
);
|
||||
}
|
||||
}
|
||||
23
sync-server/src/utils/is_filetype_mergable.rs
Normal file
23
sync-server/src/utils/is_filetype_mergable.rs
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
pub fn is_file_type_mergable(path_or_file_name: &str) -> bool {
|
||||
let file_extension = path_or_file_name.split('.').next_back().unwrap_or_default();
|
||||
|
||||
matches!(file_extension.to_lowercase().as_str(), "md" | "txt")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_file_type_mergable() {
|
||||
assert!(is_file_type_mergable(".md"));
|
||||
assert!(is_file_type_mergable("hi.md"));
|
||||
assert!(is_file_type_mergable("my/path/to/my/document.md"));
|
||||
assert!(is_file_type_mergable("hi.MD"));
|
||||
assert!(is_file_type_mergable("my/path/to/my/DOCUMENT.MD"));
|
||||
|
||||
assert!(!is_file_type_mergable(".json"));
|
||||
assert!(!is_file_type_mergable("HELLO.JSON"));
|
||||
assert!(!is_file_type_mergable("my/config.yml"));
|
||||
}
|
||||
}
|
||||
11
sync-server/src/utils/normalize.rs
Normal file
11
sync-server/src/utils/normalize.rs
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
use serde::{Deserialize, Deserializer};
|
||||
|
||||
pub fn normalize<'de, D>(deserializer: D) -> Result<String, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
Ok(normalize_string(&s))
|
||||
}
|
||||
|
||||
pub fn normalize_string(s: &str) -> String { s.trim().to_lowercase() }
|
||||
34
sync-server/src/utils/sanitize_path.rs
Normal file
34
sync-server/src/utils/sanitize_path.rs
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
/// Sanitize the document's path to allow all clients to create the same path in
|
||||
/// their filesystem. If we didn't do this server-side, client's would need to
|
||||
/// deal with mapping invalid names to valid ones and then back.
|
||||
pub fn sanitize_path(path: &str) -> String {
|
||||
let options = sanitize_filename::Options {
|
||||
truncate: true,
|
||||
windows: true, // Windows is the lowest common denominator
|
||||
replacement: "",
|
||||
};
|
||||
|
||||
path.split('/')
|
||||
.map(|part| {
|
||||
let proposal = sanitize_filename::sanitize_with_options(part, options.clone());
|
||||
if !part.is_empty() && proposal.is_empty() {
|
||||
"_".to_owned()
|
||||
} else {
|
||||
proposal
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("/")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_path() {
|
||||
assert_eq!(sanitize_path("/my/path/what?"), "/my/path/what");
|
||||
assert_eq!(sanitize_path("file (1).md"), "file (1).md");
|
||||
assert_eq!(sanitize_path("/my/path/\\\\:?"), "/my/path/_");
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue