Add device id and use it to filter out updates coming from the same device
This commit is contained in:
parent
11e2d121b1
commit
648db73628
8 changed files with 101 additions and 25 deletions
|
|
@ -3,13 +3,19 @@ use std::{collections::HashMap, sync::Arc};
|
|||
use anyhow::Context;
|
||||
use tokio::sync::{Mutex, broadcast};
|
||||
|
||||
use super::database::models::{DocumentVersionWithoutContent, VaultId};
|
||||
use super::database::models::{DeviceId, DocumentVersionWithoutContent, VaultId};
|
||||
use crate::{config::server_config::ServerConfig, errors::server_error};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Broadcasts {
|
||||
max_clients_per_vault: usize,
|
||||
tx: Arc<Mutex<HashMap<VaultId, broadcast::Sender<DocumentVersionWithoutContent>>>>,
|
||||
tx: Arc<Mutex<HashMap<VaultId, broadcast::Sender<VaultUpdate>>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VaultUpdate {
|
||||
pub origin_device_id: Option<DeviceId>,
|
||||
pub document: DocumentVersionWithoutContent,
|
||||
}
|
||||
|
||||
impl Broadcasts {
|
||||
|
|
@ -20,10 +26,7 @@ impl Broadcasts {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn get_receiver(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
) -> broadcast::Receiver<DocumentVersionWithoutContent> {
|
||||
pub async fn get_receiver(&self, vault: VaultId) -> broadcast::Receiver<VaultUpdate> {
|
||||
let tx = self.get_or_create(vault).await;
|
||||
|
||||
tx.subscribe()
|
||||
|
|
@ -31,7 +34,7 @@ impl Broadcasts {
|
|||
|
||||
/// Sent a document update to all clients subscribed to the vault.
|
||||
/// We ignore & log failures.
|
||||
pub async fn send(&self, vault: VaultId, document: DocumentVersionWithoutContent) {
|
||||
pub async fn send(&self, vault: VaultId, document: VaultUpdate) {
|
||||
let tx = self.get_or_create(vault).await;
|
||||
|
||||
let result = tx
|
||||
|
|
@ -44,10 +47,7 @@ impl Broadcasts {
|
|||
}
|
||||
}
|
||||
|
||||
async fn get_or_create(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
) -> broadcast::Sender<DocumentVersionWithoutContent> {
|
||||
async fn get_or_create(&self, vault: VaultId) -> broadcast::Sender<VaultUpdate> {
|
||||
let mut tx = self.tx.lock().await;
|
||||
|
||||
tx.entry(vault)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ use sync_lib::bytes_to_base64;
|
|||
pub type VaultId = String;
|
||||
pub type VaultUpdateId = i64;
|
||||
pub type DocumentId = uuid::Uuid;
|
||||
pub type DeviceId = String;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StoredDocumentVersion {
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ pub async fn create_server(config_path: Option<OsString>) -> Result<()> {
|
|||
TraceLayer::new_for_http()
|
||||
.make_span_with(|request: &Request<_>| {
|
||||
info_span!(
|
||||
"http_request",
|
||||
"http",
|
||||
method = ?request.method(),
|
||||
uri = ?request.uri(),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,8 +10,9 @@ use super::requests::{CreateDocumentVersion, CreateDocumentVersionMultipart};
|
|||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
broadcasts::VaultUpdate,
|
||||
database::models::{
|
||||
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
|
||||
DeviceId, DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
|
||||
},
|
||||
},
|
||||
errors::{SyncServerError, client_error, server_error},
|
||||
|
|
@ -40,6 +41,7 @@ pub async fn create_document_multipart(
|
|||
vault_id,
|
||||
request.document_id,
|
||||
request.relative_path,
|
||||
request.device_id,
|
||||
request.content.contents.to_vec(),
|
||||
)
|
||||
.await
|
||||
|
|
@ -63,6 +65,7 @@ pub async fn create_document_json(
|
|||
vault_id,
|
||||
request.document_id,
|
||||
request.relative_path,
|
||||
request.device_id,
|
||||
content_bytes,
|
||||
)
|
||||
.await
|
||||
|
|
@ -73,6 +76,7 @@ async fn internal_create_document(
|
|||
vault_id: VaultId,
|
||||
document_id: Option<DocumentId>,
|
||||
relative_path: String,
|
||||
device_id: Option<DeviceId>,
|
||||
content: Vec<u8>,
|
||||
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
|
||||
let mut transaction = state
|
||||
|
|
@ -131,7 +135,13 @@ async fn internal_create_document(
|
|||
|
||||
state
|
||||
.broadcasts
|
||||
.send(vault_id, new_version.clone().into())
|
||||
.send(
|
||||
vault_id,
|
||||
VaultUpdate {
|
||||
origin_device_id: device_id,
|
||||
document: new_version.clone().into(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(new_version.into()))
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ use super::requests::DeleteDocumentVersion;
|
|||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
broadcasts::VaultUpdate,
|
||||
database::models::{
|
||||
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
|
||||
},
|
||||
|
|
@ -67,7 +68,13 @@ pub async fn delete_document(
|
|||
|
||||
state
|
||||
.broadcasts
|
||||
.send(vault_id, new_version.clone().into())
|
||||
.send(
|
||||
vault_id,
|
||||
VaultUpdate {
|
||||
origin_device_id: request.device_id,
|
||||
document: new_version.clone().into(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(new_version.into()))
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use axum_typed_multipart::TryFromMultipart;
|
|||
use schemars::JsonSchema;
|
||||
use serde::{self, Deserialize};
|
||||
|
||||
use crate::app_state::database::models::{DocumentId, VaultUpdateId};
|
||||
use crate::app_state::database::models::{DeviceId, DocumentId, VaultUpdateId};
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
|
|
@ -16,6 +16,7 @@ pub struct CreateDocumentVersion {
|
|||
pub document_id: Option<DocumentId>,
|
||||
pub relative_path: String,
|
||||
pub content_base64: String,
|
||||
pub device_id: Option<DeviceId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, TryFromMultipart, JsonSchema)]
|
||||
|
|
@ -24,6 +25,7 @@ pub struct CreateDocumentVersionMultipart {
|
|||
pub relative_path: String,
|
||||
#[form_data(limit = "unlimited")]
|
||||
pub content: FieldData<Bytes>,
|
||||
pub device_id: Option<DeviceId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
|
|
@ -32,6 +34,7 @@ pub struct UpdateDocumentVersion {
|
|||
pub parent_version_id: VaultUpdateId,
|
||||
pub relative_path: String,
|
||||
pub content_base64: String,
|
||||
pub device_id: Option<DeviceId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, TryFromMultipart, JsonSchema)]
|
||||
|
|
@ -41,10 +44,12 @@ pub struct UpdateDocumentVersionMultipart {
|
|||
pub relative_path: String,
|
||||
#[form_data(limit = "unlimited")]
|
||||
pub content: FieldData<Bytes>,
|
||||
pub device_id: Option<DeviceId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DeleteDocumentVersion {
|
||||
pub relative_path: String,
|
||||
pub device_id: Option<DeviceId>,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@ use super::{
|
|||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
|
||||
broadcasts::VaultUpdate,
|
||||
database::models::{DeviceId, DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, client_error, not_found_error, server_error},
|
||||
utils::{deduped_file_paths, sanitize_path},
|
||||
|
|
@ -44,6 +45,7 @@ pub async fn update_document_multipart(
|
|||
document_id,
|
||||
request.parent_version_id,
|
||||
request.relative_path,
|
||||
request.device_id,
|
||||
request.content.contents.to_vec(),
|
||||
)
|
||||
.await
|
||||
|
|
@ -68,6 +70,7 @@ pub async fn update_document_json(
|
|||
document_id,
|
||||
request.parent_version_id,
|
||||
request.relative_path,
|
||||
request.device_id,
|
||||
content_bytes,
|
||||
)
|
||||
.await
|
||||
|
|
@ -80,6 +83,7 @@ async fn internal_update_document(
|
|||
document_id: DocumentId,
|
||||
parent_version_id: VaultUpdateId,
|
||||
relative_path: String,
|
||||
device_id: Option<DeviceId>,
|
||||
content: Vec<u8>,
|
||||
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
|
||||
// No need for a transaction as document versions are immutable
|
||||
|
|
@ -98,6 +102,34 @@ async fn internal_update_document(
|
|||
Ok,
|
||||
)?;
|
||||
|
||||
let sanitized_relative_path = sanitize_path(&relative_path);
|
||||
|
||||
// Return the latest version if the update is a no-op from the client's
|
||||
// perspective
|
||||
if content == parent_document.content
|
||||
&& sanitized_relative_path == parent_document.relative_path
|
||||
{
|
||||
info!("Document content is the same as the parent version, skipping update");
|
||||
|
||||
let latest_version = state
|
||||
.database
|
||||
.get_latest_document(&vault_id, &document_id, None)
|
||||
.await
|
||||
.map_err(server_error)?
|
||||
.map_or_else(
|
||||
|| {
|
||||
Err(not_found_error(anyhow!(
|
||||
"Document with id `{document_id}` not found",
|
||||
)))
|
||||
},
|
||||
Ok,
|
||||
)?;
|
||||
|
||||
return Ok(Json(DocumentUpdateResponse::FastForwardUpdate(
|
||||
latest_version.into(),
|
||||
)));
|
||||
}
|
||||
|
||||
let mut transaction = state
|
||||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
|
|
@ -136,8 +168,6 @@ async fn internal_update_document(
|
|||
)));
|
||||
}
|
||||
|
||||
let sanitized_relative_path = sanitize_path(&relative_path);
|
||||
|
||||
// Return the latest version if the content and path are the same as the latest
|
||||
// version
|
||||
if content == latest_version.content && sanitized_relative_path == latest_version.relative_path
|
||||
|
|
@ -208,7 +238,13 @@ async fn internal_update_document(
|
|||
|
||||
state
|
||||
.broadcasts
|
||||
.send(vault_id, new_version.clone().into())
|
||||
.send(
|
||||
vault_id,
|
||||
VaultUpdate {
|
||||
origin_device_id: device_id,
|
||||
document: new_version.clone().into(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(if is_different_from_request_content {
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ use super::auth::auth;
|
|||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
|
||||
database::models::{DeviceId, DocumentVersionWithoutContent, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, server_error, unauthenticated_error},
|
||||
};
|
||||
|
|
@ -61,6 +61,13 @@ async fn websocket_wrapped(
|
|||
warn!("Websocket connection closed on vault '{vault_id}'");
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct WebsocketHandshake {
|
||||
pub token: String,
|
||||
pub device_id: DeviceId,
|
||||
}
|
||||
|
||||
async fn websocket(
|
||||
state: AppState,
|
||||
stream: WebSocket,
|
||||
|
|
@ -69,13 +76,19 @@ async fn websocket(
|
|||
) -> Result<(), SyncServerError> {
|
||||
let (mut sender, mut receiver) = stream.split();
|
||||
|
||||
if let Some(Ok(Message::Text(token))) = receiver.next().await {
|
||||
auth(&state, &token, &vault_id)?;
|
||||
let handshake = if let Some(Ok(Message::Text(token))) = receiver.next().await {
|
||||
let handshake: WebsocketHandshake = serde_json::from_str(&token)
|
||||
.context("Failed to parse token")
|
||||
.map_err(server_error)?;
|
||||
|
||||
auth(&state, &handshake.token, &vault_id)?;
|
||||
|
||||
handshake
|
||||
} else {
|
||||
return Err(unauthenticated_error(anyhow::anyhow!(
|
||||
"Failed to authenticate"
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let mut rx = state.broadcasts.get_receiver(vault_id.clone()).await;
|
||||
|
||||
|
|
@ -99,7 +112,11 @@ async fn websocket(
|
|||
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Ok(update) = rx.recv().await {
|
||||
send_document_over_websocket(update, &mut sender).await?;
|
||||
if Some(&handshake.device_id) == update.origin_device_id.as_ref() {
|
||||
continue;
|
||||
}
|
||||
|
||||
send_document_over_websocket(update.document, &mut sender).await?;
|
||||
}
|
||||
|
||||
Ok::<(), SyncServerError>(())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue