Add WS endpoint
This commit is contained in:
parent
0320308f1a
commit
63a4948b87
2 changed files with 160 additions and 13 deletions
|
|
@ -1,3 +1,16 @@
|
|||
mod auth;
|
||||
mod create_document;
|
||||
mod delete_document;
|
||||
mod fetch_document_version;
|
||||
mod fetch_document_version_content;
|
||||
mod fetch_latest_document_version;
|
||||
mod fetch_latest_documents;
|
||||
mod ping;
|
||||
mod requests;
|
||||
mod responses;
|
||||
mod update_document;
|
||||
mod websocket;
|
||||
|
||||
use std::{ffi::OsString, sync::Arc};
|
||||
|
||||
use aide::{
|
||||
|
|
@ -10,7 +23,6 @@ use aide::{
|
|||
transform::TransformOpenApi,
|
||||
};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use app_state::AppState;
|
||||
use axum::{
|
||||
Extension, Json,
|
||||
extract::{DefaultBodyLimit, Request},
|
||||
|
|
@ -32,21 +44,10 @@ use tower_http::{
|
|||
use tracing::{Level, info_span};
|
||||
|
||||
use crate::{
|
||||
app_state::AppState,
|
||||
config::server_config::ServerConfig,
|
||||
errors::{SerializedError, not_found_error},
|
||||
};
|
||||
mod app_state;
|
||||
mod auth;
|
||||
mod create_document;
|
||||
mod delete_document;
|
||||
mod fetch_document_version;
|
||||
mod fetch_document_version_content;
|
||||
mod fetch_latest_document_version;
|
||||
mod fetch_latest_documents;
|
||||
mod ping;
|
||||
mod requests;
|
||||
mod responses;
|
||||
mod update_document;
|
||||
|
||||
pub async fn create_server(config_path: Option<OsString>) -> Result<()> {
|
||||
aide::r#gen::on_error(|err| error!("{err}"));
|
||||
|
|
@ -65,6 +66,7 @@ pub async fn create_server(config_path: Option<OsString>) -> Result<()> {
|
|||
"/vaults/:vault_id/documents",
|
||||
get(fetch_latest_documents::fetch_latest_documents),
|
||||
)
|
||||
.route("/vaults/:vault_id/ws", get(websocket::websocket_handler))
|
||||
.api_route(
|
||||
"/vaults/:vault_id/documents",
|
||||
post(create_document::create_document_multipart),
|
||||
|
|
|
|||
145
backend/sync_server/src/server/websocket.rs
Normal file
145
backend/sync_server/src/server/websocket.rs
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
use anyhow::Context;
|
||||
use axum::{
|
||||
extract::{
|
||||
Path, Query, State,
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
},
|
||||
response::Response,
|
||||
};
|
||||
use futures::{
|
||||
sink::SinkExt,
|
||||
stream::{SplitSink, StreamExt},
|
||||
};
|
||||
use log::{error, info, warn};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::auth::auth;
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, server_error, unauthorized_error},
|
||||
};
|
||||
|
||||
// This is required for aide to infer the path parameter types and names
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
pub struct WebsocketPathParams {
|
||||
vault_id: VaultId,
|
||||
}
|
||||
|
||||
// This is required for aide to infer the path parameter types and names
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
pub struct QueryParams {
|
||||
since_update_id: Option<VaultUpdateId>,
|
||||
}
|
||||
|
||||
pub async fn websocket_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Path(WebsocketPathParams { vault_id }): Path<WebsocketPathParams>,
|
||||
Query(QueryParams { since_update_id }): Query<QueryParams>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Response, SyncServerError> {
|
||||
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id, since_update_id)))
|
||||
}
|
||||
|
||||
async fn websocket_wrapped(
|
||||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
since_update_id: Option<VaultUpdateId>,
|
||||
) {
|
||||
info!("Websocket connection opened on vault '{}'", vault_id);
|
||||
|
||||
let result = websocket(state, stream, vault_id.clone(), since_update_id).await;
|
||||
|
||||
if let Err(err) = result {
|
||||
error!(
|
||||
"Websocket connection error on vault '{}': {}",
|
||||
vault_id, err
|
||||
);
|
||||
}
|
||||
|
||||
warn!("Websocket connection closed on vault '{}'", vault_id);
|
||||
}
|
||||
|
||||
async fn websocket(
|
||||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
since_update_id: Option<VaultUpdateId>,
|
||||
) -> Result<(), SyncServerError> {
|
||||
let (mut sender, mut receiver) = stream.split();
|
||||
|
||||
if let Some(Ok(Message::Text(token))) = receiver.next().await {
|
||||
auth(&state, &token)?;
|
||||
} else {
|
||||
return Err(unauthorized_error(anyhow::anyhow!(
|
||||
"Failed to authenticate"
|
||||
)));
|
||||
}
|
||||
|
||||
let mut rx = state.broadcasts.get_receiver(vault_id.clone()).await;
|
||||
|
||||
let documents = if let Some(since_update_id) = since_update_id {
|
||||
state
|
||||
.database
|
||||
.get_latest_documents_since(&vault_id, since_update_id, None)
|
||||
.await
|
||||
.map_err(server_error)
|
||||
} else {
|
||||
state
|
||||
.database
|
||||
.get_latest_documents(&vault_id, None)
|
||||
.await
|
||||
.map_err(server_error)
|
||||
}?;
|
||||
|
||||
for document in documents {
|
||||
send_document_over_websocket(document, &mut sender).await?;
|
||||
}
|
||||
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Ok(update) = rx.recv().await {
|
||||
send_document_over_websocket(update, &mut sender).await?;
|
||||
}
|
||||
|
||||
Ok::<(), SyncServerError>(())
|
||||
});
|
||||
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(Message::Text(text))) = receiver.next().await {
|
||||
info!("Received message: {}", text);
|
||||
// Add username before message.
|
||||
// let _ = tx.send(format!("{name}: {text}"));
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut send_task => recv_task.abort(),
|
||||
_ = &mut recv_task => send_task.abort(),
|
||||
};
|
||||
|
||||
send_task
|
||||
.await
|
||||
.context("Websocket send task failed")
|
||||
.map_err(server_error)??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_document_over_websocket(
|
||||
document: DocumentVersionWithoutContent,
|
||||
sender: &mut SplitSink<WebSocket, Message>,
|
||||
) -> Result<(), SyncServerError> {
|
||||
let serialized_update = serde_json::to_string(&document)
|
||||
.context("Failed to serialize update")
|
||||
.map_err(server_error)?;
|
||||
|
||||
sender
|
||||
.send(Message::Text(serialized_update))
|
||||
.await
|
||||
.context("Failed to send message over websocket")
|
||||
.map_err(server_error)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue