Add WebSocket support (#12)

This commit is contained in:
Andras Schmelczer 2025-03-29 10:17:46 +00:00 committed by GitHub
parent 3d27b7f313
commit 1aad0fce31
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
68 changed files with 2578 additions and 993 deletions

View file

@ -1,23 +0,0 @@
use std::ffi::OsString;
use anyhow::Result;
use crate::{config::Config, consts::DEFAULT_CONFIG_PATH, database::Database};
#[derive(Clone, Debug)]
pub struct AppState {
pub config: Config,
pub database: Database,
}
impl AppState {
pub async fn try_new(config_path: Option<OsString>) -> Result<Self> {
let config_path = config_path.unwrap_or_else(|| OsString::from(DEFAULT_CONFIG_PATH));
let path = std::path::PathBuf::from(config_path);
let config = Config::read_or_create(&path).await?;
let database = Database::try_new(&config.database).await?;
Ok(Self { config, database })
}
}

View file

@ -1,9 +1,10 @@
use super::app_state::AppState;
use crate::{
app_state::AppState,
config::user_config::User,
errors::{SyncServerError, unauthorized_error},
};
// TODO: turn this into a middleware
pub fn auth(app_state: &AppState, token: &str) -> Result<User, SyncServerError> {
app_state
.config

View file

@ -11,12 +11,16 @@ use serde::Deserialize;
use sync_lib::base64_to_bytes;
use super::{
app_state::AppState,
auth::auth,
requests::{CreateDocumentVersion, CreateDocumentVersionMultipart},
};
use crate::{
database::models::{DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId},
app_state::{
AppState,
database::models::{
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
},
},
errors::{SyncServerError, client_error, server_error},
utils::sanitize_path,
};
@ -77,7 +81,7 @@ pub async fn create_document_json(
async fn internal_create_document(
auth_header: Authorization<Bearer>,
mut state: AppState,
state: AppState,
vault_id: VaultId,
document_id: Option<DocumentId>,
relative_path: String,
@ -139,5 +143,10 @@ async fn internal_create_document(
.context("Failed to commit successful transaction")
.map_err(server_error)?;
state
.broadcasts
.send(vault_id, new_version.clone().into())
.await;
Ok(Json(new_version.into()))
}

View file

@ -8,9 +8,14 @@ use axum_jsonschema::Json;
use schemars::JsonSchema;
use serde::Deserialize;
use super::{app_state::AppState, auth::auth, requests::DeleteDocumentVersion};
use super::{auth::auth, requests::DeleteDocumentVersion};
use crate::{
database::models::{DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId},
app_state::{
AppState,
database::models::{
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
},
},
errors::{SyncServerError, server_error},
utils::sanitize_path,
};
@ -29,7 +34,7 @@ pub async fn delete_document(
vault_id,
document_id,
}): Path<DeleteDocumentPathParams>,
State(mut state): State<AppState>,
State(state): State<AppState>,
Json(request): Json<DeleteDocumentVersion>,
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
auth(&state, auth_header.token())?;
@ -67,5 +72,10 @@ pub async fn delete_document(
.context("Failed to commit successful transaction")
.map_err(server_error)?;
state
.broadcasts
.send(vault_id, new_version.clone().into())
.await;
Ok(Json(new_version.into()))
}

View file

@ -8,9 +8,12 @@ use axum_jsonschema::Json;
use schemars::JsonSchema;
use serde::Deserialize;
use super::{app_state::AppState, auth::auth};
use super::auth::auth;
use crate::{
database::models::{DocumentId, DocumentVersion, VaultId, VaultUpdateId},
app_state::{
AppState,
database::models::{DocumentId, DocumentVersion, VaultId, VaultUpdateId},
},
errors::{SyncServerError, not_found_error, server_error},
};
@ -30,7 +33,7 @@ pub async fn fetch_document_version(
document_id,
vault_update_id,
}): Path<FetchDocumentVersionPathParams>,
State(mut state): State<AppState>,
State(state): State<AppState>,
) -> Result<Json<DocumentVersion>, SyncServerError> {
auth(&state, auth_header.token())?;

View file

@ -10,9 +10,12 @@ use axum_extra::{
use schemars::JsonSchema;
use serde::Deserialize;
use super::{app_state::AppState, auth::auth};
use super::auth::auth;
use crate::{
database::models::{DocumentId, VaultId, VaultUpdateId},
app_state::{
AppState,
database::models::{DocumentId, VaultId, VaultUpdateId},
},
errors::{SyncServerError, not_found_error, server_error},
};
@ -32,7 +35,7 @@ pub async fn fetch_document_version_content(
document_id,
vault_update_id,
}): Path<FetchDocumentVersionContentPathParams>,
State(mut state): State<AppState>,
State(state): State<AppState>,
) -> Result<Bytes, SyncServerError> {
auth(&state, auth_header.token())?;

View file

@ -8,9 +8,12 @@ use axum_jsonschema::Json;
use schemars::JsonSchema;
use serde::Deserialize;
use super::{app_state::AppState, auth::auth};
use super::auth::auth;
use crate::{
database::models::{DocumentId, DocumentVersion, VaultId},
app_state::{
AppState,
database::models::{DocumentId, DocumentVersion, VaultId},
},
errors::{SyncServerError, not_found_error, server_error},
};
@ -28,7 +31,7 @@ pub async fn fetch_latest_document_version(
vault_id,
document_id,
}): Path<FetchLatestDocumentVersionPathParams>,
State(mut state): State<AppState>,
State(state): State<AppState>,
) -> Result<Json<DocumentVersion>, SyncServerError> {
auth(&state, auth_header.token())?;

View file

@ -7,9 +7,12 @@ use axum_jsonschema::Json;
use schemars::JsonSchema;
use serde::Deserialize;
use super::{app_state::AppState, auth::auth, responses::FetchLatestDocumentsResponse};
use super::{auth::auth, responses::FetchLatestDocumentsResponse};
use crate::{
database::models::{VaultId, VaultUpdateId},
app_state::{
AppState,
database::models::{VaultId, VaultUpdateId},
},
errors::{SyncServerError, server_error},
};
@ -30,7 +33,7 @@ pub async fn fetch_latest_documents(
TypedHeader(auth_header): TypedHeader<Authorization<Bearer>>,
Path(FetchLatestDocumentsPathParams { vault_id }): Path<FetchLatestDocumentsPathParams>,
Query(QueryParams { since_update_id }): Query<QueryParams>,
State(mut state): State<AppState>,
State(state): State<AppState>,
) -> Result<Json<FetchLatestDocumentsResponse>, SyncServerError> {
auth(&state, auth_header.token())?;

View file

@ -4,8 +4,8 @@ use axum_extra::{
headers::{Authorization, authorization::Bearer},
};
use super::{app_state::AppState, auth::auth, responses::PingResponse};
use crate::errors::SyncServerError;
use super::{auth::auth, responses::PingResponse};
use crate::{app_state::AppState, errors::SyncServerError};
#[axum::debug_handler]
pub async fn ping(

View file

@ -4,7 +4,7 @@ use axum_typed_multipart::TryFromMultipart;
use schemars::JsonSchema;
use serde::{self, Deserialize};
use crate::database::models::{DocumentId, VaultUpdateId};
use crate::app_state::database::models::{DocumentId, VaultUpdateId};
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]

View file

@ -1,7 +1,9 @@
use schemars::JsonSchema;
use serde::{self, Serialize};
use crate::database::models::{DocumentVersion, DocumentVersionWithoutContent, VaultUpdateId};
use crate::app_state::database::models::{
DocumentVersion, DocumentVersionWithoutContent, VaultUpdateId,
};
/// Response to a ping request.
#[derive(Debug, Clone, Serialize, JsonSchema)]

View file

@ -12,13 +12,15 @@ use serde::Deserialize;
use sync_lib::{base64_to_bytes, is_file_type_mergable, merge};
use super::{
app_state::AppState,
auth::auth,
requests::{UpdateDocumentVersion, UpdateDocumentVersionMultipart},
responses::DocumentUpdateResponse,
};
use crate::{
database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
app_state::{
AppState,
database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
},
errors::{SyncServerError, client_error, not_found_error, server_error},
utils::{deduped_file_paths, sanitize_path},
};
@ -83,7 +85,7 @@ pub async fn update_document_json(
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
async fn internal_update_document(
auth_header: Authorization<Bearer>,
mut state: AppState,
state: AppState,
vault_id: VaultId,
document_id: DocumentId,
parent_version_id: VaultUpdateId,
@ -216,6 +218,11 @@ async fn internal_update_document(
.context("Failed to commit successful transaction")
.map_err(server_error)?;
state
.broadcasts
.send(vault_id, new_version.clone().into())
.await;
Ok(Json(if is_different_from_request_content {
DocumentUpdateResponse::MergingUpdate(new_version.into())
} else {

View file

@ -0,0 +1,147 @@
use anyhow::Context;
use axum::{
extract::{
Path, Query, State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::Response,
};
use futures::{
sink::SinkExt,
stream::{SplitSink, StreamExt},
};
use log::{error, info, warn};
use schemars::JsonSchema;
use serde::Deserialize;
use super::auth::auth;
use crate::{
app_state::{
AppState,
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
},
errors::{SyncServerError, server_error, unauthorized_error},
};
// This is required for aide to infer the path parameter types and names
#[derive(Deserialize, JsonSchema)]
pub struct WebsocketPathParams {
vault_id: VaultId,
}
// This is required for aide to infer the path parameter types and names
#[derive(Deserialize, JsonSchema)]
pub struct QueryParams {
since_update_id: Option<VaultUpdateId>,
}
pub async fn websocket_handler(
ws: WebSocketUpgrade,
Path(WebsocketPathParams { vault_id }): Path<WebsocketPathParams>,
Query(QueryParams { since_update_id }): Query<QueryParams>,
State(state): State<AppState>,
) -> Result<Response, SyncServerError> {
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id, since_update_id)))
}
async fn websocket_wrapped(
state: AppState,
stream: WebSocket,
vault_id: VaultId,
since_update_id: Option<VaultUpdateId>,
) {
info!("Websocket connection opened on vault '{}'", vault_id);
let result = websocket(state, stream, vault_id.clone(), since_update_id).await;
if let Err(err) = result {
error!(
"Websocket connection error on vault '{}': {}",
vault_id, err
);
}
warn!("Websocket connection closed on vault '{}'", vault_id);
}
async fn websocket(
state: AppState,
stream: WebSocket,
vault_id: VaultId,
since_update_id: Option<VaultUpdateId>,
) -> Result<(), SyncServerError> {
let (mut sender, mut receiver) = stream.split();
if let Some(Ok(Message::Text(token))) = receiver.next().await {
auth(&state, &token)?;
} else {
return Err(unauthorized_error(anyhow::anyhow!(
"Failed to authenticate"
)));
}
let mut rx = state.broadcasts.get_receiver(vault_id.clone()).await;
let documents = if let Some(since_update_id) = since_update_id {
state
.database
.get_latest_documents_since(&vault_id, since_update_id, None)
.await
.map_err(server_error)
} else {
state
.database
.get_latest_documents(&vault_id, None)
.await
.map_err(server_error)
}?;
for document in documents {
send_document_over_websocket(document, &mut sender).await?;
}
let mut send_task = tokio::spawn(async move {
while let Ok(update) = rx.recv().await {
send_document_over_websocket(update, &mut sender).await?;
}
Ok::<(), SyncServerError>(())
});
let mut recv_task =
tokio::spawn(
async move { while let Some(Ok(Message::Text(_text))) = receiver.next().await {} },
);
tokio::select! {
_ = &mut send_task => recv_task.abort(),
_ = &mut recv_task => send_task.abort(),
};
send_task
.await
.context("Websocket send task failed")
.map_err(server_error)??;
recv_task
.await
.context("Websocket receive task failed")
.map_err(server_error)?;
Ok(())
}
async fn send_document_over_websocket(
document: DocumentVersionWithoutContent,
sender: &mut SplitSink<WebSocket, Message>,
) -> Result<(), SyncServerError> {
let serialized_update = serde_json::to_string(&document)
.context("Failed to serialize update")
.map_err(server_error)?;
sender
.send(Message::Text(serialized_update))
.await
.context("Failed to send message over websocket")
.map_err(server_error)
}