Return user name for cursors rather than device

This commit is contained in:
Andras Schmelczer 2025-06-07 22:13:07 +01:00
parent 14db4bf240
commit 02f32e894a
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
5 changed files with 27 additions and 12 deletions

View file

@ -34,6 +34,7 @@ impl Cursors {
pub async fn update_cursors( pub async fn update_cursors(
&self, &self,
vault_id: VaultId, vault_id: VaultId,
user_name: String,
device_id: &DeviceId, device_id: &DeviceId,
document_to_cursors: HashMap<String, Vec<CursorSpan>>, document_to_cursors: HashMap<String, Vec<CursorSpan>>,
) { ) {
@ -43,6 +44,7 @@ impl Cursors {
all_device_cursors.retain(|c| &c.client_cursors.device_id != device_id); all_device_cursors.retain(|c| &c.client_cursors.device_id != device_id);
all_device_cursors.push(ClientCursorsWithTimeToLive::new(ClientCursors { all_device_cursors.push(ClientCursorsWithTimeToLive::new(ClientCursors {
user_name,
device_id: device_id.to_string(), device_id: device_id.to_string(),
cursors: document_to_cursors, cursors: document_to_cursors,
})); }));

View file

@ -31,6 +31,7 @@ pub struct CursorPositionFromClient {
#[derive(TS, Serialize, Clone, Debug)] #[derive(TS, Serialize, Clone, Debug)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ClientCursors { pub struct ClientCursors {
pub user_name: String,
pub device_id: DeviceId, pub device_id: DeviceId,
pub cursors: HashMap<String, Vec<CursorSpan>>, pub cursors: HashMap<String, Vec<CursorSpan>>,
} }

View file

@ -8,15 +8,21 @@ use crate::{
AppState, AppState,
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId}, database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
}, },
config::user_config::User,
errors::{SyncServerError, server_error, unauthenticated_error}, errors::{SyncServerError, server_error, unauthenticated_error},
server::auth::auth, server::auth::auth,
}; };
pub struct AuthenticatedWebSocketHandshake {
pub handshake: WebSocketHandshake,
pub user: User,
}
pub fn get_authenticated_handshake( pub fn get_authenticated_handshake(
state: &AppState, state: &AppState,
vault_id: &VaultId, vault_id: &VaultId,
message: Option<Message>, message: Option<Message>,
) -> Result<WebSocketHandshake, SyncServerError> { ) -> Result<AuthenticatedWebSocketHandshake, SyncServerError> {
if let Some(Message::Text(message)) = message { if let Some(Message::Text(message)) = message {
let message: WebSocketClientMessage = serde_json::from_str(&message) let message: WebSocketClientMessage = serde_json::from_str(&message)
.context("Failed to parse message") .context("Failed to parse message")
@ -24,8 +30,8 @@ pub fn get_authenticated_handshake(
match message { match message {
WebSocketClientMessage::Handshake(handshake) => { WebSocketClientMessage::Handshake(handshake) => {
auth(state, handshake.token.trim(), vault_id)?; let user = auth(state, handshake.token.trim(), vault_id)?;
Ok(handshake) Ok(AuthenticatedWebSocketHandshake { handshake, user })
} }
WebSocketClientMessage::CursorPositions(_) => Err(unauthenticated_error( WebSocketClientMessage::CursorPositions(_) => Err(unauthenticated_error(
anyhow::anyhow!("Expected a handshake message"), anyhow::anyhow!("Expected a handshake message"),

View file

@ -52,6 +52,7 @@ async fn websocket_wrapped(state: AppState, stream: WebSocket, vault_id: VaultId
} }
} }
#[allow(clippy::too_many_lines)]
async fn websocket( async fn websocket(
state: AppState, state: AppState,
stream: WebSocket, stream: WebSocket,
@ -59,7 +60,7 @@ async fn websocket(
) -> Result<(), SyncServerError> { ) -> Result<(), SyncServerError> {
let (mut sender, mut websocket_receiver) = stream.split(); let (mut sender, mut websocket_receiver) = stream.split();
let handshake = get_authenticated_handshake( let authed_handshake = get_authenticated_handshake(
&state, &state,
&vault_id, &vault_id,
websocket_receiver websocket_receiver
@ -71,15 +72,19 @@ async fn websocket(
info!( info!(
"WebSocket handshake successful for vault '{vault_id}' for '{}'", "WebSocket handshake successful for vault '{vault_id}' for '{}'",
handshake.device_id authed_handshake.handshake.device_id
); );
let mut broadcast_receiver = state.broadcasts.get_receiver(vault_id.clone()).await; let mut broadcast_receiver = state.broadcasts.get_receiver(vault_id.clone()).await;
send_update_over_websocket( send_update_over_websocket(
&WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate { &WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate {
documents: get_unseen_documents(&state, &vault_id, handshake.last_seen_vault_update_id) documents: get_unseen_documents(
.await?, &state,
&vault_id,
authed_handshake.handshake.last_seen_vault_update_id,
)
.await?,
is_initial_sync: true, is_initial_sync: true,
}), }),
&mut sender, &mut sender,
@ -94,7 +99,7 @@ async fn websocket(
) )
.await?; .await?;
let device_id = handshake.device_id.clone(); let device_id = authed_handshake.handshake.device_id.clone();
let mut send_task = tokio::spawn(async move { let mut send_task = tokio::spawn(async move {
while let Ok(update) = broadcast_receiver.recv().await { while let Ok(update) = broadcast_receiver.recv().await {
if Some(&device_id) == update.origin_device_id.as_ref() { if Some(&device_id) == update.origin_device_id.as_ref() {
@ -107,7 +112,7 @@ async fn websocket(
Ok::<(), SyncServerError>(()) Ok::<(), SyncServerError>(())
}); });
let device_id = handshake.device_id.clone(); let device_id = authed_handshake.handshake.device_id.clone();
let vault_id_clone = vault_id.clone(); let vault_id_clone = vault_id.clone();
let cursor_manager = state.cursors.clone(); let cursor_manager = state.cursors.clone();
let mut receive_task = tokio::spawn(async move { let mut receive_task = tokio::spawn(async move {
@ -126,6 +131,7 @@ async fn websocket(
cursor_manager cursor_manager
.update_cursors( .update_cursors(
vault_id_clone.clone(), vault_id_clone.clone(),
authed_handshake.user.name.clone(),
&device_id, &device_id,
cursors.document_to_cursors, cursors.document_to_cursors,
) )
@ -161,13 +167,13 @@ async fn websocket(
state state
.cursors .cursors
.remove_cursors_of_device(&vault_id, &handshake.device_id) .remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id)
.await; .await;
if result.is_err() { if result.is_err() {
info!( info!(
"WebSocket disconnected on vault '{vault_id}' for '{}'", "WebSocket disconnected on vault '{vault_id}' for '{}'",
handshake.device_id authed_handshake.handshake.device_id
); );
} }

View file

@ -1,4 +1,4 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. // This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { CursorSpan } from "./CursorSpan"; import type { CursorSpan } from "./CursorSpan";
export type ClientCursors = { deviceId: string, cursors: { [key in string]?: Array<CursorSpan> }, }; export type ClientCursors = { userName: string, deviceId: string, cursors: { [key in string]?: Array<CursorSpan> }, };