split: server REST endpoints + rate limiting

server.rs router rewrite, auth.rs, device_id_header.rs, requests.rs,
responses.rs, plus per-endpoint changes: create/update/delete_document,
fetch_document_version{,_content,s}, fetch_latest_documents, index.rs.
Adds: fetch_vault_history, list_vaults, rate_limit (new files).
This commit is contained in:
Andras Schmelczer 2026-05-08 21:35:41 +01:00
parent 2d5edc6ec5
commit 4ba439b874
16 changed files with 838 additions and 202 deletions

View file

@ -4,27 +4,30 @@ mod delete_document;
mod device_id_header;
mod fetch_document_version;
mod fetch_document_version_content;
mod fetch_document_versions;
mod fetch_latest_document_version;
mod fetch_latest_documents;
mod fetch_vault_history;
mod index;
mod list_vaults;
mod ping;
mod rate_limit;
mod requests;
mod responses;
mod update_document;
mod websocket;
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Context as _, Result};
use auth::auth_middleware;
use axum::{
Router,
extract::{DefaultBodyLimit, Request},
http::{self, HeaderValue, Method},
middleware,
response::IntoResponse,
routing::{IntoMakeService, delete, get, post, put},
};
use device_id_header::DEVICE_ID_HEADER_NAME;
use log::info;
use log::{info, warn};
use tokio::signal;
use tower_http::{
LatencyUnit,
@ -41,7 +44,7 @@ use tracing::{Level, info_span};
use crate::{
app_state::AppState,
config::{Config, server_config::ServerConfig},
errors::{client_error, not_found_error},
consts::GRACEFUL_SHUTDOWN_TIMEOUT,
};
pub async fn create_server(config: Config) -> Result<()> {
@ -51,26 +54,33 @@ pub async fn create_server(config: Config) -> Result<()> {
let server_config = app_state.config.server.clone();
let app = Router::new()
let mut app = Router::new()
.nest("/", get_authed_routes(app_state.clone()))
.route("/", get(index::index))
.route("/assets/*path", get(index::spa_assets))
.route("/vaults", get(list_vaults::list_vaults))
.route("/vaults/:vault_id/ping", get(ping::ping))
.route("/vaults/:vault_id/ws", get(websocket::websocket_handler))
.fallback(index::spa_fallback);
let cors_layer = build_cors_layer(&server_config).context("Invalid CORS configuration")?;
if let Some(rate_limit) = server_config.rate_limit_per_user_per_second {
info!("Rate limiting enabled: {rate_limit} requests/second per user");
let limiter = rate_limit::RateLimiter::new(rate_limit);
app = app.layer(middleware::from_fn_with_state(
limiter,
rate_limit::rate_limit_middleware,
));
}
let app = app
.layer(DefaultBodyLimit::disable())
.layer(RequestBodyLimitLayer::new(
app_state.config.server.max_body_size_mb * 1024 * 1024,
))
.layer(TimeoutLayer::new(server_config.response_timeout))
.layer(
CorsLayer::new()
.allow_origin("*".parse::<HeaderValue>().expect("Failed to parse origin"))
.allow_headers([
http::header::CONTENT_TYPE,
http::header::AUTHORIZATION,
DEVICE_ID_HEADER_NAME.clone(),
])
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]),
)
.layer(cors_layer)
.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &Request<_>| {
@ -90,12 +100,39 @@ pub async fn create_server(config: Config) -> Result<()> {
.on_eos(DefaultOnEos::new())
.on_failure(DefaultOnFailure::new().level(Level::ERROR)),
)
.with_state(app_state)
.fallback(handle_404)
.fallback(handle_405)
.with_state(app_state.clone())
.into_make_service();
start_server(app, &server_config).await
start_server(app, &server_config, app_state).await
}
fn build_cors_layer(server_config: &ServerConfig) -> Result<CorsLayer> {
let origins = &server_config.allowed_origins;
let cors = if origins.len() == 1 && origins[0] == "*" {
info!("CORS: allowing all origins");
let header: HeaderValue = "*"
.parse()
.context("Failed to parse wildcard CORS origin")?;
CorsLayer::new().allow_origin(header)
} else {
let parsed: Vec<HeaderValue> = origins
.iter()
.map(|o| {
o.parse::<HeaderValue>()
.with_context(|| format!("Failed to parse CORS origin: `{o}`"))
})
.collect::<Result<Vec<_>>>()?;
CorsLayer::new().allow_origin(parsed)
};
Ok(cors
.allow_headers([
http::header::CONTENT_TYPE,
http::header::AUTHORIZATION,
DEVICE_ID_HEADER_NAME.clone(),
])
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]))
}
fn get_authed_routes(app_state: AppState) -> Router<AppState> {
@ -120,6 +157,10 @@ fn get_authed_routes(app_state: AppState) -> Router<AppState> {
"/vaults/:vault_id/documents/:document_id/text",
put(update_document::update_text),
)
.route(
"/vaults/:vault_id/documents/:document_id/versions",
get(fetch_document_versions::fetch_document_versions),
)
.route(
"/vaults/:vault_id/documents/:document_id/versions/:vault_update_id",
get(fetch_document_version::fetch_document_version),
@ -132,10 +173,18 @@ fn get_authed_routes(app_state: AppState) -> Router<AppState> {
"/vaults/:vault_id/documents/:document_id",
delete(delete_document::delete_document),
)
.route(
"/vaults/:vault_id/history",
get(fetch_vault_history::fetch_vault_history),
)
.layer(middleware::from_fn_with_state(app_state, auth_middleware))
}
async fn start_server(app: IntoMakeService<axum::Router>, config: &ServerConfig) -> Result<()> {
async fn start_server(
app: IntoMakeService<axum::Router>,
config: &ServerConfig,
app_state: AppState,
) -> Result<()> {
let address = format!("{}:{}", config.host, config.port);
let listener = tokio::net::TcpListener::bind(address.clone())
.await
@ -148,26 +197,46 @@ async fn start_server(app: IntoMakeService<axum::Router>, config: &ServerConfig)
.context("Failed to get local address")?
);
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.tcp_nodelay(true)
.await
.context("Failed to start server")
let mut shutdown_rx = app_state.subscribe_shutdown();
let server = axum::serve(listener, app)
.with_graceful_shutdown(async move {
shutdown_signal().await;
app_state.shutdown();
})
.tcp_nodelay(true);
tokio::select! {
result = server => result.context("Failed to start server"),
() = async {
let _ = shutdown_rx.changed().await;
info!(
"Shutdown signal received, waiting up to {}s for in-flight requests to complete...",
GRACEFUL_SHUTDOWN_TIMEOUT.as_secs()
);
tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT).await;
warn!("Graceful shutdown timed out, forcing exit");
} => Ok(()),
}
}
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
if let Err(e) = signal::ctrl_c().await {
log::error!("Failed to install Ctrl+C handler: {e}");
}
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
match signal::unix::signal(signal::unix::SignalKind::terminate()) {
Ok(mut signal) => {
signal.recv().await;
}
Err(e) => {
log::error!("Failed to install SIGTERM handler: {e}");
}
}
};
#[cfg(not(unix))]
@ -178,11 +247,3 @@ async fn shutdown_signal() {
() = terminate => {},
}
}
async fn handle_404() -> impl IntoResponse {
not_found_error(anyhow!("Page not found"))
}
async fn handle_405() -> impl IntoResponse {
client_error(anyhow!("Method not allowed"))
}

