From 042233c4d783e872060004b0478f2d849e3052fa Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Fri, 8 May 2026 21:35:52 +0100 Subject: [PATCH] split: server websocket + cursors src/server/websocket.rs handshake/catch-up rewrite, app_state/cursors.rs, app_state/websocket/{broadcasts,models,utils}.rs. --- sync-server/src/app_state/cursors.rs | 90 +++-- .../src/app_state/websocket/broadcasts.rs | 160 ++++++-- sync-server/src/app_state/websocket/models.rs | 21 +- sync-server/src/app_state/websocket/utils.rs | 16 +- sync-server/src/server/websocket.rs | 370 +++++++++++++----- 5 files changed, 487 insertions(+), 170 deletions(-) diff --git a/sync-server/src/app_state/cursors.rs b/sync-server/src/app_state/cursors.rs index d083e1ac..e17fb4f7 100644 --- a/sync-server/src/app_state/cursors.rs +++ b/sync-server/src/app_state/cursors.rs @@ -42,7 +42,9 @@ impl Cursors { ) { let mut vault_to_cursors = self.vault_to_cursors.lock().await; - let all_device_cursors = vault_to_cursors.entry(vault_id).or_insert_with(Vec::new); + let all_device_cursors = vault_to_cursors + .entry(vault_id.clone()) + .or_insert_with(Vec::new); all_device_cursors.retain(|c| &c.client_cursors.device_id != device_id); all_device_cursors.push(ClientCursorsWithTimeToLive::new(ClientCursors { @@ -52,7 +54,7 @@ impl Cursors { })); drop(vault_to_cursors); // Explicitly drop the lock before broadcasting to avoid deadlock - self.broadcast_cursors().await; + self.broadcast_cursors_for_vault(&vault_id).await; } pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec { @@ -69,45 +71,81 @@ impl Cursors { .unwrap_or_default() } - pub fn start_background_task(self) { + pub fn start_background_task(self, mut shutdown: tokio::sync::watch::Receiver<()>) { tokio::spawn(async move { loop { - self.remove_expired_cursors().await; - tokio::time::sleep(Duration::from_secs(1)).await; + tokio::select! { + () = tokio::time::sleep(Duration::from_secs(1)) => { + self.remove_expired_cursors().await; + } + Ok(()) = shutdown.changed() => break, + } } }); } async fn remove_expired_cursors(&self) { - let mut vault_to_cursors = self.vault_to_cursors.lock().await; + let changed_vaults: Vec = { + let mut vault_to_cursors = self.vault_to_cursors.lock().await; - for (_vault_id, cursors) in vault_to_cursors.iter_mut() { - cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout)); + let mut changed = Vec::new(); + for (vault_id, cursors) in vault_to_cursors.iter_mut() { + let before = cursors.len(); + cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout)); + if cursors.len() != before { + changed.push(vault_id.clone()); + } + } + + // Remove empty vault entries to prevent unbounded growth + vault_to_cursors.retain(|_, cursors| !cursors.is_empty()); + + changed + }; + + for vault_id in &changed_vaults { + self.broadcast_cursors_for_vault(vault_id).await; } } - async fn broadcast_cursors(&self) { - let vault_to_cursors = self.vault_to_cursors.lock().await; + async fn broadcast_cursors_for_vault(&self, vault_id: &VaultId) { + let client_cursors: Vec = { + let vault_to_cursors = self.vault_to_cursors.lock().await; + vault_to_cursors + .get(vault_id) + .map(|cursors| cursors.iter().map(|c| c.client_cursors.clone()).collect()) + .unwrap_or_default() + }; - for (vault_id, cursors) in vault_to_cursors.iter() { - self.broadcasts - .send_document_update( - vault_id.clone(), - WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions( - CursorPositionFromServer { - clients: cursors.iter().map(|c| c.client_cursors.clone()).collect(), - }, - )), - ) - .await; - } + self.broadcasts.send_document_update( + vault_id.clone(), + WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions( + CursorPositionFromServer { + clients: client_cursors, + }, + )), + ); } - pub async fn remove_cursors_of_device(&self, vault_id: &str, device_id: &str) { - let mut vault_to_cursors = self.vault_to_cursors.lock().await; + pub async fn remove_cursors_of_device(&self, vault_id: &VaultId, device_id: &DeviceId) { + let changed = { + let mut vault_to_cursors = self.vault_to_cursors.lock().await; - if let Some(cursors) = vault_to_cursors.get_mut(vault_id) { - cursors.retain(|c| c.client_cursors.device_id != device_id); + if let Some(cursors) = vault_to_cursors.get_mut(vault_id) { + let before = cursors.len(); + cursors.retain(|c| c.client_cursors.device_id != *device_id); + let changed = cursors.len() != before; + if cursors.is_empty() { + vault_to_cursors.remove(vault_id); + } + changed + } else { + false + } + }; + + if changed { + self.broadcast_cursors_for_vault(vault_id).await; } } } diff --git a/sync-server/src/app_state/websocket/broadcasts.rs b/sync-server/src/app_state/websocket/broadcasts.rs index 60ae0219..b9e2ea39 100644 --- a/sync-server/src/app_state/websocket/broadcasts.rs +++ b/sync-server/src/app_state/websocket/broadcasts.rs @@ -1,69 +1,147 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::HashMap, + sync::{Arc, Mutex as StdMutex}, +}; -use anyhow::Context; -use log::{debug, warn}; +use log::{debug, info, warn}; use tokio::sync::{Mutex, broadcast}; -use super::models::WebSocketServerMessageWithOrigin; -use crate::{ - app_state::database::models::VaultId, config::server_config::ServerConfig, errors::server_error, -}; +use super::models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin}; +use crate::{app_state::database::models::VaultId, config::server_config::ServerConfig}; #[derive(Debug, Clone)] pub struct Broadcasts { - max_clients_per_vault: usize, - tx: Arc>>>, + broadcast_channel_capacity: usize, + // `tx` uses a blocking std::sync::Mutex because the critical section is + // a HashMap lookup plus a synchronous `broadcast::Sender::send`. Making + // this non-async lets `send_document_update` run without an `.await`, + // so an axum handler that is cancelled between `transaction.commit()` + // and the broadcast can never drop the notification mid-flight. + tx: Arc>>>, + send_locks: Arc>>>>, } +type TxMap = HashMap>; + impl Broadcasts { pub fn new(server_config: &ServerConfig) -> Self { Self { - max_clients_per_vault: server_config.max_clients_per_vault, - tx: Arc::new(Mutex::new(HashMap::new())), + broadcast_channel_capacity: server_config.broadcast_channel_capacity, + tx: Arc::new(StdMutex::new(HashMap::new())), + send_locks: Arc::new(Mutex::new(HashMap::new())), } } - pub async fn get_receiver( + /// Acquire a per-vault lock that serializes broadcasts in commit order. + /// Must be acquired before the insert, held through commit and broadcast. + pub async fn acquire_send_lock(&self, vault: &VaultId) -> tokio::sync::OwnedMutexGuard<()> { + let lock = { + let mut locks = self.send_locks.lock().await; + locks + .entry(vault.clone()) + .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(()))) + .clone() + }; + lock.lock_owned().await + } + + /// Remove senders for vaults with no active receivers + fn prune_inactive_vaults(tx_map: &mut TxMap) -> Vec { + let mut pruned = Vec::new(); + tx_map.retain(|vault, sender| { + let alive = sender.receiver_count() > 0; + if !alive { + pruned.push(vault.clone()); + } + alive + }); + pruned + } + + pub fn get_receiver( &self, vault: VaultId, - ) -> broadcast::Receiver { - let tx = self.get_or_create(vault).await; + max_clients: usize, + ) -> Result, crate::errors::SyncServerError> + { + let mut tx_map = self + .tx + .lock() + .expect("broadcasts.tx mutex poisoned — a previous holder panicked"); - tx.subscribe() + let count_before_prune = tx_map + .get(&vault) + .map_or(0, tokio::sync::broadcast::Sender::receiver_count); + let pruned = Self::prune_inactive_vaults(&mut tx_map); + let pruned_self = pruned.contains(&vault); + + let sender = tx_map + .entry(vault.clone()) + .or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0); + + // Hold the lock across the count check *and* the subscribe so the + // `max_clients` cap is atomic: two concurrent callers can't both + // observe `receiver_count() < max_clients` and both subscribe. + if sender.receiver_count() >= max_clients { + return Err(crate::errors::client_error(anyhow::anyhow!( + "Vault has reached the maximum number of clients ({max_clients})" + ))); + } + + let receiver = sender.subscribe(); + let count_after = sender.receiver_count(); + info!( + "[BCAST] get_receiver vault={vault} count_before_prune={count_before_prune} pruned_self={pruned_self} pruned_total={} count_after_subscribe={count_after}", + pruned.len() + ); + Ok(receiver) } /// Notify all clients (who are subscribed to the vault) about an update. - /// We only log failures and don't propagate them. - pub async fn send_document_update( - &self, - vault: VaultId, - document: WebSocketServerMessageWithOrigin, - ) { - let tx = self.get_or_create(vault.clone()).await; + /// Synchronous: safe to invoke from a handler between `commit()` and + /// function return without worrying about task cancellation dropping + /// the broadcast mid-flight. Failures are logged, never propagated. + pub fn send_document_update(&self, vault: VaultId, document: WebSocketServerMessageWithOrigin) { + let vault_update_id = match &document.message { + WebSocketServerMessage::VaultUpdate(u) => Some(u.document.vault_update_id), + WebSocketServerMessage::CursorPositions(_) => None, + }; + let is_deleted = match &document.message { + WebSocketServerMessage::VaultUpdate(u) => Some(u.document.is_deleted), + WebSocketServerMessage::CursorPositions(_) => None, + }; + let mut tx_map = self + .tx + .lock() + .expect("broadcasts.tx mutex poisoned — a previous holder panicked"); + let count_before_prune = tx_map + .get(&vault) + .map_or(0, tokio::sync::broadcast::Sender::receiver_count); + let pruned = Self::prune_inactive_vaults(&mut tx_map); + let pruned_self = pruned.contains(&vault); - if tx.receiver_count() == 0 { + let sender = tx_map + .entry(vault.clone()) + .or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0); + + let count_before_send = sender.receiver_count(); + + if count_before_send == 0 { + info!( + "[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send=0 SKIPPED" + ); debug!("Skipping broadcast, no clients connected for vault `{vault}`"); return; } - let result = tx - .send(document) - .context("Cannot broadcast server message to websocket listeners") - .map_err(server_error); - - if result.is_err() { - warn!("Failed to send message: {result:?}"); + let send_result = sender.send(document); + match &send_result { + Ok(n) => info!( + "[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send={count_before_send} SENT delivered_to={n}" + ), + Err(e) => warn!( + "[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send={count_before_send} FAILED err={e}" + ), } } - - async fn get_or_create( - &self, - vault: VaultId, - ) -> broadcast::Sender { - let mut tx = self.tx.lock().await; - - tx.entry(vault) - .or_insert_with(|| broadcast::channel(self.max_clients_per_vault).0.clone()) - .clone() - } } diff --git a/sync-server/src/app_state/websocket/models.rs b/sync-server/src/app_state/websocket/models.rs index e037fb7e..eb6c956a 100644 --- a/sync-server/src/app_state/websocket/models.rs +++ b/sync-server/src/app_state/websocket/models.rs @@ -11,7 +11,7 @@ pub struct WebSocketHandshake { pub token: String, pub device_id: DeviceId, - #[ts(as = "Option")] + #[ts(type = "number | null")] pub last_seen_vault_update_id: Option, } @@ -22,13 +22,14 @@ pub struct CursorPositionFromClient { } #[derive(TS, Serialize, Deserialize, Clone, Debug)] +#[serde(rename_all = "camelCase")] pub struct DocumentWithCursors { // It's None in case the document is dirty. // We still want to sync the cursor to mark // that it exists and can be client-side // interpolated. However, the actual // position is meaningless. - #[ts(as = "Option")] + #[ts(type = "number | null")] pub vault_update_id: Option, pub document_id: DocumentId, @@ -57,11 +58,19 @@ pub struct CursorPositionFromServer { pub clients: Vec, } +// One committed version. Non-delete updates are broadcast to every +// connected client *except* the device that authored them — that +// device already has the new state via its HTTP response. Deletes are +// broadcast to every client including the author: the author keeps +// the document in its sync queue until this receipt arrives so a late +// remote update can't sneak in between the HTTP response and the +// queue cleanup. The server also emits these one-at-a-time to catch +// up a freshly-connected client on versions committed while it was +// offline, in ascending `vault_update_id` order. #[derive(TS, Serialize, Clone, Debug)] #[serde(rename_all = "camelCase")] pub struct WebSocketVaultUpdate { - pub documents: Vec, - pub is_initial_sync: bool, + pub document: DocumentVersionWithoutContent, } #[derive(TS, Deserialize, Clone, Debug)] @@ -80,6 +89,10 @@ pub enum WebSocketServerMessage { CursorPositions(CursorPositionFromServer), } +/// Broadcast envelope carrying the message plus the device that produced +/// it. The per-recipient send task compares `origin_device_id` against +/// its own device id to fill in `originates_from_self` before the message +/// is serialized on the wire. #[derive(Clone, Debug)] pub struct WebSocketServerMessageWithOrigin { pub origin_device_id: Option, diff --git a/sync-server/src/app_state/websocket/utils.rs b/sync-server/src/app_state/websocket/utils.rs index 1e0dd243..d78360de 100644 --- a/sync-server/src/app_state/websocket/utils.rs +++ b/sync-server/src/app_state/websocket/utils.rs @@ -9,7 +9,7 @@ use crate::{ database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId}, }, config::user_config::User, - errors::{SyncServerError, server_error, unauthenticated_error}, + errors::{SyncServerError, client_error, server_error, unauthenticated_error}, server::auth::auth, }; @@ -26,7 +26,7 @@ pub fn get_authenticated_handshake( if let Some(Message::Text(message)) = message { let message: WebSocketClientMessage = serde_json::from_str(&message) .context("Failed to parse message") - .map_err(server_error)?; + .map_err(client_error)?; match message { WebSocketClientMessage::Handshake(handshake) => { @@ -44,21 +44,29 @@ pub fn get_authenticated_handshake( } } +/// Stream the documents the client missed while offline, bounded above +/// by `up_to_vault_update_id` so the catch-up is a stable snapshot at +/// exactly that cursor. The WebSocket handshake atomically subscribes +/// to the broadcast channel and snapshots this cursor under the per- +/// vault send lock; commits past the cursor are then delivered solely +/// through the broadcast channel (filtered by the same cursor on the +/// receive side), so every committed update is delivered exactly once. pub async fn get_unseen_documents( state: &AppState, vault_id: &VaultId, last_seen_vault_update_id: Option, + up_to_vault_update_id: VaultUpdateId, ) -> Result, SyncServerError> { if let Some(update_id) = last_seen_vault_update_id { state .database - .get_latest_documents_since(vault_id, update_id, None) + .get_latest_documents_since(vault_id, update_id, Some(up_to_vault_update_id), None) .await .map_err(server_error) } else { state .database - .get_latest_documents(vault_id, None) + .get_latest_documents(vault_id, Some(up_to_vault_update_id), None) .await .map_err(server_error) } diff --git a/sync-server/src/server/websocket.rs b/sync-server/src/server/websocket.rs index bb10b49f..6e1af0ba 100644 --- a/sync-server/src/server/websocket.rs +++ b/sync-server/src/server/websocket.rs @@ -1,15 +1,3 @@ -use anyhow::Context; -use axum::{ - extract::{ - Path, State, - ws::{Message, WebSocket, WebSocketUpgrade}, - }, - response::Response, -}; -use futures::stream::StreamExt; -use log::{debug, info}; -use serde::Deserialize; - use crate::{ app_state::{ AppState, @@ -24,9 +12,35 @@ use crate::{ }, }, }, + consts::{ + HANDSHAKE_TIMEOUT, MAX_CURSOR_DOCUMENTS, MAX_CURSORS_PER_DOCUMENT, MAX_RELATIVE_PATH_LEN, + }, errors::{SyncServerError, client_error, server_error}, utils::normalize::normalize, }; +use anyhow::Context; +use axum::{ + extract::{ + Path, State, + ws::{Message, WebSocket, WebSocketUpgrade}, + }, + response::Response, +}; +use futures::sink::SinkExt; +use futures::stream::StreamExt; +use log::{debug, info, warn}; +use serde::Deserialize; + +/// Tracks a pending (not yet authenticated) WebSocket connection. +/// Decrements the counter when dropped, ensuring cleanup even if +/// the upgrade never completes or auth fails. +struct PendingWsGuard(std::sync::Arc); + +impl Drop for PendingWsGuard { + fn drop(&mut self) { + self.0.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } +} #[derive(Deserialize)] pub struct WebSocketPathParams { @@ -39,13 +53,31 @@ pub async fn websocket_handler( Path(WebSocketPathParams { vault_id }): Path, State(state): State, ) -> Result { - Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id))) + let current = state + .pending_ws_connections + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + if current >= state.config.server.max_pending_websocket_connections { + state + .pending_ws_connections + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + return Err(client_error(anyhow::anyhow!( + "Too many pending WebSocket connections" + ))); + } + + let guard = PendingWsGuard(state.pending_ws_connections.clone()); + Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id, guard))) } -async fn websocket_wrapped(state: AppState, stream: WebSocket, vault_id: VaultId) { +async fn websocket_wrapped( + state: AppState, + stream: WebSocket, + vault_id: VaultId, + pending_guard: PendingWsGuard, +) { info!("WebSocket connection opened on vault `{vault_id}`"); - let result = websocket(state, stream, vault_id.clone()).await; + let result = websocket(state, stream, vault_id.clone(), pending_guard).await; if let Err(err) = result { debug!("WebSocket connection error on vault `{vault_id}`: {err}"); @@ -57,39 +89,112 @@ async fn websocket( state: AppState, stream: WebSocket, vault_id: VaultId, + pending_guard: PendingWsGuard, ) -> Result<(), SyncServerError> { let (mut sender, mut websocket_receiver) = stream.split(); - let authed_handshake = get_authenticated_handshake( - &state, - &vault_id, - websocket_receiver - .next() - .await - .transpose() - .unwrap_or_default(), - )?; + let handshake_msg = tokio::time::timeout(HANDSHAKE_TIMEOUT, websocket_receiver.next()) + .await + .map_err(|_| client_error(anyhow::anyhow!("WebSocket handshake timed out")))? + .transpose() + .map_err(|e| client_error(anyhow::anyhow!("WebSocket error during handshake: {e}")))?; + + let authed_handshake = get_authenticated_handshake(&state, &vault_id, handshake_msg)?; info!( "WebSocket handshake successful for vault `{vault_id}` for `{}`", authed_handshake.handshake.device_id ); - let mut broadcast_receiver = state.broadcasts.get_receiver(vault_id.clone()).await; + // Auth complete — no longer a pending connection. + drop(pending_guard); - send_update_over_websocket( - &WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate { - documents: get_unseen_documents( - &state, - &vault_id, - authed_handshake.handshake.last_seen_vault_update_id, - ) - .await?, - is_initial_sync: true, - }), - &mut sender, + let max_clients = state.config.server.max_clients_per_vault; + + // Atomic subscribe + cursor snapshot, serialized against in-flight + // broadcasts: + // + // 1. Acquire the per-vault broadcast send lock. While we hold it, + // no `send_document_update` can run, so no broadcast can fire + // between our subscribe and our cursor snapshot. + // 2. Subscribe to the broadcast channel (now we'll see every + // broadcast that fires after we drop the send guard). + // 3. Snapshot `cursor = max committed vault_update_id`. Because + // `insert_document_version` holds the same send lock from + // *before* the commit through *after* the broadcast, every doc + // visible at this cursor has either (a) already had its + // broadcast delivered to all then-existing subscribers — and we + // weren't one of them, so we'll catch it via the snapshot — or + // (b) had its broadcast contend on the lock we're holding, and + // will be delivered to us as soon as we drop the guard, with + // `vault_update_id > cursor`. + // 4. Drop the send guard so writers can resume broadcasting. + // 5. Stream the catch-up bounded by the cursor — i.e. only docs + // with `vault_update_id <= cursor` — exactly once. + // 6. The send task forwards broadcasts but filters to + // `vault_update_id > cursor`, so a doc that's both in the + // catch-up and in a contended-then-released broadcast is + // delivered exactly once (via the catch-up). + let send_guard = state.broadcasts.acquire_send_lock(&vault_id).await; + let mut broadcast_receiver = match state.broadcasts.get_receiver(vault_id.clone(), max_clients) + { + Ok(receiver) => receiver, + Err(err) => { + drop(send_guard); + warn!( + "Vault `{vault_id}` has reached the maximum number of clients ({max_clients}), rejecting connection from `{}`", + authed_handshake.handshake.device_id + ); + if let Err(e) = sender + .send(Message::Close(Some(axum::extract::ws::CloseFrame { + code: 4000, + reason: format!( + "Vault has reached the maximum number of clients ({max_clients})" + ) + .into(), + }))) + .await + { + warn!("Failed to send WebSocket close frame: {e}"); + } + return Err(err); + } + }; + let cursor = state + .database + .get_max_update_id_in_vault(&vault_id, None) + .await + .map_err(server_error)?; + drop(send_guard); + + // Catch-up on versions committed while this client was offline, + // streamed one-at-a-time in ascending `vault_update_id` order, up + // to the snapshot cursor. + let unseen_documents = get_unseen_documents( + &state, + &vault_id, + authed_handshake.handshake.last_seen_vault_update_id, + cursor, ) .await?; + let unseen_summary: Vec<(i64, bool, String)> = unseen_documents + .iter() + .map(|d| (d.vault_update_id, d.is_deleted, d.relative_path.clone())) + .collect(); + info!( + "[CATCHUP] vault={vault_id} device={} last_seen={:?} cursor={cursor} unseen_count={} unseen={:?}", + authed_handshake.handshake.device_id, + authed_handshake.handshake.last_seen_vault_update_id, + unseen_summary.len(), + unseen_summary + ); + for document in unseen_documents { + send_update_over_websocket( + &WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate { document }), + &mut sender, + ) + .await?; + } send_update_over_websocket( &WebSocketServerMessage::CursorPositions(CursorPositionFromServer { @@ -101,24 +206,57 @@ async fn websocket( let device_id = authed_handshake.handshake.device_id.clone(); let mut send_task = tokio::spawn(async move { - while let Ok(update) = broadcast_receiver.recv().await { - if Some(&device_id) == update.origin_device_id.as_ref() { - continue; - } + loop { + match broadcast_receiver.recv().await { + Ok(update) => { + // Drop messages this device authored because the HTTP + // response already carried authoritative state back. + // Delete broadcasts are sent without an origin so the + // author also receives them — that's the receipt the + // client needs to drop the doc from its sync queue. + if Some(&device_id) == update.origin_device_id.as_ref() { + continue; + } - let message = match update.message { - WebSocketServerMessage::CursorPositions(CursorPositionFromServer { clients }) => { - WebSocketServerMessage::CursorPositions(CursorPositionFromServer { - clients: clients - .into_iter() - .filter(|client| client.device_id != device_id) - .collect(), - }) + // Filter out vault updates already covered by the + // catch-up snapshot. The handshake atomically + // subscribed and snapshotted `cursor` under the + // broadcast send lock, so any broadcast with + // `vault_update_id <= cursor` is one that contended + // on the lock during our subscribe — its row is + // already in the catch-up stream and re-delivering + // it via this channel would duplicate the message. + // Cursor messages aren't versioned and are always + // forwarded. + if let WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate { document }) = + &update.message + && document.vault_update_id <= cursor + { + continue; + } + + let message = match update.message { + WebSocketServerMessage::CursorPositions(CursorPositionFromServer { + clients, + }) => WebSocketServerMessage::CursorPositions(CursorPositionFromServer { + clients: clients + .into_iter() + .filter(|client| client.device_id != device_id) + .collect(), + }), + WebSocketServerMessage::VaultUpdate(_) => update.message, + }; + + send_update_over_websocket(&message, &mut sender).await?; } - WebSocketServerMessage::VaultUpdate(_) => update.message, - }; - - send_update_over_websocket(&message, &mut sender).await?; + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + warn!( + "WebSocket receiver lagged, dropped {n} messages — disconnecting client to force full resync" + ); + break; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } } Ok::<(), SyncServerError>(()) @@ -128,26 +266,59 @@ async fn websocket( let vault_id_clone = vault_id.clone(); let cursor_manager = state.cursors.clone(); let mut receive_task = tokio::spawn(async move { - while let Some(Ok(Message::Text(message))) = websocket_receiver.next().await { - let message: WebSocketClientMessage = serde_json::from_str(&message) - .context("Failed to parse WebSocket message from client") - .map_err(server_error)?; + while let Some(msg) = websocket_receiver.next().await { + match msg { + Ok(Message::Text(message)) => { + let message: WebSocketClientMessage = serde_json::from_str(&message) + .context("Failed to parse WebSocket message from client") + .map_err(client_error)?; - match message { - WebSocketClientMessage::Handshake(_) => { - return Err(client_error(anyhow::anyhow!( - "Unexpected handshake message" - ))); + match message { + WebSocketClientMessage::Handshake(_) => { + return Err(client_error(anyhow::anyhow!( + "Unexpected handshake message" + ))); + } + WebSocketClientMessage::CursorPositions(cursors) => { + let docs = cursors.documents_with_cursors; + if docs.len() > MAX_CURSOR_DOCUMENTS { + warn!( + "Cursor update rejected: {} documents exceeds limit of {MAX_CURSOR_DOCUMENTS}", + docs.len() + ); + continue; + } + + let valid = docs.iter().all(|doc| { + doc.cursors.len() <= MAX_CURSORS_PER_DOCUMENT + && doc.relative_path.len() <= MAX_RELATIVE_PATH_LEN + }); + if !valid { + warn!( + "Cursor update rejected: a document exceeds cursor or path length limits" + ); + continue; + } + + cursor_manager + .update_cursors( + vault_id_clone.clone(), + authed_handshake.user.name.clone(), + &device_id, + docs, + ) + .await; + } + } } - WebSocketClientMessage::CursorPositions(cursors) => { - cursor_manager - .update_cursors( - vault_id_clone.clone(), - authed_handshake.user.name.clone(), - &device_id, - cursors.documents_with_cursors, - ) - .await; + Ok(Message::Close(_)) => break, + Ok(Message::Binary(_)) => { + warn!("Received unexpected binary WebSocket message, ignoring"); + } + Ok(_) => {} // Ping/Pong frames handled by axum + Err(e) => { + debug!("WebSocket receive error: {e}"); + break; } } } @@ -155,38 +326,47 @@ async fn websocket( Ok::<(), SyncServerError>(()) }); - tokio::select! { - _ = &mut send_task => receive_task.abort(), - _ = &mut receive_task => send_task.abort(), + let result: Result<(), SyncServerError> = tokio::select! { + send_result = &mut send_task => { + receive_task.abort(); + let _ = receive_task.await; + match send_result { + Err(e) => Err(server_error( + anyhow::Error::from(e).context("WebSocket send task failed"), + )), + Ok(inner) => inner, + } + }, + receive_result = &mut receive_task => { + send_task.abort(); + let _ = send_task.await; + match receive_result { + Err(e) => Err(server_error( + anyhow::Error::from(e).context("WebSocket receive task failed"), + )), + Ok(inner) => inner, + } + }, }; - let result: Result<(), SyncServerError> = (async { - send_task - .await - .context("WebSocket send task failed") - .map_err(client_error) - .and_then(|err| err)?; - - receive_task - .await - .context("WebSocket receive task failed") - .map_err(client_error) - .and_then(|err| err)?; - - Ok(()) - }) - .await; - state .cursors .remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id) .await; - if result.is_err() { - info!( - "WebSocket disconnected on vault `{vault_id}` for `{}`", - authed_handshake.handshake.device_id - ); + match &result { + Ok(()) => { + info!( + "WebSocket disconnected on vault `{vault_id}` for `{}`", + authed_handshake.handshake.device_id + ); + } + Err(err) => { + warn!( + "WebSocket error on vault `{vault_id}` for `{}`: {err}", + authed_handshake.handshake.device_id + ); + } } result