Add device id and use it to filter out updates coming from the same device

This commit is contained in:
Andras Schmelczer 2025-04-04 23:13:50 +01:00
parent 11e2d121b1
commit 648db73628
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
8 changed files with 101 additions and 25 deletions

View file

@ -3,13 +3,19 @@ use std::{collections::HashMap, sync::Arc};
use anyhow::Context;
use tokio::sync::{Mutex, broadcast};
use super::database::models::{DocumentVersionWithoutContent, VaultId};
use super::database::models::{DeviceId, DocumentVersionWithoutContent, VaultId};
use crate::{config::server_config::ServerConfig, errors::server_error};
#[derive(Debug, Clone)]
pub struct Broadcasts {
max_clients_per_vault: usize,
tx: Arc<Mutex<HashMap<VaultId, broadcast::Sender<DocumentVersionWithoutContent>>>>,
tx: Arc<Mutex<HashMap<VaultId, broadcast::Sender<VaultUpdate>>>>,
}
#[derive(Debug, Clone)]
pub struct VaultUpdate {
pub origin_device_id: Option<DeviceId>,
pub document: DocumentVersionWithoutContent,
}
impl Broadcasts {
@ -20,10 +26,7 @@ impl Broadcasts {
}
}
pub async fn get_receiver(
&self,
vault: VaultId,
) -> broadcast::Receiver<DocumentVersionWithoutContent> {
pub async fn get_receiver(&self, vault: VaultId) -> broadcast::Receiver<VaultUpdate> {
let tx = self.get_or_create(vault).await;
tx.subscribe()
@ -31,7 +34,7 @@ impl Broadcasts {
/// Sent a document update to all clients subscribed to the vault.
/// We ignore & log failures.
pub async fn send(&self, vault: VaultId, document: DocumentVersionWithoutContent) {
pub async fn send(&self, vault: VaultId, document: VaultUpdate) {
let tx = self.get_or_create(vault).await;
let result = tx
@ -44,10 +47,7 @@ impl Broadcasts {
}
}
async fn get_or_create(
&self,
vault: VaultId,
) -> broadcast::Sender<DocumentVersionWithoutContent> {
async fn get_or_create(&self, vault: VaultId) -> broadcast::Sender<VaultUpdate> {
let mut tx = self.tx.lock().await;
tx.entry(vault)

View file

@ -6,6 +6,7 @@ use sync_lib::bytes_to_base64;
pub type VaultId = String;
pub type VaultUpdateId = i64;
pub type DocumentId = uuid::Uuid;
pub type DeviceId = String;
#[derive(Debug, Clone)]
pub struct StoredDocumentVersion {

View file

@ -86,7 +86,7 @@ pub async fn create_server(config_path: Option<OsString>) -> Result<()> {
TraceLayer::new_for_http()
.make_span_with(|request: &Request<_>| {
info_span!(
"http_request",
"http",
method = ?request.method(),
uri = ?request.uri(),
)

View file

@ -10,8 +10,9 @@ use super::requests::{CreateDocumentVersion, CreateDocumentVersionMultipart};
use crate::{
app_state::{
AppState,
broadcasts::VaultUpdate,
database::models::{
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
DeviceId, DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
},
},
errors::{SyncServerError, client_error, server_error},
@ -40,6 +41,7 @@ pub async fn create_document_multipart(
vault_id,
request.document_id,
request.relative_path,
request.device_id,
request.content.contents.to_vec(),
)
.await
@ -63,6 +65,7 @@ pub async fn create_document_json(
vault_id,
request.document_id,
request.relative_path,
request.device_id,
content_bytes,
)
.await
@ -73,6 +76,7 @@ async fn internal_create_document(
vault_id: VaultId,
document_id: Option<DocumentId>,
relative_path: String,
device_id: Option<DeviceId>,
content: Vec<u8>,
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
let mut transaction = state
@ -131,7 +135,13 @@ async fn internal_create_document(
state
.broadcasts
.send(vault_id, new_version.clone().into())
.send(
vault_id,
VaultUpdate {
origin_device_id: device_id,
document: new_version.clone().into(),
},
)
.await;
Ok(Json(new_version.into()))

View file

@ -8,6 +8,7 @@ use super::requests::DeleteDocumentVersion;
use crate::{
app_state::{
AppState,
broadcasts::VaultUpdate,
database::models::{
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
},
@ -67,7 +68,13 @@ pub async fn delete_document(
state
.broadcasts
.send(vault_id, new_version.clone().into())
.send(
vault_id,
VaultUpdate {
origin_device_id: request.device_id,
document: new_version.clone().into(),
},
)
.await;
Ok(Json(new_version.into()))

View file

@ -4,7 +4,7 @@ use axum_typed_multipart::TryFromMultipart;
use schemars::JsonSchema;
use serde::{self, Deserialize};
use crate::app_state::database::models::{DocumentId, VaultUpdateId};
use crate::app_state::database::models::{DeviceId, DocumentId, VaultUpdateId};
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
@ -16,6 +16,7 @@ pub struct CreateDocumentVersion {
pub document_id: Option<DocumentId>,
pub relative_path: String,
pub content_base64: String,
pub device_id: Option<DeviceId>,
}
#[derive(Debug, TryFromMultipart, JsonSchema)]
@ -24,6 +25,7 @@ pub struct CreateDocumentVersionMultipart {
pub relative_path: String,
#[form_data(limit = "unlimited")]
pub content: FieldData<Bytes>,
pub device_id: Option<DeviceId>,
}
#[derive(Debug, Deserialize, JsonSchema)]
@ -32,6 +34,7 @@ pub struct UpdateDocumentVersion {
pub parent_version_id: VaultUpdateId,
pub relative_path: String,
pub content_base64: String,
pub device_id: Option<DeviceId>,
}
#[derive(Debug, TryFromMultipart, JsonSchema)]
@ -41,10 +44,12 @@ pub struct UpdateDocumentVersionMultipart {
pub relative_path: String,
#[form_data(limit = "unlimited")]
pub content: FieldData<Bytes>,
pub device_id: Option<DeviceId>,
}
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct DeleteDocumentVersion {
pub relative_path: String,
pub device_id: Option<DeviceId>,
}

View file

@ -14,7 +14,8 @@ use super::{
use crate::{
app_state::{
AppState,
database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
broadcasts::VaultUpdate,
database::models::{DeviceId, DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
},
errors::{SyncServerError, client_error, not_found_error, server_error},
utils::{deduped_file_paths, sanitize_path},
@ -44,6 +45,7 @@ pub async fn update_document_multipart(
document_id,
request.parent_version_id,
request.relative_path,
request.device_id,
request.content.contents.to_vec(),
)
.await
@ -68,6 +70,7 @@ pub async fn update_document_json(
document_id,
request.parent_version_id,
request.relative_path,
request.device_id,
content_bytes,
)
.await
@ -80,6 +83,7 @@ async fn internal_update_document(
document_id: DocumentId,
parent_version_id: VaultUpdateId,
relative_path: String,
device_id: Option<DeviceId>,
content: Vec<u8>,
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
// No need for a transaction as document versions are immutable
@ -98,6 +102,34 @@ async fn internal_update_document(
Ok,
)?;
let sanitized_relative_path = sanitize_path(&relative_path);
// Return the latest version if the update is a no-op from the client's
// perspective
if content == parent_document.content
&& sanitized_relative_path == parent_document.relative_path
{
info!("Document content is the same as the parent version, skipping update");
let latest_version = state
.database
.get_latest_document(&vault_id, &document_id, None)
.await
.map_err(server_error)?
.map_or_else(
|| {
Err(not_found_error(anyhow!(
"Document with id `{document_id}` not found",
)))
},
Ok,
)?;
return Ok(Json(DocumentUpdateResponse::FastForwardUpdate(
latest_version.into(),
)));
}
let mut transaction = state
.database
.create_write_transaction(&vault_id)
@ -136,8 +168,6 @@ async fn internal_update_document(
)));
}
let sanitized_relative_path = sanitize_path(&relative_path);
// Return the latest version if the content and path are the same as the latest
// version
if content == latest_version.content && sanitized_relative_path == latest_version.relative_path
@ -208,7 +238,13 @@ async fn internal_update_document(
state
.broadcasts
.send(vault_id, new_version.clone().into())
.send(
vault_id,
VaultUpdate {
origin_device_id: device_id,
document: new_version.clone().into(),
},
)
.await;
Ok(Json(if is_different_from_request_content {

View file

@ -18,7 +18,7 @@ use super::auth::auth;
use crate::{
app_state::{
AppState,
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
database::models::{DeviceId, DocumentVersionWithoutContent, VaultId, VaultUpdateId},
},
errors::{SyncServerError, server_error, unauthenticated_error},
};
@ -61,6 +61,13 @@ async fn websocket_wrapped(
warn!("Websocket connection closed on vault '{vault_id}'");
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct WebsocketHandshake {
pub token: String,
pub device_id: DeviceId,
}
async fn websocket(
state: AppState,
stream: WebSocket,
@ -69,13 +76,19 @@ async fn websocket(
) -> Result<(), SyncServerError> {
let (mut sender, mut receiver) = stream.split();
if let Some(Ok(Message::Text(token))) = receiver.next().await {
auth(&state, &token, &vault_id)?;
let handshake = if let Some(Ok(Message::Text(token))) = receiver.next().await {
let handshake: WebsocketHandshake = serde_json::from_str(&token)
.context("Failed to parse token")
.map_err(server_error)?;
auth(&state, &handshake.token, &vault_id)?;
handshake
} else {
return Err(unauthenticated_error(anyhow::anyhow!(
"Failed to authenticate"
)));
}
};
let mut rx = state.broadcasts.get_receiver(vault_id.clone()).await;
@ -99,7 +112,11 @@ async fn websocket(
let mut send_task = tokio::spawn(async move {
while let Ok(update) = rx.recv().await {
send_document_over_websocket(update, &mut sender).await?;
if Some(&handshake.device_id) == update.origin_device_id.as_ref() {
continue;
}
send_document_over_websocket(update.document, &mut sender).await?;
}
Ok::<(), SyncServerError>(())