Add vault-level access control

This commit is contained in:
Andras Schmelczer 2025-03-29 12:25:15 +00:00
parent a8c813b9a7
commit b3e98d32b6
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
17 changed files with 86 additions and 41 deletions

View file

@ -1,6 +1,6 @@
database:
databases_directory_path: databases
max_connections: 12
max_connections_per_vault: 12
server:
host: 0.0.0.0
@ -13,7 +13,7 @@ users:
- name: admin
token: test-token-change-me
vaults:
all: true
allow_access_to_all: true
- name: test
token: other-test-token

View file

@ -73,7 +73,7 @@ impl Database {
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
let pool = SqlitePoolOptions::new()
.max_connections(config.max_connections)
.max_connections(config.max_connections_per_vault)
.test_before_acquire(true)
.connect_with(connection_options)
.await

View file

@ -3,15 +3,15 @@ use std::path::PathBuf;
use log::debug;
use serde::{Deserialize, Serialize};
use crate::consts::{DEFAULT_DATABASES_DIRECTORY_PATH, DEFAULT_MAX_CONNECTIONS};
use crate::consts::{DEFAULT_DATABASES_DIRECTORY_PATH, DEFAULT_MAX_CONNECTIONS_PER_VAULT};
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct DatabaseConfig {
#[serde(default = "default_databases_directory_path")]
pub databases_directory_path: PathBuf,
#[serde(default = "default_max_connections")]
pub max_connections: u32,
#[serde(default = "default_max_connections_per_vault")]
pub max_connections_per_vault: u32,
}
fn default_databases_directory_path() -> PathBuf {
@ -19,16 +19,16 @@ fn default_databases_directory_path() -> PathBuf {
PathBuf::from(DEFAULT_DATABASES_DIRECTORY_PATH)
}
fn default_max_connections() -> u32 {
debug!("Using default max connections: {DEFAULT_MAX_CONNECTIONS}");
DEFAULT_MAX_CONNECTIONS
fn default_max_connections_per_vault() -> u32 {
debug!("Using default max connections: {DEFAULT_MAX_CONNECTIONS_PER_VAULT}");
DEFAULT_MAX_CONNECTIONS_PER_VAULT
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
databases_directory_path: default_databases_directory_path(),
max_connections: default_max_connections(),
max_connections_per_vault: default_max_connections_per_vault(),
}
}
}

View file

@ -1,6 +1,10 @@
use std::default;
use rand::{Rng as _, distributions::Alphanumeric, thread_rng};
use serde::{Deserialize, Serialize};
use crate::app_state::database::models::VaultId;
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct UserConfig {
#[serde(default = "default_users")]
@ -17,6 +21,7 @@ impl UserConfig {
pub struct User {
pub name: String,
pub token: String,
pub vault_access: VaultAccess,
}
impl Default for UserConfig {
@ -27,10 +32,25 @@ impl Default for UserConfig {
}
}
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum VaultAccess {
#[default]
AllowAccessToAll,
AllowList(AllowListedVaults),
}
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
pub struct AllowListedVaults {
pub allowed: Vec<VaultId>,
}
fn default_users() -> Vec<User> {
vec![User {
name: "admin".to_owned(),
token: get_random_token(),
vault_access: VaultAccess::default(),
}]
}

View file

@ -2,6 +2,6 @@ pub const DEFAULT_CONFIG_PATH: &str = "config.yml";
pub const DEFAULT_DATABASES_DIRECTORY_PATH: &str = "databases";
pub const DEFAULT_HOST: &str = "127.0.0.1";
pub const DEFAULT_PORT: u16 = 3000;
pub const DEFAULT_MAX_CONNECTIONS: u32 = 12;
pub const DEFAULT_MAX_CONNECTIONS_PER_VAULT: u32 = 12;
pub const DEFAULT_MAX_BODY_SIZE_MB: usize = 4096;
pub const DEFAULT_MAX_CLIENTS_PER_VAULT: usize = 256;

View file

@ -24,7 +24,7 @@ pub enum SyncServerError {
NotFound(#[source] anyhow::Error),
#[error("Unauthorized: {0}")]
Unauthorized(#[source] anyhow::Error),
Unauthenticated(#[source] anyhow::Error),
#[error("Permission denied error: {0}")]
PermissionDeniedError(#[source] anyhow::Error),
@ -37,7 +37,7 @@ impl SyncServerError {
| Self::ClientError(error)
| Self::ServerError(error)
| Self::NotFound(error)
| Self::Unauthorized(error)
| Self::Unauthenticated(error)
| Self::PermissionDeniedError(error) => error.into(),
}
}
@ -53,7 +53,7 @@ impl IntoResponse for SyncServerError {
}
Self::ClientError(_) => (StatusCode::BAD_REQUEST, body).into_response(),
Self::NotFound(_) => (StatusCode::NOT_FOUND, body).into_response(),
Self::Unauthorized(_) => (StatusCode::UNAUTHORIZED, body).into_response(),
Self::Unauthenticated(_) => (StatusCode::UNAUTHORIZED, body).into_response(),
Self::PermissionDeniedError(_) => (StatusCode::FORBIDDEN, body).into_response(),
}
}
@ -100,17 +100,16 @@ pub fn client_error(error: anyhow::Error) -> SyncServerError {
}
pub fn not_found_error(error: anyhow::Error) -> SyncServerError {
info!("Not found error: {:?}", error);
info!("Not found: {:?}", error);
SyncServerError::NotFound(error)
}
pub fn unauthorized_error(error: anyhow::Error) -> SyncServerError {
info!("Unauthorized error: {:?}", error);
SyncServerError::Unauthorized(error)
pub fn unauthenticated_error(error: anyhow::Error) -> SyncServerError {
info!("Unauthenticated user: {:?}", error);
SyncServerError::Unauthenticated(error)
}
#[allow(dead_code)]
pub fn permission_denied_error(error: anyhow::Error) -> SyncServerError {
info!("Permission denied error: {:?}", error);
info!("Permission denied: {:?}", error);
SyncServerError::PermissionDeniedError(error)
}

View file

@ -61,7 +61,7 @@ pub async fn create_server(config_path: Option<OsString>) -> Result<()> {
let mut api = create_open_api();
let app = ApiRouter::new()
.api_route("/ping", get(ping::ping))
.api_route("/vaults/:vault_id/ping", get(ping::ping))
.api_route(
"/vaults/:vault_id/documents",
get(fetch_latest_documents::fetch_latest_documents),

View file

@ -1,15 +1,26 @@
use crate::{
app_state::AppState,
config::user_config::User,
errors::{SyncServerError, unauthorized_error},
app_state::{AppState, database::models::VaultId},
config::user_config::{AllowListedVaults, User, VaultAccess},
errors::{SyncServerError, permission_denied_error, unauthenticated_error},
};
// TODO: turn this into a middleware
pub fn auth(app_state: &AppState, token: &str) -> Result<User, SyncServerError> {
app_state
pub fn auth(app_state: &AppState, token: &str, vault: &VaultId) -> Result<User, SyncServerError> {
let user = app_state
.config
.users
.get_user(token)
.cloned()
.ok_or_else(|| unauthorized_error(anyhow::anyhow!("Invalid token")))
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Invalid token")))?;
if match user.vault_access {
VaultAccess::AllowAccessToAll => true,
VaultAccess::AllowList(AllowListedVaults { ref allowed }) => allowed.contains(vault),
} {
Ok(user)
} else {
Err(permission_denied_error(anyhow::anyhow!(
"Permission denied for vault `{vault}`"
)))
}
}

View file

@ -87,7 +87,7 @@ async fn internal_create_document(
relative_path: String,
content: Vec<u8>,
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
auth(&state, auth_header.token())?;
auth(&state, auth_header.token(), &vault_id)?;
let mut transaction = state
.database

View file

@ -37,7 +37,7 @@ pub async fn delete_document(
State(state): State<AppState>,
Json(request): Json<DeleteDocumentVersion>,
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
auth(&state, auth_header.token())?;
auth(&state, auth_header.token(), &vault_id)?;
let mut transaction = state
.database

View file

@ -35,7 +35,7 @@ pub async fn fetch_document_version(
}): Path<FetchDocumentVersionPathParams>,
State(state): State<AppState>,
) -> Result<Json<DocumentVersion>, SyncServerError> {
auth(&state, auth_header.token())?;
auth(&state, auth_header.token(), &vault_id)?;
let result = state
.database

View file

@ -37,7 +37,7 @@ pub async fn fetch_document_version_content(
}): Path<FetchDocumentVersionContentPathParams>,
State(state): State<AppState>,
) -> Result<Bytes, SyncServerError> {
auth(&state, auth_header.token())?;
auth(&state, auth_header.token(), &vault_id)?;
let result = state
.database

View file

@ -33,7 +33,7 @@ pub async fn fetch_latest_document_version(
}): Path<FetchLatestDocumentVersionPathParams>,
State(state): State<AppState>,
) -> Result<Json<DocumentVersion>, SyncServerError> {
auth(&state, auth_header.token())?;
auth(&state, auth_header.token(), &vault_id)?;
let latest_version = state
.database

View file

@ -35,7 +35,7 @@ pub async fn fetch_latest_documents(
Query(QueryParams { since_update_id }): Query<QueryParams>,
State(state): State<AppState>,
) -> Result<Json<FetchLatestDocumentsResponse>, SyncServerError> {
auth(&state, auth_header.token())?;
auth(&state, auth_header.token(), &vault_id)?;
let documents = if let Some(since_update_id) = since_update_id {
state

View file

@ -1,19 +1,34 @@
use axum::{Json, extract::State};
use axum::{
Json,
extract::{Path, State},
};
use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use schemars::JsonSchema;
use serde::Deserialize;
use super::{auth::auth, responses::PingResponse};
use crate::{app_state::AppState, errors::SyncServerError};
use crate::{
app_state::{AppState, database::models::VaultId},
errors::SyncServerError,
};
// This is required for aide to infer the path parameter types and names
#[derive(Deserialize, JsonSchema)]
pub struct PingPathParams {
vault_id: VaultId,
}
#[axum::debug_handler]
pub async fn ping(
maybe_auth_header: Option<TypedHeader<Authorization<Bearer>>>,
Path(PingPathParams { vault_id }): Path<PingPathParams>,
State(state): State<AppState>,
) -> Result<Json<PingResponse>, SyncServerError> {
let is_authenticated =
maybe_auth_header.is_some_and(|auth_header| auth(&state, auth_header.token()).is_ok());
let is_authenticated = maybe_auth_header
.is_some_and(|auth_header| auth(&state, auth_header.token(), &vault_id).is_ok());
Ok(Json(PingResponse {
server_version: env!("CARGO_PKG_VERSION").to_owned(),

View file

@ -92,7 +92,7 @@ async fn internal_update_document(
relative_path: String,
content: Vec<u8>,
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
auth(&state, auth_header.token())?;
auth(&state, auth_header.token(), &vault_id)?;
// No need for a transaction as document versions are immutable
let parent_document = state

View file

@ -20,7 +20,7 @@ use crate::{
AppState,
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
},
errors::{SyncServerError, server_error, unauthorized_error},
errors::{SyncServerError, server_error, unauthenticated_error},
};
// This is required for aide to infer the path parameter types and names
@ -73,9 +73,9 @@ async fn websocket(
let (mut sender, mut receiver) = stream.split();
if let Some(Ok(Message::Text(token))) = receiver.next().await {
auth(&state, &token)?;
auth(&state, &token, &vault_id)?;
} else {
return Err(unauthorized_error(anyhow::anyhow!(
return Err(unauthenticated_error(anyhow::anyhow!(
"Failed to authenticate"
)));
}