Extract reconcile (#85)

This commit is contained in:
Andras Schmelczer 2025-07-13 11:06:42 +01:00 committed by GitHub
parent 75b020146a
commit bb0e44f06f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
141 changed files with 294 additions and 36720 deletions

View file

@ -0,0 +1,9 @@
<!DOCTYPE html>
<html>
<head>
<title>VaultLink</title>
</head>
<body>
<h1>VaultLink server</h1>
</body>
</html>

View file

@ -0,0 +1,70 @@
use std::collections::HashMap;
use axum::{
extract::{Path, Request, State},
middleware::Next,
response::Response,
};
use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use log::info;
use crate::{
app_state::{AppState, database::models::VaultId},
config::user_config::{AllowListedVaults, User, VaultAccess},
errors::{SyncServerError, permission_denied_error, unauthenticated_error},
utils::normalize::normalize_string,
};
pub async fn auth_middleware(
State(state): State<AppState>,
Path(path_params): Path<HashMap<String, String>>,
TypedHeader(auth_header): TypedHeader<Authorization<Bearer>>,
mut req: Request,
next: Next,
) -> Result<Response, SyncServerError> {
let token = auth_header.token().trim();
let vault_id = normalize_string(
path_params
.get("vault_id")
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Missing vault_id")))?,
);
let user = auth(&state, token, &vault_id)?;
req.extensions_mut().insert(user);
Ok(next.run(req).await)
}
pub fn auth(state: &AppState, token: &str, vault_id: &VaultId) -> Result<User, SyncServerError> {
let user = state
.config
.users
.get_user(token)
.cloned()
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Invalid token")))?;
if match user.vault_access {
VaultAccess::AllowAccessToAll => true,
VaultAccess::AllowList(AllowListedVaults { ref allowed }) => allowed.contains(vault_id),
} {
info!(
"User '{}' is authenticated and is authorised to access to vault '{vault_id}'",
user.name
);
Ok(user)
} else {
info!(
"User '{}' is authenticated but is not authorised to access vault '{vault_id}'",
user.name
);
Err(permission_denied_error(anyhow::anyhow!(
"Permission denied for vault `{vault_id}`"
)))
}
}

View file

@ -0,0 +1,95 @@
use anyhow::Context as _;
use axum::{
Extension, Json,
extract::{Path, State},
};
use axum_extra::TypedHeader;
use axum_typed_multipart::TypedMultipart;
use serde::Deserialize;
use super::{device_id_header::DeviceIdHeader, requests::CreateDocumentVersion};
use crate::{
app_state::{
AppState,
database::models::{DocumentVersionWithoutContent, StoredDocumentVersion, VaultId},
},
config::user_config::User,
errors::{SyncServerError, client_error, server_error},
utils::{normalize::normalize, sanitize_path::sanitize_path},
};
#[derive(Deserialize)]
pub struct CreateDocumentPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
}
/// Create a new document in case a document with the same doesn't exist
/// already. If a document with the same path exists, a new version is created
/// with their content merged.
#[axum::debug_handler]
pub async fn create_document(
Path(CreateDocumentPathParams { vault_id }): Path<CreateDocumentPathParams>,
Extension(user): Extension<User>,
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
State(state): State<AppState>,
TypedMultipart(request): TypedMultipart<CreateDocumentVersion>,
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
let mut transaction = state
.database
.create_write_transaction(&vault_id)
.await
.map_err(server_error)?;
let document_id = match request.document_id {
Some(document_id) => {
let existing_version = state
.database
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
.await
.map_err(server_error)?;
if existing_version.is_some() {
return Err(client_error(anyhow::anyhow!(
"Document with the same ID already exists"
)));
}
document_id
}
None => uuid::Uuid::new_v4(),
};
let last_update_id = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
.await
.map_err(server_error)?;
let sanitized_relative_path = sanitize_path(&request.relative_path);
let new_version = StoredDocumentVersion {
vault_update_id: last_update_id + 1,
document_id,
relative_path: sanitized_relative_path,
content: request.content.contents.to_vec(),
updated_date: chrono::Utc::now(),
is_deleted: false,
user_id: user.name,
device_id: device_id.0,
};
state
.database
.insert_document_version(&vault_id, &new_version, Some(&mut transaction))
.await
.map_err(server_error)?;
transaction
.commit()
.await
.context("Failed to commit successful transaction")
.map_err(server_error)?;
Ok(Json(new_version.into()))
}

View file

@ -0,0 +1,84 @@
use anyhow::Context as _;
use axum::{
Extension, Json,
extract::{Path, State},
};
use axum_extra::TypedHeader;
use serde::Deserialize;
use super::{device_id_header::DeviceIdHeader, requests::DeleteDocumentVersion};
use crate::{
app_state::{
AppState,
database::models::{
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
},
},
config::user_config::User,
errors::{SyncServerError, server_error},
utils::{normalize::normalize, sanitize_path::sanitize_path},
};
#[derive(Deserialize)]
pub struct DeleteDocumentPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
document_id: DocumentId,
}
#[axum::debug_handler]
pub async fn delete_document(
Path(DeleteDocumentPathParams {
vault_id,
document_id,
}): Path<DeleteDocumentPathParams>,
Extension(user): Extension<User>,
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
State(state): State<AppState>,
Json(request): Json<DeleteDocumentVersion>,
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
let mut transaction = state
.database
.create_write_transaction(&vault_id)
.await
.map_err(server_error)?;
let last_update_id = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
.await
.map_err(server_error)?;
let latest_content = state
.database
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
.await
.map_err(server_error)?
.map_or_else(Vec::new, |version| version.content); // in case the document has never existed before deleting it
let new_version = StoredDocumentVersion {
vault_update_id: last_update_id + 1,
document_id,
relative_path: sanitize_path(&request.relative_path),
content: latest_content, // copy the content from the latest version
updated_date: chrono::Utc::now(),
is_deleted: true,
user_id: user.name,
device_id: device_id.0,
};
state
.database
.insert_document_version(&vault_id, &new_version, Some(&mut transaction))
.await
.map_err(server_error)?;
transaction
.commit()
.await
.context("Failed to commit successful transaction")
.map_err(server_error)?;
Ok(Json(new_version.into()))
}

View file

@ -0,0 +1,33 @@
use axum_extra::headers;
use headers::{Header, HeaderName, HeaderValue};
pub struct DeviceIdHeader(pub String);
pub static DEVICE_ID_HEADER_NAME: HeaderName = HeaderName::from_static("device-id");
impl Header for DeviceIdHeader {
fn name() -> &'static HeaderName { &DEVICE_ID_HEADER_NAME }
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
where
I: Iterator<Item = &'i HeaderValue>,
{
let value = values.next().ok_or_else(headers::Error::invalid)?;
Ok(DeviceIdHeader(
value
.to_str()
.map_err(|_| headers::Error::invalid())?
.to_owned(),
))
}
fn encode<E>(&self, values: &mut E)
where
E: Extend<HeaderValue>,
{
let value = HeaderValue::from_static(Box::leak(self.0.to_string().into_boxed_str()));
values.extend(std::iter::once(value));
}
}

View file

@ -0,0 +1,57 @@
use anyhow::anyhow;
use axum::{
Json,
extract::{Path, State},
};
use serde::Deserialize;
use crate::{
app_state::{
AppState,
database::models::{DocumentId, DocumentVersion, VaultId, VaultUpdateId},
},
errors::{SyncServerError, not_found_error, server_error},
utils::normalize::normalize,
};
#[derive(Deserialize)]
pub struct FetchDocumentVersionPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
document_id: DocumentId,
vault_update_id: VaultUpdateId,
}
#[axum::debug_handler]
pub async fn fetch_document_version(
Path(FetchDocumentVersionPathParams {
vault_id,
document_id,
vault_update_id,
}): Path<FetchDocumentVersionPathParams>,
State(state): State<AppState>,
) -> Result<Json<DocumentVersion>, SyncServerError> {
let result = state
.database
.get_document_version(&vault_id, vault_update_id, None)
.await
.map_err(server_error)?
.map_or_else(
|| {
Err(not_found_error(anyhow!(
"Document with vault update id `{vault_update_id}` not found",
)))
},
Ok,
)?;
if result.document_id != document_id {
return Err(not_found_error(anyhow!(
"Document with document id `{document_id}` does not have a version with id \
`{vault_update_id}`",
)));
}
Ok(Json(result.into()))
}

View file

@ -0,0 +1,57 @@
use anyhow::anyhow;
use axum::{
body::Bytes,
extract::{Path, State},
};
use serde::Deserialize;
use crate::{
app_state::{
AppState,
database::models::{DocumentId, VaultId, VaultUpdateId},
},
errors::{SyncServerError, not_found_error, server_error},
utils::normalize::normalize,
};
#[derive(Deserialize)]
pub struct FetchDocumentVersionContentPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
document_id: DocumentId,
vault_update_id: VaultUpdateId,
}
#[axum::debug_handler]
pub async fn fetch_document_version_content(
Path(FetchDocumentVersionContentPathParams {
vault_id,
document_id,
vault_update_id,
}): Path<FetchDocumentVersionContentPathParams>,
State(state): State<AppState>,
) -> Result<Bytes, SyncServerError> {
let result = state
.database
.get_document_version(&vault_id, vault_update_id, None)
.await
.map_err(server_error)?
.map_or_else(
|| {
Err(not_found_error(anyhow!(
"Document with vault update id `{vault_update_id}` not found",
)))
},
Ok,
)?;
if result.document_id != document_id {
return Err(not_found_error(anyhow!(
"Document with document id `{document_id}` does not have a version with id \
`{vault_update_id}`",
)));
}
Ok(result.content.into())
}

View file

@ -0,0 +1,48 @@
use anyhow::anyhow;
use axum::{
Json,
extract::{Path, State},
};
use serde::Deserialize;
use crate::{
app_state::{
AppState,
database::models::{DocumentId, DocumentVersion, VaultId},
},
errors::{SyncServerError, not_found_error, server_error},
utils::normalize::normalize,
};
#[derive(Deserialize)]
pub struct FetchLatestDocumentVersionPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
document_id: DocumentId,
}
#[axum::debug_handler]
pub async fn fetch_latest_document_version(
Path(FetchLatestDocumentVersionPathParams {
vault_id,
document_id,
}): Path<FetchLatestDocumentVersionPathParams>,
State(state): State<AppState>,
) -> Result<Json<DocumentVersion>, SyncServerError> {
let latest_version = state
.database
.get_latest_document(&vault_id, &document_id, None)
.await
.map_err(server_error)?
.map_or_else(
|| {
Err(not_found_error(anyhow!(
"Document with id `{document_id}` not found",
)))
},
Ok,
)?;
Ok(Json(latest_version.into()))
}

View file

@ -0,0 +1,56 @@
use axum::{
Json,
extract::{Path, Query, State},
};
use serde::Deserialize;
use super::responses::FetchLatestDocumentsResponse;
use crate::{
app_state::{
AppState,
database::models::{VaultId, VaultUpdateId},
},
errors::{SyncServerError, server_error},
utils::normalize::normalize,
};
#[derive(Deserialize)]
pub struct FetchLatestDocumentsPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
}
#[derive(Deserialize)]
pub struct QueryParams {
since_update_id: Option<VaultUpdateId>,
}
#[axum::debug_handler]
pub async fn fetch_latest_documents(
Path(FetchLatestDocumentsPathParams { vault_id }): Path<FetchLatestDocumentsPathParams>,
Query(QueryParams { since_update_id }): Query<QueryParams>,
State(state): State<AppState>,
) -> Result<Json<FetchLatestDocumentsResponse>, SyncServerError> {
let documents = if let Some(since_update_id) = since_update_id {
state
.database
.get_latest_documents_since(&vault_id, since_update_id, None)
.await
.map_err(server_error)
} else {
state
.database
.get_latest_documents(&vault_id, None)
.await
.map_err(server_error)
}?;
Ok(Json(FetchLatestDocumentsResponse {
last_update_id: documents
.iter()
.map(|doc| doc.vault_update_id)
.max()
.unwrap_or(since_update_id.unwrap_or(0)),
latest_documents: documents,
}))
}

View file

@ -0,0 +1,7 @@
use axum::response::{Html, IntoResponse};
pub async fn index() -> impl IntoResponse {
const HTML_CONTENT: &str = include_str!("./assets/index.html");
let html_content = HTML_CONTENT;
Html(html_content)
}

View file

@ -0,0 +1,37 @@
use axum::{
Json,
extract::{Path, State},
};
use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use serde::Deserialize;
use super::{auth::auth, responses::PingResponse};
use crate::{
app_state::{AppState, database::models::VaultId},
errors::SyncServerError,
utils::normalize::normalize,
};
#[derive(Deserialize)]
pub struct PingPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
}
#[axum::debug_handler]
pub async fn ping(
maybe_auth_header: Option<TypedHeader<Authorization<Bearer>>>,
Path(PingPathParams { vault_id }): Path<PingPathParams>,
State(state): State<AppState>,
) -> Result<Json<PingResponse>, SyncServerError> {
let is_authenticated = maybe_auth_header
.is_some_and(|auth_header| auth(&state, auth_header.token(), &vault_id).is_ok());
Ok(Json(PingResponse {
server_version: env!("CARGO_PKG_VERSION").to_owned(),
is_authenticated,
}))
}

View file

@ -0,0 +1,39 @@
use axum::body::Bytes;
use axum_typed_multipart::{FieldData, TryFromMultipart};
use serde::{self, Deserialize};
use ts_rs::TS;
use crate::app_state::database::models::{DocumentId, VaultUpdateId};
#[derive(TS, Debug, TryFromMultipart)]
#[ts(export)]
pub struct CreateDocumentVersion {
/// The client can decide the document id (if it wishes to) in order
/// to help with syncing. If the client does not provide a document id,
/// the server will generate one. If the client provides a document id
/// it must not already exist in the database.
pub document_id: Option<DocumentId>,
pub relative_path: String,
#[ts(as = "Vec<u8>")]
#[form_data(limit = "unlimited")]
pub content: FieldData<Bytes>,
}
#[derive(TS, Debug, TryFromMultipart)]
#[ts(export)]
pub struct UpdateDocumentVersion {
pub parent_version_id: VaultUpdateId,
pub relative_path: String,
#[ts(as = "Vec<u8>")]
#[form_data(limit = "unlimited")]
pub content: FieldData<Bytes>,
}
#[derive(TS, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct DeleteDocumentVersion {
pub relative_path: String,
}

View file

@ -0,0 +1,45 @@
use serde::{self, Serialize};
use ts_rs::TS;
use crate::app_state::database::models::{
DocumentVersion, DocumentVersionWithoutContent, VaultUpdateId,
};
/// Response to a ping request.
#[derive(TS, Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct PingResponse {
/// Semantic version of the server.
pub server_version: String,
/// Whether the client is authenticated based on the sent Authorization
/// header.
pub is_authenticated: bool,
}
/// Response to a fetch latest documents request.
#[derive(TS, Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct FetchLatestDocumentsResponse {
pub latest_documents: Vec<DocumentVersionWithoutContent>,
/// The update ID of the latest document in the response.
pub last_update_id: VaultUpdateId,
}
/// Response to an update document request.
#[derive(TS, Debug, Clone, Serialize)]
#[serde(tag = "type")]
#[ts(export)]
pub enum DocumentUpdateResponse {
/// Returned when the created/updated document's content is the same as was
/// sent in the create/update request and thus the response doesn't contain
/// the content because the client must already have it.
FastForwardUpdate(DocumentVersionWithoutContent),
/// Returned when the created/updated document's content is different from
/// what was sent in the create/update request.
MergingUpdate(DocumentVersion),
}

View file

@ -0,0 +1,199 @@
use anyhow::{Context as _, anyhow};
use axum::{
Extension, Json,
extract::{Path, State},
};
use axum_extra::TypedHeader;
use axum_typed_multipart::TypedMultipart;
use log::info;
use reconcile_text::{BuiltinTokenizer, is_binary, reconcile};
use serde::Deserialize;
use super::{
device_id_header::DeviceIdHeader, requests::UpdateDocumentVersion,
responses::DocumentUpdateResponse,
};
use crate::{
app_state::{
AppState,
database::models::{DocumentId, StoredDocumentVersion, VaultId},
},
config::user_config::User,
errors::{SyncServerError, not_found_error, server_error},
utils::{
dedup_paths::dedup_paths, is_file_type_mergable::is_file_type_mergable,
normalize::normalize, sanitize_path::sanitize_path,
},
};
#[derive(Deserialize)]
pub struct UpdateDocumentPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
document_id: DocumentId,
}
#[axum::debug_handler]
#[allow(clippy::too_many_lines)]
pub async fn update_document(
Path(UpdateDocumentPathParams {
vault_id,
document_id,
}): Path<UpdateDocumentPathParams>,
Extension(user): Extension<User>,
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
State(state): State<AppState>,
TypedMultipart(request): TypedMultipart<UpdateDocumentVersion>,
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
// No need for a transaction as document versions are immutable
let parent_document = state
.database
.get_document_version(&vault_id, request.parent_version_id, None)
.await
.map_err(server_error)?
.map_or_else(
|| {
Err(not_found_error(anyhow!(
"Parent version with id `{}` not found",
request.parent_version_id
)))
},
Ok,
)?;
let sanitized_relative_path = sanitize_path(&request.relative_path);
let mut transaction = state
.database
.create_write_transaction(&vault_id)
.await
.map_err(server_error)?;
let last_update_id = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
.await
.map_err(server_error)?;
let latest_version = state
.database
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
.await
.map_err(server_error)?
.map_or_else(
|| {
Err(not_found_error(anyhow!(
"Document with id `{document_id}` not found",
)))
},
Ok,
)?;
if latest_version.is_deleted {
transaction
.rollback()
.await
.context("Failed to roll back transaction")
.map_err(server_error)?;
return Ok(Json(DocumentUpdateResponse::FastForwardUpdate(
latest_version.into(),
)));
}
let content = request.content.contents.to_vec();
// Return the latest version if the content and path are the same as the latest
// version
if content == latest_version.content && sanitized_relative_path == latest_version.relative_path
{
info!("Document content is the same as the latest version, skipping update");
transaction
.rollback()
.await
.context("Failed to roll back transaction")
.map_err(server_error)?;
return Ok(Json(DocumentUpdateResponse::FastForwardUpdate(
latest_version.into(),
)));
}
let merged_content = if is_file_type_mergable(&sanitized_relative_path)
&& !is_binary(&parent_document.content)
&& !is_binary(&latest_version.content)
&& !is_binary(&content)
{
reconcile(
str::from_utf8(&parent_document.content)
.expect("parent must be valid UTF-8 because it's not binary"),
&str::from_utf8(&latest_version.content)
.expect("latest_version must be valid UTF-8 because it's not binary")
.into(),
&str::from_utf8(&content)
.expect("content must be valid UTF-8 because it's not binary")
.into(),
&*BuiltinTokenizer::Word,
)
.apply()
.text()
.into_bytes()
} else {
content.clone()
};
let is_different_from_request_content = merged_content != content;
// We can only update the relative path if we're the first one to do so
let new_relative_path = if parent_document.relative_path == latest_version.relative_path
&& latest_version.relative_path != sanitized_relative_path
{
let mut new_relative_path = String::default();
for candidate in dedup_paths(&sanitized_relative_path) {
if state
.database
.get_latest_document_by_path(&vault_id, &candidate, Some(&mut transaction))
.await
.map_err(server_error)?
.is_none()
{
new_relative_path = candidate;
break;
}
}
new_relative_path
} else {
latest_version.relative_path.clone()
};
let new_version = StoredDocumentVersion {
document_id,
vault_update_id: last_update_id + 1,
relative_path: new_relative_path,
content: merged_content,
updated_date: chrono::Utc::now(),
is_deleted: false,
user_id: user.name,
device_id: device_id.0,
};
state
.database
.insert_document_version(&vault_id, &new_version, Some(&mut transaction))
.await
.map_err(server_error)?;
transaction
.commit()
.await
.context("Failed to commit successful transaction")
.map_err(server_error)?;
Ok(Json(if is_different_from_request_content {
DocumentUpdateResponse::MergingUpdate(new_version.into())
} else {
DocumentUpdateResponse::FastForwardUpdate(new_version.into())
}))
}

View file

@ -0,0 +1,181 @@
use anyhow::Context;
use axum::{
extract::{
Path, State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::Response,
};
use futures::stream::StreamExt;
use log::{debug, info};
use serde::Deserialize;
use crate::{
app_state::{
AppState,
database::models::VaultId,
websocket::{
models::{
CursorPositionFromServer, WebSocketClientMessage, WebSocketServerMessage,
WebSocketVaultUpdate,
},
utils::{
get_authenticated_handshake, get_unseen_documents, send_update_over_websocket,
},
},
},
errors::{SyncServerError, client_error, server_error},
utils::normalize::normalize,
};
#[derive(Deserialize)]
pub struct WebSocketPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
}
pub async fn websocket_handler(
ws: WebSocketUpgrade,
Path(WebSocketPathParams { vault_id }): Path<WebSocketPathParams>,
State(state): State<AppState>,
) -> Result<Response, SyncServerError> {
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id)))
}
async fn websocket_wrapped(state: AppState, stream: WebSocket, vault_id: VaultId) {
info!("WebSocket connection opened on vault '{vault_id}'");
let result = websocket(state, stream, vault_id.clone()).await;
if let Err(err) = result {
debug!("WebSocket connection error on vault '{vault_id}': {err}");
}
}
#[allow(clippy::too_many_lines)]
async fn websocket(
state: AppState,
stream: WebSocket,
vault_id: VaultId,
) -> Result<(), SyncServerError> {
let (mut sender, mut websocket_receiver) = stream.split();
let authed_handshake = get_authenticated_handshake(
&state,
&vault_id,
websocket_receiver
.next()
.await
.transpose()
.unwrap_or_default(),
)?;
info!(
"WebSocket handshake successful for vault '{vault_id}' for '{}'",
authed_handshake.handshake.device_id
);
let mut broadcast_receiver = state.broadcasts.get_receiver(vault_id.clone()).await;
send_update_over_websocket(
&WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate {
documents: get_unseen_documents(
&state,
&vault_id,
authed_handshake.handshake.last_seen_vault_update_id,
)
.await?,
is_initial_sync: true,
}),
&mut sender,
)
.await?;
send_update_over_websocket(
&WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
clients: state.cursors.get_cursors(&vault_id).await,
}),
&mut sender,
)
.await?;
let device_id = authed_handshake.handshake.device_id.clone();
let mut send_task = tokio::spawn(async move {
while let Ok(update) = broadcast_receiver.recv().await {
if Some(&device_id) == update.origin_device_id.as_ref() {
continue;
}
send_update_over_websocket(&update.message, &mut sender).await?;
}
Ok::<(), SyncServerError>(())
});
let device_id = authed_handshake.handshake.device_id.clone();
let vault_id_clone = vault_id.clone();
let cursor_manager = state.cursors.clone();
let mut receive_task = tokio::spawn(async move {
while let Some(Ok(Message::Text(message))) = websocket_receiver.next().await {
let message: WebSocketClientMessage = serde_json::from_str(&message)
.context("Failed to parse WebSocket message from client")
.map_err(server_error)?;
match message {
WebSocketClientMessage::Handshake(_) => {
return Err(client_error(anyhow::anyhow!(
"Unexpected handshake message"
)));
}
WebSocketClientMessage::CursorPositions(cursors) => {
cursor_manager
.update_cursors(
vault_id_clone.clone(),
authed_handshake.user.name.clone(),
&device_id,
cursors.document_to_cursors,
)
.await;
}
}
}
Ok::<(), SyncServerError>(())
});
tokio::select! {
_ = &mut send_task => receive_task.abort(),
_ = &mut receive_task => send_task.abort(),
};
let result: Result<(), SyncServerError> = (async {
send_task
.await
.context("WebSocket send task failed")
.map_err(client_error)
.and_then(|err| err)?;
receive_task
.await
.context("WebSocket receive task failed")
.map_err(client_error)
.and_then(|err| err)?;
Ok(())
})
.await;
state
.cursors
.remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id)
.await;
if result.is_err() {
info!(
"WebSocket disconnected on vault '{vault_id}' for '{}'",
authed_handshake.handshake.device_id
);
}
result
}