diff --git a/backend/sync_server/src/app_state/cursors.rs b/backend/sync_server/src/app_state/cursors.rs index d5aa01e4..a48aceec 100644 --- a/backend/sync_server/src/app_state/cursors.rs +++ b/backend/sync_server/src/app_state/cursors.rs @@ -100,6 +100,14 @@ impl Cursors { .await; } } + + pub async fn remove_cursors_of_device(&self, vault_id: &str, device_id: &str) { + let mut vault_to_cursors = self.vault_to_cursors.lock().await; + + if let Some(cursors) = vault_to_cursors.get_mut(vault_id) { + cursors.retain(|c| c.client_cursors.device_id != device_id); + } + } } #[derive(Clone, Debug)] diff --git a/backend/sync_server/src/app_state/websocket/utils.rs b/backend/sync_server/src/app_state/websocket/utils.rs index 7c4e2c05..cf337e39 100644 --- a/backend/sync_server/src/app_state/websocket/utils.rs +++ b/backend/sync_server/src/app_state/websocket/utils.rs @@ -12,12 +12,12 @@ use crate::{ server::auth::auth, }; -pub fn get_handshake( +pub fn get_authenticated_handshake( state: &AppState, vault_id: &VaultId, - message: Message, + message: Option, ) -> Result { - if let Message::Text(message) = message { + if let Some(Message::Text(message)) = message { let message: WebSocketClientMessage = serde_json::from_str(&message) .context("Failed to parse message") .map_err(server_error)?; diff --git a/backend/sync_server/src/server/websocket.rs b/backend/sync_server/src/server/websocket.rs index 822c211c..4a7d7833 100644 --- a/backend/sync_server/src/server/websocket.rs +++ b/backend/sync_server/src/server/websocket.rs @@ -20,10 +20,12 @@ use crate::{ CursorPositionFromServer, WebSocketClientMessage, WebSocketServerMessage, WebSocketVaultUpdate, }, - utils::{get_handshake, get_unseen_documents, send_update_over_websocket}, + utils::{ + get_authenticated_handshake, get_unseen_documents, send_update_over_websocket, + }, }, }, - errors::{SyncServerError, client_error, server_error, unauthenticated_error}, + errors::{SyncServerError, client_error, server_error}, utils::normalize::normalize, }; @@ -61,13 +63,15 @@ async fn websocket( ) -> Result<(), SyncServerError> { let (mut sender, mut websocket_receiver) = stream.split(); - let handshake = if let Some(Ok(message)) = websocket_receiver.next().await { - get_handshake(&state, &vault_id, message)? - } else { - return Err(unauthenticated_error(anyhow::anyhow!( - "Failed to authenticate due to invalid message" - ))); - }; + let handshake = get_authenticated_handshake( + &state, + &vault_id, + websocket_receiver + .next() + .await + .transpose() + .unwrap_or_default(), + )?; let mut broadcast_receiver = state.broadcasts.get_receiver(vault_id.clone()).await; @@ -103,6 +107,8 @@ async fn websocket( }); let device_id = handshake.device_id.clone(); + let vault_id_clone = vault_id.clone(); + let cursor_manager = state.cursors.clone(); let mut receive_task = tokio::spawn(async move { while let Some(Ok(Message::Text(message))) = websocket_receiver.next().await { let message: WebSocketClientMessage = serde_json::from_str(&message) @@ -116,9 +122,12 @@ async fn websocket( ))); } WebSocketClientMessage::CursorPositions(cursors) => { - state - .cursors - .update_cursors(vault_id.clone(), &device_id, cursors.document_to_cursors) + cursor_manager + .update_cursors( + vault_id_clone.clone(), + &device_id, + cursors.document_to_cursors, + ) .await; } } @@ -132,15 +141,26 @@ async fn websocket( _ = &mut receive_task => send_task.abort(), }; - send_task - .await - .context("WebSocket send task failed") - .map_err(server_error)??; + let result = { + send_task + .await + .context("WebSocket send task failed") + .map_err(server_error) + .and_then(|x| x)?; - receive_task - .await - .context("WebSocket receive task failed") - .map_err(server_error)??; + receive_task + .await + .context("WebSocket receive task failed") + .map_err(server_error) + .and_then(|x| x)?; - Ok(()) + Ok(()) + }; + + state + .cursors + .remove_cursors_of_device(&vault_id, &handshake.device_id) + .await; + + result }