split: server websocket + cursors
src/server/websocket.rs handshake/catch-up rewrite, app_state/cursors.rs,
app_state/websocket/{broadcasts,models,utils}.rs.
This commit is contained in:
parent
4ba439b874
commit
042233c4d7
5 changed files with 487 additions and 170 deletions
|
|
@ -42,7 +42,9 @@ impl Cursors {
|
|||
) {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
let all_device_cursors = vault_to_cursors.entry(vault_id).or_insert_with(Vec::new);
|
||||
let all_device_cursors = vault_to_cursors
|
||||
.entry(vault_id.clone())
|
||||
.or_insert_with(Vec::new);
|
||||
|
||||
all_device_cursors.retain(|c| &c.client_cursors.device_id != device_id);
|
||||
all_device_cursors.push(ClientCursorsWithTimeToLive::new(ClientCursors {
|
||||
|
|
@ -52,7 +54,7 @@ impl Cursors {
|
|||
}));
|
||||
|
||||
drop(vault_to_cursors); // Explicitly drop the lock before broadcasting to avoid deadlock
|
||||
self.broadcast_cursors().await;
|
||||
self.broadcast_cursors_for_vault(&vault_id).await;
|
||||
}
|
||||
|
||||
pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec<ClientCursors> {
|
||||
|
|
@ -69,45 +71,81 @@ impl Cursors {
|
|||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn start_background_task(self) {
|
||||
pub fn start_background_task(self, mut shutdown: tokio::sync::watch::Receiver<()>) {
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
() = tokio::time::sleep(Duration::from_secs(1)) => {
|
||||
self.remove_expired_cursors().await;
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
Ok(()) = shutdown.changed() => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn remove_expired_cursors(&self) {
|
||||
let changed_vaults: Vec<VaultId> = {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
for (_vault_id, cursors) in vault_to_cursors.iter_mut() {
|
||||
let mut changed = Vec::new();
|
||||
for (vault_id, cursors) in vault_to_cursors.iter_mut() {
|
||||
let before = cursors.len();
|
||||
cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout));
|
||||
if cursors.len() != before {
|
||||
changed.push(vault_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
async fn broadcast_cursors(&self) {
|
||||
// Remove empty vault entries to prevent unbounded growth
|
||||
vault_to_cursors.retain(|_, cursors| !cursors.is_empty());
|
||||
|
||||
changed
|
||||
};
|
||||
|
||||
for vault_id in &changed_vaults {
|
||||
self.broadcast_cursors_for_vault(vault_id).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn broadcast_cursors_for_vault(&self, vault_id: &VaultId) {
|
||||
let client_cursors: Vec<ClientCursors> = {
|
||||
let vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
vault_to_cursors
|
||||
.get(vault_id)
|
||||
.map(|cursors| cursors.iter().map(|c| c.client_cursors.clone()).collect())
|
||||
.unwrap_or_default()
|
||||
};
|
||||
|
||||
for (vault_id, cursors) in vault_to_cursors.iter() {
|
||||
self.broadcasts
|
||||
.send_document_update(
|
||||
self.broadcasts.send_document_update(
|
||||
vault_id.clone(),
|
||||
WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions(
|
||||
CursorPositionFromServer {
|
||||
clients: cursors.iter().map(|c| c.client_cursors.clone()).collect(),
|
||||
clients: client_cursors,
|
||||
},
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
pub async fn remove_cursors_of_device(&self, vault_id: &str, device_id: &str) {
|
||||
pub async fn remove_cursors_of_device(&self, vault_id: &VaultId, device_id: &DeviceId) {
|
||||
let changed = {
|
||||
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
|
||||
|
||||
if let Some(cursors) = vault_to_cursors.get_mut(vault_id) {
|
||||
cursors.retain(|c| c.client_cursors.device_id != device_id);
|
||||
let before = cursors.len();
|
||||
cursors.retain(|c| c.client_cursors.device_id != *device_id);
|
||||
let changed = cursors.len() != before;
|
||||
if cursors.is_empty() {
|
||||
vault_to_cursors.remove(vault_id);
|
||||
}
|
||||
changed
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
if changed {
|
||||
self.broadcast_cursors_for_vault(vault_id).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,69 +1,147 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex as StdMutex},
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use log::{debug, warn};
|
||||
use log::{debug, info, warn};
|
||||
use tokio::sync::{Mutex, broadcast};
|
||||
|
||||
use super::models::WebSocketServerMessageWithOrigin;
|
||||
use crate::{
|
||||
app_state::database::models::VaultId, config::server_config::ServerConfig, errors::server_error,
|
||||
};
|
||||
use super::models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin};
|
||||
use crate::{app_state::database::models::VaultId, config::server_config::ServerConfig};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Broadcasts {
|
||||
max_clients_per_vault: usize,
|
||||
tx: Arc<Mutex<HashMap<VaultId, broadcast::Sender<WebSocketServerMessageWithOrigin>>>>,
|
||||
broadcast_channel_capacity: usize,
|
||||
// `tx` uses a blocking std::sync::Mutex because the critical section is
|
||||
// a HashMap lookup plus a synchronous `broadcast::Sender::send`. Making
|
||||
// this non-async lets `send_document_update` run without an `.await`,
|
||||
// so an axum handler that is cancelled between `transaction.commit()`
|
||||
// and the broadcast can never drop the notification mid-flight.
|
||||
tx: Arc<StdMutex<HashMap<VaultId, broadcast::Sender<WebSocketServerMessageWithOrigin>>>>,
|
||||
send_locks: Arc<Mutex<HashMap<VaultId, Arc<tokio::sync::Mutex<()>>>>>,
|
||||
}
|
||||
|
||||
type TxMap = HashMap<VaultId, broadcast::Sender<WebSocketServerMessageWithOrigin>>;
|
||||
|
||||
impl Broadcasts {
|
||||
pub fn new(server_config: &ServerConfig) -> Self {
|
||||
Self {
|
||||
max_clients_per_vault: server_config.max_clients_per_vault,
|
||||
tx: Arc::new(Mutex::new(HashMap::new())),
|
||||
broadcast_channel_capacity: server_config.broadcast_channel_capacity,
|
||||
tx: Arc::new(StdMutex::new(HashMap::new())),
|
||||
send_locks: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_receiver(
|
||||
/// Acquire a per-vault lock that serializes broadcasts in commit order.
|
||||
/// Must be acquired before the insert, held through commit and broadcast.
|
||||
pub async fn acquire_send_lock(&self, vault: &VaultId) -> tokio::sync::OwnedMutexGuard<()> {
|
||||
let lock = {
|
||||
let mut locks = self.send_locks.lock().await;
|
||||
locks
|
||||
.entry(vault.clone())
|
||||
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
|
||||
.clone()
|
||||
};
|
||||
lock.lock_owned().await
|
||||
}
|
||||
|
||||
/// Remove senders for vaults with no active receivers
|
||||
fn prune_inactive_vaults(tx_map: &mut TxMap) -> Vec<VaultId> {
|
||||
let mut pruned = Vec::new();
|
||||
tx_map.retain(|vault, sender| {
|
||||
let alive = sender.receiver_count() > 0;
|
||||
if !alive {
|
||||
pruned.push(vault.clone());
|
||||
}
|
||||
alive
|
||||
});
|
||||
pruned
|
||||
}
|
||||
|
||||
pub fn get_receiver(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
) -> broadcast::Receiver<WebSocketServerMessageWithOrigin> {
|
||||
let tx = self.get_or_create(vault).await;
|
||||
max_clients: usize,
|
||||
) -> Result<broadcast::Receiver<WebSocketServerMessageWithOrigin>, crate::errors::SyncServerError>
|
||||
{
|
||||
let mut tx_map = self
|
||||
.tx
|
||||
.lock()
|
||||
.expect("broadcasts.tx mutex poisoned — a previous holder panicked");
|
||||
|
||||
tx.subscribe()
|
||||
let count_before_prune = tx_map
|
||||
.get(&vault)
|
||||
.map_or(0, tokio::sync::broadcast::Sender::receiver_count);
|
||||
let pruned = Self::prune_inactive_vaults(&mut tx_map);
|
||||
let pruned_self = pruned.contains(&vault);
|
||||
|
||||
let sender = tx_map
|
||||
.entry(vault.clone())
|
||||
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
|
||||
|
||||
// Hold the lock across the count check *and* the subscribe so the
|
||||
// `max_clients` cap is atomic: two concurrent callers can't both
|
||||
// observe `receiver_count() < max_clients` and both subscribe.
|
||||
if sender.receiver_count() >= max_clients {
|
||||
return Err(crate::errors::client_error(anyhow::anyhow!(
|
||||
"Vault has reached the maximum number of clients ({max_clients})"
|
||||
)));
|
||||
}
|
||||
|
||||
let receiver = sender.subscribe();
|
||||
let count_after = sender.receiver_count();
|
||||
info!(
|
||||
"[BCAST] get_receiver vault={vault} count_before_prune={count_before_prune} pruned_self={pruned_self} pruned_total={} count_after_subscribe={count_after}",
|
||||
pruned.len()
|
||||
);
|
||||
Ok(receiver)
|
||||
}
|
||||
|
||||
/// Notify all clients (who are subscribed to the vault) about an update.
|
||||
/// We only log failures and don't propagate them.
|
||||
pub async fn send_document_update(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
document: WebSocketServerMessageWithOrigin,
|
||||
) {
|
||||
let tx = self.get_or_create(vault.clone()).await;
|
||||
/// Synchronous: safe to invoke from a handler between `commit()` and
|
||||
/// function return without worrying about task cancellation dropping
|
||||
/// the broadcast mid-flight. Failures are logged, never propagated.
|
||||
pub fn send_document_update(&self, vault: VaultId, document: WebSocketServerMessageWithOrigin) {
|
||||
let vault_update_id = match &document.message {
|
||||
WebSocketServerMessage::VaultUpdate(u) => Some(u.document.vault_update_id),
|
||||
WebSocketServerMessage::CursorPositions(_) => None,
|
||||
};
|
||||
let is_deleted = match &document.message {
|
||||
WebSocketServerMessage::VaultUpdate(u) => Some(u.document.is_deleted),
|
||||
WebSocketServerMessage::CursorPositions(_) => None,
|
||||
};
|
||||
let mut tx_map = self
|
||||
.tx
|
||||
.lock()
|
||||
.expect("broadcasts.tx mutex poisoned — a previous holder panicked");
|
||||
let count_before_prune = tx_map
|
||||
.get(&vault)
|
||||
.map_or(0, tokio::sync::broadcast::Sender::receiver_count);
|
||||
let pruned = Self::prune_inactive_vaults(&mut tx_map);
|
||||
let pruned_self = pruned.contains(&vault);
|
||||
|
||||
if tx.receiver_count() == 0 {
|
||||
let sender = tx_map
|
||||
.entry(vault.clone())
|
||||
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
|
||||
|
||||
let count_before_send = sender.receiver_count();
|
||||
|
||||
if count_before_send == 0 {
|
||||
info!(
|
||||
"[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send=0 SKIPPED"
|
||||
);
|
||||
debug!("Skipping broadcast, no clients connected for vault `{vault}`");
|
||||
return;
|
||||
}
|
||||
|
||||
let result = tx
|
||||
.send(document)
|
||||
.context("Cannot broadcast server message to websocket listeners")
|
||||
.map_err(server_error);
|
||||
|
||||
if result.is_err() {
|
||||
warn!("Failed to send message: {result:?}");
|
||||
let send_result = sender.send(document);
|
||||
match &send_result {
|
||||
Ok(n) => info!(
|
||||
"[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send={count_before_send} SENT delivered_to={n}"
|
||||
),
|
||||
Err(e) => warn!(
|
||||
"[BCAST] send_document_update vault={vault} vuid={vault_update_id:?} is_deleted={is_deleted:?} count_before_prune={count_before_prune} pruned_self={pruned_self} count_before_send={count_before_send} FAILED err={e}"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_or_create(
|
||||
&self,
|
||||
vault: VaultId,
|
||||
) -> broadcast::Sender<WebSocketServerMessageWithOrigin> {
|
||||
let mut tx = self.tx.lock().await;
|
||||
|
||||
tx.entry(vault)
|
||||
.or_insert_with(|| broadcast::channel(self.max_clients_per_vault).0.clone())
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ pub struct WebSocketHandshake {
|
|||
pub token: String,
|
||||
pub device_id: DeviceId,
|
||||
|
||||
#[ts(as = "Option<i32>")]
|
||||
#[ts(type = "number | null")]
|
||||
pub last_seen_vault_update_id: Option<VaultUpdateId>,
|
||||
}
|
||||
|
||||
|
|
@ -22,13 +22,14 @@ pub struct CursorPositionFromClient {
|
|||
}
|
||||
|
||||
#[derive(TS, Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DocumentWithCursors {
|
||||
// It's None in case the document is dirty.
|
||||
// We still want to sync the cursor to mark
|
||||
// that it exists and can be client-side
|
||||
// interpolated. However, the actual
|
||||
// position is meaningless.
|
||||
#[ts(as = "Option<u32>")]
|
||||
#[ts(type = "number | null")]
|
||||
pub vault_update_id: Option<VaultUpdateId>,
|
||||
|
||||
pub document_id: DocumentId,
|
||||
|
|
@ -57,11 +58,19 @@ pub struct CursorPositionFromServer {
|
|||
pub clients: Vec<ClientCursors>,
|
||||
}
|
||||
|
||||
// One committed version. Non-delete updates are broadcast to every
|
||||
// connected client *except* the device that authored them — that
|
||||
// device already has the new state via its HTTP response. Deletes are
|
||||
// broadcast to every client including the author: the author keeps
|
||||
// the document in its sync queue until this receipt arrives so a late
|
||||
// remote update can't sneak in between the HTTP response and the
|
||||
// queue cleanup. The server also emits these one-at-a-time to catch
|
||||
// up a freshly-connected client on versions committed while it was
|
||||
// offline, in ascending `vault_update_id` order.
|
||||
#[derive(TS, Serialize, Clone, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct WebSocketVaultUpdate {
|
||||
pub documents: Vec<DocumentVersionWithoutContent>,
|
||||
pub is_initial_sync: bool,
|
||||
pub document: DocumentVersionWithoutContent,
|
||||
}
|
||||
|
||||
#[derive(TS, Deserialize, Clone, Debug)]
|
||||
|
|
@ -80,6 +89,10 @@ pub enum WebSocketServerMessage {
|
|||
CursorPositions(CursorPositionFromServer),
|
||||
}
|
||||
|
||||
/// Broadcast envelope carrying the message plus the device that produced
|
||||
/// it. The per-recipient send task compares `origin_device_id` against
|
||||
/// its own device id to fill in `originates_from_self` before the message
|
||||
/// is serialized on the wire.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WebSocketServerMessageWithOrigin {
|
||||
pub origin_device_id: Option<DeviceId>,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ use crate::{
|
|||
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
|
||||
},
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, server_error, unauthenticated_error},
|
||||
errors::{SyncServerError, client_error, server_error, unauthenticated_error},
|
||||
server::auth::auth,
|
||||
};
|
||||
|
||||
|
|
@ -26,7 +26,7 @@ pub fn get_authenticated_handshake(
|
|||
if let Some(Message::Text(message)) = message {
|
||||
let message: WebSocketClientMessage = serde_json::from_str(&message)
|
||||
.context("Failed to parse message")
|
||||
.map_err(server_error)?;
|
||||
.map_err(client_error)?;
|
||||
|
||||
match message {
|
||||
WebSocketClientMessage::Handshake(handshake) => {
|
||||
|
|
@ -44,21 +44,29 @@ pub fn get_authenticated_handshake(
|
|||
}
|
||||
}
|
||||
|
||||
/// Stream the documents the client missed while offline, bounded above
|
||||
/// by `up_to_vault_update_id` so the catch-up is a stable snapshot at
|
||||
/// exactly that cursor. The WebSocket handshake atomically subscribes
|
||||
/// to the broadcast channel and snapshots this cursor under the per-
|
||||
/// vault send lock; commits past the cursor are then delivered solely
|
||||
/// through the broadcast channel (filtered by the same cursor on the
|
||||
/// receive side), so every committed update is delivered exactly once.
|
||||
pub async fn get_unseen_documents(
|
||||
state: &AppState,
|
||||
vault_id: &VaultId,
|
||||
last_seen_vault_update_id: Option<VaultUpdateId>,
|
||||
up_to_vault_update_id: VaultUpdateId,
|
||||
) -> Result<Vec<DocumentVersionWithoutContent>, SyncServerError> {
|
||||
if let Some(update_id) = last_seen_vault_update_id {
|
||||
state
|
||||
.database
|
||||
.get_latest_documents_since(vault_id, update_id, None)
|
||||
.get_latest_documents_since(vault_id, update_id, Some(up_to_vault_update_id), None)
|
||||
.await
|
||||
.map_err(server_error)
|
||||
} else {
|
||||
state
|
||||
.database
|
||||
.get_latest_documents(vault_id, None)
|
||||
.get_latest_documents(vault_id, Some(up_to_vault_update_id), None)
|
||||
.await
|
||||
.map_err(server_error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +1,3 @@
|
|||
use anyhow::Context;
|
||||
use axum::{
|
||||
extract::{
|
||||
Path, State,
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
},
|
||||
response::Response,
|
||||
};
|
||||
use futures::stream::StreamExt;
|
||||
use log::{debug, info};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
|
|
@ -24,9 +12,35 @@ use crate::{
|
|||
},
|
||||
},
|
||||
},
|
||||
consts::{
|
||||
HANDSHAKE_TIMEOUT, MAX_CURSOR_DOCUMENTS, MAX_CURSORS_PER_DOCUMENT, MAX_RELATIVE_PATH_LEN,
|
||||
},
|
||||
errors::{SyncServerError, client_error, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
use anyhow::Context;
|
||||
use axum::{
|
||||
extract::{
|
||||
Path, State,
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
},
|
||||
response::Response,
|
||||
};
|
||||
use futures::sink::SinkExt;
|
||||
use futures::stream::StreamExt;
|
||||
use log::{debug, info, warn};
|
||||
use serde::Deserialize;
|
||||
|
||||
/// Marks one pending (not yet authenticated) WebSocket connection.
///
/// The guard wraps the shared pending-connection counter and subtracts one
/// from it when dropped, so the slot is released even if the upgrade never
/// completes or authentication fails. It does not increment the counter
/// itself; whoever constructs the guard must already have done so.
struct PendingWsGuard(std::sync::Arc<std::sync::atomic::AtomicUsize>);

impl Drop for PendingWsGuard {
    /// Releases this connection's slot in the pending counter.
    fn drop(&mut self) {
        let Self(pending) = self;
        pending.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
    }
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct WebSocketPathParams {
|
||||
|
|
@ -39,13 +53,31 @@ pub async fn websocket_handler(
|
|||
Path(WebSocketPathParams { vault_id }): Path<WebSocketPathParams>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Response, SyncServerError> {
|
||||
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id)))
|
||||
let current = state
|
||||
.pending_ws_connections
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
if current >= state.config.server.max_pending_websocket_connections {
|
||||
state
|
||||
.pending_ws_connections
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Too many pending WebSocket connections"
|
||||
)));
|
||||
}
|
||||
|
||||
let guard = PendingWsGuard(state.pending_ws_connections.clone());
|
||||
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id, guard)))
|
||||
}
|
||||
|
||||
async fn websocket_wrapped(state: AppState, stream: WebSocket, vault_id: VaultId) {
|
||||
async fn websocket_wrapped(
|
||||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
pending_guard: PendingWsGuard,
|
||||
) {
|
||||
info!("WebSocket connection opened on vault `{vault_id}`");
|
||||
|
||||
let result = websocket(state, stream, vault_id.clone()).await;
|
||||
let result = websocket(state, stream, vault_id.clone(), pending_guard).await;
|
||||
|
||||
if let Err(err) = result {
|
||||
debug!("WebSocket connection error on vault `{vault_id}`: {err}");
|
||||
|
|
@ -57,39 +89,112 @@ async fn websocket(
|
|||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
pending_guard: PendingWsGuard,
|
||||
) -> Result<(), SyncServerError> {
|
||||
let (mut sender, mut websocket_receiver) = stream.split();
|
||||
|
||||
let authed_handshake = get_authenticated_handshake(
|
||||
&state,
|
||||
&vault_id,
|
||||
websocket_receiver
|
||||
.next()
|
||||
let handshake_msg = tokio::time::timeout(HANDSHAKE_TIMEOUT, websocket_receiver.next())
|
||||
.await
|
||||
.map_err(|_| client_error(anyhow::anyhow!("WebSocket handshake timed out")))?
|
||||
.transpose()
|
||||
.unwrap_or_default(),
|
||||
)?;
|
||||
.map_err(|e| client_error(anyhow::anyhow!("WebSocket error during handshake: {e}")))?;
|
||||
|
||||
let authed_handshake = get_authenticated_handshake(&state, &vault_id, handshake_msg)?;
|
||||
|
||||
info!(
|
||||
"WebSocket handshake successful for vault `{vault_id}` for `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
|
||||
let mut broadcast_receiver = state.broadcasts.get_receiver(vault_id.clone()).await;
|
||||
// Auth complete — no longer a pending connection.
|
||||
drop(pending_guard);
|
||||
|
||||
send_update_over_websocket(
|
||||
&WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate {
|
||||
documents: get_unseen_documents(
|
||||
let max_clients = state.config.server.max_clients_per_vault;
|
||||
|
||||
// Atomic subscribe + cursor snapshot, serialized against in-flight
|
||||
// broadcasts:
|
||||
//
|
||||
// 1. Acquire the per-vault broadcast send lock. While we hold it,
|
||||
// no `send_document_update` can run, so no broadcast can fire
|
||||
// between our subscribe and our cursor snapshot.
|
||||
// 2. Subscribe to the broadcast channel (now we'll see every
|
||||
// broadcast that fires after we drop the send guard).
|
||||
// 3. Snapshot `cursor = max committed vault_update_id`. Because
|
||||
// `insert_document_version` holds the same send lock from
|
||||
// *before* the commit through *after* the broadcast, every doc
|
||||
// visible at this cursor has either (a) already had its
|
||||
// broadcast delivered to all then-existing subscribers — and we
|
||||
// weren't one of them, so we'll catch it via the snapshot — or
|
||||
// (b) had its broadcast contend on the lock we're holding, and
|
||||
// will be delivered to us as soon as we drop the guard, with
|
||||
// `vault_update_id > cursor`.
|
||||
// 4. Drop the send guard so writers can resume broadcasting.
|
||||
// 5. Stream the catch-up bounded by the cursor — i.e. only docs
|
||||
// with `vault_update_id <= cursor` — exactly once.
|
||||
// 6. The send task forwards broadcasts but filters to
|
||||
// `vault_update_id > cursor`, so a doc that's both in the
|
||||
// catch-up and in a contended-then-released broadcast is
|
||||
// delivered exactly once (via the catch-up).
|
||||
let send_guard = state.broadcasts.acquire_send_lock(&vault_id).await;
|
||||
let mut broadcast_receiver = match state.broadcasts.get_receiver(vault_id.clone(), max_clients)
|
||||
{
|
||||
Ok(receiver) => receiver,
|
||||
Err(err) => {
|
||||
drop(send_guard);
|
||||
warn!(
|
||||
"Vault `{vault_id}` has reached the maximum number of clients ({max_clients}), rejecting connection from `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
if let Err(e) = sender
|
||||
.send(Message::Close(Some(axum::extract::ws::CloseFrame {
|
||||
code: 4000,
|
||||
reason: format!(
|
||||
"Vault has reached the maximum number of clients ({max_clients})"
|
||||
)
|
||||
.into(),
|
||||
})))
|
||||
.await
|
||||
{
|
||||
warn!("Failed to send WebSocket close frame: {e}");
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
let cursor = state
|
||||
.database
|
||||
.get_max_update_id_in_vault(&vault_id, None)
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
drop(send_guard);
|
||||
|
||||
// Catch-up on versions committed while this client was offline,
|
||||
// streamed one-at-a-time in ascending `vault_update_id` order, up
|
||||
// to the snapshot cursor.
|
||||
let unseen_documents = get_unseen_documents(
|
||||
&state,
|
||||
&vault_id,
|
||||
authed_handshake.handshake.last_seen_vault_update_id,
|
||||
cursor,
|
||||
)
|
||||
.await?,
|
||||
is_initial_sync: true,
|
||||
}),
|
||||
.await?;
|
||||
let unseen_summary: Vec<(i64, bool, String)> = unseen_documents
|
||||
.iter()
|
||||
.map(|d| (d.vault_update_id, d.is_deleted, d.relative_path.clone()))
|
||||
.collect();
|
||||
info!(
|
||||
"[CATCHUP] vault={vault_id} device={} last_seen={:?} cursor={cursor} unseen_count={} unseen={:?}",
|
||||
authed_handshake.handshake.device_id,
|
||||
authed_handshake.handshake.last_seen_vault_update_id,
|
||||
unseen_summary.len(),
|
||||
unseen_summary
|
||||
);
|
||||
for document in unseen_documents {
|
||||
send_update_over_websocket(
|
||||
&WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate { document }),
|
||||
&mut sender,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
send_update_over_websocket(
|
||||
&WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
|
|
@ -101,25 +206,58 @@ async fn websocket(
|
|||
|
||||
let device_id = authed_handshake.handshake.device_id.clone();
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Ok(update) = broadcast_receiver.recv().await {
|
||||
loop {
|
||||
match broadcast_receiver.recv().await {
|
||||
Ok(update) => {
|
||||
// Drop messages this device authored because the HTTP
|
||||
// response already carried authoritative state back.
|
||||
// Delete broadcasts are sent without an origin so the
|
||||
// author also receives them — that's the receipt the
|
||||
// client needs to drop the doc from its sync queue.
|
||||
if Some(&device_id) == update.origin_device_id.as_ref() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Filter out vault updates already covered by the
|
||||
// catch-up snapshot. The handshake atomically
|
||||
// subscribed and snapshotted `cursor` under the
|
||||
// broadcast send lock, so any broadcast with
|
||||
// `vault_update_id <= cursor` is one that contended
|
||||
// on the lock during our subscribe — its row is
|
||||
// already in the catch-up stream and re-delivering
|
||||
// it via this channel would duplicate the message.
|
||||
// Cursor messages aren't versioned and are always
|
||||
// forwarded.
|
||||
if let WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate { document }) =
|
||||
&update.message
|
||||
&& document.vault_update_id <= cursor
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let message = match update.message {
|
||||
WebSocketServerMessage::CursorPositions(CursorPositionFromServer { clients }) => {
|
||||
WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
clients,
|
||||
}) => WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
clients: clients
|
||||
.into_iter()
|
||||
.filter(|client| client.device_id != device_id)
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
}),
|
||||
WebSocketServerMessage::VaultUpdate(_) => update.message,
|
||||
};
|
||||
|
||||
send_update_over_websocket(&message, &mut sender).await?;
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
|
||||
warn!(
|
||||
"WebSocket receiver lagged, dropped {n} messages — disconnecting client to force full resync"
|
||||
);
|
||||
break;
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
|
||||
Ok::<(), SyncServerError>(())
|
||||
});
|
||||
|
|
@ -128,10 +266,12 @@ async fn websocket(
|
|||
let vault_id_clone = vault_id.clone();
|
||||
let cursor_manager = state.cursors.clone();
|
||||
let mut receive_task = tokio::spawn(async move {
|
||||
while let Some(Ok(Message::Text(message))) = websocket_receiver.next().await {
|
||||
while let Some(msg) = websocket_receiver.next().await {
|
||||
match msg {
|
||||
Ok(Message::Text(message)) => {
|
||||
let message: WebSocketClientMessage = serde_json::from_str(&message)
|
||||
.context("Failed to parse WebSocket message from client")
|
||||
.map_err(server_error)?;
|
||||
.map_err(client_error)?;
|
||||
|
||||
match message {
|
||||
WebSocketClientMessage::Handshake(_) => {
|
||||
|
|
@ -140,54 +280,94 @@ async fn websocket(
|
|||
)));
|
||||
}
|
||||
WebSocketClientMessage::CursorPositions(cursors) => {
|
||||
let docs = cursors.documents_with_cursors;
|
||||
if docs.len() > MAX_CURSOR_DOCUMENTS {
|
||||
warn!(
|
||||
"Cursor update rejected: {} documents exceeds limit of {MAX_CURSOR_DOCUMENTS}",
|
||||
docs.len()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let valid = docs.iter().all(|doc| {
|
||||
doc.cursors.len() <= MAX_CURSORS_PER_DOCUMENT
|
||||
&& doc.relative_path.len() <= MAX_RELATIVE_PATH_LEN
|
||||
});
|
||||
if !valid {
|
||||
warn!(
|
||||
"Cursor update rejected: a document exceeds cursor or path length limits"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
cursor_manager
|
||||
.update_cursors(
|
||||
vault_id_clone.clone(),
|
||||
authed_handshake.user.name.clone(),
|
||||
&device_id,
|
||||
cursors.documents_with_cursors,
|
||||
docs,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Message::Close(_)) => break,
|
||||
Ok(Message::Binary(_)) => {
|
||||
warn!("Received unexpected binary WebSocket message, ignoring");
|
||||
}
|
||||
Ok(_) => {} // Ping/Pong frames handled by axum
|
||||
Err(e) => {
|
||||
debug!("WebSocket receive error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok::<(), SyncServerError>(())
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut send_task => receive_task.abort(),
|
||||
_ = &mut receive_task => send_task.abort(),
|
||||
let result: Result<(), SyncServerError> = tokio::select! {
|
||||
send_result = &mut send_task => {
|
||||
receive_task.abort();
|
||||
let _ = receive_task.await;
|
||||
match send_result {
|
||||
Err(e) => Err(server_error(
|
||||
anyhow::Error::from(e).context("WebSocket send task failed"),
|
||||
)),
|
||||
Ok(inner) => inner,
|
||||
}
|
||||
},
|
||||
receive_result = &mut receive_task => {
|
||||
send_task.abort();
|
||||
let _ = send_task.await;
|
||||
match receive_result {
|
||||
Err(e) => Err(server_error(
|
||||
anyhow::Error::from(e).context("WebSocket receive task failed"),
|
||||
)),
|
||||
Ok(inner) => inner,
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let result: Result<(), SyncServerError> = (async {
|
||||
send_task
|
||||
.await
|
||||
.context("WebSocket send task failed")
|
||||
.map_err(client_error)
|
||||
.and_then(|err| err)?;
|
||||
|
||||
receive_task
|
||||
.await
|
||||
.context("WebSocket receive task failed")
|
||||
.map_err(client_error)
|
||||
.and_then(|err| err)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await;
|
||||
|
||||
state
|
||||
.cursors
|
||||
.remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id)
|
||||
.await;
|
||||
|
||||
if result.is_err() {
|
||||
match &result {
|
||||
Ok(()) => {
|
||||
info!(
|
||||
"WebSocket disconnected on vault `{vault_id}` for `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"WebSocket error on vault `{vault_id}` for `{}`: {err}",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue