Add WebSocket support (#12)
This commit is contained in:
parent
3d27b7f313
commit
1aad0fce31
68 changed files with 2578 additions and 993 deletions
|
|
@ -1,13 +1,19 @@
|
|||
pub mod broadcasts;
|
||||
pub mod database;
|
||||
|
||||
use std::ffi::OsString;
|
||||
|
||||
use anyhow::Result;
|
||||
use broadcasts::Broadcasts;
|
||||
use database::Database;
|
||||
|
||||
use crate::{config::Config, consts::DEFAULT_CONFIG_PATH, database::Database};
|
||||
use crate::{config::Config, consts::DEFAULT_CONFIG_PATH};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AppState {
|
||||
pub config: Config,
|
||||
pub database: Database,
|
||||
pub broadcasts: Broadcasts,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
|
|
@ -17,7 +23,12 @@ impl AppState {
|
|||
|
||||
let config = Config::read_or_create(&path).await?;
|
||||
let database = Database::try_new(&config.database).await?;
|
||||
let broadcasts = Broadcasts::new(&config.server);
|
||||
|
||||
Ok(Self { config, database })
|
||||
Ok(Self {
|
||||
config,
|
||||
database,
|
||||
broadcasts,
|
||||
})
|
||||
}
|
||||
}
|
||||
57
backend/sync_server/src/app_state/broadcasts.rs
Normal file
57
backend/sync_server/src/app_state/broadcasts.rs
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use tokio::sync::{Mutex, broadcast};
|
||||
|
||||
use super::database::models::{DocumentVersionWithoutContent, VaultId};
|
||||
use crate::{config::server_config::ServerConfig, errors::server_error};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Broadcasts {
|
||||
max_clients_per_vault: usize,
|
||||
tx: Arc<Mutex<HashMap<VaultId, broadcast::Sender<DocumentVersionWithoutContent>>>>,
|
||||
}
|
||||
|
||||
impl Broadcasts {
|
||||
pub fn new(server_config: &ServerConfig) -> Self {
|
||||
Self {
|
||||
max_clients_per_vault: server_config.max_clients_per_vault,
|
||||
tx: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_receiver(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
) -> broadcast::Receiver<DocumentVersionWithoutContent> {
|
||||
let tx = self.get_or_create(vault).await;
|
||||
|
||||
tx.subscribe()
|
||||
}
|
||||
|
||||
/// Sent a document update to all clients subscribed to the vault.
|
||||
/// We ignore & log failures.
|
||||
pub async fn send(&self, vault: VaultId, document: DocumentVersionWithoutContent) {
|
||||
let tx = self.get_or_create(vault).await;
|
||||
|
||||
let result = tx
|
||||
.send(document)
|
||||
.context("Cannot broadcast update message to websocket listeners")
|
||||
.map_err(server_error);
|
||||
|
||||
if result.is_err() {
|
||||
log::debug!("Failed to send message: {result:?}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_or_create(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
) -> broadcast::Sender<DocumentVersionWithoutContent> {
|
||||
let mut tx = self.tx.lock().await;
|
||||
|
||||
tx.entry(vault)
|
||||
.or_insert_with(|| broadcast::channel(self.max_clients_per_vault).0.clone())
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
|
@ -85,13 +85,13 @@ impl Database {
|
|||
}
|
||||
|
||||
async fn run_migrations(pool: &Pool<Sqlite>) -> Result<()> {
|
||||
sqlx::migrate!("src/database/migrations")
|
||||
sqlx::migrate!("src/app_state/database/migrations")
|
||||
.run(pool)
|
||||
.await
|
||||
.context("Cannot check for pending migrations")
|
||||
}
|
||||
|
||||
async fn get_connection_pool(&mut self, vault: &VaultId) -> Result<Pool<Sqlite>> {
|
||||
async fn get_connection_pool(&self, vault: &VaultId) -> Result<Pool<Sqlite>> {
|
||||
let mut pools = self.connection_pools.lock().await;
|
||||
if !pools.contains_key(vault) {
|
||||
let pool = Self::create_vault_database(&self.config, vault).await?;
|
||||
|
|
@ -108,7 +108,7 @@ impl Database {
|
|||
/// Attempting to write from this transaction might result in a
|
||||
/// database locked error. Use this transaction for read-only operations.
|
||||
pub async fn create_readonly_transaction(
|
||||
&mut self,
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
) -> Result<Transaction<'static>> {
|
||||
self.get_connection_pool(vault)
|
||||
|
|
@ -118,10 +118,7 @@ impl Database {
|
|||
.context("Cannot create transaction")
|
||||
}
|
||||
|
||||
pub async fn create_write_transaction(
|
||||
&mut self,
|
||||
vault: &VaultId,
|
||||
) -> Result<Transaction<'static>> {
|
||||
pub async fn create_write_transaction(&self, vault: &VaultId) -> Result<Transaction<'static>> {
|
||||
let mut transaction = self.create_readonly_transaction(vault).await?;
|
||||
|
||||
// sqlx doesn't support immediate transactions for sqlite: https://github.com/launchbadge/sqlx/issues/481
|
||||
|
|
@ -134,7 +131,7 @@ impl Database {
|
|||
|
||||
/// Return the latest state of all documents in the vault
|
||||
pub async fn get_latest_documents(
|
||||
&mut self,
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
) -> Result<Vec<DocumentVersionWithoutContent>> {
|
||||
|
|
@ -165,7 +162,7 @@ impl Database {
|
|||
/// Return the latest state of all documents (including deleted) in the
|
||||
/// vault which have changed since the given update id
|
||||
pub async fn get_latest_documents_since(
|
||||
&mut self,
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
vault_update_id: VaultUpdateId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
|
|
@ -199,7 +196,7 @@ impl Database {
|
|||
}
|
||||
|
||||
pub async fn get_max_update_id_in_vault(
|
||||
&mut self,
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
) -> Result<i64> {
|
||||
|
|
@ -222,7 +219,7 @@ impl Database {
|
|||
}
|
||||
|
||||
pub async fn get_latest_document_by_path(
|
||||
&mut self,
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
relative_path: &str,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
|
|
@ -258,7 +255,7 @@ impl Database {
|
|||
}
|
||||
|
||||
pub async fn get_latest_document(
|
||||
&mut self,
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
document_id: &DocumentId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
|
|
@ -291,7 +288,7 @@ impl Database {
|
|||
}
|
||||
|
||||
pub async fn get_document_version(
|
||||
&mut self,
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
vault_update_id: VaultUpdateId,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
|
|
@ -322,7 +319,7 @@ impl Database {
|
|||
}
|
||||
|
||||
pub async fn insert_document_version(
|
||||
&mut self,
|
||||
&self,
|
||||
vault: &VaultId,
|
||||
version: &StoredDocumentVersion,
|
||||
transaction: Option<&mut Transaction<'_>>,
|
||||
|
|
@ -1 +1,2 @@
|
|||
pub mod args;
|
||||
pub mod color_when;
|
||||
|
|
|
|||
|
|
@ -1,38 +1,26 @@
|
|||
use std::ffi::OsString;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
use clap::Parser;
|
||||
use clap_verbosity_flag::{InfoLevel, Verbosity};
|
||||
|
||||
/// Server for backing the VaultLink plugin
|
||||
use crate::cli::color_when::ColorWhen;
|
||||
|
||||
/// Server for backing the `VaultLink` plugin
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
#[arg(index = 1)]
|
||||
pub config_path: Option<OsString>,
|
||||
|
||||
#[command(flatten)]
|
||||
pub verbose: Verbosity<InfoLevel>,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
require_equals = true,
|
||||
value_name = "WHEN",
|
||||
num_args = 0..=1,
|
||||
default_value_t = ColorWhen::Auto,
|
||||
default_missing_value = "always",
|
||||
value_enum
|
||||
)]
|
||||
pub color: ColorWhen,
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum ColorWhen {
|
||||
Always,
|
||||
Auto,
|
||||
Never,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ColorWhen {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.to_possible_value()
|
||||
.expect("no values are skipped")
|
||||
.get_name()
|
||||
.fmt(f)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
31
backend/sync_server/src/cli/color_when.rs
Normal file
31
backend/sync_server/src/cli/color_when.rs
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
use std::io::IsTerminal;
|
||||
|
||||
use clap::ValueEnum;
|
||||
|
||||
#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum ColorWhen {
|
||||
Always,
|
||||
Auto,
|
||||
Never,
|
||||
}
|
||||
|
||||
impl ColorWhen {
|
||||
pub fn use_colors(self) -> bool {
|
||||
match self {
|
||||
ColorWhen::Always => true,
|
||||
ColorWhen::Auto => {
|
||||
std::env::var_os("NO_COLOR").is_none() && std::io::stderr().is_terminal()
|
||||
}
|
||||
ColorWhen::Never => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ColorWhen {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.to_possible_value()
|
||||
.expect("no values are skipped")
|
||||
.get_name()
|
||||
.fmt(f)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,10 @@
|
|||
use log::debug;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::consts::{DEFAULT_HOST, DEFAULT_MAX_BODY_SIZE_MB, DEFAULT_PORT};
|
||||
use crate::consts::{
|
||||
DEFAULT_HOST, DEFAULT_MAX_BODY_SIZE_MB, DEFAULT_MAX_CLIENTS_PER_VAULT, DEFAULT_PORT,
|
||||
};
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct ServerConfig {
|
||||
#[serde(default = "default_host")]
|
||||
|
|
@ -12,6 +15,9 @@ pub struct ServerConfig {
|
|||
|
||||
#[serde(default = "default_max_body_size_mb")]
|
||||
pub max_body_size_mb: usize,
|
||||
|
||||
#[serde(default = "default_max_clients_per_vault")]
|
||||
pub max_clients_per_vault: usize,
|
||||
}
|
||||
|
||||
fn default_host() -> String {
|
||||
|
|
@ -29,12 +35,18 @@ fn default_max_body_size_mb() -> usize {
|
|||
DEFAULT_MAX_BODY_SIZE_MB
|
||||
}
|
||||
|
||||
fn default_max_clients_per_vault() -> usize {
|
||||
debug!("Using default max clients per vault: {DEFAULT_MAX_CLIENTS_PER_VAULT}");
|
||||
DEFAULT_MAX_CLIENTS_PER_VAULT
|
||||
}
|
||||
|
||||
impl Default for ServerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: default_host(),
|
||||
port: default_port(),
|
||||
max_body_size_mb: default_max_body_size_mb(),
|
||||
max_clients_per_vault: default_max_clients_per_vault(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,3 +4,4 @@ pub const DEFAULT_HOST: &str = "127.0.0.1";
|
|||
pub const DEFAULT_PORT: u16 = 3000;
|
||||
pub const DEFAULT_MAX_CONNECTIONS: u32 = 12;
|
||||
pub const DEFAULT_MAX_BODY_SIZE_MB: usize = 4096;
|
||||
pub const DEFAULT_MAX_CLIENTS_PER_VAULT: usize = 256;
|
||||
|
|
|
|||
|
|
@ -1,38 +1,79 @@
|
|||
mod app_state;
|
||||
mod cli;
|
||||
mod config;
|
||||
mod consts;
|
||||
mod database;
|
||||
mod errors;
|
||||
mod server;
|
||||
mod utils;
|
||||
|
||||
use std::process::ExitCode;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use clap::Parser;
|
||||
use cli::args::Args;
|
||||
use errors::{SyncServerError, init_error};
|
||||
use log::info;
|
||||
use server::create_server;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
use tracing_subscriber::{EnvFilter, fmt::format, util::SubscriberInitExt};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), SyncServerError> {
|
||||
async fn main() -> ExitCode {
|
||||
let args = Args::parse();
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
|
||||
format!(
|
||||
"{}=debug,tower_http=debug,axum::rejection=trace",
|
||||
env!("CARGO_CRATE_NAME")
|
||||
)
|
||||
.into()
|
||||
}),
|
||||
let mut result = set_up_logging(&args);
|
||||
|
||||
if result.is_ok() {
|
||||
result = start_server(args).await;
|
||||
}
|
||||
|
||||
match result {
|
||||
Ok(()) => ExitCode::SUCCESS,
|
||||
Err(e) => {
|
||||
eprintln!("Failed to set up logging: {e}");
|
||||
ExitCode::FAILURE
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn set_up_logging(args: &Args) -> Result<(), SyncServerError> {
|
||||
let level_filter = match args.verbose.log_level_filter() {
|
||||
// We don't want to allow disabling all logging
|
||||
log::LevelFilter::Off | log::LevelFilter::Error => tracing::Level::ERROR,
|
||||
log::LevelFilter::Warn => tracing::Level::WARN,
|
||||
log::LevelFilter::Info => tracing::Level::INFO,
|
||||
log::LevelFilter::Debug => tracing::Level::DEBUG,
|
||||
log::LevelFilter::Trace => tracing::Level::TRACE,
|
||||
};
|
||||
|
||||
let env_filter = EnvFilter::builder()
|
||||
.with_default_directive(level_filter.into())
|
||||
.from_env()
|
||||
.context("Failed to create logging env filter")
|
||||
.map_err(init_error)?;
|
||||
|
||||
let use_colors = args.color.use_colors();
|
||||
|
||||
let is_debug_mode = args.verbose.log_level_filter() >= log::LevelFilter::Debug;
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_ansi(use_colors)
|
||||
.with_env_filter(env_filter)
|
||||
.event_format(
|
||||
format()
|
||||
.without_time()
|
||||
.with_target(is_debug_mode)
|
||||
.with_line_number(is_debug_mode)
|
||||
.compact(),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.finish()
|
||||
.try_init()
|
||||
.context("Failed to initialise tracing")
|
||||
.map_err(init_error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_server(args: Args) -> Result<(), SyncServerError> {
|
||||
info!(
|
||||
"Starting VaultLink server version {}",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,16 @@
|
|||
mod auth;
|
||||
mod create_document;
|
||||
mod delete_document;
|
||||
mod fetch_document_version;
|
||||
mod fetch_document_version_content;
|
||||
mod fetch_latest_document_version;
|
||||
mod fetch_latest_documents;
|
||||
mod ping;
|
||||
mod requests;
|
||||
mod responses;
|
||||
mod update_document;
|
||||
mod websocket;
|
||||
|
||||
use std::{ffi::OsString, sync::Arc};
|
||||
|
||||
use aide::{
|
||||
|
|
@ -10,7 +23,6 @@ use aide::{
|
|||
transform::TransformOpenApi,
|
||||
};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use app_state::AppState;
|
||||
use axum::{
|
||||
Extension, Json,
|
||||
extract::{DefaultBodyLimit, Request},
|
||||
|
|
@ -32,21 +44,10 @@ use tower_http::{
|
|||
use tracing::{Level, info_span};
|
||||
|
||||
use crate::{
|
||||
app_state::AppState,
|
||||
config::server_config::ServerConfig,
|
||||
errors::{SerializedError, not_found_error},
|
||||
};
|
||||
mod app_state;
|
||||
mod auth;
|
||||
mod create_document;
|
||||
mod delete_document;
|
||||
mod fetch_document_version;
|
||||
mod fetch_document_version_content;
|
||||
mod fetch_latest_document_version;
|
||||
mod fetch_latest_documents;
|
||||
mod ping;
|
||||
mod requests;
|
||||
mod responses;
|
||||
mod update_document;
|
||||
|
||||
pub async fn create_server(config_path: Option<OsString>) -> Result<()> {
|
||||
aide::r#gen::on_error(|err| error!("{err}"));
|
||||
|
|
@ -65,6 +66,7 @@ pub async fn create_server(config_path: Option<OsString>) -> Result<()> {
|
|||
"/vaults/:vault_id/documents",
|
||||
get(fetch_latest_documents::fetch_latest_documents),
|
||||
)
|
||||
.route("/vaults/:vault_id/ws", get(websocket::websocket_handler))
|
||||
.api_route(
|
||||
"/vaults/:vault_id/documents",
|
||||
post(create_document::create_document_multipart),
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
use super::app_state::AppState;
|
||||
use crate::{
|
||||
app_state::AppState,
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, unauthorized_error},
|
||||
};
|
||||
|
||||
// TODO: turn this into a middleware
|
||||
pub fn auth(app_state: &AppState, token: &str) -> Result<User, SyncServerError> {
|
||||
app_state
|
||||
.config
|
||||
|
|
|
|||
|
|
@ -11,12 +11,16 @@ use serde::Deserialize;
|
|||
use sync_lib::base64_to_bytes;
|
||||
|
||||
use super::{
|
||||
app_state::AppState,
|
||||
auth::auth,
|
||||
requests::{CreateDocumentVersion, CreateDocumentVersionMultipart},
|
||||
};
|
||||
use crate::{
|
||||
database::models::{DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId},
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{
|
||||
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
|
||||
},
|
||||
},
|
||||
errors::{SyncServerError, client_error, server_error},
|
||||
utils::sanitize_path,
|
||||
};
|
||||
|
|
@ -77,7 +81,7 @@ pub async fn create_document_json(
|
|||
|
||||
async fn internal_create_document(
|
||||
auth_header: Authorization<Bearer>,
|
||||
mut state: AppState,
|
||||
state: AppState,
|
||||
vault_id: VaultId,
|
||||
document_id: Option<DocumentId>,
|
||||
relative_path: String,
|
||||
|
|
@ -139,5 +143,10 @@ async fn internal_create_document(
|
|||
.context("Failed to commit successful transaction")
|
||||
.map_err(server_error)?;
|
||||
|
||||
state
|
||||
.broadcasts
|
||||
.send(vault_id, new_version.clone().into())
|
||||
.await;
|
||||
|
||||
Ok(Json(new_version.into()))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,9 +8,14 @@ use axum_jsonschema::Json;
|
|||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{app_state::AppState, auth::auth, requests::DeleteDocumentVersion};
|
||||
use super::{auth::auth, requests::DeleteDocumentVersion};
|
||||
use crate::{
|
||||
database::models::{DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId},
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{
|
||||
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
|
||||
},
|
||||
},
|
||||
errors::{SyncServerError, server_error},
|
||||
utils::sanitize_path,
|
||||
};
|
||||
|
|
@ -29,7 +34,7 @@ pub async fn delete_document(
|
|||
vault_id,
|
||||
document_id,
|
||||
}): Path<DeleteDocumentPathParams>,
|
||||
State(mut state): State<AppState>,
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<DeleteDocumentVersion>,
|
||||
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
|
||||
auth(&state, auth_header.token())?;
|
||||
|
|
@ -67,5 +72,10 @@ pub async fn delete_document(
|
|||
.context("Failed to commit successful transaction")
|
||||
.map_err(server_error)?;
|
||||
|
||||
state
|
||||
.broadcasts
|
||||
.send(vault_id, new_version.clone().into())
|
||||
.await;
|
||||
|
||||
Ok(Json(new_version.into()))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,9 +8,12 @@ use axum_jsonschema::Json;
|
|||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{app_state::AppState, auth::auth};
|
||||
use super::auth::auth;
|
||||
use crate::{
|
||||
database::models::{DocumentId, DocumentVersion, VaultId, VaultUpdateId},
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentId, DocumentVersion, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, not_found_error, server_error},
|
||||
};
|
||||
|
||||
|
|
@ -30,7 +33,7 @@ pub async fn fetch_document_version(
|
|||
document_id,
|
||||
vault_update_id,
|
||||
}): Path<FetchDocumentVersionPathParams>,
|
||||
State(mut state): State<AppState>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<DocumentVersion>, SyncServerError> {
|
||||
auth(&state, auth_header.token())?;
|
||||
|
||||
|
|
|
|||
|
|
@ -10,9 +10,12 @@ use axum_extra::{
|
|||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{app_state::AppState, auth::auth};
|
||||
use super::auth::auth;
|
||||
use crate::{
|
||||
database::models::{DocumentId, VaultId, VaultUpdateId},
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentId, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, not_found_error, server_error},
|
||||
};
|
||||
|
||||
|
|
@ -32,7 +35,7 @@ pub async fn fetch_document_version_content(
|
|||
document_id,
|
||||
vault_update_id,
|
||||
}): Path<FetchDocumentVersionContentPathParams>,
|
||||
State(mut state): State<AppState>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Bytes, SyncServerError> {
|
||||
auth(&state, auth_header.token())?;
|
||||
|
||||
|
|
|
|||
|
|
@ -8,9 +8,12 @@ use axum_jsonschema::Json;
|
|||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{app_state::AppState, auth::auth};
|
||||
use super::auth::auth;
|
||||
use crate::{
|
||||
database::models::{DocumentId, DocumentVersion, VaultId},
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentId, DocumentVersion, VaultId},
|
||||
},
|
||||
errors::{SyncServerError, not_found_error, server_error},
|
||||
};
|
||||
|
||||
|
|
@ -28,7 +31,7 @@ pub async fn fetch_latest_document_version(
|
|||
vault_id,
|
||||
document_id,
|
||||
}): Path<FetchLatestDocumentVersionPathParams>,
|
||||
State(mut state): State<AppState>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<DocumentVersion>, SyncServerError> {
|
||||
auth(&state, auth_header.token())?;
|
||||
|
||||
|
|
|
|||
|
|
@ -7,9 +7,12 @@ use axum_jsonschema::Json;
|
|||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{app_state::AppState, auth::auth, responses::FetchLatestDocumentsResponse};
|
||||
use super::{auth::auth, responses::FetchLatestDocumentsResponse};
|
||||
use crate::{
|
||||
database::models::{VaultId, VaultUpdateId},
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, server_error},
|
||||
};
|
||||
|
||||
|
|
@ -30,7 +33,7 @@ pub async fn fetch_latest_documents(
|
|||
TypedHeader(auth_header): TypedHeader<Authorization<Bearer>>,
|
||||
Path(FetchLatestDocumentsPathParams { vault_id }): Path<FetchLatestDocumentsPathParams>,
|
||||
Query(QueryParams { since_update_id }): Query<QueryParams>,
|
||||
State(mut state): State<AppState>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<FetchLatestDocumentsResponse>, SyncServerError> {
|
||||
auth(&state, auth_header.token())?;
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ use axum_extra::{
|
|||
headers::{Authorization, authorization::Bearer},
|
||||
};
|
||||
|
||||
use super::{app_state::AppState, auth::auth, responses::PingResponse};
|
||||
use crate::errors::SyncServerError;
|
||||
use super::{auth::auth, responses::PingResponse};
|
||||
use crate::{app_state::AppState, errors::SyncServerError};
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn ping(
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use axum_typed_multipart::TryFromMultipart;
|
|||
use schemars::JsonSchema;
|
||||
use serde::{self, Deserialize};
|
||||
|
||||
use crate::database::models::{DocumentId, VaultUpdateId};
|
||||
use crate::app_state::database::models::{DocumentId, VaultUpdateId};
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
use schemars::JsonSchema;
|
||||
use serde::{self, Serialize};
|
||||
|
||||
use crate::database::models::{DocumentVersion, DocumentVersionWithoutContent, VaultUpdateId};
|
||||
use crate::app_state::database::models::{
|
||||
DocumentVersion, DocumentVersionWithoutContent, VaultUpdateId,
|
||||
};
|
||||
|
||||
/// Response to a ping request.
|
||||
#[derive(Debug, Clone, Serialize, JsonSchema)]
|
||||
|
|
|
|||
|
|
@ -12,13 +12,15 @@ use serde::Deserialize;
|
|||
use sync_lib::{base64_to_bytes, is_file_type_mergable, merge};
|
||||
|
||||
use super::{
|
||||
app_state::AppState,
|
||||
auth::auth,
|
||||
requests::{UpdateDocumentVersion, UpdateDocumentVersionMultipart},
|
||||
responses::DocumentUpdateResponse,
|
||||
};
|
||||
use crate::{
|
||||
database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, client_error, not_found_error, server_error},
|
||||
utils::{deduped_file_paths, sanitize_path},
|
||||
};
|
||||
|
|
@ -83,7 +85,7 @@ pub async fn update_document_json(
|
|||
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
|
||||
async fn internal_update_document(
|
||||
auth_header: Authorization<Bearer>,
|
||||
mut state: AppState,
|
||||
state: AppState,
|
||||
vault_id: VaultId,
|
||||
document_id: DocumentId,
|
||||
parent_version_id: VaultUpdateId,
|
||||
|
|
@ -216,6 +218,11 @@ async fn internal_update_document(
|
|||
.context("Failed to commit successful transaction")
|
||||
.map_err(server_error)?;
|
||||
|
||||
state
|
||||
.broadcasts
|
||||
.send(vault_id, new_version.clone().into())
|
||||
.await;
|
||||
|
||||
Ok(Json(if is_different_from_request_content {
|
||||
DocumentUpdateResponse::MergingUpdate(new_version.into())
|
||||
} else {
|
||||
|
|
|
|||
147
backend/sync_server/src/server/websocket.rs
Normal file
147
backend/sync_server/src/server/websocket.rs
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
use anyhow::Context;
|
||||
use axum::{
|
||||
extract::{
|
||||
Path, Query, State,
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
},
|
||||
response::Response,
|
||||
};
|
||||
use futures::{
|
||||
sink::SinkExt,
|
||||
stream::{SplitSink, StreamExt},
|
||||
};
|
||||
use log::{error, info, warn};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::auth::auth;
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, server_error, unauthorized_error},
|
||||
};
|
||||
|
||||
// This is required for aide to infer the path parameter types and names
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
pub struct WebsocketPathParams {
|
||||
vault_id: VaultId,
|
||||
}
|
||||
|
||||
// This is required for aide to infer the path parameter types and names
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
pub struct QueryParams {
|
||||
since_update_id: Option<VaultUpdateId>,
|
||||
}
|
||||
|
||||
pub async fn websocket_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Path(WebsocketPathParams { vault_id }): Path<WebsocketPathParams>,
|
||||
Query(QueryParams { since_update_id }): Query<QueryParams>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Response, SyncServerError> {
|
||||
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id, since_update_id)))
|
||||
}
|
||||
|
||||
async fn websocket_wrapped(
|
||||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
since_update_id: Option<VaultUpdateId>,
|
||||
) {
|
||||
info!("Websocket connection opened on vault '{}'", vault_id);
|
||||
|
||||
let result = websocket(state, stream, vault_id.clone(), since_update_id).await;
|
||||
|
||||
if let Err(err) = result {
|
||||
error!(
|
||||
"Websocket connection error on vault '{}': {}",
|
||||
vault_id, err
|
||||
);
|
||||
}
|
||||
|
||||
warn!("Websocket connection closed on vault '{}'", vault_id);
|
||||
}
|
||||
|
||||
async fn websocket(
|
||||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
since_update_id: Option<VaultUpdateId>,
|
||||
) -> Result<(), SyncServerError> {
|
||||
let (mut sender, mut receiver) = stream.split();
|
||||
|
||||
if let Some(Ok(Message::Text(token))) = receiver.next().await {
|
||||
auth(&state, &token)?;
|
||||
} else {
|
||||
return Err(unauthorized_error(anyhow::anyhow!(
|
||||
"Failed to authenticate"
|
||||
)));
|
||||
}
|
||||
|
||||
let mut rx = state.broadcasts.get_receiver(vault_id.clone()).await;
|
||||
|
||||
let documents = if let Some(since_update_id) = since_update_id {
|
||||
state
|
||||
.database
|
||||
.get_latest_documents_since(&vault_id, since_update_id, None)
|
||||
.await
|
||||
.map_err(server_error)
|
||||
} else {
|
||||
state
|
||||
.database
|
||||
.get_latest_documents(&vault_id, None)
|
||||
.await
|
||||
.map_err(server_error)
|
||||
}?;
|
||||
|
||||
for document in documents {
|
||||
send_document_over_websocket(document, &mut sender).await?;
|
||||
}
|
||||
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Ok(update) = rx.recv().await {
|
||||
send_document_over_websocket(update, &mut sender).await?;
|
||||
}
|
||||
|
||||
Ok::<(), SyncServerError>(())
|
||||
});
|
||||
|
||||
let mut recv_task =
|
||||
tokio::spawn(
|
||||
async move { while let Some(Ok(Message::Text(_text))) = receiver.next().await {} },
|
||||
);
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut send_task => recv_task.abort(),
|
||||
_ = &mut recv_task => send_task.abort(),
|
||||
};
|
||||
|
||||
send_task
|
||||
.await
|
||||
.context("Websocket send task failed")
|
||||
.map_err(server_error)??;
|
||||
|
||||
recv_task
|
||||
.await
|
||||
.context("Websocket receive task failed")
|
||||
.map_err(server_error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_document_over_websocket(
|
||||
document: DocumentVersionWithoutContent,
|
||||
sender: &mut SplitSink<WebSocket, Message>,
|
||||
) -> Result<(), SyncServerError> {
|
||||
let serialized_update = serde_json::to_string(&document)
|
||||
.context("Failed to serialize update")
|
||||
.map_err(server_error)?;
|
||||
|
||||
sender
|
||||
.send(Message::Text(serialized_update))
|
||||
.await
|
||||
.context("Failed to send message over websocket")
|
||||
.map_err(server_error)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue