Add proper shutdown, rate limits, config validation, cors config, fix dangling cursors, cache regex, merge created texts
This commit is contained in:
parent
4763bc9d04
commit
e15b0f9903
28 changed files with 1277 additions and 464 deletions
|
|
@ -6,10 +6,10 @@ use axum::{
|
|||
},
|
||||
response::Response,
|
||||
};
|
||||
use futures::sink::SinkExt;
|
||||
use futures::stream::StreamExt;
|
||||
use log::{debug, info};
|
||||
use log::{debug, info, warn};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
|
|
@ -24,10 +24,26 @@ use crate::{
|
|||
},
|
||||
},
|
||||
},
|
||||
consts::{
|
||||
HANDSHAKE_TIMEOUT, MAX_CURSORS_PER_DOCUMENT, MAX_CURSOR_DOCUMENTS,
|
||||
MAX_RELATIVE_PATH_LEN,
|
||||
},
|
||||
errors::{SyncServerError, client_error, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
/// Tracks a pending (not yet authenticated) WebSocket connection.
|
||||
/// Decrements the counter when dropped, ensuring cleanup even if
|
||||
/// the upgrade never completes or auth fails.
|
||||
struct PendingWsGuard(std::sync::Arc<std::sync::atomic::AtomicUsize>);
|
||||
|
||||
impl Drop for PendingWsGuard {
|
||||
fn drop(&mut self) {
|
||||
self.0
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct WebSocketPathParams {
|
||||
#[serde(deserialize_with = "normalize")]
|
||||
|
|
@ -39,13 +55,31 @@ pub async fn websocket_handler(
|
|||
Path(WebSocketPathParams { vault_id }): Path<WebSocketPathParams>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Response, SyncServerError> {
|
||||
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id)))
|
||||
let current = state
|
||||
.pending_ws_connections
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
if current >= state.config.server.max_pending_websocket_connections {
|
||||
state
|
||||
.pending_ws_connections
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Too many pending WebSocket connections"
|
||||
)));
|
||||
}
|
||||
|
||||
let guard = PendingWsGuard(state.pending_ws_connections.clone());
|
||||
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id, guard)))
|
||||
}
|
||||
|
||||
async fn websocket_wrapped(state: AppState, stream: WebSocket, vault_id: VaultId) {
|
||||
async fn websocket_wrapped(
|
||||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
pending_guard: PendingWsGuard,
|
||||
) {
|
||||
info!("WebSocket connection opened on vault `{vault_id}`");
|
||||
|
||||
let result = websocket(state, stream, vault_id.clone()).await;
|
||||
let result = websocket(state, stream, vault_id.clone(), pending_guard).await;
|
||||
|
||||
if let Err(err) = result {
|
||||
debug!("WebSocket connection error on vault `{vault_id}`: {err}");
|
||||
|
|
@ -57,25 +91,53 @@ async fn websocket(
|
|||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
pending_guard: PendingWsGuard,
|
||||
) -> Result<(), SyncServerError> {
|
||||
let (mut sender, mut websocket_receiver) = stream.split();
|
||||
|
||||
let authed_handshake = get_authenticated_handshake(
|
||||
&state,
|
||||
&vault_id,
|
||||
websocket_receiver
|
||||
.next()
|
||||
.await
|
||||
.transpose()
|
||||
.unwrap_or_default(),
|
||||
)?;
|
||||
let handshake_msg = tokio::time::timeout(HANDSHAKE_TIMEOUT, websocket_receiver.next())
|
||||
.await
|
||||
.map_err(|_| client_error(anyhow::anyhow!("WebSocket handshake timed out")))?
|
||||
.transpose()
|
||||
.map_err(|e| client_error(anyhow::anyhow!("WebSocket error during handshake: {e}")))?;
|
||||
|
||||
let authed_handshake = get_authenticated_handshake(&state, &vault_id, handshake_msg)?;
|
||||
|
||||
info!(
|
||||
"WebSocket handshake successful for vault `{vault_id}` for `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
|
||||
let mut broadcast_receiver = state.broadcasts.get_receiver(vault_id.clone()).await;
|
||||
// Auth complete — no longer a pending connection.
|
||||
drop(pending_guard);
|
||||
|
||||
let max_clients = state.config.server.max_clients_per_vault;
|
||||
let mut broadcast_receiver = match state
|
||||
.broadcasts
|
||||
.get_receiver(vault_id.clone(), max_clients)
|
||||
.await
|
||||
{
|
||||
Ok(receiver) => receiver,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Vault `{vault_id}` has reached the maximum number of clients ({max_clients}), rejecting connection from `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
if let Err(e) = sender
|
||||
.send(Message::Close(Some(axum::extract::ws::CloseFrame {
|
||||
code: 4000,
|
||||
reason: format!(
|
||||
"Vault has reached the maximum number of clients ({max_clients})"
|
||||
)
|
||||
.into(),
|
||||
})))
|
||||
.await
|
||||
{
|
||||
warn!("Failed to send WebSocket close frame: {e}");
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
send_update_over_websocket(
|
||||
&WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate {
|
||||
|
|
@ -101,24 +163,35 @@ async fn websocket(
|
|||
|
||||
let device_id = authed_handshake.handshake.device_id.clone();
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Ok(update) = broadcast_receiver.recv().await {
|
||||
if Some(&device_id) == update.origin_device_id.as_ref() {
|
||||
continue;
|
||||
}
|
||||
loop {
|
||||
match broadcast_receiver.recv().await {
|
||||
Ok(update) => {
|
||||
if Some(&device_id) == update.origin_device_id.as_ref() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let message = match update.message {
|
||||
WebSocketServerMessage::CursorPositions(CursorPositionFromServer { clients }) => {
|
||||
WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
clients: clients
|
||||
.into_iter()
|
||||
.filter(|client| client.device_id != device_id)
|
||||
.collect(),
|
||||
})
|
||||
let message = match update.message {
|
||||
WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
clients,
|
||||
}) => WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
clients: clients
|
||||
.into_iter()
|
||||
.filter(|client| client.device_id != device_id)
|
||||
.collect(),
|
||||
}),
|
||||
WebSocketServerMessage::VaultUpdate(_) => update.message,
|
||||
};
|
||||
|
||||
send_update_over_websocket(&message, &mut sender).await?;
|
||||
}
|
||||
WebSocketServerMessage::VaultUpdate(_) => update.message,
|
||||
};
|
||||
|
||||
send_update_over_websocket(&message, &mut sender).await?;
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
|
||||
warn!(
|
||||
"WebSocket receiver lagged, dropped {n} messages — disconnecting client to force full resync"
|
||||
);
|
||||
break;
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
|
||||
Ok::<(), SyncServerError>(())
|
||||
|
|
@ -128,26 +201,57 @@ async fn websocket(
|
|||
let vault_id_clone = vault_id.clone();
|
||||
let cursor_manager = state.cursors.clone();
|
||||
let mut receive_task = tokio::spawn(async move {
|
||||
while let Some(Ok(Message::Text(message))) = websocket_receiver.next().await {
|
||||
let message: WebSocketClientMessage = serde_json::from_str(&message)
|
||||
.context("Failed to parse WebSocket message from client")
|
||||
.map_err(server_error)?;
|
||||
while let Some(msg) = websocket_receiver.next().await {
|
||||
match msg {
|
||||
Ok(Message::Text(message)) => {
|
||||
let message: WebSocketClientMessage = serde_json::from_str(&message)
|
||||
.context("Failed to parse WebSocket message from client")
|
||||
.map_err(client_error)?;
|
||||
|
||||
match message {
|
||||
WebSocketClientMessage::Handshake(_) => {
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Unexpected handshake message"
|
||||
)));
|
||||
match message {
|
||||
WebSocketClientMessage::Handshake(_) => {
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Unexpected handshake message"
|
||||
)));
|
||||
}
|
||||
WebSocketClientMessage::CursorPositions(cursors) => {
|
||||
let docs = cursors.documents_with_cursors;
|
||||
if docs.len() > MAX_CURSOR_DOCUMENTS {
|
||||
warn!(
|
||||
"Cursor update rejected: {} documents exceeds limit of {MAX_CURSOR_DOCUMENTS}",
|
||||
docs.len()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let valid = docs.iter().all(|doc| {
|
||||
doc.cursors.len() <= MAX_CURSORS_PER_DOCUMENT
|
||||
&& doc.relative_path.len() <= MAX_RELATIVE_PATH_LEN
|
||||
});
|
||||
if !valid {
|
||||
warn!("Cursor update rejected: a document exceeds cursor or path length limits");
|
||||
continue;
|
||||
}
|
||||
|
||||
cursor_manager
|
||||
.update_cursors(
|
||||
vault_id_clone.clone(),
|
||||
authed_handshake.user.name.clone(),
|
||||
&device_id,
|
||||
docs,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
WebSocketClientMessage::CursorPositions(cursors) => {
|
||||
cursor_manager
|
||||
.update_cursors(
|
||||
vault_id_clone.clone(),
|
||||
authed_handshake.user.name.clone(),
|
||||
&device_id,
|
||||
cursors.documents_with_cursors,
|
||||
)
|
||||
.await;
|
||||
Ok(Message::Close(_)) => break,
|
||||
Ok(Message::Binary(_)) => {
|
||||
warn!("Received unexpected binary WebSocket message, ignoring");
|
||||
}
|
||||
Ok(_) => {} // Ping/Pong frames handled by axum
|
||||
Err(e) => {
|
||||
debug!("WebSocket receive error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -155,38 +259,47 @@ async fn websocket(
|
|||
Ok::<(), SyncServerError>(())
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut send_task => receive_task.abort(),
|
||||
_ = &mut receive_task => send_task.abort(),
|
||||
let result: Result<(), SyncServerError> = tokio::select! {
|
||||
send_result = &mut send_task => {
|
||||
receive_task.abort();
|
||||
let _ = receive_task.await;
|
||||
match send_result {
|
||||
Err(e) => Err(server_error(
|
||||
anyhow::Error::from(e).context("WebSocket send task failed"),
|
||||
)),
|
||||
Ok(inner) => inner,
|
||||
}
|
||||
},
|
||||
receive_result = &mut receive_task => {
|
||||
send_task.abort();
|
||||
let _ = send_task.await;
|
||||
match receive_result {
|
||||
Err(e) => Err(server_error(
|
||||
anyhow::Error::from(e).context("WebSocket receive task failed"),
|
||||
)),
|
||||
Ok(inner) => inner,
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let result: Result<(), SyncServerError> = (async {
|
||||
send_task
|
||||
.await
|
||||
.context("WebSocket send task failed")
|
||||
.map_err(client_error)
|
||||
.and_then(|err| err)?;
|
||||
|
||||
receive_task
|
||||
.await
|
||||
.context("WebSocket receive task failed")
|
||||
.map_err(client_error)
|
||||
.and_then(|err| err)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await;
|
||||
|
||||
state
|
||||
.cursors
|
||||
.remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id)
|
||||
.await;
|
||||
|
||||
if result.is_err() {
|
||||
info!(
|
||||
"WebSocket disconnected on vault `{vault_id}` for `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
match &result {
|
||||
Ok(()) => {
|
||||
info!(
|
||||
"WebSocket disconnected on vault `{vault_id}` for `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"WebSocket error on vault `{vault_id}` for `{}`: {err}",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue