Various server improvements
This commit is contained in:
parent
3fe5f49050
commit
233ce1254b
10 changed files with 177 additions and 55 deletions
|
|
@ -22,6 +22,7 @@ pub struct StoredDocumentVersion {
|
|||
pub device_id: DeviceId,
|
||||
#[allow(dead_code)] // This is for manual analysis
|
||||
pub has_been_merged: bool,
|
||||
pub idempotency_key: Option<String>,
|
||||
}
|
||||
|
||||
impl PartialEq<Self> for StoredDocumentVersion {
|
||||
|
|
@ -33,7 +34,7 @@ impl PartialEq<Self> for StoredDocumentVersion {
|
|||
#[derive(TS, Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DocumentVersionWithoutContent {
|
||||
#[ts(as = "i32")]
|
||||
#[ts(type = "number")]
|
||||
pub vault_update_id: VaultUpdateId,
|
||||
|
||||
pub document_id: DocumentId,
|
||||
|
|
@ -43,7 +44,7 @@ pub struct DocumentVersionWithoutContent {
|
|||
pub user_id: UserId,
|
||||
pub device_id: DeviceId,
|
||||
|
||||
#[ts(as = "i32")]
|
||||
#[ts(type = "number")]
|
||||
pub content_size: u64,
|
||||
}
|
||||
|
||||
|
|
@ -65,7 +66,7 @@ impl From<StoredDocumentVersion> for DocumentVersionWithoutContent {
|
|||
#[derive(TS, Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DocumentVersion {
|
||||
#[ts(as = "i32")]
|
||||
#[ts(type = "number")]
|
||||
pub vault_update_id: VaultUpdateId,
|
||||
|
||||
pub document_id: DocumentId,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ pub struct WebSocketHandshake {
|
|||
pub token: String,
|
||||
pub device_id: DeviceId,
|
||||
|
||||
#[ts(as = "Option<i32>")]
|
||||
#[ts(type = "number | null")]
|
||||
pub last_seen_vault_update_id: Option<VaultUpdateId>,
|
||||
}
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ pub struct DocumentWithCursors {
|
|||
// that it exists and can be client-side
|
||||
// interpolated. However, the actual
|
||||
// position is meaningless.
|
||||
#[ts(as = "Option<u32>")]
|
||||
#[ts(type = "number | null")]
|
||||
pub vault_update_id: Option<VaultUpdateId>,
|
||||
|
||||
pub document_id: DocumentId,
|
||||
|
|
@ -70,6 +70,7 @@ pub struct WebSocketVaultUpdate {
|
|||
pub enum WebSocketClientMessage {
|
||||
Handshake(WebSocketHandshake),
|
||||
CursorPositions(CursorPositionFromClient),
|
||||
Ping {},
|
||||
}
|
||||
|
||||
#[derive(TS, Serialize, Clone, Debug)]
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ use crate::{
|
|||
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
|
||||
},
|
||||
config::user_config::User,
|
||||
errors::{SyncServerError, server_error, unauthenticated_error},
|
||||
errors::{SyncServerError, client_error, server_error, unauthenticated_error},
|
||||
server::auth::auth,
|
||||
};
|
||||
|
||||
|
|
@ -26,16 +26,16 @@ pub fn get_authenticated_handshake(
|
|||
if let Some(Message::Text(message)) = message {
|
||||
let message: WebSocketClientMessage = serde_json::from_str(&message)
|
||||
.context("Failed to parse message")
|
||||
.map_err(server_error)?;
|
||||
.map_err(client_error)?;
|
||||
|
||||
match message {
|
||||
WebSocketClientMessage::Handshake(handshake) => {
|
||||
let user = auth(state, handshake.token.trim(), vault_id)?;
|
||||
Ok(AuthenticatedWebSocketHandshake { handshake, user })
|
||||
}
|
||||
WebSocketClientMessage::CursorPositions(_) => Err(unauthenticated_error(
|
||||
anyhow::anyhow!("Expected a handshake message"),
|
||||
)),
|
||||
WebSocketClientMessage::CursorPositions(_) | WebSocketClientMessage::Ping {} => Err(
|
||||
unauthenticated_error(anyhow::anyhow!("Expected a handshake message")),
|
||||
),
|
||||
}
|
||||
} else {
|
||||
Err(unauthenticated_error(anyhow::anyhow!(
|
||||
|
|
|
|||
|
|
@ -28,23 +28,20 @@ pub struct Config {
|
|||
|
||||
impl Config {
|
||||
pub async fn read_or_create(path: &Path) -> Result<Self> {
|
||||
let config = if path.exists() {
|
||||
info!(
|
||||
"Loading configuration from `{}`",
|
||||
path.canonicalize().unwrap().display()
|
||||
);
|
||||
Self::load_from_file(path).await?
|
||||
let display_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
|
||||
|
||||
if path.exists() {
|
||||
info!("Loading configuration from `{}`", display_path.display());
|
||||
Self::load_from_file(path).await
|
||||
} else {
|
||||
Self::default()
|
||||
};
|
||||
|
||||
config.write(path).await?;
|
||||
info!(
|
||||
"Updated configuration at `{}`",
|
||||
path.canonicalize().unwrap().display()
|
||||
);
|
||||
|
||||
Ok(config)
|
||||
let config = Self::default();
|
||||
config.write(path).await?;
|
||||
info!(
|
||||
"Created default configuration at `{}`",
|
||||
display_path.display()
|
||||
);
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn load_from_file(path: &Path) -> Result<Self> {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use bimap::BiHashMap;
|
||||
use rand::{Rng, distr::Alphanumeric, rng};
|
||||
use serde::{Deserialize, Deserializer, Serialize, de::Error};
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
use crate::app_state::database::models::VaultId;
|
||||
|
||||
|
|
@ -19,10 +20,19 @@ where
|
|||
let mut user_token_map = BiHashMap::new();
|
||||
for user in &users {
|
||||
if let Some(existing_name) = user_token_map.get_by_right(&user.token) {
|
||||
let redacted = if user.token.len() > 6 {
|
||||
format!(
|
||||
"{}...{}",
|
||||
&user.token[..3],
|
||||
&user.token[user.token.len() - 3..]
|
||||
)
|
||||
} else {
|
||||
"***".to_owned()
|
||||
};
|
||||
return Err(D::Error::custom(format!(
|
||||
"Duplicate user token found: `{}` for users `{}` and `{}`. User tokens must be \
|
||||
unique.",
|
||||
user.token, existing_name, user.name
|
||||
"Duplicate user token found: `{redacted}` for users `{}` and `{}`. User tokens \
|
||||
must be unique.",
|
||||
existing_name, user.name
|
||||
)));
|
||||
}
|
||||
|
||||
|
|
@ -41,7 +51,9 @@ where
|
|||
|
||||
impl UserConfig {
|
||||
pub fn get_user(&self, token: &str) -> Option<&User> {
|
||||
self.user_configs.iter().find(|u| u.token == token)
|
||||
self.user_configs
|
||||
.iter()
|
||||
.find(|u| u.token.as_bytes().ct_eq(token.as_bytes()).into())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use axum::{
|
|||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use log::{debug, error};
|
||||
use log::{debug, error, warn};
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use ts_rs::TS;
|
||||
|
|
@ -29,6 +29,9 @@ pub enum SyncServerError {
|
|||
|
||||
#[error("Permission denied error: {0}")]
|
||||
PermissionDeniedError(#[source] anyhow::Error),
|
||||
|
||||
#[error("Too many requests: {0}")]
|
||||
TooManyRequests(#[source] anyhow::Error),
|
||||
}
|
||||
|
||||
impl SyncServerError {
|
||||
|
|
@ -39,7 +42,8 @@ impl SyncServerError {
|
|||
| Self::ServerError(error)
|
||||
| Self::NotFound(error)
|
||||
| Self::Unauthenticated(error)
|
||||
| Self::PermissionDeniedError(error) => error.into(),
|
||||
| Self::PermissionDeniedError(error)
|
||||
| Self::TooManyRequests(error) => error.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -69,7 +73,22 @@ impl Display for SerializedError {
|
|||
|
||||
impl IntoResponse for SyncServerError {
|
||||
fn into_response(self) -> Response {
|
||||
let body = Json(self.serialize());
|
||||
let serialized = self.serialize();
|
||||
|
||||
match &self {
|
||||
Self::InitError(_) | Self::ServerError(_) => {
|
||||
error!("{serialized}");
|
||||
}
|
||||
Self::ClientError(_) | Self::NotFound(_) => {
|
||||
warn!("{serialized}");
|
||||
}
|
||||
Self::TooManyRequests(_) => {
|
||||
warn!("{serialized}");
|
||||
}
|
||||
Self::Unauthenticated(_) | Self::PermissionDeniedError(_) => {}
|
||||
}
|
||||
|
||||
let body = Json(serialized);
|
||||
|
||||
match self {
|
||||
Self::InitError(_) | Self::ServerError(_) => {
|
||||
|
|
@ -79,6 +98,9 @@ impl IntoResponse for SyncServerError {
|
|||
Self::NotFound(_) => (StatusCode::NOT_FOUND, body).into_response(),
|
||||
Self::Unauthenticated(_) => (StatusCode::UNAUTHORIZED, body).into_response(),
|
||||
Self::PermissionDeniedError(_) => (StatusCode::FORBIDDEN, body).into_response(),
|
||||
Self::TooManyRequests(_) => {
|
||||
(StatusCode::TOO_MANY_REQUESTS, body).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -102,6 +124,7 @@ impl From<&anyhow::Error> for SerializedError {
|
|||
SyncServerError::NotFound(_) => "NotFound",
|
||||
SyncServerError::Unauthenticated(_) => "Unauthenticated",
|
||||
SyncServerError::PermissionDeniedError(_) => "PermissionDeniedError",
|
||||
SyncServerError::TooManyRequests(_) => "TooManyRequests",
|
||||
},
|
||||
),
|
||||
message: error.to_string(),
|
||||
|
|
@ -139,3 +162,18 @@ pub fn permission_denied_error(error: anyhow::Error) -> SyncServerError {
|
|||
debug!("Permission denied: {error:?}");
|
||||
SyncServerError::PermissionDeniedError(error)
|
||||
}
|
||||
|
||||
pub fn too_many_requests_error(error: anyhow::Error) -> SyncServerError {
|
||||
debug!("Too many requests: {error:?}");
|
||||
SyncServerError::TooManyRequests(error)
|
||||
}
|
||||
|
||||
/// Maps a `create_write_transaction` error to 429 if the database is busy,
|
||||
/// or 500 for all other failures.
|
||||
pub fn write_transaction_error(error: anyhow::Error) -> SyncServerError {
|
||||
if error.downcast_ref::<crate::app_state::database::WriteBusyError>().is_some() {
|
||||
too_many_requests_error(error)
|
||||
} else {
|
||||
server_error(error)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ use axum_extra::{
|
|||
TypedHeader,
|
||||
headers::{Authorization, authorization::Bearer},
|
||||
};
|
||||
use log::info;
|
||||
use log::{debug, info};
|
||||
|
||||
use crate::{
|
||||
app_state::{AppState, database::models::VaultId},
|
||||
|
|
@ -21,10 +21,12 @@ use crate::{
|
|||
pub async fn auth_middleware(
|
||||
State(state): State<AppState>,
|
||||
Path(path_params): Path<HashMap<String, String>>,
|
||||
TypedHeader(auth_header): TypedHeader<Authorization<Bearer>>,
|
||||
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, SyncServerError> {
|
||||
let auth_header = auth_header
|
||||
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Missing Authorization header")))?;
|
||||
let token = auth_header.token().trim();
|
||||
let vault_id = normalize_string(
|
||||
path_params
|
||||
|
|
@ -51,8 +53,8 @@ pub fn auth(state: &AppState, token: &str, vault_id: &VaultId) -> Result<User, S
|
|||
VaultAccess::AllowAccessToAll => true,
|
||||
VaultAccess::AllowList(AllowListedVaults { ref allowed }) => allowed.contains(vault_id),
|
||||
} {
|
||||
info!(
|
||||
"User `{}` is authenticated and is authorised to access to vault `{vault_id}`",
|
||||
debug!(
|
||||
"User `{}` is authenticated and is authorised to access vault `{vault_id}`",
|
||||
user.name
|
||||
);
|
||||
|
||||
|
|
|
|||
72
sync-server/src/server/rate_limit.rs
Normal file
72
sync-server/src/server/rate_limit.rs
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicU64, Ordering},
|
||||
};
|
||||
|
||||
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
|
||||
|
||||
/// Simple token-bucket rate limiter that refills every second.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RateLimiter {
|
||||
inner: Arc<TokenBucket>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TokenBucket {
|
||||
tokens: AtomicU64,
|
||||
max_tokens: u64,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
/// Create a new rate limiter. Spawns a background task that refills tokens
|
||||
/// every second.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `max_per_second` is 0.
|
||||
pub fn new(max_per_second: u64) -> Self {
|
||||
assert!(
|
||||
max_per_second > 0,
|
||||
"max_per_second must be > 0 (use 0 in config to disable rate limiting entirely)"
|
||||
);
|
||||
|
||||
let bucket = Arc::new(TokenBucket {
|
||||
tokens: AtomicU64::new(max_per_second),
|
||||
max_tokens: max_per_second,
|
||||
});
|
||||
|
||||
let bucket_clone = bucket.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
bucket_clone
|
||||
.tokens
|
||||
.store(bucket_clone.max_tokens, Ordering::Release);
|
||||
}
|
||||
});
|
||||
|
||||
Self { inner: bucket }
|
||||
}
|
||||
|
||||
fn try_acquire(&self) -> bool {
|
||||
self.inner
|
||||
.tokens
|
||||
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
|
||||
if current > 0 { Some(current - 1) } else { None }
|
||||
})
|
||||
.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn rate_limit_middleware(
|
||||
axum::extract::State(limiter): axum::extract::State<RateLimiter>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if limiter.try_acquire() {
|
||||
Ok(next.run(req).await)
|
||||
} else {
|
||||
Err(StatusCode::TOO_MANY_REQUESTS)
|
||||
}
|
||||
}
|
||||
|
|
@ -4,21 +4,18 @@ use reconcile_text::NumberOrText;
|
|||
use serde::{self, Deserialize};
|
||||
use ts_rs::TS;
|
||||
|
||||
use crate::app_state::database::models::{DocumentId, VaultUpdateId};
|
||||
use crate::app_state::database::models::VaultUpdateId;
|
||||
|
||||
#[derive(TS, Debug, TryFromMultipart)]
|
||||
#[ts(export)]
|
||||
pub struct CreateDocumentVersion {
|
||||
/// The client can decide the document id (if it wishes to) in order
|
||||
/// to help with syncing. If the client does not provide a document id,
|
||||
/// the server will generate one. If the client provides a document id
|
||||
/// it must not already exist in the database.
|
||||
pub document_id: Option<DocumentId>,
|
||||
pub relative_path: String,
|
||||
|
||||
#[ts(as = "Vec<u8>")]
|
||||
#[form_data(limit = "unlimited")]
|
||||
pub content: FieldData<Bytes>,
|
||||
|
||||
pub idempotency_key: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, TryFromMultipart)]
|
||||
|
|
@ -34,7 +31,7 @@ pub struct UpdateBinaryDocumentVersion {
|
|||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
pub struct UpdateTextDocumentVersion {
|
||||
#[ts(as = "i32")]
|
||||
#[ts(type = "number")]
|
||||
pub parent_version_id: VaultUpdateId,
|
||||
|
||||
pub relative_path: String,
|
||||
|
|
@ -43,9 +40,5 @@ pub struct UpdateTextDocumentVersion {
|
|||
pub content: Vec<NumberOrText>,
|
||||
}
|
||||
|
||||
#[derive(TS, Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
pub struct DeleteDocumentVersion {
|
||||
pub relative_path: String,
|
||||
}
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DeleteDocumentVersion {}
|
||||
|
|
|
|||
|
|
@ -1,25 +1,31 @@
|
|||
use crate::app_state::database::models::VaultId;
|
||||
use crate::{app_state::database::Transaction, utils::dedup_paths::dedup_paths};
|
||||
use anyhow::Result;
|
||||
use log::{debug, info};
|
||||
use crate::utils::dedup_paths::dedup_paths;
|
||||
use anyhow::{Result, bail};
|
||||
use log::info;
|
||||
use sqlx::sqlite::SqliteConnection;
|
||||
|
||||
|
||||
pub async fn find_first_available_path(
|
||||
vault_id: &VaultId,
|
||||
sanitized_relative_path: &str,
|
||||
database: &crate::app_state::database::Database,
|
||||
transaction: &mut Transaction<'_>,
|
||||
connection: &mut SqliteConnection,
|
||||
) -> Result<String> {
|
||||
info!("Finding first available path for `{sanitized_relative_path}` in vault `{vault_id}`");
|
||||
for candidate in dedup_paths(sanitized_relative_path) {
|
||||
debug!("Checking candidate path for deconflicting names: `{candidate}`");
|
||||
if database
|
||||
.get_latest_document_by_path(vault_id, &candidate, Some(transaction))
|
||||
.get_latest_non_deleted_document_by_path(vault_id, &candidate, Some(connection))
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
info!("Selected available path: `{candidate}`");
|
||||
return Ok(candidate);
|
||||
}
|
||||
|
||||
info!(
|
||||
"Finding first available path for `{sanitized_relative_path}` in vault `{vault_id}` as `{candidate}` is already taken"
|
||||
);
|
||||
}
|
||||
|
||||
unreachable!("dedup_paths produces infinite paths");
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue