From eb1cc6104259776a99f456ef147e0f5d9103154c Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Sun, 1 Jun 2025 09:50:52 +0100 Subject: [PATCH] Implement cursor broadcasting backend --- .gitignore | 35 +++-- backend/Cargo.lock | 35 +++++ backend/sync_server/Cargo.toml | 1 + backend/sync_server/src/app_state.rs | 13 +- backend/sync_server/src/app_state/cursors.rs | 124 +++++++++++++++ backend/sync_server/src/app_state/database.rs | 28 +++- .../src/app_state/database/models.rs | 4 +- .../sync_server/src/app_state/websocket.rs | 3 + .../app_state/{ => websocket}/broadcasts.rs | 36 +++-- .../src/app_state/websocket/models.rs | 76 +++++++++ .../src/app_state/websocket/utils.rs | 74 +++++++++ .../sync_server/src/config/database_config.rs | 14 +- backend/sync_server/src/consts.rs | 7 +- backend/sync_server/src/server.rs | 2 +- .../sync_server/src/server/create_document.rs | 29 +--- .../sync_server/src/server/delete_document.rs | 16 +- backend/sync_server/src/server/requests.rs | 7 +- .../sync_server/src/server/update_document.rs | 29 +--- backend/sync_server/src/server/websocket.rs | 146 ++++++++---------- 19 files changed, 488 insertions(+), 191 deletions(-) create mode 100644 backend/sync_server/src/app_state/cursors.rs create mode 100644 backend/sync_server/src/app_state/websocket.rs rename backend/sync_server/src/app_state/{ => websocket}/broadcasts.rs (53%) create mode 100644 backend/sync_server/src/app_state/websocket/models.rs create mode 100644 backend/sync_server/src/app_state/websocket/utils.rs diff --git a/.gitignore b/.gitignore index a91ed90b..384c91eb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,17 +1,18 @@ -# npm -node_modules - -# Exclude macOS Finder (System Explorer) View States -.DS_Store - -# Rust build folder -backend/target - -frontend/*/dist - -backend/db.sqlite3* -backend/databases - -*.log - -*.sqlx +# npm +node_modules + +# Exclude macOS Finder (System Explorer) View States +.DS_Store + +# Rust build folder +backend/target + +# Frontend build folders +frontend/*/dist + +backend/db.sqlite3* +backend/databases +backend/sync_server/bindings/*.ts + +*.log +*.sqlx diff --git a/backend/Cargo.lock b/backend/Cargo.lock index adbb5d20..2f009e1d 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -2589,6 +2589,7 @@ dependencies = [ "tower-http", "tracing", "tracing-subscriber", + "ts-rs", "uuid", ] @@ -2622,6 +2623,15 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "test-case" version = "3.3.1" @@ -2920,6 +2930,31 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "ts-rs" +version = "10.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e640d9b0964e9d39df633548591090ab92f7a4567bc31d3891af23471a3365c6" +dependencies = [ + "chrono", + "lazy_static", + "thiserror 2.0.12", + "ts-rs-macros", + "uuid", +] + +[[package]] +name = "ts-rs-macros" +version = "10.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e9d8656589772eeec2cf7a8264d9cda40fb28b9bc53118ceb9e8c07f8f38730" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", + "termcolor", +] + [[package]] name = "tungstenite" version = "0.24.0" diff --git a/backend/sync_server/Cargo.toml b/backend/sync_server/Cargo.toml index a483ed5c..e593dc3b 100644 --- a/backend/sync_server/Cargo.toml +++ b/backend/sync_server/Cargo.toml @@ -37,6 +37,7 @@ futures = "0.3.31" serde_json = "1.0.140" clap-verbosity-flag = "3.0.3" bimap = "0.6.3" +ts-rs = { version = "10.1", features = ["uuid-impl", "chrono-impl"] } [lints] workspace = true diff --git a/backend/sync_server/src/app_state.rs b/backend/sync_server/src/app_state.rs index 1cad9149..a61467d5 100644 --- a/backend/sync_server/src/app_state.rs +++ b/backend/sync_server/src/app_state.rs @@ -1,11 +1,13 @@ -pub mod broadcasts; +pub mod cursors; pub mod database; +pub mod websocket; use std::ffi::OsString; use anyhow::Result; -use broadcasts::Broadcasts; +use cursors::Cursors; use database::Database; +use websocket::broadcasts::Broadcasts; use crate::{config::Config, consts::DEFAULT_CONFIG_PATH}; @@ -13,6 +15,7 @@ use crate::{config::Config, consts::DEFAULT_CONFIG_PATH}; pub struct AppState { pub config: Config, pub database: Database, + pub cursors: Cursors, pub broadcasts: Broadcasts, } @@ -22,12 +25,16 @@ impl AppState { let path = std::path::PathBuf::from(config_path); let config = Config::read_or_create(&path).await?; - let database = Database::try_new(&config.database).await?; let broadcasts = Broadcasts::new(&config.server); + let database = Database::try_new(&config.database, &broadcasts).await?; + let cursors: Cursors = Cursors::new(&config.database, &broadcasts); + + Cursors::start_background_task(cursors.clone()); Ok(Self { config, database, + cursors, broadcasts, }) } diff --git a/backend/sync_server/src/app_state/cursors.rs b/backend/sync_server/src/app_state/cursors.rs new file mode 100644 index 00000000..851566a7 --- /dev/null +++ b/backend/sync_server/src/app_state/cursors.rs @@ -0,0 +1,124 @@ +use core::time::Duration; +use std::{collections::HashMap, sync::Arc}; + +use chrono::TimeDelta; +use sqlx::types::chrono::Utc; +use tokio::sync::Mutex; + +use super::{ + database::models::{DeviceId, VaultId}, + websocket::{ + broadcasts::Broadcasts, + models::{ + ClientCursors, CursorPositionFromServer, WebSocketServerMessage, + WebSocketServerMessageWithOrigin, + }, + }, +}; +use crate::config::database_config::DatabaseConfig; + +const BACKGROUND_TASK_INTERVAL: Duration = Duration::from_secs(1); + +#[derive(Clone, Debug)] +pub struct Cursors { + config: DatabaseConfig, + broadcasts: Broadcasts, + vault_to_cursors: Arc>>>, +} + +impl Cursors { + pub fn new(config: &DatabaseConfig, broadcasts: &Broadcasts) -> Self { + Self { + config: config.clone(), + broadcasts: broadcasts.clone(), + vault_to_cursors: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub async fn update_cursors( + &self, + vault_id: VaultId, + device_id: &DeviceId, + document_to_cursors: HashMap>, + ) { + let mut vault_to_cursors = self.vault_to_cursors.lock().await; + + let all_device_cursors = vault_to_cursors.entry(vault_id).or_insert_with(Vec::new); + + all_device_cursors.retain(|c| &c.client_cursors.device_id != device_id); + all_device_cursors.push(ClientCursorsWithTimeToLive::new(ClientCursors { + device_id: device_id.to_string(), + cursors: document_to_cursors, + })); + } + + pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec { + let vault_to_cursors = self.vault_to_cursors.lock().await; + vault_to_cursors + .get(vault_id) + .map(|cursors| { + cursors + .iter() + .cloned() + .map(|with_ttl| with_ttl.client_cursors) + .collect::>() + }) + .unwrap_or_default() + } + + pub fn start_background_task(self) { + tokio::spawn(async move { + self.run_backround_task().await; + }); + } + + async fn run_backround_task(&self) { + loop { + self.remove_expired_cursors().await; + self.broadcast_cursors().await; + tokio::time::sleep(BACKGROUND_TASK_INTERVAL).await; + } + } + + async fn remove_expired_cursors(&self) { + let mut vault_to_cursors = self.vault_to_cursors.lock().await; + + for (_vault_id, cursors) in vault_to_cursors.iter_mut() { + cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout)); + } + } + + async fn broadcast_cursors(&self) { + let vault_to_cursors = self.vault_to_cursors.lock().await; + + for (vault_id, cursors) in vault_to_cursors.iter() { + self.broadcasts + .send_document_update( + vault_id.clone(), + WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions( + CursorPositionFromServer { + clients: cursors.iter().map(|c| c.client_cursors.clone()).collect(), + }, + )), + ) + .await; + } + } +} + +#[derive(Clone, Debug)] +struct ClientCursorsWithTimeToLive { + client_cursors: ClientCursors, + last_updated: chrono::DateTime, +} + +impl ClientCursorsWithTimeToLive { + fn new(client_cursors: ClientCursors) -> Self { + Self { + client_cursors, + last_updated: Utc::now(), + } + } + + pub fn is_expired(&self, ttl: TimeDelta) -> bool { Utc::now() - self.last_updated > ttl } +} diff --git a/backend/sync_server/src/app_state/database.rs b/backend/sync_server/src/app_state/database.rs index 2ef03ba1..f8940140 100644 --- a/backend/sync_server/src/app_state/database.rs +++ b/backend/sync_server/src/app_state/database.rs @@ -6,23 +6,29 @@ use models::{ DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, VaultUpdateId, }; use sqlx::{sqlite::SqliteConnectOptions, types::chrono::Utc}; + pub mod models; use sqlx::{Pool, Sqlite, sqlite::SqlitePoolOptions}; use tokio::sync::Mutex; use uuid::fmt::Hyphenated; +use super::websocket::{ + broadcasts::Broadcasts, + models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin, WebSocketVaultUpdate}, +}; use crate::config::database_config::DatabaseConfig; #[derive(Clone, Debug)] pub struct Database { config: DatabaseConfig, + broadcasts: Broadcasts, connection_pools: Arc>>>, } pub type Transaction<'a> = sqlx::Transaction<'a, Sqlite>; impl Database { - pub async fn try_new(config: &DatabaseConfig) -> Result { + pub async fn try_new(config: &DatabaseConfig, broadcasts: &Broadcasts) -> Result { tokio::fs::create_dir_all(&config.databases_directory_path) .await .with_context(|| { @@ -55,6 +61,7 @@ impl Database { Ok(Self { config: config.clone(), connection_pools: Arc::new(Mutex::new(connection_pools)), + broadcasts: broadcasts.clone(), }) } @@ -362,7 +369,7 @@ impl Database { pub async fn insert_document_version( &self, - vault: &VaultId, + vault_id: &VaultId, version: &StoredDocumentVersion, transaction: Option<&mut Transaction<'_>>, ) -> Result<()> { @@ -394,10 +401,25 @@ impl Database { if let Some(transaction) = transaction { query.execute(&mut **transaction).await } else { - query.execute(&self.get_connection_pool(vault).await?).await + query + .execute(&self.get_connection_pool(vault_id).await?) + .await } .context("Cannot insert document version")?; + self.broadcasts + .send_document_update( + vault_id.clone(), + WebSocketServerMessageWithOrigin::with_origin( + version.device_id.clone(), + WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate { + documents: vec![version.clone().into()], + is_initial_sync: false, + }), + ), + ) + .await; + Ok(()) } } diff --git a/backend/sync_server/src/app_state/database/models.rs b/backend/sync_server/src/app_state/database/models.rs index 62ba66b6..197d96d7 100644 --- a/backend/sync_server/src/app_state/database/models.rs +++ b/backend/sync_server/src/app_state/database/models.rs @@ -2,9 +2,11 @@ use chrono::{DateTime, Utc}; use schemars::JsonSchema; use serde::Serialize; use sync_lib::bytes_to_base64; +use ts_rs::TS; pub type VaultId = String; pub type VaultUpdateId = i64; + pub type DocumentId = uuid::Uuid; pub type UserId = String; pub type DeviceId = String; @@ -25,7 +27,7 @@ impl PartialEq for StoredDocumentVersion { fn eq(&self, other: &Self) -> bool { self.vault_update_id == other.vault_update_id } } -#[derive(Debug, Clone, Serialize, JsonSchema)] +#[derive(TS, Debug, Clone, Serialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub struct DocumentVersionWithoutContent { pub vault_update_id: VaultUpdateId, diff --git a/backend/sync_server/src/app_state/websocket.rs b/backend/sync_server/src/app_state/websocket.rs new file mode 100644 index 00000000..b945606f --- /dev/null +++ b/backend/sync_server/src/app_state/websocket.rs @@ -0,0 +1,3 @@ +pub mod broadcasts; +pub mod models; +pub mod utils; diff --git a/backend/sync_server/src/app_state/broadcasts.rs b/backend/sync_server/src/app_state/websocket/broadcasts.rs similarity index 53% rename from backend/sync_server/src/app_state/broadcasts.rs rename to backend/sync_server/src/app_state/websocket/broadcasts.rs index f71886cf..cef6ee6a 100644 --- a/backend/sync_server/src/app_state/broadcasts.rs +++ b/backend/sync_server/src/app_state/websocket/broadcasts.rs @@ -3,19 +3,15 @@ use std::{collections::HashMap, sync::Arc}; use anyhow::Context; use tokio::sync::{Mutex, broadcast}; -use super::database::models::{DeviceId, DocumentVersionWithoutContent, VaultId}; -use crate::{config::server_config::ServerConfig, errors::server_error}; +use super::models::WebSocketServerMessageWithOrigin; +use crate::{ + app_state::database::models::VaultId, config::server_config::ServerConfig, errors::server_error, +}; #[derive(Debug, Clone)] pub struct Broadcasts { max_clients_per_vault: usize, - tx: Arc>>>, -} - -#[derive(Debug, Clone)] -pub struct VaultUpdate { - pub origin_device_id: Option, - pub document: DocumentVersionWithoutContent, + tx: Arc>>>, } impl Broadcasts { @@ -26,20 +22,27 @@ impl Broadcasts { } } - pub async fn get_receiver(&self, vault: VaultId) -> broadcast::Receiver { + pub async fn get_receiver( + &self, + vault: VaultId, + ) -> broadcast::Receiver { let tx = self.get_or_create(vault).await; tx.subscribe() } - /// Sent a document update to all clients subscribed to the vault. - /// We ignore & log failures. - pub async fn send(&self, vault: VaultId, document: VaultUpdate) { + /// Notify all clients (who are subscribed to the vault) about an update. + /// We only log failures. + pub async fn send_document_update( + &self, + vault: VaultId, + document: WebSocketServerMessageWithOrigin, + ) { let tx = self.get_or_create(vault).await; let result = tx .send(document) - .context("Cannot broadcast update message to websocket listeners") + .context("Cannot broadcast server message to websocket listeners") .map_err(server_error); if result.is_err() { @@ -47,7 +50,10 @@ impl Broadcasts { } } - async fn get_or_create(&self, vault: VaultId) -> broadcast::Sender { + async fn get_or_create( + &self, + vault: VaultId, + ) -> broadcast::Sender { let mut tx = self.tx.lock().await; tx.entry(vault) diff --git a/backend/sync_server/src/app_state/websocket/models.rs b/backend/sync_server/src/app_state/websocket/models.rs new file mode 100644 index 00000000..3205ff25 --- /dev/null +++ b/backend/sync_server/src/app_state/websocket/models.rs @@ -0,0 +1,76 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; +use ts_rs::TS; + +use crate::app_state::database::models::{DeviceId, DocumentVersionWithoutContent, VaultUpdateId}; + +#[derive(TS, Deserialize, Clone, Debug)] +#[serde(rename_all = "camelCase")] +pub struct WebSocketHandshake { + pub token: String, + pub device_id: DeviceId, + pub last_seen_vault_update_id: Option, +} + +#[derive(TS, Deserialize, Clone, Debug)] +#[serde(rename_all = "camelCase")] +pub struct CursorPositionFromClient { + pub document_to_cursors: HashMap>, +} + +#[derive(TS, Serialize, Clone, Debug)] +#[serde(rename_all = "camelCase")] +pub struct ClientCursors { + pub device_id: DeviceId, + pub cursors: HashMap>, +} + +#[derive(TS, Serialize, Clone, Debug)] +#[serde(rename_all = "camelCase")] +pub struct CursorPositionFromServer { + pub clients: Vec, +} + +#[derive(TS, Serialize, Clone, Debug)] +#[serde(rename_all = "camelCase")] +pub struct WebSocketVaultUpdate { + pub documents: Vec, + pub is_initial_sync: bool, +} + +#[derive(TS, Deserialize, Clone, Debug)] +#[ts(export)] +pub enum WebSocketClientMessage { + Handshake(WebSocketHandshake), + CursorPositions(CursorPositionFromClient), +} + +#[derive(TS, Serialize, Clone, Debug)] +#[ts(export)] +pub enum WebSocketServerMessage { + VaultUpdate(WebSocketVaultUpdate), + CursorPositions(CursorPositionFromServer), +} + +#[derive(Clone, Debug)] +pub struct WebSocketServerMessageWithOrigin { + pub origin_device_id: Option, + pub message: WebSocketServerMessage, +} + +impl WebSocketServerMessageWithOrigin { + pub fn new(message: WebSocketServerMessage) -> Self { + Self { + origin_device_id: None, + message, + } + } + + pub fn with_origin(origin_device_id: DeviceId, message: WebSocketServerMessage) -> Self { + Self { + origin_device_id: Some(origin_device_id), + message, + } + } +} diff --git a/backend/sync_server/src/app_state/websocket/utils.rs b/backend/sync_server/src/app_state/websocket/utils.rs new file mode 100644 index 00000000..7c4e2c05 --- /dev/null +++ b/backend/sync_server/src/app_state/websocket/utils.rs @@ -0,0 +1,74 @@ +use anyhow::Context; +use axum::extract::ws::{Message, WebSocket}; +use futures::{sink::SinkExt, stream::SplitSink}; + +use super::models::{WebSocketClientMessage, WebSocketHandshake, WebSocketServerMessage}; +use crate::{ + app_state::{ + AppState, + database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId}, + }, + errors::{SyncServerError, server_error, unauthenticated_error}, + server::auth::auth, +}; + +pub fn get_handshake( + state: &AppState, + vault_id: &VaultId, + message: Message, +) -> Result { + if let Message::Text(message) = message { + let message: WebSocketClientMessage = serde_json::from_str(&message) + .context("Failed to parse message") + .map_err(server_error)?; + + match message { + WebSocketClientMessage::Handshake(handshake) => { + auth(state, handshake.token.trim(), vault_id)?; + Ok(handshake) + } + WebSocketClientMessage::CursorPositions(_) => Err(unauthenticated_error( + anyhow::anyhow!("Expected a handshake message"), + )), + } + } else { + Err(unauthenticated_error(anyhow::anyhow!( + "Failed to authenticate due to invalid message" + ))) + } +} + +pub async fn get_unseen_documents( + state: &AppState, + vault_id: &VaultId, + last_seen_vault_update_id: Option, +) -> Result, SyncServerError> { + if let Some(update_id) = last_seen_vault_update_id { + state + .database + .get_latest_documents_since(vault_id, update_id, None) + .await + .map_err(server_error) + } else { + state + .database + .get_latest_documents(vault_id, None) + .await + .map_err(server_error) + } +} + +pub async fn send_update_over_websocket( + update: &WebSocketServerMessage, + sender: &mut SplitSink, +) -> Result<(), SyncServerError> { + let serialized_update = serde_json::to_string(update) + .context("Failed to serialize update") + .map_err(server_error)?; + + sender + .send(Message::Text(serialized_update)) + .await + .context("Failed to send message over websocket") + .map_err(server_error) +} diff --git a/backend/sync_server/src/config/database_config.rs b/backend/sync_server/src/config/database_config.rs index ef26a09d..6f91e19c 100644 --- a/backend/sync_server/src/config/database_config.rs +++ b/backend/sync_server/src/config/database_config.rs @@ -1,9 +1,12 @@ use std::path::PathBuf; +use chrono::TimeDelta; use log::debug; use serde::{Deserialize, Serialize}; -use crate::consts::{DEFAULT_DATABASES_DIRECTORY_PATH, DEFAULT_MAX_CONNECTIONS_PER_VAULT}; +use crate::consts::{ + DEFAULT_CURSOR_TIMEOUT, DEFAULT_DATABASES_DIRECTORY_PATH, DEFAULT_MAX_CONNECTIONS_PER_VAULT, +}; #[derive(Debug, Deserialize, Serialize, Clone)] pub struct DatabaseConfig { @@ -12,6 +15,9 @@ pub struct DatabaseConfig { #[serde(default = "default_max_connections_per_vault")] pub max_connections_per_vault: u32, + + #[serde(default = "default_cursor_timeout")] + pub cursor_timeout: TimeDelta, } fn default_databases_directory_path() -> PathBuf { @@ -24,11 +30,17 @@ fn default_max_connections_per_vault() -> u32 { DEFAULT_MAX_CONNECTIONS_PER_VAULT } +fn default_cursor_timeout() -> TimeDelta { + debug!("Using default cursor timeout: {DEFAULT_CURSOR_TIMEOUT}"); + DEFAULT_CURSOR_TIMEOUT +} + impl Default for DatabaseConfig { fn default() -> Self { Self { databases_directory_path: default_databases_directory_path(), max_connections_per_vault: default_max_connections_per_vault(), + cursor_timeout: default_cursor_timeout(), } } } diff --git a/backend/sync_server/src/consts.rs b/backend/sync_server/src/consts.rs index 57fb2559..03d5f4c2 100644 --- a/backend/sync_server/src/consts.rs +++ b/backend/sync_server/src/consts.rs @@ -1,8 +1,13 @@ +use chrono::TimeDelta; + pub const DEFAULT_CONFIG_PATH: &str = "config.yml"; + pub const DEFAULT_DATABASES_DIRECTORY_PATH: &str = "databases"; +pub const DEFAULT_MAX_CONNECTIONS_PER_VAULT: u32 = 12; +pub const DEFAULT_CURSOR_TIMEOUT: TimeDelta = TimeDelta::seconds(60); + pub const DEFAULT_HOST: &str = "127.0.0.1"; pub const DEFAULT_PORT: u16 = 3000; -pub const DEFAULT_MAX_CONNECTIONS_PER_VAULT: u32 = 12; pub const DEFAULT_MAX_BODY_SIZE_MB: usize = 4096; pub const DEFAULT_RESPONSE_TIMEOUT_SECONDS: u64 = 60; pub const DEFAULT_MAX_CLIENTS_PER_VAULT: usize = 256; diff --git a/backend/sync_server/src/server.rs b/backend/sync_server/src/server.rs index 0fd5fa03..3b1f7201 100644 --- a/backend/sync_server/src/server.rs +++ b/backend/sync_server/src/server.rs @@ -1,4 +1,4 @@ -mod auth; +pub mod auth; mod create_document; mod delete_document; mod device_id_header; diff --git a/backend/sync_server/src/server/create_document.rs b/backend/sync_server/src/server/create_document.rs index b9459df5..84f16d6a 100644 --- a/backend/sync_server/src/server/create_document.rs +++ b/backend/sync_server/src/server/create_document.rs @@ -17,9 +17,8 @@ use super::{ use crate::{ app_state::{ AppState, - broadcasts::VaultUpdate, database::models::{ - DeviceId, DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, + DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, }, }, config::user_config::User, @@ -41,7 +40,7 @@ pub struct CreateDocumentPathParams { pub async fn create_document_multipart( Path(CreateDocumentPathParams { vault_id }): Path, Extension(user): Extension, - TypedHeader(user_agent): TypedHeader, + TypedHeader(device_id): TypedHeader, State(state): State, TypedMultipart(axum_typed_multipart::TypedMultipart(request)): TypedMultipart< CreateDocumentVersionMultipart, @@ -49,12 +48,11 @@ pub async fn create_document_multipart( ) -> Result, SyncServerError> { internal_create_document( user, - user_agent, + device_id, state, vault_id, request.document_id, request.relative_path, - request.device_id, request.content.contents.to_vec(), ) .await @@ -67,7 +65,7 @@ pub async fn create_document_multipart( pub async fn create_document_json( Path(CreateDocumentPathParams { vault_id }): Path, Extension(user): Extension, - TypedHeader(user_agent): TypedHeader, + TypedHeader(device_id): TypedHeader, State(state): State, Json(request): Json, ) -> Result, SyncServerError> { @@ -77,12 +75,11 @@ pub async fn create_document_json( internal_create_document( user, - user_agent, + device_id, state, vault_id, request.document_id, request.relative_path, - request.device_id, content_bytes, ) .await @@ -91,12 +88,11 @@ pub async fn create_document_json( #[allow(clippy::too_many_arguments)] async fn internal_create_document( user: User, - user_agent: DeviceIdHeader, + device_id: DeviceIdHeader, state: AppState, vault_id: VaultId, document_id: Option, relative_path: String, - device_id: Option, content: Vec, ) -> Result, SyncServerError> { let mut transaction = state @@ -140,7 +136,7 @@ async fn internal_create_document( updated_date: chrono::Utc::now(), is_deleted: false, user_id: user.name, - device_id: user_agent.0, + device_id: device_id.0, }; state @@ -155,16 +151,5 @@ async fn internal_create_document( .context("Failed to commit successful transaction") .map_err(server_error)?; - state - .broadcasts - .send( - vault_id, - VaultUpdate { - origin_device_id: device_id, - document: new_version.clone().into(), - }, - ) - .await; - Ok(Json(new_version.into())) } diff --git a/backend/sync_server/src/server/delete_document.rs b/backend/sync_server/src/server/delete_document.rs index dbb9a0df..d27e97cd 100644 --- a/backend/sync_server/src/server/delete_document.rs +++ b/backend/sync_server/src/server/delete_document.rs @@ -12,7 +12,6 @@ use super::{device_id_header::DeviceIdHeader, requests::DeleteDocumentVersion}; use crate::{ app_state::{ AppState, - broadcasts::VaultUpdate, database::models::{ DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, }, @@ -38,7 +37,7 @@ pub async fn delete_document( document_id, }): Path, Extension(user): Extension, - TypedHeader(user_agent): TypedHeader, + TypedHeader(device_id): TypedHeader, State(state): State, Json(request): Json, ) -> Result, SyncServerError> { @@ -69,7 +68,7 @@ pub async fn delete_document( updated_date: chrono::Utc::now(), is_deleted: true, user_id: user.name, - device_id: user_agent.0, + device_id: device_id.0, }; state @@ -84,16 +83,5 @@ pub async fn delete_document( .context("Failed to commit successful transaction") .map_err(server_error)?; - state - .broadcasts - .send( - vault_id, - VaultUpdate { - origin_device_id: request.device_id, - document: new_version.clone().into(), - }, - ) - .await; - Ok(Json(new_version.into())) } diff --git a/backend/sync_server/src/server/requests.rs b/backend/sync_server/src/server/requests.rs index 26e6a398..89820dbe 100644 --- a/backend/sync_server/src/server/requests.rs +++ b/backend/sync_server/src/server/requests.rs @@ -4,7 +4,7 @@ use axum_typed_multipart::TryFromMultipart; use schemars::JsonSchema; use serde::{self, Deserialize}; -use crate::app_state::database::models::{DeviceId, DocumentId, VaultUpdateId}; +use crate::app_state::database::models::{DocumentId, VaultUpdateId}; #[derive(Debug, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] @@ -16,7 +16,6 @@ pub struct CreateDocumentVersion { pub document_id: Option, pub relative_path: String, pub content_base64: String, - pub device_id: Option, } #[derive(Debug, TryFromMultipart, JsonSchema)] @@ -25,7 +24,6 @@ pub struct CreateDocumentVersionMultipart { pub relative_path: String, #[form_data(limit = "unlimited")] pub content: FieldData, - pub device_id: Option, } #[derive(Debug, Deserialize, JsonSchema)] @@ -34,7 +32,6 @@ pub struct UpdateDocumentVersion { pub parent_version_id: VaultUpdateId, pub relative_path: String, pub content_base64: String, - pub device_id: Option, } #[derive(Debug, TryFromMultipart, JsonSchema)] @@ -44,12 +41,10 @@ pub struct UpdateDocumentVersionMultipart { pub relative_path: String, #[form_data(limit = "unlimited")] pub content: FieldData, - pub device_id: Option, } #[derive(Debug, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub struct DeleteDocumentVersion { pub relative_path: String, - pub device_id: Option, } diff --git a/backend/sync_server/src/server/update_document.rs b/backend/sync_server/src/server/update_document.rs index 22eb38b0..a784dad4 100644 --- a/backend/sync_server/src/server/update_document.rs +++ b/backend/sync_server/src/server/update_document.rs @@ -19,8 +19,7 @@ use super::{ use crate::{ app_state::{ AppState, - broadcasts::VaultUpdate, - database::models::{DeviceId, DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId}, + database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId}, }, config::user_config::User, errors::{SyncServerError, client_error, not_found_error, server_error}, @@ -43,7 +42,7 @@ pub async fn update_document_multipart( document_id, }): Path, Extension(user): Extension, - TypedHeader(user_agent): TypedHeader, + TypedHeader(device_id): TypedHeader, State(state): State, TypedMultipart(axum_typed_multipart::TypedMultipart(request)): TypedMultipart< UpdateDocumentVersionMultipart, @@ -51,13 +50,12 @@ pub async fn update_document_multipart( ) -> Result, SyncServerError> { internal_update_document( user, - user_agent, + device_id, state, vault_id, document_id, request.parent_version_id, request.relative_path, - request.device_id, request.content.contents.to_vec(), ) .await @@ -70,7 +68,7 @@ pub async fn update_document_json( document_id, }): Path, Extension(user): Extension, - TypedHeader(user_agent): TypedHeader, + TypedHeader(device_id): TypedHeader, State(state): State, Json(request): Json, ) -> Result, SyncServerError> { @@ -80,13 +78,12 @@ pub async fn update_document_json( internal_update_document( user, - user_agent, + device_id, state, vault_id, document_id, request.parent_version_id, request.relative_path, - request.device_id, content_bytes, ) .await @@ -95,13 +92,12 @@ pub async fn update_document_json( #[allow(clippy::too_many_arguments, clippy::too_many_lines)] async fn internal_update_document( user: User, - user_agent: DeviceIdHeader, + device_id: DeviceIdHeader, state: AppState, vault_id: VaultId, document_id: DocumentId, parent_version_id: VaultUpdateId, relative_path: String, - device_id: Option, content: Vec, ) -> Result, SyncServerError> { // No need for a transaction as document versions are immutable @@ -215,7 +211,7 @@ async fn internal_update_document( updated_date: chrono::Utc::now(), is_deleted: false, user_id: user.name, - device_id: user_agent.0, + device_id: device_id.0, }; state @@ -230,17 +226,6 @@ async fn internal_update_document( .context("Failed to commit successful transaction") .map_err(server_error)?; - state - .broadcasts - .send( - vault_id, - VaultUpdate { - origin_device_id: device_id, - document: new_version.clone().into(), - }, - ) - .await; - Ok(Json(if is_different_from_request_content { DocumentUpdateResponse::MergingUpdate(new_version.into()) } else { diff --git a/backend/sync_server/src/server/websocket.rs b/backend/sync_server/src/server/websocket.rs index 2517fe88..ea0e7fad 100644 --- a/backend/sync_server/src/server/websocket.rs +++ b/backend/sync_server/src/server/websocket.rs @@ -6,64 +6,52 @@ use axum::{ }, response::Response, }; -use futures::{ - sink::SinkExt, - stream::{SplitSink, StreamExt}, -}; +use futures::stream::StreamExt; use log::{error, info, warn}; use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; -use super::auth::auth; use crate::{ app_state::{ AppState, - database::models::{DeviceId, DocumentVersionWithoutContent, VaultId, VaultUpdateId}, + database::models::VaultId, + websocket::{ + models::{ + CursorPositionFromServer, WebSocketClientMessage, WebSocketServerMessage, + WebSocketVaultUpdate, + }, + utils::{get_handshake, get_unseen_documents, send_update_over_websocket}, + }, }, - errors::{SyncServerError, server_error, unauthenticated_error}, + errors::{SyncServerError, client_error, server_error, unauthenticated_error}, utils::normalize::normalize, }; // This is required for aide to infer the path parameter types and names #[derive(Deserialize, JsonSchema)] -pub struct WebsocketPathParams { +pub struct WebSocketPathParams { #[serde(deserialize_with = "normalize")] vault_id: VaultId, } pub async fn websocket_handler( ws: WebSocketUpgrade, - Path(WebsocketPathParams { vault_id }): Path, + Path(WebSocketPathParams { vault_id }): Path, State(state): State, ) -> Result { Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id))) } async fn websocket_wrapped(state: AppState, stream: WebSocket, vault_id: VaultId) { - info!("Websocket connection opened on vault '{vault_id}'"); + info!("WebSocket connection opened on vault '{vault_id}'"); let result = websocket(state, stream, vault_id.clone()).await; if let Err(err) = result { - error!("Websocket connection error on vault '{vault_id}': {err}"); + error!("WebSocket connection error on vault '{vault_id}': {err}"); } - warn!("Websocket connection closed on vault '{vault_id}'"); -} - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -struct WebsocketHandshake { - pub token: String, - pub device_id: DeviceId, - pub last_seen_vault_update_id: Option, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct WebsocketVaultUpdate { - pub documents: Vec, - pub is_initial_sync: bool, + warn!("WebSocket connection closed on vault '{vault_id}'"); } async fn websocket( @@ -73,68 +61,71 @@ async fn websocket( ) -> Result<(), SyncServerError> { let (mut sender, mut receiver) = stream.split(); - let handshake = if let Some(Ok(Message::Text(token))) = receiver.next().await { - let handshake: WebsocketHandshake = serde_json::from_str(&token) - .context("Failed to parse token") - .map_err(server_error)?; - - auth(&state, handshake.token.trim(), &vault_id)?; - - handshake + let handshake = if let Some(Ok(message)) = receiver.next().await { + get_handshake(&state, &vault_id, message)? } else { return Err(unauthenticated_error(anyhow::anyhow!( - "Failed to authenticate" + "Failed to authenticate due to invalid message" ))); }; let mut rx = state.broadcasts.get_receiver(vault_id.clone()).await; - let documents = if let Some(update_id) = handshake.last_seen_vault_update_id { - state - .database - .get_latest_documents_since(&vault_id, update_id, None) - .await - .map_err(server_error) - } else { - state - .database - .get_latest_documents(&vault_id, None) - .await - .map_err(server_error) - }?; - send_update_over_websocket( - &WebsocketVaultUpdate { - documents, + &WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate { + documents: get_unseen_documents(&state, &vault_id, handshake.last_seen_vault_update_id) + .await?, is_initial_sync: true, - }, + }), &mut sender, ) .await?; + send_update_over_websocket( + &WebSocketServerMessage::CursorPositions(CursorPositionFromServer { + clients: state.cursors.get_cursors(&vault_id).await, + }), + &mut sender, + ) + .await?; + + let device_id = handshake.device_id.clone(); let mut send_task = tokio::spawn(async move { while let Ok(update) = rx.recv().await { - if Some(&handshake.device_id) == update.origin_device_id.as_ref() { + if Some(&device_id) == update.origin_device_id.as_ref() { continue; } - send_update_over_websocket( - &WebsocketVaultUpdate { - documents: vec![update.document], - is_initial_sync: false, - }, - &mut sender, - ) - .await?; + send_update_over_websocket(&update.message, &mut sender).await?; } Ok::<(), SyncServerError>(()) }); - let mut recv_task = - tokio::spawn( - async move { while let Some(Ok(Message::Text(_text))) = receiver.next().await {} }, - ); + let device_id = handshake.device_id.clone(); + let mut recv_task = tokio::spawn(async move { + while let Some(Ok(Message::Text(message))) = receiver.next().await { + let message: WebSocketClientMessage = serde_json::from_str(&message) + .context("Failed to parse message") + .map_err(server_error)?; + + match message { + WebSocketClientMessage::Handshake(_) => { + return Err(client_error(anyhow::anyhow!( + "Unexpected handshake message" + ))); + } + WebSocketClientMessage::CursorPositions(cursors) => { + state + .cursors + .update_cursors(vault_id.clone(), &device_id, cursors.document_to_cursors) + .await; + } + } + } + + Ok::<(), SyncServerError>(()) + }); tokio::select! { _ = &mut send_task => recv_task.abort(), @@ -143,28 +134,13 @@ async fn websocket( send_task .await - .context("Websocket send task failed") + .context("WebSocket send task failed") .map_err(server_error)??; recv_task .await - .context("Websocket receive task failed") - .map_err(server_error)?; + .context("WebSocket receive task failed") + .map_err(server_error)??; Ok(()) } - -async fn send_update_over_websocket( - update: &WebsocketVaultUpdate, - sender: &mut SplitSink, -) -> Result<(), SyncServerError> { - let serialized_update = serde_json::to_string(update) - .context("Failed to serialize update") - .map_err(server_error)?; - - sender - .send(Message::Text(serialized_update)) - .await - .context("Failed to send message over websocket") - .map_err(server_error) -}