View file

@ -9,7 +9,7 @@ use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use log::info;
use log::{debug, info};
use crate::{
app_state::{AppState, database::models::VaultId},
@ -21,10 +21,12 @@ use crate::{
pub async fn auth_middleware(
State(state): State<AppState>,
Path(path_params): Path<HashMap<String, String>>,
TypedHeader(auth_header): TypedHeader<Authorization<Bearer>>,
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
mut req: Request,
next: Next,
) -> Result<Response, SyncServerError> {
let auth_header = auth_header
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Missing Authorization header")))?;
let token = auth_header.token().trim();
let vault_id = normalize_string(
path_params
@ -39,20 +41,24 @@ pub async fn auth_middleware(
Ok(next.run(req).await)
}
pub fn auth(state: &AppState, token: &str, vault_id: &VaultId) -> Result<User, SyncServerError> {
let user = state
pub fn authenticate(state: &AppState, token: &str) -> Result<User, SyncServerError> {
state
.config
.users
.get_user(token)
.cloned()
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Invalid token")))?;
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Invalid token")))
}
pub fn auth(state: &AppState, token: &str, vault_id: &VaultId) -> Result<User, SyncServerError> {
let user = authenticate(state, token)?;
if match user.vault_access {
VaultAccess::AllowAccessToAll => true,
VaultAccess::AllowList(AllowListedVaults { ref allowed }) => allowed.contains(vault_id),
} {
info!(
"User `{}` is authenticated and is authorised to access to vault `{vault_id}`",
debug!(
"User `{}` is authenticated and is authorised to access vault `{vault_id}`",
user.name
);

View file

@ -11,12 +11,14 @@ use super::{device_id_header::DeviceIdHeader, requests::CreateDocumentVersion};
use crate::{
app_state::{
AppState,
database::models::{DocumentVersionWithoutContent, StoredDocumentVersion, VaultId},
database::models::{StoredDocumentVersion, VaultId},
},
config::user_config::User,
errors::{SyncServerError, client_error, server_error},
errors::{SyncServerError, client_error, server_error, write_transaction_error},
server::{responses::DocumentUpdateResponse, update_document},
utils::{
find_first_available_path::find_first_available_path, normalize::normalize,
find_first_available_path::find_first_available_path, is_binary::is_binary,
is_file_type_mergable::is_file_type_mergable, normalize::normalize,
sanitize_path::sanitize_path,
},
};
@ -30,48 +32,137 @@ pub struct CreateDocumentPathParams {
/// Create a new document in case a document with the same doesn't exist
/// already. If a document with the same path exists, a new version is created
/// with their content merged.
///
/// Text content must be UTF-8 encoded. Clients are responsible for
/// transcoding other encodings (e.g. UTF-16) to UTF-8 before sending.
#[axum::debug_handler]
#[allow(clippy::too_many_lines)]
pub async fn create_document(
Path(CreateDocumentPathParams { vault_id }): Path<CreateDocumentPathParams>,
Extension(user): Extension<User>,
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
State(state): State<AppState>,
TypedMultipart(request): TypedMultipart<CreateDocumentVersion>,
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
debug!("Creating document in vault `{vault_id}`");
let mut transaction = state
.database
.create_write_transaction(&vault_id)
.await
.map_err(server_error)?;
.map_err(write_transaction_error)?;
let document_id = match request.document_id {
Some(document_id) => {
let existing_version = state
.database
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
.await
.map_err(server_error)?;
let sanitized_relative_path = sanitize_path(&request.relative_path).map_err(client_error)?;
let new_content = request.content.contents.to_vec();
if existing_version.is_some() {
return Err(client_error(anyhow::anyhow!(
"Document with the same ID `{document_id}` already exists"
)));
}
document_id
}
None => uuid::Uuid::new_v4(),
};
let last_update_id = state
let latest_version = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
.get_latest_non_deleted_document_by_path(
&vault_id,
&sanitized_relative_path,
Some(&mut *transaction),
)
.await
.map_err(server_error)?;
if let Some(latest_version) = latest_version {
// Only merge with an existing document the client couldn't have
// known about: its creation is newer than the client's last seen
// vault update to avoid creating cycles by merging two documents into one.
// This could happen if both clients know of document A at path P1,
// but client 2 moves it to P2 while client 1 creates a new document at P2,
// then client 1 would merge its new document with the moved version of A at P2
// that client 2 resulting in two files (P1 and P2) with the same doc id (A).
if latest_version.creation_vault_update_id > request.last_seen_vault_update_id
&& latest_version.creation_vault_update_id == latest_version.vault_update_id
// can't allow merging with a moved document as that could create a cycle
{
let is_mergeable_text = is_file_type_mergable(
&sanitized_relative_path,
&state.config.server.mergeable_file_extensions,
) && !is_binary(&latest_version.content)
&& !is_binary(&new_content);
if is_mergeable_text || new_content == latest_version.content {
return update_document::update_document(
&sanitized_relative_path,
Vec::new(),
vault_id,
latest_version.document_id,
Some(&request.relative_path),
new_content,
user,
device_id,
state,
transaction,
)
.await;
}
// For non-mergeable (binary) files with different content, don't
// merge, create a separate document at a deconflicted path so
// neither client's data is silently overwritten.
}
}
// Lost-create + local rename recovery. If this device has a doc
// the requesting client hasn't seen yet (its create succeeded
// server-side but the response was discarded — e.g. a sync
// reset mid-flight) and the new request carries the same content
// at a different path (the user renamed the file before the
// retry), bind the retry to that existing doc instead of
// creating a duplicate. The dedup is scoped tightly:
// - same `device_id` (only this client's own lost create),
// - `creation_vault_update_id > last_seen` (client never saw
// this doc, so it can't be deliberately creating another
// copy with matching content),
// - `creation == latest` (the doc has only its create version,
// nobody else has touched it; safe to relocate),
// - exact content match.
// Outside that window we fall through to the normal deconflict
// path, so legitimate "this device created a duplicate of an
// already-acknowledged file" flows still produce a new doc.
if let Some(lost_create) = state
.database
.find_unseen_lost_create_by_device_and_content(
&vault_id,
&device_id.0,
request.last_seen_vault_update_id,
&new_content,
Some(&mut *transaction),
)
.await
.map_err(server_error)?
{
info!(
"Lost-create recovery: binding retry at `{sanitized_relative_path}` to existing doc {} (was at `{}`) in vault `{vault_id}` for device `{}`",
lost_create.document_id,
lost_create.relative_path,
device_id.0
);
return update_document::update_document(
&sanitized_relative_path,
Vec::new(),
vault_id,
lost_create.document_id,
Some(&request.relative_path),
new_content,
user,
device_id,
state,
transaction,
)
.await;
}
let document_id = uuid::Uuid::new_v4();
let last_update_id = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut *transaction))
.await
.map_err(server_error)?;
let sanitized_relative_path = sanitize_path(&request.relative_path);
let deduped_path = find_first_available_path(
&vault_id,
&sanitized_relative_path,
@ -87,11 +178,13 @@ pub async fn create_document(
);
}
let new_vault_update_id = last_update_id + 1;
let new_version = StoredDocumentVersion {
vault_update_id: last_update_id + 1,
vault_update_id: new_vault_update_id,
creation_vault_update_id: new_vault_update_id,
document_id,
relative_path: deduped_path,
content: request.content.contents.to_vec(),
content: new_content,
updated_date: chrono::Utc::now(),
is_deleted: false,
user_id: user.name,
@ -101,9 +194,11 @@ pub async fn create_document(
state
.database
.insert_document_version(&vault_id, &new_version, Some(transaction))
.insert_document_version(&vault_id, &new_version, transaction)
.await
.map_err(server_error)?;
Ok(Json(new_version.into()))
Ok(Json(DocumentUpdateResponse::FastForwardUpdate(
new_version.into(),
)))
}

View file

@ -1,4 +1,4 @@
use anyhow::Context;
use anyhow::{Context, anyhow};
use axum::{
Extension, Json,
extract::{Path, State},
@ -7,7 +7,7 @@ use axum_extra::TypedHeader;
use log::{debug, info};
use serde::Deserialize;
use super::{device_id_header::DeviceIdHeader, requests::DeleteDocumentVersion};
use super::device_id_header::DeviceIdHeader;
use crate::{
app_state::{
AppState,
@ -16,8 +16,8 @@ use crate::{
},
},
config::user_config::User,
errors::{SyncServerError, server_error},
utils::{normalize::normalize, sanitize_path::sanitize_path},
errors::{SyncServerError, not_found_error, server_error, write_transaction_error},
utils::normalize::normalize,
};
#[derive(Deserialize)]
@ -37,7 +37,6 @@ pub async fn delete_document(
Extension(user): Extension<User>,
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
State(state): State<AppState>,
Json(request): Json<DeleteDocumentVersion>,
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
debug!("Deleting document `{document_id}` in vault `{vault_id}`");
@ -45,7 +44,7 @@ pub async fn delete_document(
.database
.create_write_transaction(&vault_id)
.await
.map_err(server_error)?;
.map_err(write_transaction_error)?;
let last_update_id = state
.database
@ -59,9 +58,18 @@ pub async fn delete_document(
.await
.map_err(server_error)?;
if let Some(latest_version) = &latest_version
&& latest_version.is_deleted
{
let Some(latest_version) = latest_version else {
transaction
.rollback()
.await
.context("Failed to roll back transaction")
.map_err(server_error)?;
return Err(not_found_error(anyhow!(
"Document `{document_id}` not found in vault `{vault_id}`"
)));
};
if latest_version.is_deleted {
transaction
.rollback()
.await
@ -69,15 +77,19 @@ pub async fn delete_document(
.map_err(server_error)?;
info!("Document `{document_id}` has already been deleted",);
return Ok(Json(latest_version.clone().into()));
return Ok(Json(latest_version.into()));
}
let latest_content = latest_version.map_or_else(Vec::new, |version| version.content); // in case the document has never existed before deleting it
let new_vault_update_id = last_update_id + 1;
let latest_relative_path = latest_version.relative_path;
let latest_content = latest_version.content;
let creation_vault_update_id = latest_version.creation_vault_update_id;
let new_version = StoredDocumentVersion {
vault_update_id: last_update_id + 1,
vault_update_id: new_vault_update_id,
creation_vault_update_id,
document_id,
relative_path: sanitize_path(&request.relative_path),
relative_path: latest_relative_path,
content: latest_content, // copy the content from the latest version
updated_date: chrono::Utc::now(),
is_deleted: true,
@ -88,7 +100,7 @@ pub async fn delete_document(
state
.database
.insert_document_version(&vault_id, &new_version, Some(transaction))
.insert_document_version(&vault_id, &new_version, transaction)
.await
.map_err(server_error)?;

View file

@ -16,20 +16,31 @@ impl Header for DeviceIdHeader {
{
let value = values.next().ok_or_else(headers::Error::invalid)?;
Ok(DeviceIdHeader(
value
.to_str()
.map_err(|_| headers::Error::invalid())?
.to_owned(),
))
let s = value.to_str().map_err(|_| headers::Error::invalid())?;
if s.is_empty() || s.len() > 256 {
return Err(headers::Error::invalid());
}
// Only allow safe characters to prevent log injection and similar attacks.
// Covers UUIDs, user-agent strings like "vault-link/1.0 (12345; linux)",
// and human-readable device names.
if !s
.chars()
.all(|c| c.is_ascii_alphanumeric() || "-_./ ();:@+,".contains(c))
{
return Err(headers::Error::invalid());
}
Ok(DeviceIdHeader(s.to_owned()))
}
fn encode<E>(&self, values: &mut E)
where
E: Extend<HeaderValue>,
{
let value = HeaderValue::from_static(Box::leak(self.0.clone().into_boxed_str()));
values.extend(std::iter::once(value));
if let Ok(value) = HeaderValue::from_str(&self.0) {
values.extend(std::iter::once(value));
}
}
}

View file

@ -11,7 +11,7 @@ use crate::{
AppState,
database::models::{DocumentId, DocumentVersion, VaultId, VaultUpdateId},
},
errors::{SyncServerError, not_found_error, server_error},
errors::{SyncServerError, client_error, not_found_error, server_error},
utils::normalize::normalize,
};
@ -52,7 +52,7 @@ pub async fn fetch_document_version(
)?;
if result.document_id != document_id {
return Err(not_found_error(anyhow!(
return Err(client_error(anyhow!(
"Document with document id `{document_id}` does not have a version with id \
`{vault_update_id}`",
)));

View file

@ -11,7 +11,7 @@ use crate::{
AppState,
database::models::{DocumentId, VaultId, VaultUpdateId},
},
errors::{SyncServerError, not_found_error, server_error},
errors::{SyncServerError, client_error, not_found_error, server_error},
utils::normalize::normalize,
};
@ -52,7 +52,7 @@ pub async fn fetch_document_version_content(
)?;
if result.document_id != document_id {
return Err(not_found_error(anyhow!(
return Err(client_error(anyhow!(
"Document with document id `{document_id}` does not have a version with id \
`{vault_update_id}`",
)));

View file

@ -0,0 +1,42 @@
use axum::{
Json,
extract::{Path, State},
};
use log::debug;
use serde::Deserialize;
use crate::{
app_state::{
AppState,
database::models::{DocumentId, DocumentVersionWithoutContent, VaultId},
},
errors::{SyncServerError, server_error},
utils::normalize::normalize,
};
#[derive(Deserialize)]
pub struct FetchDocumentVersionsPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
document_id: DocumentId,
}
#[axum::debug_handler]
pub async fn fetch_document_versions(
Path(FetchDocumentVersionsPathParams {
vault_id,
document_id,
}): Path<FetchDocumentVersionsPathParams>,
State(state): State<AppState>,
) -> Result<Json<Vec<DocumentVersionWithoutContent>>, SyncServerError> {
debug!("Fetching all versions for document `{document_id}` in vault `{vault_id}`");
let versions = state
.database
.get_document_versions(&vault_id, &document_id, None)
.await
.map_err(server_error)?;
Ok(Json(versions))
}

View file

@ -37,13 +37,13 @@ pub async fn fetch_latest_documents(
let documents = if let Some(since_update_id) = since_update_id {
state
.database
.get_latest_documents_since(&vault_id, since_update_id, None)
.get_latest_documents_since(&vault_id, since_update_id, None, None)
.await
.map_err(server_error)
} else {
state
.database
.get_latest_documents(&vault_id, None)
.get_latest_documents(&vault_id, None, None)
.await
.map_err(server_error)
}?;

View file

@ -0,0 +1,70 @@
use axum::{
Json,
extract::{Path, Query, State},
};
use log::debug;
use serde::Deserialize;
use super::responses::VaultHistoryResponse;
use crate::{
app_state::{
AppState,
database::models::{VaultId, VaultUpdateId},
},
errors::{SyncServerError, client_error, server_error},
utils::normalize::normalize,
};
const DEFAULT_LIMIT: i64 = 50;
const MAX_LIMIT: i64 = 500;
#[derive(Deserialize)]
pub struct FetchVaultHistoryPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
}
#[derive(Deserialize)]
pub struct QueryParams {
limit: Option<i64>,
before_update_id: Option<VaultUpdateId>,
}
#[axum::debug_handler]
pub async fn fetch_vault_history(
Path(FetchVaultHistoryPathParams { vault_id }): Path<FetchVaultHistoryPathParams>,
Query(QueryParams {
limit,
before_update_id,
}): Query<QueryParams>,
State(state): State<AppState>,
) -> Result<Json<VaultHistoryResponse>, SyncServerError> {
if let Some(id) = before_update_id
&& id <= 0
{
return Err(client_error(anyhow::anyhow!(
"before_update_id must be a positive integer"
)));
}
let limit = limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT);
debug!(
"Fetching vault history for vault `{vault_id}` (limit={limit}, before={before_update_id:?})"
);
// Fetch one extra row to determine if there are more results
let mut versions = state
.database
.get_vault_history(&vault_id, limit + 1, before_update_id, None)
.await
.map_err(server_error)?;
#[allow(clippy::cast_sign_loss)] // limit is clamped to [1, 500] above
let has_more = versions.len() > limit as usize;
if has_more {
versions.pop();
}
Ok(Json(VaultHistoryResponse { versions, has_more }))
}

View file

@ -1,7 +1,77 @@
use axum::response::{Html, IntoResponse};
use axum::{
body::Body,
extract::{Path, State},
http::{StatusCode, header},
response::{Html, IntoResponse, Response},
};
use log::warn;
use rust_embed::Embed;
pub async fn index() -> impl IntoResponse {
const HTML_CONTENT: &str = include_str!("./assets/index.html");
let html_content = HTML_CONTENT;
Html(html_content)
use crate::app_state::AppState;
#[derive(Embed)]
#[folder = "../frontend/history-ui/dist/"]
struct HistoryUiAssets;
pub async fn index(State(_state): State<AppState>) -> impl IntoResponse {
if let Some(content) = HistoryUiAssets::get("index.html") {
Html(
std::str::from_utf8(content.data.as_ref())
.inspect_err(|e| warn!("Embedded index.html is not valid UTF-8: {e}"))
.unwrap_or("<h1>VaultLink</h1>")
.to_owned(),
)
.into_response()
} else {
warn!("No embedded index.html found — history UI may not have been built");
Html("<h1>VaultLink server</h1>".to_owned()).into_response()
}
}
pub async fn spa_assets(Path(path): Path<String>) -> impl IntoResponse {
// The route is /assets/*path so path is relative to assets/.
// The embedded files include the assets/ prefix from the dist directory.
let full_path = format!("assets/{path}");
if let Some(content) = HistoryUiAssets::get(&full_path) {
let mime = mime_guess::from_path(&full_path).first_or_octet_stream();
return Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, mime.as_ref())
.body(Body::from(content.data.to_vec()))
.unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap_or_else(|_| Response::new(Body::empty()))
});
}
// Asset paths must match an embedded file — no SPA fallback.
// Serving index.html here would return 200 with text/html for missing
// .css/.js files, causing the browser to silently ignore the content.
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("Not found"))
.unwrap_or_else(|_| Response::new(Body::from("Not found")))
}
/// SPA fallback for production: serves index.html for client-side routes
/// (e.g. `/documents/123`).
pub async fn spa_fallback() -> impl IntoResponse {
match HistoryUiAssets::get("index.html") {
Some(content) => Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/html")
.body(Body::from(content.data.to_vec()))
.unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap_or_else(|_| Response::new(Body::empty()))
}),
None => Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("Not found"))
.unwrap_or_else(|_| Response::new(Body::from("Not found"))),
}
}

View file

@ -0,0 +1,82 @@
use axum::{
Json,
extract::{Query, State},
};
use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use log::debug;
use serde::Deserialize;
use super::{
auth::authenticate,
responses::{ListVaultsResponse, VaultInfo},
};
use crate::{
app_state::AppState,
config::user_config::{AllowListedVaults, VaultAccess},
errors::{SyncServerError, server_error, unauthenticated_error},
};
const DEFAULT_LIMIT: usize = 50;
const MAX_LIMIT: usize = 200;
#[derive(Deserialize)]
pub struct QueryParams {
limit: Option<usize>,
after: Option<String>,
}
#[axum::debug_handler]
pub async fn list_vaults(
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
Query(QueryParams { limit, after }): Query<QueryParams>,
State(state): State<AppState>,
) -> Result<Json<ListVaultsResponse>, SyncServerError> {
let auth_header = auth_header
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Missing Authorization header")))?;
let user = authenticate(&state, auth_header.token().trim())?;
debug!("User `{}` listing accessible vaults", user.name);
let existing_vaults = state.database.list_vaults().await.map_err(server_error)?;
let mut accessible: Vec<String> = match user.vault_access {
VaultAccess::AllowAccessToAll => existing_vaults,
VaultAccess::AllowList(AllowListedVaults { ref allowed }) => existing_vaults
.into_iter()
.filter(|v| allowed.contains(v))
.collect(),
};
// Cursor-based pagination: skip vaults up to and including `after`
if let Some(ref cursor) = after {
accessible.retain(|v| v.as_str() > cursor.as_str());
}
let limit = limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT);
let has_more = accessible.len() > limit;
accessible.truncate(limit);
let mut vaults = Vec::with_capacity(accessible.len());
for name in accessible {
let stats = state
.database
.get_vault_stats(&name)
.await
.map_err(server_error)?;
vaults.push(VaultInfo {
name,
document_count: stats.document_count,
created_at: stats.created_at,
});
}
Ok(Json(ListVaultsResponse {
vaults,
has_more,
user_name: user.name,
}))
}

View file

@ -0,0 +1,102 @@
use std::{
collections::HashMap,
sync::{Arc, Mutex},
time::Instant,
};
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
/// Per-user token-bucket rate limiter. Each bearer token gets its own bucket
/// that refills to `max_per_second` tokens every second.
#[derive(Clone, Debug)]
pub struct RateLimiter {
max_per_second: u64,
buckets: Arc<Mutex<HashMap<String, Arc<TokenBucket>>>>,
}
#[derive(Debug)]
struct TokenBucket {
state: Mutex<BucketState>,
max_tokens: u64,
}
#[derive(Debug)]
struct BucketState {
tokens: u64,
last_refill: Instant,
}
impl RateLimiter {
/// Create a new per-user rate limiter.
///
/// # Panics
///
/// Panics if `max_per_second` is 0.
pub fn new(max_per_second: u64) -> Self {
assert!(
max_per_second > 0,
"max_per_second must be > 0 (set rate_limit_per_user_per_second to null in config to disable)"
);
Self {
max_per_second,
buckets: Arc::new(Mutex::new(HashMap::new())),
}
}
fn get_or_create_bucket(&self, token: &str) -> Arc<TokenBucket> {
self.buckets
.lock()
.expect("rate limiter lock poisoned")
.entry(token.to_owned())
.or_insert_with(|| {
Arc::new(TokenBucket {
state: Mutex::new(BucketState {
tokens: self.max_per_second,
last_refill: Instant::now(),
}),
max_tokens: self.max_per_second,
})
})
.clone()
}
}
impl TokenBucket {
fn try_acquire(&self) -> bool {
let mut state = self.state.lock().expect("token bucket lock poisoned");
let now = Instant::now();
if now.duration_since(state.last_refill).as_secs() >= 1 {
state.tokens = self.max_tokens;
state.last_refill = now;
}
if state.tokens > 0 {
state.tokens -= 1;
true
} else {
false
}
}
}
pub async fn rate_limit_middleware(
axum::extract::State(limiter): axum::extract::State<RateLimiter>,
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
req: Request,
next: Next,
) -> Result<Response, StatusCode> {
let Some(TypedHeader(auth)) = auth_header else {
return Ok(next.run(req).await);
};
let bucket = limiter.get_or_create_bucket(auth.token());
if bucket.try_acquire() {
Ok(next.run(req).await)
} else {
Err(StatusCode::TOO_MANY_REQUESTS)
}
}

View file

@ -4,18 +4,16 @@ use reconcile_text::NumberOrText;
use serde::{self, Deserialize};
use ts_rs::TS;
use crate::app_state::database::models::{DocumentId, VaultUpdateId};
use crate::app_state::database::models::VaultUpdateId;
#[derive(TS, Debug, TryFromMultipart)]
#[ts(export)]
pub struct CreateDocumentVersion {
/// The client can decide the document id (if it wishes to) in order
/// to help with syncing. If the client does not provide a document id,
/// the server will generate one. If the client provides a document id
/// it must not already exist in the database.
pub document_id: Option<DocumentId>,
pub relative_path: String,
#[ts(type = "number")]
pub last_seen_vault_update_id: VaultUpdateId,
#[ts(as = "Vec<u8>")]
#[form_data(limit = "unlimited")]
pub content: FieldData<Bytes>,
@ -24,7 +22,9 @@ pub struct CreateDocumentVersion {
#[derive(Debug, TryFromMultipart)]
pub struct UpdateBinaryDocumentVersion {
pub parent_version_id: VaultUpdateId,
pub relative_path: String,
// None on a content-only edit; Some on a user rename. When None,
// the server keeps the document at its current path.
pub relative_path: Option<String>,
#[form_data(limit = "unlimited")]
pub content: FieldData<Bytes>,
@ -34,18 +34,13 @@ pub struct UpdateBinaryDocumentVersion {
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct UpdateTextDocumentVersion {
#[ts(as = "i32")]
#[ts(type = "number")]
pub parent_version_id: VaultUpdateId,
pub relative_path: String,
// None on a content-only edit; Some on a user rename. When None,
// the server keeps the document at its current path.
pub relative_path: Option<String>,
#[ts(type = "Array<number | string>")]
pub content: Vec<NumberOrText>,
}
#[derive(TS, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct DeleteDocumentVersion {
pub relative_path: String,
}

View file

@ -1,3 +1,4 @@
use chrono::{DateTime, Utc};
use serde::{self, Serialize};
use ts_rs::TS;
@ -36,7 +37,36 @@ pub struct FetchLatestDocumentsResponse {
pub last_update_id: VaultUpdateId,
}
/// Response to an update document request.
/// Response to a vault history request (paginated).
#[derive(TS, Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct VaultHistoryResponse {
pub versions: Vec<DocumentVersionWithoutContent>,
pub has_more: bool,
}
/// Summary of a single vault returned by the list-vaults endpoint.
#[derive(TS, Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct VaultInfo {
pub name: String,
pub document_count: u32,
pub created_at: Option<DateTime<Utc>>,
}
/// Response to listing vaults accessible to the authenticated user.
#[derive(TS, Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct ListVaultsResponse {
pub vaults: Vec<VaultInfo>,
pub has_more: bool,
pub user_name: String,
}
/// Response to a create/update document request.
#[derive(TS, Debug, Clone, Serialize)]
#[serde(tag = "type")]
#[ts(export)]

View file

@ -16,10 +16,15 @@ use super::{
use crate::{
app_state::{
AppState,
database::models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
database::{
WriteTransaction,
models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
},
},
config::user_config::User,
errors::{SyncServerError, client_error, not_found_error, server_error},
errors::{
SyncServerError, client_error, not_found_error, server_error, write_transaction_error,
},
server::requests::UpdateBinaryDocumentVersion,
utils::{
find_first_available_path::find_first_available_path, is_binary::is_binary,
@ -46,18 +51,27 @@ pub async fn update_binary(
State(state): State<AppState>,
TypedMultipart(request): TypedMultipart<UpdateBinaryDocumentVersion>,
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
let parent_document = get_parent_document(&state, &vault_id, request.parent_version_id).await?;
let parent_document =
get_parent_document(&state, &vault_id, &document_id, request.parent_version_id).await?;
let content = request.content.contents.to_vec();
let transaction = state
.database
.create_write_transaction(&vault_id)
.await
.map_err(write_transaction_error)?;
update_document(
parent_document,
&parent_document.relative_path,
parent_document.content,
vault_id,
document_id,
request.relative_path.as_deref(),
content,
user,
device_id,
state,
&request.relative_path,
content,
transaction,
)
.await
}
@ -74,28 +88,36 @@ pub async fn update_text(
State(state): State<AppState>,
Json(request): Json<UpdateTextDocumentVersion>,
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
let parent_document = get_parent_document(&state, &vault_id, request.parent_version_id).await?;
let parent_document =
get_parent_document(&state, &vault_id, &document_id, request.parent_version_id).await?;
let edited_text = EditedText::from_diff(
str::from_utf8(&parent_document.content)
.expect("parent must be valid UTF-8 because it's a text document"),
request.content,
&*BuiltinTokenizer::Word,
)
.context("Failed to apply given diff to parent document")
.map_err(client_error)?;
let parent_text = str::from_utf8(&parent_document.content)
.context("Parent version contains binary content; use putBinary instead of putText")
.map_err(client_error)?;
let edited_text = EditedText::from_diff(parent_text, request.content, &*BuiltinTokenizer::Word)
.context("Failed to apply given diff to parent document")
.map_err(client_error)?;
let content = edited_text.apply().text().into_bytes();
let transaction = state
.database
.create_write_transaction(&vault_id)
.await
.map_err(write_transaction_error)?;
update_document(
parent_document,
&parent_document.relative_path,
parent_document.content,
vault_id,
document_id,
request.relative_path.as_deref(),
content,
user,
device_id,
state,
&request.relative_path,
content,
transaction,
)
.await
}
@ -103,9 +125,10 @@ pub async fn update_text(
async fn get_parent_document(
state: &AppState,
vault_id: &VaultId,
document_id: &DocumentId,
parent_version_id: VaultUpdateId,
) -> Result<StoredDocumentVersion, SyncServerError> {
state
let parent = state
.database
.get_document_version(vault_id, parent_version_id, None)
.await
@ -117,29 +140,36 @@ async fn get_parent_document(
)))
},
Ok,
)
)?;
if &parent.document_id != document_id {
return Err(client_error(anyhow!(
"Parent version `{parent_version_id}` does not belong to document `{document_id}`"
)));
}
Ok(parent)
}
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
async fn update_document(
parent_document: StoredDocumentVersion,
pub async fn update_document(
parent_relative_path: &str,
parent_content: Vec<u8>,
vault_id: VaultId,
document_id: DocumentId,
relative_path: Option<&str>,
content: Vec<u8>,
user: User,
device_id: DeviceIdHeader,
state: AppState,
relative_path: &str,
content: Vec<u8>,
mut transaction: WriteTransaction,
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
debug!("Updating document `{document_id}` in vault `{vault_id}`");
let sanitized_relative_path = sanitize_path(relative_path);
let mut transaction = state
.database
.create_write_transaction(&vault_id)
.await
.map_err(server_error)?;
let sanitized_relative_path = relative_path
.map(sanitize_path)
.transpose()
.map_err(client_error)?;
let last_update_id = state
.database
@ -175,9 +205,12 @@ async fn update_document(
}
// Return the latest version if the content and path are the same as the latest
// version
if content == latest_version.content && sanitized_relative_path == latest_version.relative_path
{
// version. A missing relative_path means "keep current path", so the path
// is implicitly unchanged.
let path_unchanged = sanitized_relative_path
.as_deref()
.is_none_or(|p| p == latest_version.relative_path);
if content == latest_version.content && path_unchanged {
info!(
"Document content is the same as the latest version for `{document_id}`, skipping update"
);
@ -192,62 +225,89 @@ async fn update_document(
)));
}
// For mergability, use whichever path the new version will live at — the
// requested rename target if the client sent one, otherwise the existing
// server-side path.
let mergable_check_path = sanitized_relative_path
.as_deref()
.unwrap_or(&latest_version.relative_path);
let are_all_participants_mergable = is_file_type_mergable(
&sanitized_relative_path,
mergable_check_path,
&state.config.server.mergeable_file_extensions,
) && !is_binary(&parent_document.content)
) && !is_binary(&parent_content)
&& !is_binary(&latest_version.content)
&& !is_binary(&content);
let merged_content = if are_all_participants_mergable {
let (merged_content, is_different_from_request_content) = if are_all_participants_mergable {
info!("Merging changes for document `{document_id}` in vault `{vault_id}`");
reconcile(
str::from_utf8(&parent_document.content)
.expect("parent must be valid UTF-8 because it's not binary"),
&str::from_utf8(&latest_version.content)
.expect("latest_version must be valid UTF-8 because it's not binary")
.into(),
&str::from_utf8(&content)
.expect("content must be valid UTF-8 because it's not binary")
.into(),
&*BuiltinTokenizer::Word,
)
.apply()
.text()
.into_bytes()
let parent_text = str::from_utf8(&parent_content)
.context("Parent document content is not valid UTF-8")
.map_err(client_error)?;
let latest_text = str::from_utf8(&latest_version.content)
.context("Latest version content is not valid UTF-8")
.map_err(client_error)?;
let new_text = str::from_utf8(&content)
.context("New content is not valid UTF-8")
.map_err(client_error)?;
let parent_owned = parent_text.to_owned();
let latest_owned = latest_text.to_owned();
let new_owned = new_text.to_owned();
let content_clone = content.clone();
let (merged, is_different) = tokio::task::spawn_blocking(move || {
let merged = reconcile(
&parent_owned,
&latest_owned.into(),
&new_owned.into(),
&*BuiltinTokenizer::Word,
)
.apply()
.text()
.into_bytes();
let is_different = merged != content_clone;
(merged, is_different)
})
.await
.map_err(|e| server_error(anyhow::anyhow!("Reconcile task failed: {e}")))?;
(merged, is_different)
} else {
content.clone()
(content, false) // false means that the client doesn't need to refetch the file as we can ensure the remote and local versions are the same as LWW is the merging method for binary files
};
let is_different_from_request_content = merged_content != content;
// Rename resolution: only apply the client's rename if (a) the client
// requested one (`sanitized_relative_path` is `Some`) and (b) the
// document's path hasn't changed since this client's parent version.
// If the parent and latest paths differ, another client already renamed
// the document — keep the latest path (first rename wins). Content
// changes from both clients are still merged correctly via the 3-way
// reconcile above, independent of which rename wins. A missing
// relative_path means "keep current path" (content-only edit).
let new_relative_path = match sanitized_relative_path.as_deref() {
Some(requested)
if parent_relative_path == latest_version.relative_path
&& requested != latest_version.relative_path =>
{
let new_path =
find_first_available_path(&vault_id, requested, &state.database, &mut transaction)
.await
.map_err(server_error)?;
// We can only update the relative path if we're the first one to do so
let new_relative_path = if parent_document.relative_path == latest_version.relative_path
&& latest_version.relative_path != sanitized_relative_path
{
let new_path = find_first_available_path(
&vault_id,
&sanitized_relative_path,
&state.database,
&mut transaction,
)
.await
.map_err(server_error)?;
if new_path != requested {
info!(
"Document already exists at new location: `{requested}` when trying to update it in vault `{vault_id}`, deconflicting by creating at `{new_path}`"
);
}
if new_path != sanitized_relative_path {
info!(
"Document already exists at new location: `{sanitized_relative_path}` when trying to update it in vault `{vault_id}`, deconflicting by creating at `{new_path}`"
);
new_path
}
new_path
} else {
latest_version.relative_path.clone()
_ => latest_version.relative_path.clone(),
};
let new_version = StoredDocumentVersion {
document_id,
vault_update_id: last_update_id + 1,
creation_vault_update_id: latest_version.creation_vault_update_id,
relative_path: new_relative_path,
content: merged_content,
updated_date: chrono::Utc::now(),
@ -259,7 +319,7 @@ async fn update_document(
state
.database
.insert_document_version(&vault_id, &new_version, Some(transaction))
.insert_document_version(&vault_id, &new_version, transaction)
.await
.map_err(server_error)?;