diff --git a/backend/sync_server/src/app_state/broadcasts.rs b/backend/sync_server/src/app_state/broadcasts.rs index 9d2d219..f71886c 100644 --- a/backend/sync_server/src/app_state/broadcasts.rs +++ b/backend/sync_server/src/app_state/broadcasts.rs @@ -3,13 +3,19 @@ use std::{collections::HashMap, sync::Arc}; use anyhow::Context; use tokio::sync::{Mutex, broadcast}; -use super::database::models::{DocumentVersionWithoutContent, VaultId}; +use super::database::models::{DeviceId, DocumentVersionWithoutContent, VaultId}; use crate::{config::server_config::ServerConfig, errors::server_error}; #[derive(Debug, Clone)] pub struct Broadcasts { max_clients_per_vault: usize, - tx: Arc>>>, + tx: Arc>>>, +} + +#[derive(Debug, Clone)] +pub struct VaultUpdate { + pub origin_device_id: Option, + pub document: DocumentVersionWithoutContent, } impl Broadcasts { @@ -20,10 +26,7 @@ impl Broadcasts { } } - pub async fn get_receiver( - &self, - vault: VaultId, - ) -> broadcast::Receiver { + pub async fn get_receiver(&self, vault: VaultId) -> broadcast::Receiver { let tx = self.get_or_create(vault).await; tx.subscribe() @@ -31,7 +34,7 @@ impl Broadcasts { /// Sent a document update to all clients subscribed to the vault. /// We ignore & log failures. - pub async fn send(&self, vault: VaultId, document: DocumentVersionWithoutContent) { + pub async fn send(&self, vault: VaultId, document: VaultUpdate) { let tx = self.get_or_create(vault).await; let result = tx @@ -44,10 +47,7 @@ impl Broadcasts { } } - async fn get_or_create( - &self, - vault: VaultId, - ) -> broadcast::Sender { + async fn get_or_create(&self, vault: VaultId) -> broadcast::Sender { let mut tx = self.tx.lock().await; tx.entry(vault) diff --git a/backend/sync_server/src/app_state/database/models.rs b/backend/sync_server/src/app_state/database/models.rs index a837e93..55079c8 100644 --- a/backend/sync_server/src/app_state/database/models.rs +++ b/backend/sync_server/src/app_state/database/models.rs @@ -6,6 +6,7 @@ use sync_lib::bytes_to_base64; pub type VaultId = String; pub type VaultUpdateId = i64; pub type DocumentId = uuid::Uuid; +pub type DeviceId = String; #[derive(Debug, Clone)] pub struct StoredDocumentVersion { diff --git a/backend/sync_server/src/server.rs b/backend/sync_server/src/server.rs index 45d43d8..e993ed1 100644 --- a/backend/sync_server/src/server.rs +++ b/backend/sync_server/src/server.rs @@ -86,7 +86,7 @@ pub async fn create_server(config_path: Option) -> Result<()> { TraceLayer::new_for_http() .make_span_with(|request: &Request<_>| { info_span!( - "http_request", + "http", method = ?request.method(), uri = ?request.uri(), ) diff --git a/backend/sync_server/src/server/create_document.rs b/backend/sync_server/src/server/create_document.rs index bc54264..1c2e612 100644 --- a/backend/sync_server/src/server/create_document.rs +++ b/backend/sync_server/src/server/create_document.rs @@ -10,8 +10,9 @@ use super::requests::{CreateDocumentVersion, CreateDocumentVersionMultipart}; use crate::{ app_state::{ AppState, + broadcasts::VaultUpdate, database::models::{ - DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, + DeviceId, DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, }, }, errors::{SyncServerError, client_error, server_error}, @@ -40,6 +41,7 @@ pub async fn create_document_multipart( vault_id, request.document_id, request.relative_path, + request.device_id, request.content.contents.to_vec(), ) .await @@ -63,6 +65,7 @@ pub async fn create_document_json( vault_id, request.document_id, request.relative_path, + request.device_id, content_bytes, ) .await @@ -73,6 +76,7 @@ async fn internal_create_document( vault_id: VaultId, document_id: Option, relative_path: String, + device_id: Option, content: Vec, ) -> Result, SyncServerError> { let mut transaction = state @@ -131,7 +135,13 @@ async fn internal_create_document( state .broadcasts - .send(vault_id, new_version.clone().into()) + .send( + vault_id, + VaultUpdate { + origin_device_id: device_id, + document: new_version.clone().into(), + }, + ) .await; Ok(Json(new_version.into())) diff --git a/backend/sync_server/src/server/delete_document.rs b/backend/sync_server/src/server/delete_document.rs index f278f77..2d02dec 100644 --- a/backend/sync_server/src/server/delete_document.rs +++ b/backend/sync_server/src/server/delete_document.rs @@ -8,6 +8,7 @@ use super::requests::DeleteDocumentVersion; use crate::{ app_state::{ AppState, + broadcasts::VaultUpdate, database::models::{ DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, }, @@ -67,7 +68,13 @@ pub async fn delete_document( state .broadcasts - .send(vault_id, new_version.clone().into()) + .send( + vault_id, + VaultUpdate { + origin_device_id: request.device_id, + document: new_version.clone().into(), + }, + ) .await; Ok(Json(new_version.into())) diff --git a/backend/sync_server/src/server/requests.rs b/backend/sync_server/src/server/requests.rs index 89820db..26e6a39 100644 --- a/backend/sync_server/src/server/requests.rs +++ b/backend/sync_server/src/server/requests.rs @@ -4,7 +4,7 @@ use axum_typed_multipart::TryFromMultipart; use schemars::JsonSchema; use serde::{self, Deserialize}; -use crate::app_state::database::models::{DocumentId, VaultUpdateId}; +use crate::app_state::database::models::{DeviceId, DocumentId, VaultUpdateId}; #[derive(Debug, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] @@ -16,6 +16,7 @@ pub struct CreateDocumentVersion { pub document_id: Option, pub relative_path: String, pub content_base64: String, + pub device_id: Option, } #[derive(Debug, TryFromMultipart, JsonSchema)] @@ -24,6 +25,7 @@ pub struct CreateDocumentVersionMultipart { pub relative_path: String, #[form_data(limit = "unlimited")] pub content: FieldData, + pub device_id: Option, } #[derive(Debug, Deserialize, JsonSchema)] @@ -32,6 +34,7 @@ pub struct UpdateDocumentVersion { pub parent_version_id: VaultUpdateId, pub relative_path: String, pub content_base64: String, + pub device_id: Option, } #[derive(Debug, TryFromMultipart, JsonSchema)] @@ -41,10 +44,12 @@ pub struct UpdateDocumentVersionMultipart { pub relative_path: String, #[form_data(limit = "unlimited")] pub content: FieldData, + pub device_id: Option, } #[derive(Debug, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub struct DeleteDocumentVersion { pub relative_path: String, + pub device_id: Option, } diff --git a/backend/sync_server/src/server/update_document.rs b/backend/sync_server/src/server/update_document.rs index c953b68..a572394 100644 --- a/backend/sync_server/src/server/update_document.rs +++ b/backend/sync_server/src/server/update_document.rs @@ -14,7 +14,8 @@ use super::{ use crate::{ app_state::{ AppState, - database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId}, + broadcasts::VaultUpdate, + database::models::{DeviceId, DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId}, }, errors::{SyncServerError, client_error, not_found_error, server_error}, utils::{deduped_file_paths, sanitize_path}, @@ -44,6 +45,7 @@ pub async fn update_document_multipart( document_id, request.parent_version_id, request.relative_path, + request.device_id, request.content.contents.to_vec(), ) .await @@ -68,6 +70,7 @@ pub async fn update_document_json( document_id, request.parent_version_id, request.relative_path, + request.device_id, content_bytes, ) .await @@ -80,6 +83,7 @@ async fn internal_update_document( document_id: DocumentId, parent_version_id: VaultUpdateId, relative_path: String, + device_id: Option, content: Vec, ) -> Result, SyncServerError> { // No need for a transaction as document versions are immutable @@ -98,6 +102,34 @@ async fn internal_update_document( Ok, )?; + let sanitized_relative_path = sanitize_path(&relative_path); + + // Return the latest version if the update is a no-op from the client's + // perspective + if content == parent_document.content + && sanitized_relative_path == parent_document.relative_path + { + info!("Document content is the same as the parent version, skipping update"); + + let latest_version = state + .database + .get_latest_document(&vault_id, &document_id, None) + .await + .map_err(server_error)? + .map_or_else( + || { + Err(not_found_error(anyhow!( + "Document with id `{document_id}` not found", + ))) + }, + Ok, + )?; + + return Ok(Json(DocumentUpdateResponse::FastForwardUpdate( + latest_version.into(), + ))); + } + let mut transaction = state .database .create_write_transaction(&vault_id) @@ -136,8 +168,6 @@ async fn internal_update_document( ))); } - let sanitized_relative_path = sanitize_path(&relative_path); - // Return the latest version if the content and path are the same as the latest // version if content == latest_version.content && sanitized_relative_path == latest_version.relative_path @@ -208,7 +238,13 @@ async fn internal_update_document( state .broadcasts - .send(vault_id, new_version.clone().into()) + .send( + vault_id, + VaultUpdate { + origin_device_id: device_id, + document: new_version.clone().into(), + }, + ) .await; Ok(Json(if is_different_from_request_content { diff --git a/backend/sync_server/src/server/websocket.rs b/backend/sync_server/src/server/websocket.rs index aa5bc88..7241b12 100644 --- a/backend/sync_server/src/server/websocket.rs +++ b/backend/sync_server/src/server/websocket.rs @@ -18,7 +18,7 @@ use super::auth::auth; use crate::{ app_state::{ AppState, - database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId}, + database::models::{DeviceId, DocumentVersionWithoutContent, VaultId, VaultUpdateId}, }, errors::{SyncServerError, server_error, unauthenticated_error}, }; @@ -61,6 +61,13 @@ async fn websocket_wrapped( warn!("Websocket connection closed on vault '{vault_id}'"); } +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct WebsocketHandshake { + pub token: String, + pub device_id: DeviceId, +} + async fn websocket( state: AppState, stream: WebSocket, @@ -69,13 +76,19 @@ async fn websocket( ) -> Result<(), SyncServerError> { let (mut sender, mut receiver) = stream.split(); - if let Some(Ok(Message::Text(token))) = receiver.next().await { - auth(&state, &token, &vault_id)?; + let handshake = if let Some(Ok(Message::Text(token))) = receiver.next().await { + let handshake: WebsocketHandshake = serde_json::from_str(&token) + .context("Failed to parse token") + .map_err(server_error)?; + + auth(&state, &handshake.token, &vault_id)?; + + handshake } else { return Err(unauthenticated_error(anyhow::anyhow!( "Failed to authenticate" ))); - } + }; let mut rx = state.broadcasts.get_receiver(vault_id.clone()).await; @@ -99,7 +112,11 @@ async fn websocket( let mut send_task = tokio::spawn(async move { while let Ok(update) = rx.recv().await { - send_document_over_websocket(update, &mut sender).await?; + if Some(&handshake.device_id) == update.origin_device_id.as_ref() { + continue; + } + + send_document_over_websocket(update.document, &mut sender).await?; } Ok::<(), SyncServerError>(())