diff --git a/backend/config-e2e.yml b/backend/config-e2e.yml index 04fe344a..a49ab287 100644 --- a/backend/config-e2e.yml +++ b/backend/config-e2e.yml @@ -1,6 +1,6 @@ database: databases_directory_path: databases - max_connections: 12 + max_connections_per_vault: 12 server: host: 0.0.0.0 @@ -13,7 +13,7 @@ users: - name: admin token: test-token-change-me vaults: - all: true + allow_access_to_all: true - name: test token: other-test-token diff --git a/backend/sync_server/src/app_state/database.rs b/backend/sync_server/src/app_state/database.rs index fa7f35b0..2c2cfced 100644 --- a/backend/sync_server/src/app_state/database.rs +++ b/backend/sync_server/src/app_state/database.rs @@ -73,7 +73,7 @@ impl Database { .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal); let pool = SqlitePoolOptions::new() - .max_connections(config.max_connections) + .max_connections(config.max_connections_per_vault) .test_before_acquire(true) .connect_with(connection_options) .await diff --git a/backend/sync_server/src/config/database_config.rs b/backend/sync_server/src/config/database_config.rs index b3d2fad7..ef26a09d 100644 --- a/backend/sync_server/src/config/database_config.rs +++ b/backend/sync_server/src/config/database_config.rs @@ -3,15 +3,15 @@ use std::path::PathBuf; use log::debug; use serde::{Deserialize, Serialize}; -use crate::consts::{DEFAULT_DATABASES_DIRECTORY_PATH, DEFAULT_MAX_CONNECTIONS}; +use crate::consts::{DEFAULT_DATABASES_DIRECTORY_PATH, DEFAULT_MAX_CONNECTIONS_PER_VAULT}; #[derive(Debug, Deserialize, Serialize, Clone)] pub struct DatabaseConfig { #[serde(default = "default_databases_directory_path")] pub databases_directory_path: PathBuf, - #[serde(default = "default_max_connections")] - pub max_connections: u32, + #[serde(default = "default_max_connections_per_vault")] + pub max_connections_per_vault: u32, } fn default_databases_directory_path() -> PathBuf { @@ -19,16 +19,16 @@ fn default_databases_directory_path() -> PathBuf { PathBuf::from(DEFAULT_DATABASES_DIRECTORY_PATH) } -fn default_max_connections() -> u32 { - debug!("Using default max connections: {DEFAULT_MAX_CONNECTIONS}"); - DEFAULT_MAX_CONNECTIONS +fn default_max_connections_per_vault() -> u32 { + debug!("Using default max connections: {DEFAULT_MAX_CONNECTIONS_PER_VAULT}"); + DEFAULT_MAX_CONNECTIONS_PER_VAULT } impl Default for DatabaseConfig { fn default() -> Self { Self { databases_directory_path: default_databases_directory_path(), - max_connections: default_max_connections(), + max_connections_per_vault: default_max_connections_per_vault(), } } } diff --git a/backend/sync_server/src/config/user_config.rs b/backend/sync_server/src/config/user_config.rs index c3afca14..4ee7c72d 100644 --- a/backend/sync_server/src/config/user_config.rs +++ b/backend/sync_server/src/config/user_config.rs @@ -1,6 +1,10 @@ +use std::default; + use rand::{Rng as _, distributions::Alphanumeric, thread_rng}; use serde::{Deserialize, Serialize}; +use crate::app_state::database::models::VaultId; + #[derive(Debug, Deserialize, Serialize, Clone)] pub struct UserConfig { #[serde(default = "default_users")] @@ -17,6 +21,7 @@ impl UserConfig { pub struct User { pub name: String, pub token: String, + pub vault_access: VaultAccess, } impl Default for UserConfig { @@ -27,10 +32,25 @@ impl Default for UserConfig { } } +#[derive(Debug, Deserialize, Serialize, Clone, Default)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum VaultAccess { + #[default] + AllowAccessToAll, + + AllowList(AllowListedVaults), +} + +#[derive(Debug, Deserialize, Serialize, Clone, Default)] +pub struct AllowListedVaults { + pub allowed: Vec, +} + fn default_users() -> Vec { vec![User { name: "admin".to_owned(), token: get_random_token(), + vault_access: VaultAccess::default(), }] } diff --git a/backend/sync_server/src/consts.rs b/backend/sync_server/src/consts.rs index 2d3bec55..1453f25a 100644 --- a/backend/sync_server/src/consts.rs +++ b/backend/sync_server/src/consts.rs @@ -2,6 +2,6 @@ pub const DEFAULT_CONFIG_PATH: &str = "config.yml"; pub const DEFAULT_DATABASES_DIRECTORY_PATH: &str = "databases"; pub const DEFAULT_HOST: &str = "127.0.0.1"; pub const DEFAULT_PORT: u16 = 3000; -pub const DEFAULT_MAX_CONNECTIONS: u32 = 12; +pub const DEFAULT_MAX_CONNECTIONS_PER_VAULT: u32 = 12; pub const DEFAULT_MAX_BODY_SIZE_MB: usize = 4096; pub const DEFAULT_MAX_CLIENTS_PER_VAULT: usize = 256; diff --git a/backend/sync_server/src/errors.rs b/backend/sync_server/src/errors.rs index 5aec9c32..31538107 100644 --- a/backend/sync_server/src/errors.rs +++ b/backend/sync_server/src/errors.rs @@ -24,7 +24,7 @@ pub enum SyncServerError { NotFound(#[source] anyhow::Error), #[error("Unauthorized: {0}")] - Unauthorized(#[source] anyhow::Error), + Unauthenticated(#[source] anyhow::Error), #[error("Permission denied error: {0}")] PermissionDeniedError(#[source] anyhow::Error), @@ -37,7 +37,7 @@ impl SyncServerError { | Self::ClientError(error) | Self::ServerError(error) | Self::NotFound(error) - | Self::Unauthorized(error) + | Self::Unauthenticated(error) | Self::PermissionDeniedError(error) => error.into(), } } @@ -53,7 +53,7 @@ impl IntoResponse for SyncServerError { } Self::ClientError(_) => (StatusCode::BAD_REQUEST, body).into_response(), Self::NotFound(_) => (StatusCode::NOT_FOUND, body).into_response(), - Self::Unauthorized(_) => (StatusCode::UNAUTHORIZED, body).into_response(), + Self::Unauthenticated(_) => (StatusCode::UNAUTHORIZED, body).into_response(), Self::PermissionDeniedError(_) => (StatusCode::FORBIDDEN, body).into_response(), } } @@ -100,17 +100,16 @@ pub fn client_error(error: anyhow::Error) -> SyncServerError { } pub fn not_found_error(error: anyhow::Error) -> SyncServerError { - info!("Not found error: {:?}", error); + info!("Not found: {:?}", error); SyncServerError::NotFound(error) } -pub fn unauthorized_error(error: anyhow::Error) -> SyncServerError { - info!("Unauthorized error: {:?}", error); - SyncServerError::Unauthorized(error) +pub fn unauthenticated_error(error: anyhow::Error) -> SyncServerError { + info!("Unauthenticated user: {:?}", error); + SyncServerError::Unauthenticated(error) } -#[allow(dead_code)] pub fn permission_denied_error(error: anyhow::Error) -> SyncServerError { - info!("Permission denied error: {:?}", error); + info!("Permission denied: {:?}", error); SyncServerError::PermissionDeniedError(error) } diff --git a/backend/sync_server/src/server.rs b/backend/sync_server/src/server.rs index 90bd8ff3..4bc85c0f 100644 --- a/backend/sync_server/src/server.rs +++ b/backend/sync_server/src/server.rs @@ -61,7 +61,7 @@ pub async fn create_server(config_path: Option) -> Result<()> { let mut api = create_open_api(); let app = ApiRouter::new() - .api_route("/ping", get(ping::ping)) + .api_route("/vaults/:vault_id/ping", get(ping::ping)) .api_route( "/vaults/:vault_id/documents", get(fetch_latest_documents::fetch_latest_documents), diff --git a/backend/sync_server/src/server/auth.rs b/backend/sync_server/src/server/auth.rs index ae20e187..06bfe5db 100644 --- a/backend/sync_server/src/server/auth.rs +++ b/backend/sync_server/src/server/auth.rs @@ -1,15 +1,26 @@ use crate::{ - app_state::AppState, - config::user_config::User, - errors::{SyncServerError, unauthorized_error}, + app_state::{AppState, database::models::VaultId}, + config::user_config::{AllowListedVaults, User, VaultAccess}, + errors::{SyncServerError, permission_denied_error, unauthenticated_error}, }; // TODO: turn this into a middleware -pub fn auth(app_state: &AppState, token: &str) -> Result { - app_state +pub fn auth(app_state: &AppState, token: &str, vault: &VaultId) -> Result { + let user = app_state .config .users .get_user(token) .cloned() - .ok_or_else(|| unauthorized_error(anyhow::anyhow!("Invalid token"))) + .ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Invalid token")))?; + + if match user.vault_access { + VaultAccess::AllowAccessToAll => true, + VaultAccess::AllowList(AllowListedVaults { ref allowed }) => allowed.contains(vault), + } { + Ok(user) + } else { + Err(permission_denied_error(anyhow::anyhow!( + "Permission denied for vault `{vault}`" + ))) + } } diff --git a/backend/sync_server/src/server/create_document.rs b/backend/sync_server/src/server/create_document.rs index 826b37c6..25919384 100644 --- a/backend/sync_server/src/server/create_document.rs +++ b/backend/sync_server/src/server/create_document.rs @@ -87,7 +87,7 @@ async fn internal_create_document( relative_path: String, content: Vec, ) -> Result, SyncServerError> { - auth(&state, auth_header.token())?; + auth(&state, auth_header.token(), &vault_id)?; let mut transaction = state .database diff --git a/backend/sync_server/src/server/delete_document.rs b/backend/sync_server/src/server/delete_document.rs index 10fbca3c..82955676 100644 --- a/backend/sync_server/src/server/delete_document.rs +++ b/backend/sync_server/src/server/delete_document.rs @@ -37,7 +37,7 @@ pub async fn delete_document( State(state): State, Json(request): Json, ) -> Result, SyncServerError> { - auth(&state, auth_header.token())?; + auth(&state, auth_header.token(), &vault_id)?; let mut transaction = state .database diff --git a/backend/sync_server/src/server/fetch_document_version.rs b/backend/sync_server/src/server/fetch_document_version.rs index aab06c85..87900696 100644 --- a/backend/sync_server/src/server/fetch_document_version.rs +++ b/backend/sync_server/src/server/fetch_document_version.rs @@ -35,7 +35,7 @@ pub async fn fetch_document_version( }): Path, State(state): State, ) -> Result, SyncServerError> { - auth(&state, auth_header.token())?; + auth(&state, auth_header.token(), &vault_id)?; let result = state .database diff --git a/backend/sync_server/src/server/fetch_document_version_content.rs b/backend/sync_server/src/server/fetch_document_version_content.rs index a2504ba1..24eddf40 100644 --- a/backend/sync_server/src/server/fetch_document_version_content.rs +++ b/backend/sync_server/src/server/fetch_document_version_content.rs @@ -37,7 +37,7 @@ pub async fn fetch_document_version_content( }): Path, State(state): State, ) -> Result { - auth(&state, auth_header.token())?; + auth(&state, auth_header.token(), &vault_id)?; let result = state .database diff --git a/backend/sync_server/src/server/fetch_latest_document_version.rs b/backend/sync_server/src/server/fetch_latest_document_version.rs index ec777f30..5ccfa4e9 100644 --- a/backend/sync_server/src/server/fetch_latest_document_version.rs +++ b/backend/sync_server/src/server/fetch_latest_document_version.rs @@ -33,7 +33,7 @@ pub async fn fetch_latest_document_version( }): Path, State(state): State, ) -> Result, SyncServerError> { - auth(&state, auth_header.token())?; + auth(&state, auth_header.token(), &vault_id)?; let latest_version = state .database diff --git a/backend/sync_server/src/server/fetch_latest_documents.rs b/backend/sync_server/src/server/fetch_latest_documents.rs index 2b4dc841..4b62a2f8 100644 --- a/backend/sync_server/src/server/fetch_latest_documents.rs +++ b/backend/sync_server/src/server/fetch_latest_documents.rs @@ -35,7 +35,7 @@ pub async fn fetch_latest_documents( Query(QueryParams { since_update_id }): Query, State(state): State, ) -> Result, SyncServerError> { - auth(&state, auth_header.token())?; + auth(&state, auth_header.token(), &vault_id)?; let documents = if let Some(since_update_id) = since_update_id { state diff --git a/backend/sync_server/src/server/ping.rs b/backend/sync_server/src/server/ping.rs index 1fe75ee6..38dc2037 100644 --- a/backend/sync_server/src/server/ping.rs +++ b/backend/sync_server/src/server/ping.rs @@ -1,19 +1,34 @@ -use axum::{Json, extract::State}; +use axum::{ + Json, + extract::{Path, State}, +}; use axum_extra::{ TypedHeader, headers::{Authorization, authorization::Bearer}, }; +use schemars::JsonSchema; +use serde::Deserialize; use super::{auth::auth, responses::PingResponse}; -use crate::{app_state::AppState, errors::SyncServerError}; +use crate::{ + app_state::{AppState, database::models::VaultId}, + errors::SyncServerError, +}; + +// This is required for aide to infer the path parameter types and names +#[derive(Deserialize, JsonSchema)] +pub struct PingPathParams { + vault_id: VaultId, +} #[axum::debug_handler] pub async fn ping( maybe_auth_header: Option>>, + Path(PingPathParams { vault_id }): Path, State(state): State, ) -> Result, SyncServerError> { - let is_authenticated = - maybe_auth_header.is_some_and(|auth_header| auth(&state, auth_header.token()).is_ok()); + let is_authenticated = maybe_auth_header + .is_some_and(|auth_header| auth(&state, auth_header.token(), &vault_id).is_ok()); Ok(Json(PingResponse { server_version: env!("CARGO_PKG_VERSION").to_owned(), diff --git a/backend/sync_server/src/server/update_document.rs b/backend/sync_server/src/server/update_document.rs index 0448ddb7..5bb39b70 100644 --- a/backend/sync_server/src/server/update_document.rs +++ b/backend/sync_server/src/server/update_document.rs @@ -92,7 +92,7 @@ async fn internal_update_document( relative_path: String, content: Vec, ) -> Result, SyncServerError> { - auth(&state, auth_header.token())?; + auth(&state, auth_header.token(), &vault_id)?; // No need for a transaction as document versions are immutable let parent_document = state diff --git a/backend/sync_server/src/server/websocket.rs b/backend/sync_server/src/server/websocket.rs index 30125f41..d672b944 100644 --- a/backend/sync_server/src/server/websocket.rs +++ b/backend/sync_server/src/server/websocket.rs @@ -20,7 +20,7 @@ use crate::{ AppState, database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId}, }, - errors::{SyncServerError, server_error, unauthorized_error}, + errors::{SyncServerError, server_error, unauthenticated_error}, }; // This is required for aide to infer the path parameter types and names @@ -73,9 +73,9 @@ async fn websocket( let (mut sender, mut receiver) = stream.split(); if let Some(Ok(Message::Text(token))) = receiver.next().await { - auth(&state, &token)?; + auth(&state, &token, &vault_id)?; } else { - return Err(unauthorized_error(anyhow::anyhow!( + return Err(unauthenticated_error(anyhow::anyhow!( "Failed to authenticate" ))); }