From 63a4948b8778ab909300e81edd259bdf4469a147 Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Tue, 25 Mar 2025 22:26:32 +0000 Subject: [PATCH] Add WS endpoint --- backend/sync_server/src/server.rs | 28 ++-- backend/sync_server/src/server/websocket.rs | 145 ++++++++++++++++++++ 2 files changed, 160 insertions(+), 13 deletions(-) create mode 100644 backend/sync_server/src/server/websocket.rs diff --git a/backend/sync_server/src/server.rs b/backend/sync_server/src/server.rs index 083be5ba..90bd8ff3 100644 --- a/backend/sync_server/src/server.rs +++ b/backend/sync_server/src/server.rs @@ -1,3 +1,16 @@ +mod auth; +mod create_document; +mod delete_document; +mod fetch_document_version; +mod fetch_document_version_content; +mod fetch_latest_document_version; +mod fetch_latest_documents; +mod ping; +mod requests; +mod responses; +mod update_document; +mod websocket; + use std::{ffi::OsString, sync::Arc}; use aide::{ @@ -10,7 +23,6 @@ use aide::{ transform::TransformOpenApi, }; use anyhow::{Context as _, Result, anyhow}; -use app_state::AppState; use axum::{ Extension, Json, extract::{DefaultBodyLimit, Request}, @@ -32,21 +44,10 @@ use tower_http::{ use tracing::{Level, info_span}; use crate::{ + app_state::AppState, config::server_config::ServerConfig, errors::{SerializedError, not_found_error}, }; -mod app_state; -mod auth; -mod create_document; -mod delete_document; -mod fetch_document_version; -mod fetch_document_version_content; -mod fetch_latest_document_version; -mod fetch_latest_documents; -mod ping; -mod requests; -mod responses; -mod update_document; pub async fn create_server(config_path: Option) -> Result<()> { aide::r#gen::on_error(|err| error!("{err}")); @@ -65,6 +66,7 @@ pub async fn create_server(config_path: Option) -> Result<()> { "/vaults/:vault_id/documents", get(fetch_latest_documents::fetch_latest_documents), ) + .route("/vaults/:vault_id/ws", get(websocket::websocket_handler)) .api_route( "/vaults/:vault_id/documents", post(create_document::create_document_multipart), diff --git a/backend/sync_server/src/server/websocket.rs b/backend/sync_server/src/server/websocket.rs new file mode 100644 index 00000000..8c9ed31e --- /dev/null +++ b/backend/sync_server/src/server/websocket.rs @@ -0,0 +1,145 @@ +use anyhow::Context; +use axum::{ + extract::{ + Path, Query, State, + ws::{Message, WebSocket, WebSocketUpgrade}, + }, + response::Response, +}; +use futures::{ + sink::SinkExt, + stream::{SplitSink, StreamExt}, +}; +use log::{error, info, warn}; +use schemars::JsonSchema; +use serde::Deserialize; + +use super::auth::auth; +use crate::{ + app_state::{ + AppState, + database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId}, + }, + errors::{SyncServerError, server_error, unauthorized_error}, +}; + +// This is required for aide to infer the path parameter types and names +#[derive(Deserialize, JsonSchema)] +pub struct WebsocketPathParams { + vault_id: VaultId, +} + +// This is required for aide to infer the path parameter types and names +#[derive(Deserialize, JsonSchema)] +pub struct QueryParams { + since_update_id: Option, +} + +pub async fn websocket_handler( + ws: WebSocketUpgrade, + Path(WebsocketPathParams { vault_id }): Path, + Query(QueryParams { since_update_id }): Query, + State(state): State, +) -> Result { + Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id, since_update_id))) +} + +async fn websocket_wrapped( + state: AppState, + stream: WebSocket, + vault_id: VaultId, + since_update_id: Option, +) { + info!("Websocket connection opened on vault '{}'", vault_id); + + let result = websocket(state, stream, vault_id.clone(), since_update_id).await; + + if let Err(err) = result { + error!( + "Websocket connection error on vault '{}': {}", + vault_id, err + ); + } + + warn!("Websocket connection closed on vault '{}'", vault_id); +} + +async fn websocket( + state: AppState, + stream: WebSocket, + vault_id: VaultId, + since_update_id: Option, +) -> Result<(), SyncServerError> { + let (mut sender, mut receiver) = stream.split(); + + if let Some(Ok(Message::Text(token))) = receiver.next().await { + auth(&state, &token)?; + } else { + return Err(unauthorized_error(anyhow::anyhow!( + "Failed to authenticate" + ))); + } + + let mut rx = state.broadcasts.get_receiver(vault_id.clone()).await; + + let documents = if let Some(since_update_id) = since_update_id { + state + .database + .get_latest_documents_since(&vault_id, since_update_id, None) + .await + .map_err(server_error) + } else { + state + .database + .get_latest_documents(&vault_id, None) + .await + .map_err(server_error) + }?; + + for document in documents { + send_document_over_websocket(document, &mut sender).await?; + } + + let mut send_task = tokio::spawn(async move { + while let Ok(update) = rx.recv().await { + send_document_over_websocket(update, &mut sender).await?; + } + + Ok::<(), SyncServerError>(()) + }); + + let mut recv_task = tokio::spawn(async move { + while let Some(Ok(Message::Text(text))) = receiver.next().await { + info!("Received message: {}", text); + // Add username before message. + // let _ = tx.send(format!("{name}: {text}")); + } + }); + + tokio::select! { + _ = &mut send_task => recv_task.abort(), + _ = &mut recv_task => send_task.abort(), + }; + + send_task + .await + .context("Websocket send task failed") + .map_err(server_error)??; + + Ok(()) +} + +async fn send_document_over_websocket( + document: DocumentVersionWithoutContent, + sender: &mut SplitSink, +) -> Result<(), SyncServerError> { + let serialized_update = serde_json::to_string(&document) + .context("Failed to serialize update") + .map_err(server_error)?; + + sender + .send(Message::Text(serialized_update)) + .await + .context("Failed to send message over websocket") + .map_err(server_error) +}