Add proper shutdown, rate limits, config validation, cors config, fix dangling cursors, cache regex, merge created texts
This commit is contained in:
parent
4763bc9d04
commit
e15b0f9903
28 changed files with 1277 additions and 464 deletions
|
|
@ -1,3 +1,4 @@
|
|||
use anyhow::Context as _;
|
||||
use axum::{
|
||||
Extension, Json,
|
||||
extract::{Path, State},
|
||||
|
|
@ -5,18 +6,21 @@ use axum::{
|
|||
use axum_extra::TypedHeader;
|
||||
use axum_typed_multipart::TypedMultipart;
|
||||
use log::{debug, info};
|
||||
use reconcile_text::{BuiltinTokenizer, reconcile};
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{device_id_header::DeviceIdHeader, requests::CreateDocumentVersion};
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentVersionWithoutContent, StoredDocumentVersion, VaultId},
|
||||
database::models::{StoredDocumentVersion, VaultId},
|
||||
},
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, client_error, server_error},
|
||||
errors::{SyncServerError, client_error, server_error, write_transaction_error},
|
||||
server::{responses::DocumentUpdateResponse, update_document},
|
||||
utils::{
|
||||
find_first_available_path::find_first_available_path, normalize::normalize,
|
||||
find_first_available_path::find_first_available_path, is_binary::is_binary,
|
||||
is_file_type_mergable::is_file_type_mergable, normalize::normalize,
|
||||
sanitize_path::sanitize_path,
|
||||
},
|
||||
};
|
||||
|
|
@ -30,48 +34,75 @@ pub struct CreateDocumentPathParams {
|
|||
/// Create a new document in case a document with the same doesn't exist
|
||||
/// already. If a document with the same path exists, a new version is created
|
||||
/// with their content merged.
|
||||
///
|
||||
/// Text content must be UTF-8 encoded. Clients are responsible for
|
||||
/// transcoding other encodings (e.g. UTF-16) to UTF-8 before sending.
|
||||
#[axum::debug_handler]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn create_document(
|
||||
Path(CreateDocumentPathParams { vault_id }): Path<CreateDocumentPathParams>,
|
||||
Extension(user): Extension<User>,
|
||||
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
|
||||
State(state): State<AppState>,
|
||||
TypedMultipart(request): TypedMultipart<CreateDocumentVersion>,
|
||||
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
|
||||
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
|
||||
debug!("Creating document in vault `{vault_id}`");
|
||||
|
||||
let mut transaction = state
|
||||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
.map_err(write_transaction_error)?;
|
||||
|
||||
let document_id = match request.document_id {
|
||||
Some(document_id) => {
|
||||
let existing_version = state
|
||||
.database
|
||||
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
let sanitized_relative_path = sanitize_path(&request.relative_path).map_err(client_error)?;
|
||||
let new_content = request.content.contents.to_vec();
|
||||
|
||||
if existing_version.is_some() {
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Document with the same ID `{document_id}` already exists"
|
||||
)));
|
||||
}
|
||||
|
||||
document_id
|
||||
}
|
||||
None => uuid::Uuid::new_v4(),
|
||||
};
|
||||
|
||||
let last_update_id = state
|
||||
let latest_version = state
|
||||
.database
|
||||
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
|
||||
.get_latest_non_deleted_document_by_path(
|
||||
&vault_id,
|
||||
&sanitized_relative_path,
|
||||
Some(&mut *transaction),
|
||||
)
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
if let Some(latest_version) = latest_version {
|
||||
let is_mergeable_text = is_file_type_mergable(
|
||||
&sanitized_relative_path,
|
||||
&state.config.server.mergeable_file_extensions,
|
||||
) && !is_binary(&latest_version.content)
|
||||
&& !is_binary(&new_content);
|
||||
|
||||
if is_mergeable_text || new_content == latest_version.content {
|
||||
return update_document::update_document(
|
||||
&sanitized_relative_path,
|
||||
Vec::new(),
|
||||
vault_id,
|
||||
latest_version.document_id,
|
||||
&request.relative_path,
|
||||
new_content,
|
||||
user,
|
||||
device_id,
|
||||
state,
|
||||
transaction,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// For non-mergeable (binary) files with different content, don't
|
||||
// merge, create a separate document at a deconflicted path so
|
||||
// neither client's data is silently overwritten.
|
||||
}
|
||||
|
||||
let document_id = uuid::Uuid::new_v4();
|
||||
|
||||
let last_update_id = state
|
||||
.database
|
||||
.get_max_update_id_in_vault(&vault_id, Some(&mut *transaction))
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
let sanitized_relative_path = sanitize_path(&request.relative_path);
|
||||
let deduped_path = find_first_available_path(
|
||||
&vault_id,
|
||||
&sanitized_relative_path,
|
||||
|
|
@ -91,7 +122,7 @@ pub async fn create_document(
|
|||
vault_update_id: last_update_id + 1,
|
||||
document_id,
|
||||
relative_path: deduped_path,
|
||||
content: request.content.contents.to_vec(),
|
||||
content: new_content,
|
||||
updated_date: chrono::Utc::now(),
|
||||
is_deleted: false,
|
||||
user_id: user.name,
|
||||
|
|
@ -105,5 +136,7 @@ pub async fn create_document(
|
|||
.await
|
||||
.map_err(server_error)?;
|
||||
|
||||
Ok(Json(new_version.into()))
|
||||
Ok(Json(DocumentUpdateResponse::FastForwardUpdate(
|
||||
new_version.into(),
|
||||
)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use anyhow::Context;
|
||||
use anyhow::{Context, anyhow};
|
||||
use axum::{
|
||||
Extension, Json,
|
||||
extract::{Path, State},
|
||||
|
|
@ -16,7 +16,7 @@ use crate::{
|
|||
},
|
||||
},
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, server_error},
|
||||
errors::{SyncServerError, client_error, not_found_error, server_error, write_transaction_error},
|
||||
utils::{normalize::normalize, sanitize_path::sanitize_path},
|
||||
};
|
||||
|
||||
|
|
@ -37,7 +37,7 @@ pub async fn delete_document(
|
|||
Extension(user): Extension<User>,
|
||||
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<DeleteDocumentVersion>,
|
||||
Json(_request): Json<DeleteDocumentVersion>,
|
||||
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
|
||||
debug!("Deleting document `{document_id}` in vault `{vault_id}`");
|
||||
|
||||
|
|
@ -45,7 +45,7 @@ pub async fn delete_document(
|
|||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
.map_err(write_transaction_error)?;
|
||||
|
||||
let last_update_id = state
|
||||
.database
|
||||
|
|
@ -77,7 +77,7 @@ pub async fn delete_document(
|
|||
let new_version = StoredDocumentVersion {
|
||||
vault_update_id: last_update_id + 1,
|
||||
document_id,
|
||||
relative_path: sanitize_path(&request.relative_path),
|
||||
relative_path: sanitize_path(&request.relative_path).map_err(client_error)?,
|
||||
content: latest_content, // copy the content from the latest version
|
||||
updated_date: chrono::Utc::now(),
|
||||
is_deleted: true,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ use crate::{
|
|||
AppState,
|
||||
database::models::{DocumentId, DocumentVersion, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, not_found_error, server_error},
|
||||
errors::{SyncServerError, client_error, not_found_error, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ pub async fn fetch_document_version(
|
|||
)?;
|
||||
|
||||
if result.document_id != document_id {
|
||||
return Err(not_found_error(anyhow!(
|
||||
return Err(client_error(anyhow!(
|
||||
"Document with document id `{document_id}` does not have a version with id \
|
||||
`{vault_update_id}`",
|
||||
)));
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ use crate::{
|
|||
AppState,
|
||||
database::models::{DocumentId, VaultId, VaultUpdateId},
|
||||
},
|
||||
errors::{SyncServerError, not_found_error, server_error},
|
||||
errors::{SyncServerError, client_error, not_found_error, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ pub async fn fetch_document_version_content(
|
|||
)?;
|
||||
|
||||
if result.document_id != document_id {
|
||||
return Err(not_found_error(anyhow!(
|
||||
return Err(client_error(anyhow!(
|
||||
"Document with document id `{document_id}` does not have a version with id \
|
||||
`{vault_update_id}`",
|
||||
)));
|
||||
|
|
|
|||
|
|
@ -1,25 +1,37 @@
|
|||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicU64, Ordering},
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
|
||||
use axum_extra::{
|
||||
TypedHeader,
|
||||
headers::{Authorization, authorization::Bearer},
|
||||
};
|
||||
|
||||
/// Simple token-bucket rate limiter that refills every second.
|
||||
/// Per-user token-bucket rate limiter. Each bearer token gets its own bucket
|
||||
/// that refills to `max_per_second` tokens every second.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RateLimiter {
|
||||
inner: Arc<TokenBucket>,
|
||||
max_per_second: u64,
|
||||
buckets: Arc<Mutex<HashMap<String, Arc<TokenBucket>>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TokenBucket {
|
||||
tokens: AtomicU64,
|
||||
state: Mutex<BucketState>,
|
||||
max_tokens: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BucketState {
|
||||
tokens: u64,
|
||||
last_refill: Instant,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
/// Create a new rate limiter. Spawns a background task that refills tokens
|
||||
/// every second.
|
||||
/// Create a new per-user rate limiter.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
|
|
@ -27,44 +39,62 @@ impl RateLimiter {
|
|||
pub fn new(max_per_second: u64) -> Self {
|
||||
assert!(
|
||||
max_per_second > 0,
|
||||
"max_per_second must be > 0 (use 0 in config to disable rate limiting entirely)"
|
||||
"max_per_second must be > 0 (set rate_limit_per_user_per_second to null in config to disable)"
|
||||
);
|
||||
|
||||
let bucket = Arc::new(TokenBucket {
|
||||
tokens: AtomicU64::new(max_per_second),
|
||||
max_tokens: max_per_second,
|
||||
});
|
||||
|
||||
let bucket_clone = bucket.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
bucket_clone
|
||||
.tokens
|
||||
.store(bucket_clone.max_tokens, Ordering::Release);
|
||||
}
|
||||
});
|
||||
|
||||
Self { inner: bucket }
|
||||
Self {
|
||||
max_per_second,
|
||||
buckets: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_acquire(&self) -> bool {
|
||||
self.inner
|
||||
.tokens
|
||||
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
|
||||
if current > 0 { Some(current - 1) } else { None }
|
||||
fn get_or_create_bucket(&self, token: &str) -> Arc<TokenBucket> {
|
||||
self.buckets
|
||||
.lock()
|
||||
.expect("rate limiter lock poisoned")
|
||||
.entry(token.to_owned())
|
||||
.or_insert_with(|| {
|
||||
Arc::new(TokenBucket {
|
||||
state: Mutex::new(BucketState {
|
||||
tokens: self.max_per_second,
|
||||
last_refill: Instant::now(),
|
||||
}),
|
||||
max_tokens: self.max_per_second,
|
||||
})
|
||||
})
|
||||
.is_ok()
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenBucket {
|
||||
fn try_acquire(&self) -> bool {
|
||||
let mut state = self.state.lock().expect("token bucket lock poisoned");
|
||||
let now = Instant::now();
|
||||
if now.duration_since(state.last_refill).as_secs() >= 1 {
|
||||
state.tokens = self.max_tokens;
|
||||
state.last_refill = now;
|
||||
}
|
||||
if state.tokens > 0 {
|
||||
state.tokens -= 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn rate_limit_middleware(
|
||||
axum::extract::State(limiter): axum::extract::State<RateLimiter>,
|
||||
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if limiter.try_acquire() {
|
||||
let Some(TypedHeader(auth)) = auth_header else {
|
||||
return Ok(next.run(req).await);
|
||||
};
|
||||
|
||||
let bucket = limiter.get_or_create_bucket(auth.token());
|
||||
if bucket.try_acquire() {
|
||||
Ok(next.run(req).await)
|
||||
} else {
|
||||
Err(StatusCode::TOO_MANY_REQUESTS)
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@ pub struct CreateDocumentVersion {
|
|||
#[ts(as = "Vec<u8>")]
|
||||
#[form_data(limit = "unlimited")]
|
||||
pub content: FieldData<Bytes>,
|
||||
|
||||
pub idempotency_key: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, TryFromMultipart)]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use axum::{
|
|||
};
|
||||
use axum_extra::TypedHeader;
|
||||
use axum_typed_multipart::TypedMultipart;
|
||||
use futures::io::Write;
|
||||
use log::{debug, info};
|
||||
use reconcile_text::{BuiltinTokenizer, EditedText, reconcile};
|
||||
use serde::Deserialize;
|
||||
|
|
@ -16,10 +17,15 @@ use super::{
|
|||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
|
||||
database::{
|
||||
WriteTransaction,
|
||||
models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
|
||||
},
|
||||
},
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, client_error, not_found_error, server_error},
|
||||
errors::{
|
||||
SyncServerError, client_error, not_found_error, server_error, write_transaction_error,
|
||||
},
|
||||
server::requests::UpdateBinaryDocumentVersion,
|
||||
utils::{
|
||||
find_first_available_path::find_first_available_path, is_binary::is_binary,
|
||||
|
|
@ -46,18 +52,27 @@ pub async fn update_binary(
|
|||
State(state): State<AppState>,
|
||||
TypedMultipart(request): TypedMultipart<UpdateBinaryDocumentVersion>,
|
||||
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
|
||||
let parent_document = get_parent_document(&state, &vault_id, request.parent_version_id).await?;
|
||||
let parent_document =
|
||||
get_parent_document(&state, &vault_id, &document_id, request.parent_version_id).await?;
|
||||
let content = request.content.contents.to_vec();
|
||||
|
||||
let mut transaction = state
|
||||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(write_transaction_error)?;
|
||||
|
||||
update_document(
|
||||
parent_document,
|
||||
&parent_document.relative_path,
|
||||
parent_document.content,
|
||||
vault_id,
|
||||
document_id,
|
||||
&request.relative_path,
|
||||
content,
|
||||
user,
|
||||
device_id,
|
||||
state,
|
||||
&request.relative_path,
|
||||
content,
|
||||
transaction,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
@ -74,28 +89,36 @@ pub async fn update_text(
|
|||
State(state): State<AppState>,
|
||||
Json(request): Json<UpdateTextDocumentVersion>,
|
||||
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
|
||||
let parent_document = get_parent_document(&state, &vault_id, request.parent_version_id).await?;
|
||||
let parent_document =
|
||||
get_parent_document(&state, &vault_id, &document_id, request.parent_version_id).await?;
|
||||
|
||||
let edited_text = EditedText::from_diff(
|
||||
str::from_utf8(&parent_document.content)
|
||||
.expect("parent must be valid UTF-8 because it's a text document"),
|
||||
request.content,
|
||||
&*BuiltinTokenizer::Word,
|
||||
)
|
||||
.context("Failed to apply given diff to parent document")
|
||||
.map_err(client_error)?;
|
||||
let parent_text = str::from_utf8(&parent_document.content)
|
||||
.context("Parent version contains binary content; use putBinary instead of putText")
|
||||
.map_err(client_error)?;
|
||||
|
||||
let edited_text = EditedText::from_diff(parent_text, request.content, &*BuiltinTokenizer::Word)
|
||||
.context("Failed to apply given diff to parent document")
|
||||
.map_err(client_error)?;
|
||||
|
||||
let content = edited_text.apply().text().into_bytes();
|
||||
|
||||
let mut transaction = state
|
||||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(write_transaction_error)?;
|
||||
|
||||
update_document(
|
||||
parent_document,
|
||||
&parent_document.relative_path,
|
||||
parent_document.content,
|
||||
vault_id,
|
||||
document_id,
|
||||
&request.relative_path,
|
||||
content,
|
||||
user,
|
||||
device_id,
|
||||
state,
|
||||
&request.relative_path,
|
||||
content,
|
||||
transaction,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
@ -103,9 +126,10 @@ pub async fn update_text(
|
|||
async fn get_parent_document(
|
||||
state: &AppState,
|
||||
vault_id: &VaultId,
|
||||
document_id: &DocumentId,
|
||||
parent_version_id: VaultUpdateId,
|
||||
) -> Result<StoredDocumentVersion, SyncServerError> {
|
||||
state
|
||||
let parent = state
|
||||
.database
|
||||
.get_document_version(vault_id, parent_version_id, None)
|
||||
.await
|
||||
|
|
@ -117,29 +141,33 @@ async fn get_parent_document(
|
|||
)))
|
||||
},
|
||||
Ok,
|
||||
)
|
||||
)?;
|
||||
|
||||
if &parent.document_id != document_id {
|
||||
return Err(client_error(anyhow!(
|
||||
"Parent version `{parent_version_id}` does not belong to document `{document_id}`"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(parent)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
|
||||
async fn update_document(
|
||||
parent_document: StoredDocumentVersion,
|
||||
pub async fn update_document(
|
||||
parent_relative_path: &str,
|
||||
parent_content: Vec<u8>,
|
||||
vault_id: VaultId,
|
||||
document_id: DocumentId,
|
||||
relative_path: &str,
|
||||
content: Vec<u8>,
|
||||
user: User,
|
||||
device_id: DeviceIdHeader,
|
||||
state: AppState,
|
||||
relative_path: &str,
|
||||
content: Vec<u8>,
|
||||
mut transaction: WriteTransaction,
|
||||
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
|
||||
debug!("Updating document `{document_id}` in vault `{vault_id}`");
|
||||
|
||||
let sanitized_relative_path = sanitize_path(relative_path);
|
||||
|
||||
let mut transaction = state
|
||||
.database
|
||||
.create_write_transaction(&vault_id)
|
||||
.await
|
||||
.map_err(server_error)?;
|
||||
let sanitized_relative_path = sanitize_path(relative_path).map_err(client_error)?;
|
||||
|
||||
let last_update_id = state
|
||||
.database
|
||||
|
|
@ -195,35 +223,44 @@ async fn update_document(
|
|||
let are_all_participants_mergable = is_file_type_mergable(
|
||||
&sanitized_relative_path,
|
||||
&state.config.server.mergeable_file_extensions,
|
||||
) && !is_binary(&parent_document.content)
|
||||
) && !is_binary(&parent_content)
|
||||
&& !is_binary(&latest_version.content)
|
||||
&& !is_binary(&content);
|
||||
|
||||
let merged_content = if are_all_participants_mergable {
|
||||
let (merged_content, is_different_from_request_content) = if are_all_participants_mergable {
|
||||
info!("Merging changes for document `{document_id}` in vault `{vault_id}`");
|
||||
reconcile(
|
||||
str::from_utf8(&parent_document.content)
|
||||
.expect("parent must be valid UTF-8 because it's not binary"),
|
||||
&str::from_utf8(&latest_version.content)
|
||||
.expect("latest_version must be valid UTF-8 because it's not binary")
|
||||
.into(),
|
||||
&str::from_utf8(&content)
|
||||
.expect("content must be valid UTF-8 because it's not binary")
|
||||
.into(),
|
||||
let parent_text = str::from_utf8(&parent_content)
|
||||
.context("Parent document content is not valid UTF-8")
|
||||
.map_err(client_error)?;
|
||||
let latest_text = str::from_utf8(&latest_version.content)
|
||||
.context("Latest version content is not valid UTF-8")
|
||||
.map_err(client_error)?;
|
||||
let new_text = str::from_utf8(&content)
|
||||
.context("New content is not valid UTF-8")
|
||||
.map_err(client_error)?;
|
||||
let merged = reconcile(
|
||||
parent_text,
|
||||
&latest_text.into(),
|
||||
&new_text.into(),
|
||||
&*BuiltinTokenizer::Word,
|
||||
)
|
||||
.apply()
|
||||
.text()
|
||||
.into_bytes()
|
||||
.into_bytes();
|
||||
let is_different = merged != content;
|
||||
(merged, is_different)
|
||||
} else {
|
||||
content.clone()
|
||||
(content, false) // false means that the client doesn't need to refetch the file as we can ensure the remote and local versions are the same as LWW is the merging method for binary files
|
||||
};
|
||||
|
||||
let is_different_from_request_content = merged_content != content;
|
||||
|
||||
// We can only update the relative path if we're the first one to do so
|
||||
let new_relative_path = if parent_document.relative_path == latest_version.relative_path
|
||||
&& latest_version.relative_path != sanitized_relative_path
|
||||
// Rename resolution: only apply the client's rename if the document's path
|
||||
// hasn't changed since this client's parent version. Check the parent
|
||||
// version's path against the latest version's path. If they differ, another
|
||||
// client already renamed the document — keep the latest path (first rename
|
||||
// wins). Content changes from both clients are still merged correctly via
|
||||
// the 3-way reconcile above, independent of which rename wins.
|
||||
let new_relative_path = if parent_relative_path == latest_version.relative_path
|
||||
&& sanitized_relative_path != latest_version.relative_path
|
||||
{
|
||||
let new_path = find_first_available_path(
|
||||
&vault_id,
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@ use axum::{
|
|||
},
|
||||
response::Response,
|
||||
};
|
||||
use futures::sink::SinkExt;
|
||||
use futures::stream::StreamExt;
|
||||
use log::{debug, info};
|
||||
use log::{debug, info, warn};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{
|
||||
app_state::{
|
||||
AppState,
|
||||
|
|
@ -24,10 +24,26 @@ use crate::{
|
|||
},
|
||||
},
|
||||
},
|
||||
consts::{
|
||||
HANDSHAKE_TIMEOUT, MAX_CURSORS_PER_DOCUMENT, MAX_CURSOR_DOCUMENTS,
|
||||
MAX_RELATIVE_PATH_LEN,
|
||||
},
|
||||
errors::{SyncServerError, client_error, server_error},
|
||||
utils::normalize::normalize,
|
||||
};
|
||||
|
||||
/// Tracks a pending (not yet authenticated) WebSocket connection.
|
||||
/// Decrements the counter when dropped, ensuring cleanup even if
|
||||
/// the upgrade never completes or auth fails.
|
||||
struct PendingWsGuard(std::sync::Arc<std::sync::atomic::AtomicUsize>);
|
||||
|
||||
impl Drop for PendingWsGuard {
|
||||
fn drop(&mut self) {
|
||||
self.0
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct WebSocketPathParams {
|
||||
#[serde(deserialize_with = "normalize")]
|
||||
|
|
@ -39,13 +55,31 @@ pub async fn websocket_handler(
|
|||
Path(WebSocketPathParams { vault_id }): Path<WebSocketPathParams>,
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Response, SyncServerError> {
|
||||
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id)))
|
||||
let current = state
|
||||
.pending_ws_connections
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
if current >= state.config.server.max_pending_websocket_connections {
|
||||
state
|
||||
.pending_ws_connections
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Too many pending WebSocket connections"
|
||||
)));
|
||||
}
|
||||
|
||||
let guard = PendingWsGuard(state.pending_ws_connections.clone());
|
||||
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id, guard)))
|
||||
}
|
||||
|
||||
async fn websocket_wrapped(state: AppState, stream: WebSocket, vault_id: VaultId) {
|
||||
async fn websocket_wrapped(
|
||||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
pending_guard: PendingWsGuard,
|
||||
) {
|
||||
info!("WebSocket connection opened on vault `{vault_id}`");
|
||||
|
||||
let result = websocket(state, stream, vault_id.clone()).await;
|
||||
let result = websocket(state, stream, vault_id.clone(), pending_guard).await;
|
||||
|
||||
if let Err(err) = result {
|
||||
debug!("WebSocket connection error on vault `{vault_id}`: {err}");
|
||||
|
|
@ -57,25 +91,53 @@ async fn websocket(
|
|||
state: AppState,
|
||||
stream: WebSocket,
|
||||
vault_id: VaultId,
|
||||
pending_guard: PendingWsGuard,
|
||||
) -> Result<(), SyncServerError> {
|
||||
let (mut sender, mut websocket_receiver) = stream.split();
|
||||
|
||||
let authed_handshake = get_authenticated_handshake(
|
||||
&state,
|
||||
&vault_id,
|
||||
websocket_receiver
|
||||
.next()
|
||||
.await
|
||||
.transpose()
|
||||
.unwrap_or_default(),
|
||||
)?;
|
||||
let handshake_msg = tokio::time::timeout(HANDSHAKE_TIMEOUT, websocket_receiver.next())
|
||||
.await
|
||||
.map_err(|_| client_error(anyhow::anyhow!("WebSocket handshake timed out")))?
|
||||
.transpose()
|
||||
.map_err(|e| client_error(anyhow::anyhow!("WebSocket error during handshake: {e}")))?;
|
||||
|
||||
let authed_handshake = get_authenticated_handshake(&state, &vault_id, handshake_msg)?;
|
||||
|
||||
info!(
|
||||
"WebSocket handshake successful for vault `{vault_id}` for `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
|
||||
let mut broadcast_receiver = state.broadcasts.get_receiver(vault_id.clone()).await;
|
||||
// Auth complete — no longer a pending connection.
|
||||
drop(pending_guard);
|
||||
|
||||
let max_clients = state.config.server.max_clients_per_vault;
|
||||
let mut broadcast_receiver = match state
|
||||
.broadcasts
|
||||
.get_receiver(vault_id.clone(), max_clients)
|
||||
.await
|
||||
{
|
||||
Ok(receiver) => receiver,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Vault `{vault_id}` has reached the maximum number of clients ({max_clients}), rejecting connection from `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
if let Err(e) = sender
|
||||
.send(Message::Close(Some(axum::extract::ws::CloseFrame {
|
||||
code: 4000,
|
||||
reason: format!(
|
||||
"Vault has reached the maximum number of clients ({max_clients})"
|
||||
)
|
||||
.into(),
|
||||
})))
|
||||
.await
|
||||
{
|
||||
warn!("Failed to send WebSocket close frame: {e}");
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
send_update_over_websocket(
|
||||
&WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate {
|
||||
|
|
@ -101,24 +163,35 @@ async fn websocket(
|
|||
|
||||
let device_id = authed_handshake.handshake.device_id.clone();
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Ok(update) = broadcast_receiver.recv().await {
|
||||
if Some(&device_id) == update.origin_device_id.as_ref() {
|
||||
continue;
|
||||
}
|
||||
loop {
|
||||
match broadcast_receiver.recv().await {
|
||||
Ok(update) => {
|
||||
if Some(&device_id) == update.origin_device_id.as_ref() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let message = match update.message {
|
||||
WebSocketServerMessage::CursorPositions(CursorPositionFromServer { clients }) => {
|
||||
WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
clients: clients
|
||||
.into_iter()
|
||||
.filter(|client| client.device_id != device_id)
|
||||
.collect(),
|
||||
})
|
||||
let message = match update.message {
|
||||
WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
clients,
|
||||
}) => WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
|
||||
clients: clients
|
||||
.into_iter()
|
||||
.filter(|client| client.device_id != device_id)
|
||||
.collect(),
|
||||
}),
|
||||
WebSocketServerMessage::VaultUpdate(_) => update.message,
|
||||
};
|
||||
|
||||
send_update_over_websocket(&message, &mut sender).await?;
|
||||
}
|
||||
WebSocketServerMessage::VaultUpdate(_) => update.message,
|
||||
};
|
||||
|
||||
send_update_over_websocket(&message, &mut sender).await?;
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
|
||||
warn!(
|
||||
"WebSocket receiver lagged, dropped {n} messages — disconnecting client to force full resync"
|
||||
);
|
||||
break;
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
|
||||
Ok::<(), SyncServerError>(())
|
||||
|
|
@ -128,26 +201,57 @@ async fn websocket(
|
|||
let vault_id_clone = vault_id.clone();
|
||||
let cursor_manager = state.cursors.clone();
|
||||
let mut receive_task = tokio::spawn(async move {
|
||||
while let Some(Ok(Message::Text(message))) = websocket_receiver.next().await {
|
||||
let message: WebSocketClientMessage = serde_json::from_str(&message)
|
||||
.context("Failed to parse WebSocket message from client")
|
||||
.map_err(server_error)?;
|
||||
while let Some(msg) = websocket_receiver.next().await {
|
||||
match msg {
|
||||
Ok(Message::Text(message)) => {
|
||||
let message: WebSocketClientMessage = serde_json::from_str(&message)
|
||||
.context("Failed to parse WebSocket message from client")
|
||||
.map_err(client_error)?;
|
||||
|
||||
match message {
|
||||
WebSocketClientMessage::Handshake(_) => {
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Unexpected handshake message"
|
||||
)));
|
||||
match message {
|
||||
WebSocketClientMessage::Handshake(_) => {
|
||||
return Err(client_error(anyhow::anyhow!(
|
||||
"Unexpected handshake message"
|
||||
)));
|
||||
}
|
||||
WebSocketClientMessage::CursorPositions(cursors) => {
|
||||
let docs = cursors.documents_with_cursors;
|
||||
if docs.len() > MAX_CURSOR_DOCUMENTS {
|
||||
warn!(
|
||||
"Cursor update rejected: {} documents exceeds limit of {MAX_CURSOR_DOCUMENTS}",
|
||||
docs.len()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let valid = docs.iter().all(|doc| {
|
||||
doc.cursors.len() <= MAX_CURSORS_PER_DOCUMENT
|
||||
&& doc.relative_path.len() <= MAX_RELATIVE_PATH_LEN
|
||||
});
|
||||
if !valid {
|
||||
warn!("Cursor update rejected: a document exceeds cursor or path length limits");
|
||||
continue;
|
||||
}
|
||||
|
||||
cursor_manager
|
||||
.update_cursors(
|
||||
vault_id_clone.clone(),
|
||||
authed_handshake.user.name.clone(),
|
||||
&device_id,
|
||||
docs,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
WebSocketClientMessage::CursorPositions(cursors) => {
|
||||
cursor_manager
|
||||
.update_cursors(
|
||||
vault_id_clone.clone(),
|
||||
authed_handshake.user.name.clone(),
|
||||
&device_id,
|
||||
cursors.documents_with_cursors,
|
||||
)
|
||||
.await;
|
||||
Ok(Message::Close(_)) => break,
|
||||
Ok(Message::Binary(_)) => {
|
||||
warn!("Received unexpected binary WebSocket message, ignoring");
|
||||
}
|
||||
Ok(_) => {} // Ping/Pong frames handled by axum
|
||||
Err(e) => {
|
||||
debug!("WebSocket receive error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -155,38 +259,47 @@ async fn websocket(
|
|||
Ok::<(), SyncServerError>(())
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut send_task => receive_task.abort(),
|
||||
_ = &mut receive_task => send_task.abort(),
|
||||
let result: Result<(), SyncServerError> = tokio::select! {
|
||||
send_result = &mut send_task => {
|
||||
receive_task.abort();
|
||||
let _ = receive_task.await;
|
||||
match send_result {
|
||||
Err(e) => Err(server_error(
|
||||
anyhow::Error::from(e).context("WebSocket send task failed"),
|
||||
)),
|
||||
Ok(inner) => inner,
|
||||
}
|
||||
},
|
||||
receive_result = &mut receive_task => {
|
||||
send_task.abort();
|
||||
let _ = send_task.await;
|
||||
match receive_result {
|
||||
Err(e) => Err(server_error(
|
||||
anyhow::Error::from(e).context("WebSocket receive task failed"),
|
||||
)),
|
||||
Ok(inner) => inner,
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let result: Result<(), SyncServerError> = (async {
|
||||
send_task
|
||||
.await
|
||||
.context("WebSocket send task failed")
|
||||
.map_err(client_error)
|
||||
.and_then(|err| err)?;
|
||||
|
||||
receive_task
|
||||
.await
|
||||
.context("WebSocket receive task failed")
|
||||
.map_err(client_error)
|
||||
.and_then(|err| err)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await;
|
||||
|
||||
state
|
||||
.cursors
|
||||
.remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id)
|
||||
.await;
|
||||
|
||||
if result.is_err() {
|
||||
info!(
|
||||
"WebSocket disconnected on vault `{vault_id}` for `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
match &result {
|
||||
Ok(()) => {
|
||||
info!(
|
||||
"WebSocket disconnected on vault `{vault_id}` for `{}`",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"WebSocket error on vault `{vault_id}` for `{}`: {err}",
|
||||
authed_handshake.handshake.device_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue