Add WebSocket support (#12)

This commit is contained in:
Andras Schmelczer 2025-03-29 10:17:46 +00:00 committed by GitHub
parent 3d27b7f313
commit 1aad0fce31
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
68 changed files with 2578 additions and 993 deletions

View file

@ -1,13 +1,19 @@
pub mod broadcasts;
pub mod database;
use std::ffi::OsString;
use anyhow::Result;
use broadcasts::Broadcasts;
use database::Database;
use crate::{config::Config, consts::DEFAULT_CONFIG_PATH, database::Database};
use crate::{config::Config, consts::DEFAULT_CONFIG_PATH};
#[derive(Clone, Debug)]
pub struct AppState {
pub config: Config,
pub database: Database,
pub broadcasts: Broadcasts,
}
impl AppState {
@ -17,7 +23,12 @@ impl AppState {
let config = Config::read_or_create(&path).await?;
let database = Database::try_new(&config.database).await?;
let broadcasts = Broadcasts::new(&config.server);
Ok(Self { config, database })
Ok(Self {
config,
database,
broadcasts,
})
}
}

View file

@ -0,0 +1,57 @@
use std::{collections::HashMap, sync::Arc};
use anyhow::Context;
use tokio::sync::{Mutex, broadcast};
use super::database::models::{DocumentVersionWithoutContent, VaultId};
use crate::{config::server_config::ServerConfig, errors::server_error};
#[derive(Debug, Clone)]
pub struct Broadcasts {
max_clients_per_vault: usize,
tx: Arc<Mutex<HashMap<VaultId, broadcast::Sender<DocumentVersionWithoutContent>>>>,
}
impl Broadcasts {
pub fn new(server_config: &ServerConfig) -> Self {
Self {
max_clients_per_vault: server_config.max_clients_per_vault,
tx: Arc::new(Mutex::new(HashMap::new())),
}
}
pub async fn get_receiver(
&self,
vault: VaultId,
) -> broadcast::Receiver<DocumentVersionWithoutContent> {
let tx = self.get_or_create(vault).await;
tx.subscribe()
}
/// Sent a document update to all clients subscribed to the vault.
/// We ignore & log failures.
pub async fn send(&self, vault: VaultId, document: DocumentVersionWithoutContent) {
let tx = self.get_or_create(vault).await;
let result = tx
.send(document)
.context("Cannot broadcast update message to websocket listeners")
.map_err(server_error);
if result.is_err() {
log::debug!("Failed to send message: {result:?}");
}
}
async fn get_or_create(
&self,
vault: VaultId,
) -> broadcast::Sender<DocumentVersionWithoutContent> {
let mut tx = self.tx.lock().await;
tx.entry(vault)
.or_insert_with(|| broadcast::channel(self.max_clients_per_vault).0.clone())
.clone()
}
}

View file

@ -85,13 +85,13 @@ impl Database {
}
async fn run_migrations(pool: &Pool<Sqlite>) -> Result<()> {
sqlx::migrate!("src/database/migrations")
sqlx::migrate!("src/app_state/database/migrations")
.run(pool)
.await
.context("Cannot check for pending migrations")
}
async fn get_connection_pool(&mut self, vault: &VaultId) -> Result<Pool<Sqlite>> {
async fn get_connection_pool(&self, vault: &VaultId) -> Result<Pool<Sqlite>> {
let mut pools = self.connection_pools.lock().await;
if !pools.contains_key(vault) {
let pool = Self::create_vault_database(&self.config, vault).await?;
@ -108,7 +108,7 @@ impl Database {
/// Attempting to write from this transaction might result in a
/// database locked error. Use this transaction for read-only operations.
pub async fn create_readonly_transaction(
&mut self,
&self,
vault: &VaultId,
) -> Result<Transaction<'static>> {
self.get_connection_pool(vault)
@ -118,10 +118,7 @@ impl Database {
.context("Cannot create transaction")
}
pub async fn create_write_transaction(
&mut self,
vault: &VaultId,
) -> Result<Transaction<'static>> {
pub async fn create_write_transaction(&self, vault: &VaultId) -> Result<Transaction<'static>> {
let mut transaction = self.create_readonly_transaction(vault).await?;
// sqlx doesn't support immediate transactions for sqlite: https://github.com/launchbadge/sqlx/issues/481
@ -134,7 +131,7 @@ impl Database {
/// Return the latest state of all documents in the vault
pub async fn get_latest_documents(
&mut self,
&self,
vault: &VaultId,
transaction: Option<&mut Transaction<'_>>,
) -> Result<Vec<DocumentVersionWithoutContent>> {
@ -165,7 +162,7 @@ impl Database {
/// Return the latest state of all documents (including deleted) in the
/// vault which have changed since the given update id
pub async fn get_latest_documents_since(
&mut self,
&self,
vault: &VaultId,
vault_update_id: VaultUpdateId,
transaction: Option<&mut Transaction<'_>>,
@ -199,7 +196,7 @@ impl Database {
}
pub async fn get_max_update_id_in_vault(
&mut self,
&self,
vault: &VaultId,
transaction: Option<&mut Transaction<'_>>,
) -> Result<i64> {
@ -222,7 +219,7 @@ impl Database {
}
pub async fn get_latest_document_by_path(
&mut self,
&self,
vault: &VaultId,
relative_path: &str,
transaction: Option<&mut Transaction<'_>>,
@ -258,7 +255,7 @@ impl Database {
}
pub async fn get_latest_document(
&mut self,
&self,
vault: &VaultId,
document_id: &DocumentId,
transaction: Option<&mut Transaction<'_>>,
@ -291,7 +288,7 @@ impl Database {
}
pub async fn get_document_version(
&mut self,
&self,
vault: &VaultId,
vault_update_id: VaultUpdateId,
transaction: Option<&mut Transaction<'_>>,
@ -322,7 +319,7 @@ impl Database {
}
pub async fn insert_document_version(
&mut self,
&self,
vault: &VaultId,
version: &StoredDocumentVersion,
transaction: Option<&mut Transaction<'_>>,

View file

@ -1 +1,2 @@
pub mod args;
pub mod color_when;

View file

@ -1,38 +1,26 @@
use std::ffi::OsString;
use clap::{Parser, ValueEnum};
use clap::Parser;
use clap_verbosity_flag::{InfoLevel, Verbosity};
/// Server for backing the VaultLink plugin
use crate::cli::color_when::ColorWhen;
/// Server for backing the `VaultLink` plugin
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
pub struct Args {
#[arg(index = 1)]
pub config_path: Option<OsString>,
#[command(flatten)]
pub verbose: Verbosity<InfoLevel>,
#[arg(
long,
require_equals = true,
value_name = "WHEN",
num_args = 0..=1,
default_value_t = ColorWhen::Auto,
default_missing_value = "always",
value_enum
)]
pub color: ColorWhen,
}
#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq)]
pub enum ColorWhen {
Always,
Auto,
Never,
}
impl std::fmt::Display for ColorWhen {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.to_possible_value()
.expect("no values are skipped")
.get_name()
.fmt(f)
}
}

View file

@ -0,0 +1,31 @@
use std::io::IsTerminal;
use clap::ValueEnum;
#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq)]
pub enum ColorWhen {
Always,
Auto,
Never,
}
impl ColorWhen {
pub fn use_colors(self) -> bool {
match self {
ColorWhen::Always => true,
ColorWhen::Auto => {
std::env::var_os("NO_COLOR").is_none() && std::io::stderr().is_terminal()
}
ColorWhen::Never => false,
}
}
}
impl std::fmt::Display for ColorWhen {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.to_possible_value()
.expect("no values are skipped")
.get_name()
.fmt(f)
}
}

View file

@ -1,7 +1,10 @@
use log::debug;
use serde::{Deserialize, Serialize};
use crate::consts::{DEFAULT_HOST, DEFAULT_MAX_BODY_SIZE_MB, DEFAULT_PORT};
use crate::consts::{
DEFAULT_HOST, DEFAULT_MAX_BODY_SIZE_MB, DEFAULT_MAX_CLIENTS_PER_VAULT, DEFAULT_PORT,
};
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ServerConfig {
#[serde(default = "default_host")]
@ -12,6 +15,9 @@ pub struct ServerConfig {
#[serde(default = "default_max_body_size_mb")]
pub max_body_size_mb: usize,
#[serde(default = "default_max_clients_per_vault")]
pub max_clients_per_vault: usize,
}
fn default_host() -> String {
@ -29,12 +35,18 @@ fn default_max_body_size_mb() -> usize {
DEFAULT_MAX_BODY_SIZE_MB
}
fn default_max_clients_per_vault() -> usize {
debug!("Using default max clients per vault: {DEFAULT_MAX_CLIENTS_PER_VAULT}");
DEFAULT_MAX_CLIENTS_PER_VAULT
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
host: default_host(),
port: default_port(),
max_body_size_mb: default_max_body_size_mb(),
max_clients_per_vault: default_max_clients_per_vault(),
}
}
}

View file

@ -4,3 +4,4 @@ pub const DEFAULT_HOST: &str = "127.0.0.1";
pub const DEFAULT_PORT: u16 = 3000;
pub const DEFAULT_MAX_CONNECTIONS: u32 = 12;
pub const DEFAULT_MAX_BODY_SIZE_MB: usize = 4096;
pub const DEFAULT_MAX_CLIENTS_PER_VAULT: usize = 256;

View file

@ -1,38 +1,79 @@
mod app_state;
mod cli;
mod config;
mod consts;
mod database;
mod errors;
mod server;
mod utils;
use std::process::ExitCode;
use anyhow::{Context as _, Result};
use clap::Parser;
use cli::args::Args;
use errors::{SyncServerError, init_error};
use log::info;
use server::create_server;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use tracing_subscriber::{EnvFilter, fmt::format, util::SubscriberInitExt};
#[tokio::main]
async fn main() -> Result<(), SyncServerError> {
async fn main() -> ExitCode {
let args = Args::parse();
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
format!(
"{}=debug,tower_http=debug,axum::rejection=trace",
env!("CARGO_CRATE_NAME")
)
.into()
}),
let mut result = set_up_logging(&args);
if result.is_ok() {
result = start_server(args).await;
}
match result {
Ok(()) => ExitCode::SUCCESS,
Err(e) => {
eprintln!("Failed to set up logging: {e}");
ExitCode::FAILURE
}
}
}
fn set_up_logging(args: &Args) -> Result<(), SyncServerError> {
let level_filter = match args.verbose.log_level_filter() {
// We don't want to allow disabling all logging
log::LevelFilter::Off | log::LevelFilter::Error => tracing::Level::ERROR,
log::LevelFilter::Warn => tracing::Level::WARN,
log::LevelFilter::Info => tracing::Level::INFO,
log::LevelFilter::Debug => tracing::Level::DEBUG,
log::LevelFilter::Trace => tracing::Level::TRACE,
};
let env_filter = EnvFilter::builder()
.with_default_directive(level_filter.into())
.from_env()
.context("Failed to create logging env filter")
.map_err(init_error)?;
let use_colors = args.color.use_colors();
let is_debug_mode = args.verbose.log_level_filter() >= log::LevelFilter::Debug;
tracing_subscriber::fmt()
.with_ansi(use_colors)
.with_env_filter(env_filter)
.event_format(
format()
.without_time()
.with_target(is_debug_mode)
.with_line_number(is_debug_mode)
.compact(),
)
.with(tracing_subscriber::fmt::layer())
.finish()
.try_init()
.context("Failed to initialise tracing")
.map_err(init_error)?;
Ok(())
}
async fn start_server(args: Args) -> Result<(), SyncServerError> {
info!(
"Starting VaultLink server version {}",
env!("CARGO_PKG_VERSION")

View file

@ -1,3 +1,16 @@
mod auth;
mod create_document;
mod delete_document;
mod fetch_document_version;
mod fetch_document_version_content;
mod fetch_latest_document_version;
mod fetch_latest_documents;
mod ping;
mod requests;
mod responses;
mod update_document;
mod websocket;
use std::{ffi::OsString, sync::Arc};
use aide::{
@ -10,7 +23,6 @@ use aide::{
transform::TransformOpenApi,
};
use anyhow::{Context as _, Result, anyhow};
use app_state::AppState;
use axum::{
Extension, Json,
extract::{DefaultBodyLimit, Request},
@ -32,21 +44,10 @@ use tower_http::{
use tracing::{Level, info_span};
use crate::{
app_state::AppState,
config::server_config::ServerConfig,
errors::{SerializedError, not_found_error},
};
mod app_state;
mod auth;
mod create_document;
mod delete_document;
mod fetch_document_version;
mod fetch_document_version_content;
mod fetch_latest_document_version;
mod fetch_latest_documents;
mod ping;
mod requests;
mod responses;
mod update_document;
pub async fn create_server(config_path: Option<OsString>) -> Result<()> {
aide::r#gen::on_error(|err| error!("{err}"));
@ -65,6 +66,7 @@ pub async fn create_server(config_path: Option<OsString>) -> Result<()> {
"/vaults/:vault_id/documents",
get(fetch_latest_documents::fetch_latest_documents),
)
.route("/vaults/:vault_id/ws", get(websocket::websocket_handler))
.api_route(
"/vaults/:vault_id/documents",
post(create_document::create_document_multipart),

View file

@ -1,9 +1,10 @@
use super::app_state::AppState;
use crate::{
app_state::AppState,
config::user_config::User,
errors::{SyncServerError, unauthorized_error},
};
// TODO: turn this into a middleware
pub fn auth(app_state: &AppState, token: &str) -> Result<User, SyncServerError> {
app_state
.config

View file

@ -11,12 +11,16 @@ use serde::Deserialize;
use sync_lib::base64_to_bytes;
use super::{
app_state::AppState,
auth::auth,
requests::{CreateDocumentVersion, CreateDocumentVersionMultipart},
};
use crate::{
database::models::{DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId},
app_state::{
AppState,
database::models::{
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
},
},
errors::{SyncServerError, client_error, server_error},
utils::sanitize_path,
};
@ -77,7 +81,7 @@ pub async fn create_document_json(
async fn internal_create_document(
auth_header: Authorization<Bearer>,
mut state: AppState,
state: AppState,
vault_id: VaultId,
document_id: Option<DocumentId>,
relative_path: String,
@ -139,5 +143,10 @@ async fn internal_create_document(
.context("Failed to commit successful transaction")
.map_err(server_error)?;
state
.broadcasts
.send(vault_id, new_version.clone().into())
.await;
Ok(Json(new_version.into()))
}

View file

@ -8,9 +8,14 @@ use axum_jsonschema::Json;
use schemars::JsonSchema;
use serde::Deserialize;
use super::{app_state::AppState, auth::auth, requests::DeleteDocumentVersion};
use super::{auth::auth, requests::DeleteDocumentVersion};
use crate::{
database::models::{DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId},
app_state::{
AppState,
database::models::{
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
},
},
errors::{SyncServerError, server_error},
utils::sanitize_path,
};
@ -29,7 +34,7 @@ pub async fn delete_document(
vault_id,
document_id,
}): Path<DeleteDocumentPathParams>,
State(mut state): State<AppState>,
State(state): State<AppState>,
Json(request): Json<DeleteDocumentVersion>,
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
auth(&state, auth_header.token())?;
@ -67,5 +72,10 @@ pub async fn delete_document(
.context("Failed to commit successful transaction")
.map_err(server_error)?;
state
.broadcasts
.send(vault_id, new_version.clone().into())
.await;
Ok(Json(new_version.into()))
}

View file

@ -8,9 +8,12 @@ use axum_jsonschema::Json;
use schemars::JsonSchema;
use serde::Deserialize;
use super::{app_state::AppState, auth::auth};
use super::auth::auth;
use crate::{
database::models::{DocumentId, DocumentVersion, VaultId, VaultUpdateId},
app_state::{
AppState,
database::models::{DocumentId, DocumentVersion, VaultId, VaultUpdateId},
},
errors::{SyncServerError, not_found_error, server_error},
};
@ -30,7 +33,7 @@ pub async fn fetch_document_version(
document_id,
vault_update_id,
}): Path<FetchDocumentVersionPathParams>,
State(mut state): State<AppState>,
State(state): State<AppState>,
) -> Result<Json<DocumentVersion>, SyncServerError> {
auth(&state, auth_header.token())?;

View file

@ -10,9 +10,12 @@ use axum_extra::{
use schemars::JsonSchema;
use serde::Deserialize;
use super::{app_state::AppState, auth::auth};
use super::auth::auth;
use crate::{
database::models::{DocumentId, VaultId, VaultUpdateId},
app_state::{
AppState,
database::models::{DocumentId, VaultId, VaultUpdateId},
},
errors::{SyncServerError, not_found_error, server_error},
};
@ -32,7 +35,7 @@ pub async fn fetch_document_version_content(
document_id,
vault_update_id,
}): Path<FetchDocumentVersionContentPathParams>,
State(mut state): State<AppState>,
State(state): State<AppState>,
) -> Result<Bytes, SyncServerError> {
auth(&state, auth_header.token())?;

View file

@ -8,9 +8,12 @@ use axum_jsonschema::Json;
use schemars::JsonSchema;
use serde::Deserialize;
use super::{app_state::AppState, auth::auth};
use super::auth::auth;
use crate::{
database::models::{DocumentId, DocumentVersion, VaultId},
app_state::{
AppState,
database::models::{DocumentId, DocumentVersion, VaultId},
},
errors::{SyncServerError, not_found_error, server_error},
};
@ -28,7 +31,7 @@ pub async fn fetch_latest_document_version(
vault_id,
document_id,
}): Path<FetchLatestDocumentVersionPathParams>,
State(mut state): State<AppState>,
State(state): State<AppState>,
) -> Result<Json<DocumentVersion>, SyncServerError> {
auth(&state, auth_header.token())?;

View file

@ -7,9 +7,12 @@ use axum_jsonschema::Json;
use schemars::JsonSchema;
use serde::Deserialize;
use super::{app_state::AppState, auth::auth, responses::FetchLatestDocumentsResponse};
use super::{auth::auth, responses::FetchLatestDocumentsResponse};
use crate::{
database::models::{VaultId, VaultUpdateId},
app_state::{
AppState,
database::models::{VaultId, VaultUpdateId},
},
errors::{SyncServerError, server_error},
};
@ -30,7 +33,7 @@ pub async fn fetch_latest_documents(
TypedHeader(auth_header): TypedHeader<Authorization<Bearer>>,
Path(FetchLatestDocumentsPathParams { vault_id }): Path<FetchLatestDocumentsPathParams>,
Query(QueryParams { since_update_id }): Query<QueryParams>,
State(mut state): State<AppState>,
State(state): State<AppState>,
) -> Result<Json<FetchLatestDocumentsResponse>, SyncServerError> {
auth(&state, auth_header.token())?;

View file

@ -4,8 +4,8 @@ use axum_extra::{
headers::{Authorization, authorization::Bearer},
};
use super::{app_state::AppState, auth::auth, responses::PingResponse};
use crate::errors::SyncServerError;
use super::{auth::auth, responses::PingResponse};
use crate::{app_state::AppState, errors::SyncServerError};
#[axum::debug_handler]
pub async fn ping(

View file

@ -4,7 +4,7 @@ use axum_typed_multipart::TryFromMultipart;
use schemars::JsonSchema;
use serde::{self, Deserialize};
use crate::database::models::{DocumentId, VaultUpdateId};
use crate::app_state::database::models::{DocumentId, VaultUpdateId};
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]

View file

@ -1,7 +1,9 @@
use schemars::JsonSchema;
use serde::{self, Serialize};
use crate::database::models::{DocumentVersion, DocumentVersionWithoutContent, VaultUpdateId};
use crate::app_state::database::models::{
DocumentVersion, DocumentVersionWithoutContent, VaultUpdateId,
};
/// Response to a ping request.
#[derive(Debug, Clone, Serialize, JsonSchema)]

View file

@ -12,13 +12,15 @@ use serde::Deserialize;
use sync_lib::{base64_to_bytes, is_file_type_mergable, merge};
use super::{
app_state::AppState,
auth::auth,
requests::{UpdateDocumentVersion, UpdateDocumentVersionMultipart},
responses::DocumentUpdateResponse,
};
use crate::{
database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
app_state::{
AppState,
database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
},
errors::{SyncServerError, client_error, not_found_error, server_error},
utils::{deduped_file_paths, sanitize_path},
};
@ -83,7 +85,7 @@ pub async fn update_document_json(
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
async fn internal_update_document(
auth_header: Authorization<Bearer>,
mut state: AppState,
state: AppState,
vault_id: VaultId,
document_id: DocumentId,
parent_version_id: VaultUpdateId,
@ -216,6 +218,11 @@ async fn internal_update_document(
.context("Failed to commit successful transaction")
.map_err(server_error)?;
state
.broadcasts
.send(vault_id, new_version.clone().into())
.await;
Ok(Json(if is_different_from_request_content {
DocumentUpdateResponse::MergingUpdate(new_version.into())
} else {

View file

@ -0,0 +1,147 @@
use anyhow::Context;
use axum::{
extract::{
Path, Query, State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::Response,
};
use futures::{
sink::SinkExt,
stream::{SplitSink, StreamExt},
};
use log::{error, info, warn};
use schemars::JsonSchema;
use serde::Deserialize;
use super::auth::auth;
use crate::{
app_state::{
AppState,
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
},
errors::{SyncServerError, server_error, unauthorized_error},
};
// This is required for aide to infer the path parameter types and names
#[derive(Deserialize, JsonSchema)]
pub struct WebsocketPathParams {
vault_id: VaultId,
}
// This is required for aide to infer the path parameter types and names
#[derive(Deserialize, JsonSchema)]
pub struct QueryParams {
since_update_id: Option<VaultUpdateId>,
}
pub async fn websocket_handler(
ws: WebSocketUpgrade,
Path(WebsocketPathParams { vault_id }): Path<WebsocketPathParams>,
Query(QueryParams { since_update_id }): Query<QueryParams>,
State(state): State<AppState>,
) -> Result<Response, SyncServerError> {
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id, since_update_id)))
}
async fn websocket_wrapped(
state: AppState,
stream: WebSocket,
vault_id: VaultId,
since_update_id: Option<VaultUpdateId>,
) {
info!("Websocket connection opened on vault '{}'", vault_id);
let result = websocket(state, stream, vault_id.clone(), since_update_id).await;
if let Err(err) = result {
error!(
"Websocket connection error on vault '{}': {}",
vault_id, err
);
}
warn!("Websocket connection closed on vault '{}'", vault_id);
}
async fn websocket(
state: AppState,
stream: WebSocket,
vault_id: VaultId,
since_update_id: Option<VaultUpdateId>,
) -> Result<(), SyncServerError> {
let (mut sender, mut receiver) = stream.split();
if let Some(Ok(Message::Text(token))) = receiver.next().await {
auth(&state, &token)?;
} else {
return Err(unauthorized_error(anyhow::anyhow!(
"Failed to authenticate"
)));
}
let mut rx = state.broadcasts.get_receiver(vault_id.clone()).await;
let documents = if let Some(since_update_id) = since_update_id {
state
.database
.get_latest_documents_since(&vault_id, since_update_id, None)
.await
.map_err(server_error)
} else {
state
.database
.get_latest_documents(&vault_id, None)
.await
.map_err(server_error)
}?;
for document in documents {
send_document_over_websocket(document, &mut sender).await?;
}
let mut send_task = tokio::spawn(async move {
while let Ok(update) = rx.recv().await {
send_document_over_websocket(update, &mut sender).await?;
}
Ok::<(), SyncServerError>(())
});
let mut recv_task =
tokio::spawn(
async move { while let Some(Ok(Message::Text(_text))) = receiver.next().await {} },
);
tokio::select! {
_ = &mut send_task => recv_task.abort(),
_ = &mut recv_task => send_task.abort(),
};
send_task
.await
.context("Websocket send task failed")
.map_err(server_error)??;
recv_task
.await
.context("Websocket receive task failed")
.map_err(server_error)?;
Ok(())
}
async fn send_document_over_websocket(
document: DocumentVersionWithoutContent,
sender: &mut SplitSink<WebSocket, Message>,
) -> Result<(), SyncServerError> {
let serialized_update = serde_json::to_string(&document)
.context("Failed to serialize update")
.map_err(server_error)?;
sender
.send(Message::Text(serialized_update))
.await
.context("Failed to send message over websocket")
.map_err(server_error)
}