Add WS endpoint

This commit is contained in:
Andras Schmelczer 2025-03-25 22:26:32 +00:00
parent 0320308f1a
commit 63a4948b87
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 160 additions and 13 deletions

View file

@ -1,3 +1,16 @@
mod auth;
mod create_document;
mod delete_document;
mod fetch_document_version;
mod fetch_document_version_content;
mod fetch_latest_document_version;
mod fetch_latest_documents;
mod ping;
mod requests;
mod responses;
mod update_document;
mod websocket;
use std::{ffi::OsString, sync::Arc};
use aide::{
@ -10,7 +23,6 @@ use aide::{
transform::TransformOpenApi,
};
use anyhow::{Context as _, Result, anyhow};
use app_state::AppState;
use axum::{
Extension, Json,
extract::{DefaultBodyLimit, Request},
@ -32,21 +44,10 @@ use tower_http::{
use tracing::{Level, info_span};
use crate::{
app_state::AppState,
config::server_config::ServerConfig,
errors::{SerializedError, not_found_error},
};
mod app_state;
mod auth;
mod create_document;
mod delete_document;
mod fetch_document_version;
mod fetch_document_version_content;
mod fetch_latest_document_version;
mod fetch_latest_documents;
mod ping;
mod requests;
mod responses;
mod update_document;
pub async fn create_server(config_path: Option<OsString>) -> Result<()> {
aide::r#gen::on_error(|err| error!("{err}"));
@ -65,6 +66,7 @@ pub async fn create_server(config_path: Option<OsString>) -> Result<()> {
"/vaults/:vault_id/documents",
get(fetch_latest_documents::fetch_latest_documents),
)
.route("/vaults/:vault_id/ws", get(websocket::websocket_handler))
.api_route(
"/vaults/:vault_id/documents",
post(create_document::create_document_multipart),

View file

@ -0,0 +1,145 @@
use anyhow::Context;
use axum::{
extract::{
Path, Query, State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::Response,
};
use futures::{
sink::SinkExt,
stream::{SplitSink, StreamExt},
};
use log::{error, info, warn};
use schemars::JsonSchema;
use serde::Deserialize;
use super::auth::auth;
use crate::{
app_state::{
AppState,
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
},
errors::{SyncServerError, server_error, unauthorized_error},
};
// This is required for aide to infer the path parameter types and names
#[derive(Deserialize, JsonSchema)]
pub struct WebsocketPathParams {
vault_id: VaultId,
}
// This is required for aide to infer the path parameter types and names
#[derive(Deserialize, JsonSchema)]
pub struct QueryParams {
since_update_id: Option<VaultUpdateId>,
}
pub async fn websocket_handler(
ws: WebSocketUpgrade,
Path(WebsocketPathParams { vault_id }): Path<WebsocketPathParams>,
Query(QueryParams { since_update_id }): Query<QueryParams>,
State(state): State<AppState>,
) -> Result<Response, SyncServerError> {
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id, since_update_id)))
}
async fn websocket_wrapped(
state: AppState,
stream: WebSocket,
vault_id: VaultId,
since_update_id: Option<VaultUpdateId>,
) {
info!("Websocket connection opened on vault '{}'", vault_id);
let result = websocket(state, stream, vault_id.clone(), since_update_id).await;
if let Err(err) = result {
error!(
"Websocket connection error on vault '{}': {}",
vault_id, err
);
}
warn!("Websocket connection closed on vault '{}'", vault_id);
}
async fn websocket(
state: AppState,
stream: WebSocket,
vault_id: VaultId,
since_update_id: Option<VaultUpdateId>,
) -> Result<(), SyncServerError> {
let (mut sender, mut receiver) = stream.split();
if let Some(Ok(Message::Text(token))) = receiver.next().await {
auth(&state, &token)?;
} else {
return Err(unauthorized_error(anyhow::anyhow!(
"Failed to authenticate"
)));
}
let mut rx = state.broadcasts.get_receiver(vault_id.clone()).await;
let documents = if let Some(since_update_id) = since_update_id {
state
.database
.get_latest_documents_since(&vault_id, since_update_id, None)
.await
.map_err(server_error)
} else {
state
.database
.get_latest_documents(&vault_id, None)
.await
.map_err(server_error)
}?;
for document in documents {
send_document_over_websocket(document, &mut sender).await?;
}
let mut send_task = tokio::spawn(async move {
while let Ok(update) = rx.recv().await {
send_document_over_websocket(update, &mut sender).await?;
}
Ok::<(), SyncServerError>(())
});
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(Message::Text(text))) = receiver.next().await {
info!("Received message: {}", text);
// Add username before message.
// let _ = tx.send(format!("{name}: {text}"));
}
});
tokio::select! {
_ = &mut send_task => recv_task.abort(),
_ = &mut recv_task => send_task.abort(),
};
send_task
.await
.context("Websocket send task failed")
.map_err(server_error)??;
Ok(())
}
async fn send_document_over_websocket(
document: DocumentVersionWithoutContent,
sender: &mut SplitSink<WebSocket, Message>,
) -> Result<(), SyncServerError> {
let serialized_update = serde_json::to_string(&document)
.context("Failed to serialize update")
.map_err(server_error)?;
sender
.send(Message::Text(serialized_update))
.await
.context("Failed to send message over websocket")
.map_err(server_error)
}