This commit is contained in:
Andras Schmelczer 2026-03-21 12:47:39 +00:00
parent 8f2f5e4fa9
commit a20264bcaf
112 changed files with 12567 additions and 2694 deletions

433
sync-server/Cargo.lock generated
View file

@ -337,10 +337,11 @@ checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b"
[[package]]
name = "cc"
version = "1.2.2"
version = "1.2.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f34d93e62b03caf570cccc334cbc6c2fceca82f39211051345108adcba3eebdc"
checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423"
dependencies = [
"find-msvc-tools",
"shlex",
]
@ -350,6 +351,12 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chrono"
version = "0.4.41"
@ -624,6 +631,12 @@ version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4"
[[package]]
name = "find-msvc-tools"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582"
[[package]]
name = "flume"
version = "0.11.1"
@ -773,8 +786,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
@ -957,6 +972,24 @@ dependencies = [
"pin-project-lite",
"smallvec",
"tokio",
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
dependencies = [
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
"webpki-roots 1.0.6",
]
[[package]]
@ -966,13 +999,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"http",
"http-body",
"hyper",
"pin-project-lite",
"socket2 0.5.10",
"tokio",
"tower-service",
"tracing",
]
[[package]]
@ -1153,6 +1189,12 @@ dependencies = [
"hashbrown",
]
[[package]]
name = "ipnet"
version = "2.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2"
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
@ -1272,6 +1314,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "miniz_oxide"
version = "0.8.0"
@ -1505,6 +1557,58 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "quinn"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2 0.5.10",
"thiserror 2.0.18",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
"getrandom 0.2.15",
"rand 0.8.5",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.18",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2 0.6.0",
"tracing",
"windows-sys 0.59.0",
]
[[package]]
name = "quote"
version = "1.0.37"
@ -1582,12 +1686,12 @@ dependencies = [
[[package]]
name = "reconcile-text"
version = "0.8.0"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "599cf9539996a2a19e501110404c59ba62f4974009f8fb864a8b7151c15ee5a5"
checksum = "52e0cf361887ea64c479ca871c1170dda761f84e122f2616b5579906a38d7557"
dependencies = [
"serde",
"thiserror 2.0.17",
"thiserror 2.0.18",
]
[[package]]
@ -1628,6 +1732,63 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
version = "0.12.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
dependencies = [
"base64 0.22.1",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-pemfile",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-rustls",
"tower",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"webpki-roots 0.26.11",
"windows-registry",
]
[[package]]
name = "ring"
version = "0.17.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7"
dependencies = [
"cc",
"cfg-if",
"getrandom 0.2.15",
"libc",
"untrusted",
"windows-sys 0.52.0",
]
[[package]]
name = "rsa"
version = "0.9.7"
@ -1648,12 +1809,52 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rust-embed"
version = "8.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04113cb9355a377d83f06ef1f0a45b8ab8cd7d8b1288160717d66df5c7988d27"
dependencies = [
"rust-embed-impl",
"rust-embed-utils",
"walkdir",
]
[[package]]
name = "rust-embed-impl"
version = "8.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da0902e4c7c8e997159ab384e6d0fc91c221375f6894346ae107f47dd0f3ccaa"
dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
"syn 2.0.90",
"walkdir",
]
[[package]]
name = "rust-embed-utils"
version = "8.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bcdef0be6fe7f6fa333b1073c949729274b05f123a0ad7efcb8efd878e5c3b1"
dependencies = [
"sha2",
"walkdir",
]
[[package]]
name = "rustc-demangle"
version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustix"
version = "0.38.41"
@ -1667,6 +1868,50 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "rustls"
version = "0.23.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4"
dependencies = [
"once_cell",
"ring",
"rustls-pki-types",
"rustls-webpki",
"subtle",
"zeroize",
]
[[package]]
name = "rustls-pemfile"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "rustls-pki-types"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd"
dependencies = [
"web-time",
"zeroize",
]
[[package]]
name = "rustls-webpki"
version = "0.103.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
]
[[package]]
name = "rustversion"
version = "1.0.18"
@ -1679,6 +1924,15 @@ version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "sanitize-filename"
version = "0.6.0"
@ -1846,6 +2100,16 @@ dependencies = [
"serde",
]
[[package]]
name = "socket2"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678"
dependencies = [
"libc",
"windows-sys 0.52.0",
]
[[package]]
name = "socket2"
version = "0.6.0"
@ -1916,7 +2180,7 @@ dependencies = [
"serde_json",
"sha2",
"smallvec",
"thiserror 2.0.17",
"thiserror 2.0.18",
"tokio",
"tokio-stream",
"tracing",
@ -2000,7 +2264,7 @@ dependencies = [
"smallvec",
"sqlx-core",
"stringprep",
"thiserror 2.0.17",
"thiserror 2.0.18",
"tracing",
"uuid",
"whoami",
@ -2039,7 +2303,7 @@ dependencies = [
"smallvec",
"sqlx-core",
"stringprep",
"thiserror 2.0.17",
"thiserror 2.0.18",
"tracing",
"uuid",
"whoami",
@ -2065,7 +2329,7 @@ dependencies = [
"serde",
"serde_urlencoded",
"sqlx-core",
"thiserror 2.0.17",
"thiserror 2.0.18",
"tracing",
"url",
"uuid",
@ -2136,15 +2400,18 @@ dependencies = [
"futures",
"humantime-serde",
"log",
"mime_guess",
"rand 0.9.0",
"reconcile-text",
"regex",
"reqwest",
"rust-embed",
"sanitize-filename",
"serde",
"serde_json",
"serde_yaml",
"sqlx",
"thiserror 2.0.17",
"thiserror 2.0.18",
"tokio",
"tower-http",
"tracing",
@ -2158,6 +2425,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
@ -2203,11 +2473,11 @@ dependencies = [
[[package]]
name = "thiserror"
version = "2.0.17"
version = "2.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8"
checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4"
dependencies = [
"thiserror-impl 2.0.17",
"thiserror-impl 2.0.18",
]
[[package]]
@ -2223,9 +2493,9 @@ dependencies = [
[[package]]
name = "thiserror-impl"
version = "2.0.17"
version = "2.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913"
checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5"
dependencies = [
"proc-macro2",
"quote",
@ -2276,10 +2546,9 @@ dependencies = [
"bytes",
"libc",
"mio",
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
"socket2",
"socket2 0.6.0",
"tokio-macros",
"windows-sys 0.61.2",
]
@ -2295,6 +2564,16 @@ dependencies = [
"syn 2.0.90",
]
[[package]]
name = "tokio-rustls"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61"
dependencies = [
"rustls",
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.17"
@ -2426,6 +2705,12 @@ dependencies = [
"tracing-log",
]
[[package]]
name = "try-lock"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "ts-rs"
version = "10.1.0"
@ -2434,7 +2719,7 @@ checksum = "e640d9b0964e9d39df633548591090ab92f7a4567bc31d3891af23471a3365c6"
dependencies = [
"chrono",
"lazy_static",
"thiserror 2.0.17",
"thiserror 2.0.18",
"ts-rs-macros",
"uuid",
]
@ -2481,6 +2766,12 @@ version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f720def6ce1ee2fc44d40ac9ed6d3a59c361c80a75a7aa8e75bb9baed31cf2ea"
[[package]]
name = "unicase"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142"
[[package]]
name = "unicode-bidi"
version = "0.3.17"
@ -2514,6 +2805,12 @@ version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "url"
version = "2.5.4"
@ -2577,6 +2874,25 @@ version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e"
dependencies = [
"try-lock",
]
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
@ -2623,6 +2939,19 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.49"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38176d9b44ea84e9184eff0bc34cc167ed044f816accfe5922e54d84cf48eca2"
dependencies = [
"cfg-if",
"js-sys",
"once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.99"
@ -2652,6 +2981,44 @@ version = "0.2.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6"
[[package]]
name = "web-sys"
version = "0.3.76"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-roots"
version = "0.26.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9"
dependencies = [
"webpki-roots 1.0.6",
]
[[package]]
name = "webpki-roots"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22cfaf3c063993ff62e73cb4311efde4db1efb31ab78a3e5c457939ad5cc0bed"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "whoami"
version = "1.5.2"
@ -2692,6 +3059,36 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
[[package]]
name = "windows-registry"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.48.0"

View file

@ -1,6 +1,6 @@
[package]
name = "sync_server"
rust-version = "1.92.0"
rust-version = "1.94.0"
authors = ["Andras Schmelczer <andras@schmelczer.dev>"]
edition = "2024"
license = "MIT"
@ -10,7 +10,7 @@ version = "0.14.0"
[dependencies]
serde = { version = "1.0.219", default-features = false, features = ["derive"] }
thiserror = { version = "2.0.12", default-features = false }
tokio = { version = "1.48.0", features = ["full"]}
tokio = { version = "1.48.0", features = ["macros", "rt-multi-thread", "sync", "time", "net", "fs", "signal"]}
uuid = { version = "1.16.0", features = ["v4", "serde"] }
log = { version = "0.4.28" }
anyhow = { version = "1.0.100", features = ["backtrace"] }
@ -33,7 +33,10 @@ serde_json = "1.0.140"
bimap = "0.6.3"
ts-rs = { version = "10.1", features = ["uuid-impl", "chrono-impl"] }
base64 = "0.22.1"
reconcile-text = { version = "0.8.0", features = ["serde"] }
reconcile-text = { version = "0.11.0", features = ["serde"] }
rust-embed = "8.5"
mime_guess = "2.0"
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
[profile.release]
codegen-units = 1

View file

@ -1,5 +1,16 @@
// generated by `sqlx migrate build-script`
fn main() {
// trigger recompilation when a new migration is added
println!("cargo:rerun-if-changed=migrations");
// Ensure the history-ui dist directory exists so rust-embed can compile
// even when the frontend hasn't been built yet.
let dist_path = std::path::Path::new("../frontend/history-ui/dist");
if !dist_path.exists() {
std::fs::create_dir_all(dist_path).expect("Failed to create history-ui dist directory");
std::fs::write(
dist_path.join("index.html"),
"<!DOCTYPE html><html><body><p>Run <code>npm run build -w history-ui</code> first.</p></body></html>",
)
.expect("Failed to write placeholder index.html");
}
}

View file

@ -1,12 +1,14 @@
database:
databases_directory_path: databases
max_connections_per_vault: 12
max_connections_per_vault: 64
cursor_timeout: 1m
server:
host: 0.0.0.0
port: 3000
port: 3010
max_body_size_mb: 512
max_clients_per_vault: 256
broadcast_channel_capacity: 1024
dev_proxy_url: "http://localhost:5173"
response_timeout: 30m
mergeable_file_extensions:
- md

View file

@ -1,5 +1,5 @@
[toolchain]
channel = "1.92.0"
channel = "1.94.0"
targets = [
"x86_64-unknown-linux-gnu",
"x86_64-unknown-linux-musl",

View file

@ -2,6 +2,11 @@ pub mod cursors;
pub mod database;
pub mod websocket;
use std::sync::{
Arc,
atomic::AtomicUsize,
};
use anyhow::Result;
use cursors::Cursors;
use database::Database;
@ -15,21 +20,34 @@ pub struct AppState {
pub database: Database,
pub cursors: Cursors,
pub broadcasts: Broadcasts,
/// Tracks WebSocket connections that have upgraded but not yet completed
/// the authentication handshake.
pub pending_ws_connections: Arc<AtomicUsize>,
/// Send on this channel to stop background tasks (cursor cleanup,
/// idle-pool cleanup). Held by `AppState` so dropping it also
/// triggers shutdown.
#[allow(dead_code)]
shutdown_tx: Arc<tokio::sync::watch::Sender<()>>,
}
impl AppState {
pub async fn try_new(config: Config) -> Result<Self> {
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(());
let broadcasts = Broadcasts::new(&config.server);
let database = Database::try_new(&config.database, &broadcasts).await?;
let database =
Database::try_new(&config.database, &broadcasts, shutdown_rx.clone()).await?;
let cursors: Cursors = Cursors::new(&config.database, &broadcasts);
Cursors::start_background_task(cursors.clone());
Cursors::start_background_task(cursors.clone(), shutdown_rx);
Ok(Self {
config,
database,
cursors,
broadcasts,
pending_ws_connections: Arc::new(AtomicUsize::new(0)),
shutdown_tx: Arc::new(shutdown_tx),
})
}
}

View file

@ -42,7 +42,9 @@ impl Cursors {
) {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
let all_device_cursors = vault_to_cursors.entry(vault_id).or_insert_with(Vec::new);
let all_device_cursors = vault_to_cursors
.entry(vault_id.clone())
.or_insert_with(Vec::new);
all_device_cursors.retain(|c| &c.client_cursors.device_id != device_id);
all_device_cursors.push(ClientCursorsWithTimeToLive::new(ClientCursors {
@ -51,8 +53,11 @@ impl Cursors {
documents_with_cursors: document_to_cursors,
}));
drop(vault_to_cursors); // Explicitly drop the lock before broadcasting to avoid deadlock
self.broadcast_cursors().await;
// IMPORTANT: Drop the lock BEFORE calling broadcast_cursors_for_vault,
// which re-acquires the same lock internally. Holding the lock here
// while calling broadcast would cause a deadlock.
drop(vault_to_cursors);
self.broadcast_cursors_for_vault(&vault_id).await;
}
pub async fn get_cursors(&self, vault_id: &VaultId) -> Vec<ClientCursors> {
@ -69,45 +74,83 @@ impl Cursors {
.unwrap_or_default()
}
pub fn start_background_task(self) {
pub fn start_background_task(self, mut shutdown: tokio::sync::watch::Receiver<()>) {
tokio::spawn(async move {
loop {
self.remove_expired_cursors().await;
tokio::time::sleep(Duration::from_secs(1)).await;
tokio::select! {
() = tokio::time::sleep(Duration::from_secs(1)) => {
self.remove_expired_cursors().await;
}
Ok(()) = shutdown.changed() => break,
}
}
});
}
async fn remove_expired_cursors(&self) {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
let changed_vaults: Vec<VaultId> = {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
for (_vault_id, cursors) in vault_to_cursors.iter_mut() {
cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout));
let mut changed = Vec::new();
for (vault_id, cursors) in vault_to_cursors.iter_mut() {
let before = cursors.len();
cursors.retain(|cursor| !cursor.is_expired(self.config.cursor_timeout));
if cursors.len() != before {
changed.push(vault_id.clone());
}
}
// Remove empty vault entries to prevent unbounded growth
vault_to_cursors.retain(|_, cursors| !cursors.is_empty());
changed
};
for vault_id in &changed_vaults {
self.broadcast_cursors_for_vault(vault_id).await;
}
}
async fn broadcast_cursors(&self) {
let vault_to_cursors = self.vault_to_cursors.lock().await;
async fn broadcast_cursors_for_vault(&self, vault_id: &VaultId) {
let client_cursors: Vec<ClientCursors> = {
let vault_to_cursors = self.vault_to_cursors.lock().await;
vault_to_cursors
.get(vault_id)
.map(|cursors| cursors.iter().map(|c| c.client_cursors.clone()).collect())
.unwrap_or_default()
};
for (vault_id, cursors) in vault_to_cursors.iter() {
self.broadcasts
.send_document_update(
vault_id.clone(),
WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions(
CursorPositionFromServer {
clients: cursors.iter().map(|c| c.client_cursors.clone()).collect(),
},
)),
)
.await;
}
self.broadcasts
.send_document_update(
vault_id.clone(),
WebSocketServerMessageWithOrigin::new(WebSocketServerMessage::CursorPositions(
CursorPositionFromServer {
clients: client_cursors,
},
)),
)
.await;
}
pub async fn remove_cursors_of_device(&self, vault_id: &str, device_id: &str) {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
pub async fn remove_cursors_of_device(&self, vault_id: &VaultId, device_id: &DeviceId) {
let changed = {
let mut vault_to_cursors = self.vault_to_cursors.lock().await;
if let Some(cursors) = vault_to_cursors.get_mut(vault_id) {
cursors.retain(|c| c.client_cursors.device_id != device_id);
if let Some(cursors) = vault_to_cursors.get_mut(vault_id) {
let before = cursors.len();
cursors.retain(|c| c.client_cursors.device_id != *device_id);
let changed = cursors.len() != before;
if cursors.is_empty() {
vault_to_cursors.remove(vault_id);
}
changed
} else {
false
}
};
if changed {
self.broadcast_cursors_for_vault(vault_id).await;
}
}
}

View file

@ -6,14 +6,29 @@ use log::info;
use models::{
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId, VaultUpdateId,
};
use sqlx::{ConnectOptions, sqlite::SqliteConnectOptions, types::chrono::Utc};
use sqlx::{ConnectOptions, Connection, sqlite::SqliteConnectOptions, types::chrono::Utc};
pub mod models;
use sqlx::{Pool, Sqlite, sqlite::SqlitePoolOptions};
use tokio::sync::Mutex;
use sqlx::{
Pool, Sqlite, pool::PoolConnection, sqlite::SqliteConnection, sqlite::SqlitePoolOptions,
};
use tokio::sync::{Mutex, OnceCell};
use tokio::time::Instant;
use uuid::fmt::Hyphenated;
/// Row struct for vault history queries (used by `sqlx::query_as!`)
#[derive(Debug)]
struct VaultHistoryRow {
vault_update_id: models::VaultUpdateId,
document_id: models::DocumentId,
relative_path: String,
updated_date: chrono::DateTime<chrono::Utc>,
is_deleted: bool,
user_id: String,
device_id: String,
content_size: Option<u64>,
}
use super::websocket::{
broadcasts::Broadcasts,
models::{WebSocketServerMessage, WebSocketServerMessageWithOrigin, WebSocketVaultUpdate},
@ -21,32 +36,98 @@ use super::websocket::{
use crate::config::database_config::DatabaseConfig;
use crate::consts::IDLE_POOL_TIMEOUT;
#[derive(Clone)]
struct PoolWithTimestamp {
pool: Pool<Sqlite>,
last_accessed: Instant,
}
impl std::fmt::Debug for PoolWithTimestamp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PoolWithTimestamp")
.field("pool", &"Pool<Sqlite>")
.field("last_accessed", &self.last_accessed)
.finish()
}
#[derive(Debug)]
struct VaultPool {
cell: Arc<OnceCell<Pool<Sqlite>>>,
last_accessed: Mutex<Instant>,
}
#[derive(Clone, Debug)]
pub struct Database {
config: DatabaseConfig,
broadcasts: Broadcasts,
connection_pools: Arc<Mutex<HashMap<VaultId, PoolWithTimestamp>>>,
connection_pools: Arc<Mutex<HashMap<VaultId, Arc<VaultPool>>>>,
}
pub type Transaction<'a> = sqlx::Transaction<'a, Sqlite>;
/// A write transaction backed by a raw `BEGIN IMMEDIATE` instead of sqlx's
/// savepoint-based `Transaction`. This avoids the savepoint mismatch caused
/// by the old `END; BEGIN IMMEDIATE;` workaround.
pub struct WriteTransaction {
conn: Option<PoolConnection<Sqlite>>,
}
impl WriteTransaction {
async fn new(pool: &Pool<Sqlite>) -> Result<Self> {
let mut conn = pool
.acquire()
.await
.context("Cannot acquire connection for write transaction")?;
sqlx::query("BEGIN IMMEDIATE")
.execute(&mut *conn)
.await
.context("Cannot begin immediate transaction")?;
Ok(Self { conn: Some(conn) })
}
pub async fn commit(mut self) -> Result<()> {
if let Some(mut conn) = self.conn.take() {
sqlx::query("COMMIT")
.execute(&mut *conn)
.await
.context("Failed to commit transaction")?;
}
Ok(())
}
pub async fn rollback(mut self) -> Result<()> {
if let Some(mut conn) = self.conn.take() {
sqlx::query("ROLLBACK")
.execute(&mut *conn)
.await
.context("Failed to rollback transaction")?;
}
Ok(())
}
}
impl Drop for WriteTransaction {
fn drop(&mut self) {
if self.conn.is_some() {
// The connection is returned to the pool with an open transaction.
// The pool's `before_acquire` hook issues a ROLLBACK before
// handing it to the next consumer, so no async work is needed
// here. If the pool is being shut down, SQLite itself rolls back
// uncommitted transactions when the connection closes.
log::warn!("WriteTransaction dropped without commit or rollback");
}
}
}
impl std::ops::Deref for WriteTransaction {
type Target = SqliteConnection;
fn deref(&self) -> &Self::Target {
self.conn
.as_ref()
.expect("BUG: WriteTransaction dereferenced after being consumed")
.deref()
}
}
impl std::ops::DerefMut for WriteTransaction {
fn deref_mut(&mut self) -> &mut Self::Target {
self.conn
.as_mut()
.expect("BUG: WriteTransaction dereferenced after being consumed")
.deref_mut()
}
}
impl Database {
pub async fn try_new(config: &DatabaseConfig, broadcasts: &Broadcasts) -> Result<Self> {
pub async fn try_new(
config: &DatabaseConfig,
broadcasts: &Broadcasts,
shutdown: tokio::sync::watch::Receiver<()>,
) -> Result<Self> {
tokio::fs::create_dir_all(&config.databases_directory_path)
.await
.with_context(|| {
@ -71,13 +152,17 @@ impl Database {
.trim_end_matches(".sqlite")
.to_owned();
Self::validate_vault_id(&vault)?;
let pool = Self::create_vault_database(config, &vault).await?;
let cell = Arc::new(OnceCell::new());
cell.set(pool).expect("cell is new");
connection_pools.insert(
vault.clone(),
PoolWithTimestamp {
pool,
last_accessed: Instant::now(),
},
Arc::new(VaultPool {
cell,
last_accessed: Mutex::new(Instant::now()),
}),
);
}
info!("Database migrations applied");
@ -88,8 +173,7 @@ impl Database {
broadcasts: broadcasts.clone(),
};
// Start background task to cleanup idle connection pools
database.start_idle_pool_cleanup();
database.start_idle_pool_cleanup(shutdown);
Ok(database)
}
@ -102,91 +186,128 @@ impl Database {
.databases_directory_path
.join(format!("{vault}.sqlite"));
let connection_options = SqliteConnectOptions::new()
// Database-level PRAGMAs (auto_vacuum, journal_mode) require a write
// lock and persist across connections. Set them once with a dedicated
// init connection so pool connections never need the write lock just to
// open.
let init_options = SqliteConnectOptions::new()
.filename(file_name.clone())
.create_if_missing(true)
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
.busy_timeout(Duration::from_secs(30))
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
// Run migrations on a dedicated connection, NOT through the pool.
// The pool's `before_acquire` hook issues ROLLBACK on every checkout,
// which can roll back the migration's bookkeeping transaction (the
// _sqlx_migrations INSERT) while the DDL (ALTER TABLE) has already
// auto-committed — leaving the migration in a dirty state.
//
// Uses `run_direct` instead of `run` because `run` takes
// `impl Acquire<'_>`, whose lifetime bound prevents the enclosing
// future from satisfying the `Send` requirement of axum handlers.
let mut init_conn = sqlx::SqliteConnection::connect_with(&init_options).await?;
sqlx::migrate!("src/app_state/database/migrations")
.run_direct(&mut init_conn)
.await
.context("Cannot run pending migrations")?;
drop(init_conn);
// Pool connections only set per-connection PRAGMAs that don't require a
// write lock. journal_mode = WAL is a no-op on an already-WAL database.
let pool_options = SqliteConnectOptions::new()
.filename(file_name.clone())
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
.busy_timeout(Duration::from_secs(30))
.log_slow_statements(log::LevelFilter::Warn, Duration::from_secs(30));
let pool = SqlitePoolOptions::new()
.max_connections(config.max_connections_per_vault)
.acquire_slow_threshold(Duration::from_secs(30))
.test_before_acquire(true)
.connect_with(connection_options)
.before_acquire(|conn, _meta| {
Box::pin(async move {
// Ensure the connection has no leftover open transaction
// (e.g. from a WriteTransaction that was dropped without
// commit/rollback). ROLLBACK is a harmless no-op if no
// transaction is active.
if let Err(e) = sqlx::query("ROLLBACK").execute(&mut *conn).await {
// "cannot rollback - no transaction is active" is the
// common case (connection returned cleanly). Only
// unexpected errors deserve attention.
log::debug!("before_acquire ROLLBACK failed: {e}");
}
Ok(true)
})
})
.connect_with(pool_options)
.await
.with_context(|| format!("Cannot open database at `{}`", file_name.display()))?;
Self::run_migrations(&pool).await?;
Ok(pool)
}
async fn run_migrations(pool: &Pool<Sqlite>) -> Result<()> {
sqlx::migrate!("src/app_state/database/migrations")
.run(pool)
.await
.context("Cannot check for pending migrations")
fn validate_vault_id(vault: &VaultId) -> Result<()> {
if vault.is_empty() {
anyhow::bail!("Vault ID must not be empty");
}
if vault.contains('/')
|| vault.contains('\\')
|| vault.contains("..")
|| vault.contains('\0')
{
anyhow::bail!(
"Invalid vault ID: must not contain path separators, '..', or null bytes"
);
}
Ok(())
}
async fn get_connection_pool(&self, vault: &VaultId) -> Result<Pool<Sqlite>> {
// First, check if the pool exists without holding the lock during creation
{
Self::validate_vault_id(vault)?;
// Get or create the VaultPool entry. The global lock is held only
// long enough for a HashMap lookup/insert — never across
// create_vault_database.
let vault_pool = {
let mut pools = self.connection_pools.lock().await;
if let Some(pool_with_timestamp) = pools.get_mut(vault) {
pool_with_timestamp.last_accessed = Instant::now();
return Ok(pool_with_timestamp.pool.clone());
}
}
pools
.entry(vault.clone())
.or_insert_with(|| {
Arc::new(VaultPool {
cell: Arc::new(OnceCell::new()),
last_accessed: Mutex::new(Instant::now()),
})
})
.clone()
};
// Create the pool outside of the lock to avoid blocking other vaults
// Note: This may result in multiple pools being created for the same vault
// under high concurrency, but only one will be kept
let new_pool = Self::create_vault_database(&self.config, vault).await?;
// Re-acquire lock and insert (or use existing if another task created it)
let mut pools = self.connection_pools.lock().await;
let pool_with_timestamp = pools
.entry(vault.clone())
.or_insert_with(|| PoolWithTimestamp {
pool: new_pool.clone(),
last_accessed: Instant::now(),
});
pool_with_timestamp.last_accessed = Instant::now();
Ok(pool_with_timestamp.pool.clone())
}
/// Attempting to write from this transaction might result in a
/// database locked error. Use this transaction for read-only operations.
pub async fn create_readonly_transaction(
&self,
vault: &VaultId,
) -> Result<Transaction<'static>> {
self.get_connection_pool(vault)
.await?
.begin()
.await
.context("Cannot create transaction")
}
pub async fn create_write_transaction(&self, vault: &VaultId) -> Result<Transaction<'static>> {
let mut transaction = self.create_readonly_transaction(vault).await?;
// sqlx doesn't support immediate transactions for sqlite: https://github.com/launchbadge/sqlx/issues/481
sqlx::query!("END; BEGIN IMMEDIATE;")
.execute(&mut *transaction)
// OnceCell::get_or_try_init guarantees exactly-once
// initialization: concurrent callers for the same vault wait
// here; callers for other vaults are not blocked.
let config = self.config.clone();
let vault_clone = vault.clone();
let pool = vault_pool
.cell
.get_or_try_init(|| async {
Self::create_vault_database(&config, &vault_clone).await
})
.await?;
Ok(transaction)
*vault_pool.last_accessed.lock().await = Instant::now();
Ok(pool.clone())
}
pub async fn create_write_transaction(&self, vault: &VaultId) -> Result<WriteTransaction> {
let pool = self.get_connection_pool(vault).await?;
WriteTransaction::new(&pool).await
}
/// Return the latest state of all documents in the vault
pub async fn get_latest_documents(
&self,
vault: &VaultId,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<Vec<DocumentVersionWithoutContent>> {
let query = sqlx::query!(
r#"
@ -204,8 +325,8 @@ impl Database {
"#,
);
if let Some(transaction) = transaction {
query.fetch_all(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_all(&mut *conn).await
} else {
query
.fetch_all(&self.get_connection_pool(vault).await?)
@ -222,9 +343,7 @@ impl Database {
is_deleted: row.is_deleted,
user_id: row.user_id,
device_id: row.device_id,
content_size: row
.content_size
.expect("Content size can't be null but sqlx can't infer it"),
content_size: row.content_size.unwrap_or(0),
})
.collect()
})
@ -236,7 +355,7 @@ impl Database {
&self,
vault: &VaultId,
vault_update_id: VaultUpdateId,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<Vec<DocumentVersionWithoutContent>> {
let query = sqlx::query!(
r#"
@ -256,8 +375,8 @@ impl Database {
vault_update_id
);
if let Some(transaction) = transaction {
query.fetch_all(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_all(&mut *conn).await
} else {
query
.fetch_all(&self.get_connection_pool(vault).await?)
@ -276,9 +395,7 @@ impl Database {
is_deleted: row.is_deleted,
user_id: row.user_id,
device_id: row.device_id,
content_size: row
.content_size
.expect("Content size can't be null but sqlx can't infer it"),
content_size: row.content_size.unwrap_or(0),
})
.collect()
})
@ -287,7 +404,7 @@ impl Database {
pub async fn get_max_update_id_in_vault(
&self,
vault: &VaultId,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<i64> {
let query = sqlx::query!(
r#"
@ -296,8 +413,8 @@ impl Database {
"#,
);
if let Some(transaction) = transaction {
query.fetch_one(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_one(&mut *conn).await
} else {
query
.fetch_one(&self.get_connection_pool(vault).await?)
@ -311,7 +428,7 @@ impl Database {
&self,
vault: &VaultId,
relative_path: &str,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<Option<StoredDocumentVersion>> {
let query = sqlx::query_as!(
StoredDocumentVersion,
@ -337,8 +454,8 @@ impl Database {
relative_path
);
if let Some(transaction) = transaction {
query.fetch_optional(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_optional(&mut *conn).await
} else {
query
.fetch_optional(&self.get_connection_pool(vault).await?)
@ -351,7 +468,7 @@ impl Database {
&self,
vault: &VaultId,
document_id: &DocumentId,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<Option<StoredDocumentVersion>> {
let document_id = document_id.as_hyphenated();
let query = sqlx::query_as!(
@ -374,8 +491,8 @@ impl Database {
document_id
);
if let Some(transaction) = transaction {
query.fetch_optional(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_optional(&mut *conn).await
} else {
query
.fetch_optional(&self.get_connection_pool(vault).await?)
@ -388,7 +505,7 @@ impl Database {
&self,
vault: &VaultId,
vault_update_id: VaultUpdateId,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<Option<StoredDocumentVersion>> {
let query = sqlx::query_as!(
StoredDocumentVersion,
@ -409,8 +526,8 @@ impl Database {
vault_update_id
);
if let Some(transaction) = transaction {
query.fetch_optional(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_optional(&mut *conn).await
} else {
query
.fetch_optional(&self.get_connection_pool(vault).await?)
@ -424,7 +541,7 @@ impl Database {
&self,
vault_id: &VaultId,
version: &StoredDocumentVersion,
transaction: Option<Transaction<'_>>,
transaction: Option<WriteTransaction>,
) -> Result<()> {
let document_id = version.document_id.as_hyphenated();
let query = sqlx::query!(
@ -438,9 +555,10 @@ impl Database {
is_deleted,
user_id,
device_id,
idempotency_key
idempotency_key,
has_been_merged
)
values (?, ?, ?, ?, ?, ?, ?, ?, ?)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
version.vault_update_id,
document_id,
@ -450,7 +568,8 @@ impl Database {
version.is_deleted,
version.user_id,
version.device_id,
version.idempotency_key
version.idempotency_key,
version.has_been_merged
);
if let Some(mut transaction) = transaction {
@ -490,32 +609,36 @@ impl Database {
&self,
vault: &VaultId,
idempotency_key: &str,
transaction: Option<&mut Transaction<'_>>,
connection: Option<&mut SqliteConnection>,
) -> Result<Option<StoredDocumentVersion>> {
// Start from the `documents` table (which has an index on
// `idempotency_key`) to find the document_id, then join to
// `latest_document_versions` for the latest state.
let query = sqlx::query_as!(
StoredDocumentVersion,
r#"
select
d.vault_update_id,
d.document_id as "document_id: Hyphenated",
d.relative_path,
d.updated_date as "updated_date: chrono::DateTime<Utc>",
d.content,
d.is_deleted,
d.user_id,
d.device_id,
d.has_been_merged,
d.idempotency_key
from latest_document_versions d
inner join documents d2 on d.document_id = d2.document_id
where d2.idempotency_key = ?
ldv.vault_update_id,
ldv.document_id as "document_id: Hyphenated",
ldv.relative_path,
ldv.updated_date as "updated_date: chrono::DateTime<Utc>",
ldv.content,
ldv.is_deleted,
ldv.user_id,
ldv.device_id,
ldv.has_been_merged,
ldv.idempotency_key
from documents d
inner join latest_document_versions ldv on d.document_id = ldv.document_id
where d.idempotency_key = ?
order by ldv.vault_update_id desc
limit 1
"#,
idempotency_key
);
if let Some(transaction) = transaction {
query.fetch_optional(&mut **transaction).await
if let Some(conn) = connection {
query.fetch_optional(&mut *conn).await
} else {
query
.fetch_optional(&self.get_connection_pool(vault).await?)
@ -524,39 +647,192 @@ impl Database {
.context("Cannot fetch document by idempotency key")
}
/// Return all versions (without content) of a specific document, ordered by `vault_update_id`
pub async fn get_document_versions(
&self,
vault: &VaultId,
document_id: &DocumentId,
connection: Option<&mut SqliteConnection>,
) -> Result<Vec<DocumentVersionWithoutContent>> {
let document_id = document_id.as_hyphenated();
let query = sqlx::query!(
r#"
select
vault_update_id,
document_id as "document_id: Hyphenated",
relative_path,
updated_date as "updated_date: chrono::DateTime<Utc>",
is_deleted,
user_id,
device_id,
length(content) as "content_size: u64"
from documents
where document_id = ?
order by vault_update_id
"#,
document_id,
);
if let Some(conn) = connection {
query.fetch_all(&mut *conn).await
} else {
query
.fetch_all(&self.get_connection_pool(vault).await?)
.await
}
.with_context(|| format!("Cannot fetch document versions for document `{document_id}`"))
.map(|rows| {
rows.into_iter()
.map(|row| DocumentVersionWithoutContent {
vault_update_id: row.vault_update_id,
document_id: row.document_id.into(),
relative_path: row.relative_path,
updated_date: row.updated_date,
is_deleted: row.is_deleted,
user_id: row.user_id,
device_id: row.device_id,
content_size: row.content_size.unwrap_or(0),
})
.collect()
})
}
/// Return all versions across all documents, paginated, ordered by `vault_update_id` DESC
pub async fn get_vault_history(
&self,
vault: &VaultId,
limit: i64,
before_update_id: Option<VaultUpdateId>,
connection: Option<&mut SqliteConnection>,
) -> Result<Vec<DocumentVersionWithoutContent>> {
let map_row = |row: VaultHistoryRow| DocumentVersionWithoutContent {
vault_update_id: row.vault_update_id,
document_id: row.document_id,
relative_path: row.relative_path,
updated_date: row.updated_date,
is_deleted: row.is_deleted,
user_id: row.user_id,
device_id: row.device_id,
content_size: row.content_size.unwrap_or(0),
};
if let Some(before) = before_update_id {
let query = sqlx::query_as!(
VaultHistoryRow,
r#"
select
vault_update_id,
document_id as "document_id: Hyphenated",
relative_path,
updated_date as "updated_date: chrono::DateTime<Utc>",
is_deleted,
user_id,
device_id,
length(content) as "content_size: u64"
from documents
where vault_update_id < ?
order by vault_update_id desc
limit ?
"#,
before,
limit,
);
let rows = if let Some(conn) = connection {
query.fetch_all(&mut *conn).await
} else {
query
.fetch_all(&self.get_connection_pool(vault).await?)
.await
}
.context("Cannot fetch vault history")?;
Ok(rows.into_iter().map(map_row).collect())
} else {
let query = sqlx::query_as!(
VaultHistoryRow,
r#"
select
vault_update_id,
document_id as "document_id: Hyphenated",
relative_path,
updated_date as "updated_date: chrono::DateTime<Utc>",
is_deleted,
user_id,
device_id,
length(content) as "content_size: u64"
from documents
order by vault_update_id desc
limit ?
"#,
limit,
);
let rows = if let Some(conn) = connection {
query.fetch_all(&mut *conn).await
} else {
query
.fetch_all(&self.get_connection_pool(vault).await?)
.await
}
.context("Cannot fetch vault history")?;
Ok(rows.into_iter().map(map_row).collect())
}
}
/// Cleanup idle connection pools that haven't been accessed in more than 5 minutes
async fn cleanup_idle_pools(&self) {
let mut pools = self.connection_pools.lock().await;
let now = Instant::now();
// Collect idle vaults and remove them from the map while holding
// the lock briefly. Close pools OUTSIDE the lock so that
// pool.close().await doesn't block other get_connection_pool calls.
let idle_pools: Vec<(VaultId, Arc<VaultPool>)> = {
let mut pools = self.connection_pools.lock().await;
let now = Instant::now();
// Collect vaults to remove
let vaults_to_remove: Vec<VaultId> = pools
.iter()
.filter(|(_, pool_with_timestamp)| {
now.duration_since(pool_with_timestamp.last_accessed) > IDLE_POOL_TIMEOUT
})
.map(|(vault_id, _)| vault_id.clone())
.collect();
let vaults_to_remove: Vec<VaultId> = pools
.iter()
.filter(|(_, vp)| {
// If the lock is contested, the pool is actively used — not idle.
let Ok(last) = vp.last_accessed.try_lock() else {
return false;
};
now.duration_since(*last) > IDLE_POOL_TIMEOUT
})
.map(|(vault_id, _)| vault_id.clone())
.collect();
// Close and remove idle pools
for vault_id in &vaults_to_remove {
if let Some(pool_with_timestamp) = pools.remove(vault_id) {
vaults_to_remove
.into_iter()
.filter_map(|id| pools.remove(&id).map(|vp| (id, vp)))
.collect()
};
for (vault_id, vault_pool) in idle_pools {
if let Some(pool) = vault_pool.cell.get() {
info!("Closing idle database connection pool for vault `{vault_id}`");
pool_with_timestamp.pool.close().await;
pool.close().await;
}
}
}
/// Start a background task that periodically cleans up idle connection pools
fn start_idle_pool_cleanup(&self) {
fn start_idle_pool_cleanup(&self, mut shutdown: tokio::sync::watch::Receiver<()>) {
let database = self.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60)); // Check every minute
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
interval.tick().await;
database.cleanup_idle_pools().await;
tokio::select! {
_ = interval.tick() => {
database.cleanup_idle_pools().await;
}
_ = shutdown.changed() => {
info!("Idle pool cleanup task shutting down");
break;
}
}
}
});
}

View file

@ -0,0 +1,2 @@
CREATE UNIQUE INDEX IF NOT EXISTS idx_documents_idempotency_key
ON documents (idempotency_key) WHERE idempotency_key IS NOT NULL AND is_deleted = 0;

View file

@ -0,0 +1,2 @@
CREATE INDEX IF NOT EXISTS idx_documents_document_id
ON documents (document_id, vault_update_id);

View file

@ -25,6 +25,8 @@ pub struct StoredDocumentVersion {
pub idempotency_key: Option<String>,
}
/// Equality is based solely on `vault_update_id` (the primary key).
/// Two rows with the same PK are the same database record.
impl PartialEq<Self> for StoredDocumentVersion {
fn eq(&self, other: &Self) -> bool {
self.vault_update_id == other.vault_update_id
@ -34,7 +36,7 @@ impl PartialEq<Self> for StoredDocumentVersion {
#[derive(TS, Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct DocumentVersionWithoutContent {
#[ts(as = "i32")]
#[ts(type = "number")]
pub vault_update_id: VaultUpdateId,
pub document_id: DocumentId,
@ -44,7 +46,7 @@ pub struct DocumentVersionWithoutContent {
pub user_id: UserId,
pub device_id: DeviceId,
#[ts(as = "i32")]
#[ts(type = "number")]
pub content_size: u64,
}
@ -66,7 +68,7 @@ impl From<StoredDocumentVersion> for DocumentVersionWithoutContent {
#[derive(TS, Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct DocumentVersion {
#[ts(as = "i32")]
#[ts(type = "number")]
pub vault_update_id: VaultUpdateId,
pub document_id: DocumentId,

View file

@ -1,24 +1,21 @@
use std::{collections::HashMap, sync::Arc};
use anyhow::Context;
use log::{debug, warn};
use tokio::sync::{Mutex, broadcast};
use super::models::WebSocketServerMessageWithOrigin;
use crate::{
app_state::database::models::VaultId, config::server_config::ServerConfig, errors::server_error,
};
use crate::{app_state::database::models::VaultId, config::server_config::ServerConfig};
#[derive(Debug, Clone)]
pub struct Broadcasts {
max_clients_per_vault: usize,
broadcast_channel_capacity: usize,
tx: Arc<Mutex<HashMap<VaultId, broadcast::Sender<WebSocketServerMessageWithOrigin>>>>,
}
impl Broadcasts {
pub fn new(server_config: &ServerConfig) -> Self {
Self {
max_clients_per_vault: server_config.max_clients_per_vault,
broadcast_channel_capacity: server_config.broadcast_channel_capacity,
tx: Arc::new(Mutex::new(HashMap::new())),
}
}
@ -26,10 +23,25 @@ impl Broadcasts {
pub async fn get_receiver(
&self,
vault: VaultId,
) -> broadcast::Receiver<WebSocketServerMessageWithOrigin> {
let tx = self.get_or_create(vault).await;
max_clients: usize,
) -> Result<broadcast::Receiver<WebSocketServerMessageWithOrigin>, crate::errors::SyncServerError>
{
let mut tx_map = self.tx.lock().await;
tx.subscribe()
// Prune senders for vaults with no active receivers
tx_map.retain(|_, sender| sender.receiver_count() > 0);
let sender = tx_map
.entry(vault)
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
if sender.receiver_count() >= max_clients {
return Err(crate::errors::client_error(anyhow::anyhow!(
"Vault has reached the maximum number of clients ({max_clients})"
)));
}
Ok(sender.subscribe())
}
/// Notify all clients (who are subscribed to the vault) about an update.
@ -39,31 +51,22 @@ impl Broadcasts {
vault: VaultId,
document: WebSocketServerMessageWithOrigin,
) {
let tx = self.get_or_create(vault.clone()).await;
let mut tx_map = self.tx.lock().await;
if tx.receiver_count() == 0 {
// Prune senders for vaults with no active receivers
tx_map.retain(|_, sender| sender.receiver_count() > 0);
let sender = tx_map
.entry(vault.clone())
.or_insert_with(|| broadcast::channel(self.broadcast_channel_capacity).0);
if sender.receiver_count() == 0 {
debug!("Skipping broadcast, no clients connected for vault `{vault}`");
return;
}
let result = tx
.send(document)
.context("Cannot broadcast server message to websocket listeners")
.map_err(server_error);
if result.is_err() {
warn!("Failed to send message: {result:?}");
if let Err(e) = sender.send(document) {
warn!("Failed to broadcast to vault `{vault}`: {e}");
}
}
async fn get_or_create(
&self,
vault: VaultId,
) -> broadcast::Sender<WebSocketServerMessageWithOrigin> {
let mut tx = self.tx.lock().await;
tx.entry(vault)
.or_insert_with(|| broadcast::channel(self.max_clients_per_vault).0.clone())
.clone()
}
}

View file

@ -11,7 +11,7 @@ pub struct WebSocketHandshake {
pub token: String,
pub device_id: DeviceId,
#[ts(as = "Option<i32>")]
#[ts(type = "number | null")]
pub last_seen_vault_update_id: Option<VaultUpdateId>,
}
@ -28,7 +28,7 @@ pub struct DocumentWithCursors {
// that it exists and can be client-side
// interpolated. However, the actual
// position is meaningless.
#[ts(as = "Option<u32>")]
#[ts(type = "number | null")]
pub vault_update_id: Option<VaultUpdateId>,
pub document_id: DocumentId,
@ -70,6 +70,7 @@ pub struct WebSocketVaultUpdate {
pub enum WebSocketClientMessage {
Handshake(WebSocketHandshake),
CursorPositions(CursorPositionFromClient),
Ping {},
}
#[derive(TS, Serialize, Clone, Debug)]

View file

@ -9,7 +9,7 @@ use crate::{
database::models::{DocumentVersionWithoutContent, VaultId, VaultUpdateId},
},
config::user_config::User,
errors::{SyncServerError, server_error, unauthenticated_error},
errors::{SyncServerError, client_error, server_error, unauthenticated_error},
server::auth::auth,
};
@ -26,16 +26,16 @@ pub fn get_authenticated_handshake(
if let Some(Message::Text(message)) = message {
let message: WebSocketClientMessage = serde_json::from_str(&message)
.context("Failed to parse message")
.map_err(server_error)?;
.map_err(client_error)?;
match message {
WebSocketClientMessage::Handshake(handshake) => {
let user = auth(state, handshake.token.trim(), vault_id)?;
Ok(AuthenticatedWebSocketHandshake { handshake, user })
}
WebSocketClientMessage::CursorPositions(_) => Err(unauthenticated_error(
anyhow::anyhow!("Expected a handshake message"),
)),
WebSocketClientMessage::CursorPositions(_) | WebSocketClientMessage::Ping {} => Err(
unauthenticated_error(anyhow::anyhow!("Expected a handshake message")),
),
}
} else {
Err(unauthenticated_error(anyhow::anyhow!(

View file

@ -28,23 +28,20 @@ pub struct Config {
impl Config {
pub async fn read_or_create(path: &Path) -> Result<Self> {
let config = if path.exists() {
info!(
"Loading configuration from `{}`",
path.canonicalize().unwrap().display()
);
Self::load_from_file(path).await?
let display_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
if path.exists() {
info!("Loading configuration from `{}`", display_path.display());
Self::load_from_file(path).await
} else {
Self::default()
};
config.write(path).await?;
info!(
"Updated configuration at `{}`",
path.canonicalize().unwrap().display()
);
Ok(config)
let config = Self::default();
config.write(path).await?;
info!(
"Created default configuration at `{}`",
display_path.display()
);
Ok(config)
}
}
pub async fn load_from_file(path: &Path) -> Result<Self> {

View file

@ -1,10 +1,13 @@
use anyhow::{Result, ensure};
use log::debug;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use crate::consts::{
DEFAULT_HOST, DEFAULT_MAX_BODY_SIZE_MB, DEFAULT_MAX_CLIENTS_PER_VAULT,
DEFAULT_MERGEABLE_FILE_EXTENSIONS, DEFAULT_PORT, DEFAULT_RESPONSE_TIMEOUT_SECONDS,
DEFAULT_ALLOWED_ORIGINS, DEFAULT_BROADCAST_CHANNEL_CAPACITY, DEFAULT_HOST,
DEFAULT_MAX_BODY_SIZE_MB, DEFAULT_MAX_CLIENTS_PER_VAULT, DEFAULT_MAX_PENDING_WS_CONNECTIONS,
DEFAULT_MERGEABLE_FILE_EXTENSIONS, DEFAULT_PORT, DEFAULT_RATE_LIMIT_PER_SECOND,
DEFAULT_RESPONSE_TIMEOUT_SECONDS,
};
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
@ -21,11 +24,68 @@ pub struct ServerConfig {
#[serde(default = "default_max_clients_per_vault")]
pub max_clients_per_vault: usize,
#[serde(default = "default_broadcast_channel_capacity")]
pub broadcast_channel_capacity: usize,
#[serde(default = "default_response_timeout", with = "humantime_serde")]
pub response_timeout: Duration,
#[serde(default = "default_mergeable_file_extensions")]
pub mergeable_file_extensions: Vec<String>,
/// Maximum requests per second (0 = disabled).
#[serde(default = "default_rate_limit_per_second")]
pub rate_limit_per_second: u64,
/// Allowed CORS origins. Default: `["*"]` (allow all).
#[serde(default = "default_allowed_origins")]
pub allowed_origins: Vec<String>,
/// Maximum concurrent unauthenticated WebSocket connections waiting for
/// handshake. Limits resource consumption from clients that connect but
/// never authenticate.
#[serde(default = "default_max_pending_websocket_connections")]
pub max_pending_websocket_connections: usize,
/// When set, proxies all UI requests (index, assets, Vite HMR) to this
/// URL instead of serving embedded assets. Typically
/// `http://localhost:5173` for the Vite dev server.
#[serde(default)]
pub dev_proxy_url: Option<String>,
}
impl ServerConfig {
pub fn validate(&self) -> Result<()> {
ensure!(
!self.response_timeout.is_zero(),
"response_timeout must be greater than 0"
);
ensure!(
self.max_body_size_mb > 0,
"max_body_size_mb must be greater than 0"
);
ensure!(
self.max_clients_per_vault > 0,
"max_clients_per_vault must be greater than 0"
);
ensure!(
self.broadcast_channel_capacity > 0,
"broadcast_channel_capacity must be greater than 0"
);
ensure!(
self.max_pending_websocket_connections > 0,
"max_pending_websocket_connections must be greater than 0"
);
ensure!(
self.max_clients_per_vault <= 10_000,
"max_clients_per_vault must be at most 10000"
);
ensure!(
self.broadcast_channel_capacity <= 1_000_000,
"broadcast_channel_capacity must be at most 1000000"
);
Ok(())
}
}
fn default_host() -> String {
@ -48,6 +108,11 @@ fn default_max_clients_per_vault() -> usize {
DEFAULT_MAX_CLIENTS_PER_VAULT
}
fn default_broadcast_channel_capacity() -> usize {
debug!("Using default broadcast channel capacity: {DEFAULT_BROADCAST_CHANNEL_CAPACITY}");
DEFAULT_BROADCAST_CHANNEL_CAPACITY
}
fn default_response_timeout() -> Duration {
debug!("Using default response timeout: {DEFAULT_RESPONSE_TIMEOUT_SECONDS:?}");
DEFAULT_RESPONSE_TIMEOUT_SECONDS
@ -60,3 +125,23 @@ fn default_mergeable_file_extensions() -> Vec<String> {
.map(|s| (*s).to_owned())
.collect()
}
fn default_rate_limit_per_second() -> u64 {
debug!("Using default rate limit per second: {DEFAULT_RATE_LIMIT_PER_SECOND}");
DEFAULT_RATE_LIMIT_PER_SECOND
}
fn default_allowed_origins() -> Vec<String> {
debug!("Using default allowed origins: {DEFAULT_ALLOWED_ORIGINS:?}");
DEFAULT_ALLOWED_ORIGINS
.iter()
.map(|s| (*s).to_owned())
.collect()
}
fn default_max_pending_websocket_connections() -> usize {
debug!(
"Using default max pending WebSocket connections: {DEFAULT_MAX_PENDING_WS_CONNECTIONS}"
);
DEFAULT_MAX_PENDING_WS_CONNECTIONS
}

View file

@ -19,10 +19,19 @@ where
let mut user_token_map = BiHashMap::new();
for user in &users {
if let Some(existing_name) = user_token_map.get_by_right(&user.token) {
let redacted = if user.token.len() > 6 {
format!(
"{}...{}",
&user.token[..3],
&user.token[user.token.len() - 3..]
)
} else {
"***".to_owned()
};
return Err(D::Error::custom(format!(
"Duplicate user token found: `{}` for users `{}` and `{}`. User tokens must be \
unique.",
user.token, existing_name, user.name
"Duplicate user token found: `{redacted}` for users `{}` and `{}`. User tokens \
must be unique.",
existing_name, user.name
)));
}
@ -41,10 +50,23 @@ where
impl UserConfig {
pub fn get_user(&self, token: &str) -> Option<&User> {
self.user_configs.iter().find(|u| u.token == token)
self.user_configs
.iter()
.find(|u| constant_time_eq(u.token.as_bytes(), token.as_bytes()))
}
}
/// Constant-time byte comparison to prevent timing attacks on token lookups.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
if a.len() != b.len() {
return false;
}
a.iter()
.zip(b.iter())
.fold(0u8, |acc, (x, y)| acc | (x ^ y))
== 0
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct User {
pub name: String,

View file

@ -13,12 +13,21 @@ pub const DEFAULT_PORT: u16 = 3000;
pub const DEFAULT_MAX_BODY_SIZE_MB: usize = 4096;
pub const DEFAULT_RESPONSE_TIMEOUT_SECONDS: Duration = Duration::from_mins(30);
pub const DEFAULT_MAX_CLIENTS_PER_VAULT: usize = 256;
pub const DEFAULT_BROADCAST_CHANNEL_CAPACITY: usize = 4096;
pub const DEFAULT_MAX_PENDING_WS_CONNECTIONS: usize = 128;
pub const DEFAULT_LOG_DIRECTORY: &str = "logs";
pub const DEFAULT_LOG_ROTATION_INTERVAL: Duration = Duration::from_hours(24);
pub const IDLE_POOL_TIMEOUT: Duration = Duration::from_mins(5);
pub const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
pub const DEFAULT_LOG_LEVEL: LogLevel = LogLevel::Info;
pub const DEFAULT_MERGEABLE_FILE_EXTENSIONS: &[&str] = &["md", "txt"];
/// 0 means rate limiting is disabled.
pub const DEFAULT_RATE_LIMIT_PER_SECOND: u64 = 0;
/// Default: allow all origins.
pub const DEFAULT_ALLOWED_ORIGINS: &[&str] = &["*"];
pub const SUPPORTED_API_VERSION: u32 = 3;

View file

@ -5,7 +5,7 @@ use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use log::debug;
use log::{debug, error, warn};
use serde::Serialize;
use thiserror::Error;
use ts_rs::TS;
@ -69,7 +69,19 @@ impl Display for SerializedError {
impl IntoResponse for SyncServerError {
fn into_response(self) -> Response {
let body = Json(self.serialize());
let serialized = self.serialize();
match &self {
Self::InitError(_) | Self::ServerError(_) => {
error!("{serialized}");
}
Self::ClientError(_) | Self::NotFound(_) => {
warn!("{serialized}");
}
Self::Unauthenticated(_) | Self::PermissionDeniedError(_) => {}
}
let body = Json(serialized);
match self {
Self::InitError(_) | Self::ServerError(_) => {

View file

@ -41,7 +41,15 @@ async fn main() -> ExitCode {
}
};
let mut result = set_up_logging(&args, &config.logging);
let mut result = config
.server
.validate()
.context("Invalid server configuration")
.map_err(init_error);
if result.is_ok() {
result = set_up_logging(&args, &config.logging);
}
if result.is_ok() {
result = start_server(config).await;

View file

@ -4,28 +4,31 @@ mod delete_document;
mod device_id_header;
mod fetch_document_version;
mod fetch_document_version_content;
mod fetch_document_versions;
mod fetch_latest_document_version;
mod fetch_latest_documents;
mod fetch_vault_history;
mod index;
mod ping;
mod rate_limit;
mod requests;
mod resolve_keys;
mod responses;
mod restore_document_version;
mod update_document;
mod websocket;
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Context as _, Result};
use auth::auth_middleware;
use axum::{
Router,
extract::{DefaultBodyLimit, Request},
http::{self, HeaderValue, Method},
middleware,
response::IntoResponse,
routing::{IntoMakeService, delete, get, post, put},
};
use device_id_header::DEVICE_ID_HEADER_NAME;
use log::info;
use log::{info, warn};
use tokio::signal;
use tower_http::{
LatencyUnit,
@ -42,7 +45,7 @@ use tracing::{Level, info_span};
use crate::{
app_state::AppState,
config::{Config, server_config::ServerConfig},
errors::{client_error, not_found_error},
consts::GRACEFUL_SHUTDOWN_TIMEOUT,
};
pub async fn create_server(config: Config) -> Result<()> {
@ -52,26 +55,42 @@ pub async fn create_server(config: Config) -> Result<()> {
let server_config = app_state.config.server.clone();
let app = Router::new()
let mut app = Router::new()
.nest("/", get_authed_routes(app_state.clone()))
.route("/", get(index::index))
.route("/assets/*path", get(index::spa_assets))
.route("/vaults/:vault_id/ping", get(ping::ping))
.route("/vaults/:vault_id/ws", get(websocket::websocket_handler))
.route("/vaults/:vault_id/ws", get(websocket::websocket_handler));
if app_state.config.server.dev_proxy_url.is_some() {
info!(
"Dev proxy enabled → {}",
app_state.config.server.dev_proxy_url.as_deref().unwrap()
);
app = app.fallback(index::vite_proxy);
}
let cors_layer = build_cors_layer(&server_config).context("Invalid CORS configuration")?;
if server_config.rate_limit_per_second > 0 {
info!(
"Rate limiting enabled: {} requests/second",
server_config.rate_limit_per_second
);
let limiter = rate_limit::RateLimiter::new(server_config.rate_limit_per_second);
app = app.layer(middleware::from_fn_with_state(
limiter,
rate_limit::rate_limit_middleware,
));
}
let app = app
.layer(DefaultBodyLimit::disable())
.layer(RequestBodyLimitLayer::new(
app_state.config.server.max_body_size_mb * 1024 * 1024,
))
.layer(TimeoutLayer::new(server_config.response_timeout))
.layer(
CorsLayer::new()
.allow_origin("*".parse::<HeaderValue>().expect("Failed to parse origin"))
.allow_headers([
http::header::CONTENT_TYPE,
http::header::AUTHORIZATION,
DEVICE_ID_HEADER_NAME.clone(),
])
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]),
)
.layer(cors_layer)
.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &Request<_>| {
@ -92,13 +111,40 @@ pub async fn create_server(config: Config) -> Result<()> {
.on_failure(DefaultOnFailure::new().level(Level::ERROR)),
)
.with_state(app_state)
.fallback(handle_404)
.fallback(handle_405)
.into_make_service();
start_server(app, &server_config).await
}
fn build_cors_layer(server_config: &ServerConfig) -> Result<CorsLayer> {
let origins = &server_config.allowed_origins;
let cors = if origins.len() == 1 && origins[0] == "*" {
info!("CORS: allowing all origins (wildcard)");
let header: HeaderValue = "*"
.parse()
.context("Failed to parse wildcard CORS origin")?;
CorsLayer::new().allow_origin(header)
} else {
let parsed: Vec<HeaderValue> = origins
.iter()
.map(|o| {
o.parse::<HeaderValue>()
.with_context(|| format!("Failed to parse CORS origin: `{o}`"))
})
.collect::<Result<Vec<_>>>()?;
CorsLayer::new().allow_origin(parsed)
};
Ok(cors
.allow_headers([
http::header::CONTENT_TYPE,
http::header::AUTHORIZATION,
DEVICE_ID_HEADER_NAME.clone(),
])
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]))
}
fn get_authed_routes(app_state: AppState) -> Router<AppState> {
Router::new()
.route(
@ -125,6 +171,10 @@ fn get_authed_routes(app_state: AppState) -> Router<AppState> {
"/vaults/:vault_id/documents/:document_id/text",
put(update_document::update_text),
)
.route(
"/vaults/:vault_id/documents/:document_id/versions",
get(fetch_document_versions::fetch_document_versions),
)
.route(
"/vaults/:vault_id/documents/:document_id/versions/:vault_update_id",
get(fetch_document_version::fetch_document_version),
@ -137,6 +187,14 @@ fn get_authed_routes(app_state: AppState) -> Router<AppState> {
"/vaults/:vault_id/documents/:document_id",
delete(delete_document::delete_document),
)
.route(
"/vaults/:vault_id/documents/:document_id/restore",
post(restore_document_version::restore_document_version),
)
.route(
"/vaults/:vault_id/history",
get(fetch_vault_history::fetch_vault_history),
)
.layer(middleware::from_fn_with_state(app_state, auth_middleware))
}
@ -153,26 +211,46 @@ async fn start_server(app: IntoMakeService<axum::Router>, config: &ServerConfig)
.context("Failed to get local address")?
);
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.tcp_nodelay(true)
.await
.context("Failed to start server")
let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
let server = axum::serve(listener, app)
.with_graceful_shutdown(async move {
shutdown_signal().await;
let _ = shutdown_tx.send(true);
})
.tcp_nodelay(true);
tokio::select! {
result = server => result.context("Failed to start server"),
() = async {
let _ = shutdown_rx.changed().await;
info!(
"Shutdown signal received, waiting up to {}s for in-flight requests to complete...",
GRACEFUL_SHUTDOWN_TIMEOUT.as_secs()
);
tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT).await;
warn!("Graceful shutdown timed out, forcing exit");
} => Ok(()),
}
}
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
if let Err(e) = signal::ctrl_c().await {
log::error!("Failed to install Ctrl+C handler: {e}");
}
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
match signal::unix::signal(signal::unix::SignalKind::terminate()) {
Ok(mut signal) => {
signal.recv().await;
}
Err(e) => {
log::error!("Failed to install SIGTERM handler: {e}");
}
}
};
#[cfg(not(unix))]
@ -183,11 +261,3 @@ async fn shutdown_signal() {
() = terminate => {},
}
}
async fn handle_404() -> impl IntoResponse {
not_found_error(anyhow!("Page not found"))
}
async fn handle_405() -> impl IntoResponse {
client_error(anyhow!("Method not allowed"))
}

View file

@ -9,7 +9,7 @@ use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use log::info;
use log::{debug, info};
use crate::{
app_state::{AppState, database::models::VaultId},
@ -21,10 +21,12 @@ use crate::{
pub async fn auth_middleware(
State(state): State<AppState>,
Path(path_params): Path<HashMap<String, String>>,
TypedHeader(auth_header): TypedHeader<Authorization<Bearer>>,
auth_header: Option<TypedHeader<Authorization<Bearer>>>,
mut req: Request,
next: Next,
) -> Result<Response, SyncServerError> {
let auth_header = auth_header
.ok_or_else(|| unauthenticated_error(anyhow::anyhow!("Missing Authorization header")))?;
let token = auth_header.token().trim();
let vault_id = normalize_string(
path_params
@ -51,8 +53,8 @@ pub fn auth(state: &AppState, token: &str, vault_id: &VaultId) -> Result<User, S
VaultAccess::AllowAccessToAll => true,
VaultAccess::AllowList(AllowListedVaults { ref allowed }) => allowed.contains(vault_id),
} {
info!(
"User `{}` is authenticated and is authorised to access to vault `{vault_id}`",
debug!(
"User `{}` is authenticated and is authorised to access vault `{vault_id}`",
user.name
);

View file

@ -1,4 +1,3 @@
use anyhow::Context;
use axum::{
Extension, Json,
extract::{Path, State},
@ -16,9 +15,13 @@ use crate::{
},
config::user_config::User,
errors::{SyncServerError, server_error},
server::{responses::DocumentUpdateResponse, update_document::merge_with_stored_version},
server::{
responses::DocumentUpdateResponse,
update_document::{MergeInput, merge_with_stored_version},
},
utils::{
find_first_available_path::find_first_available_path, normalize::normalize,
dedup_paths::get_base_path, find_first_available_path::find_first_available_path,
is_binary::is_binary, is_file_type_mergable::is_file_type_mergable, normalize::normalize,
sanitize_path::sanitize_path,
},
};
@ -32,7 +35,11 @@ pub struct CreateDocumentPathParams {
/// Create a new document in case a document with the same doesn't exist
/// already. If a document with the same path exists, a new version is created
/// with their content merged.
///
/// Text content must be UTF-8 encoded. Clients are responsible for
/// transcoding other encodings (e.g. UTF-16) to UTF-8 before sending.
#[axum::debug_handler]
#[allow(clippy::too_many_lines)]
pub async fn create_document(
Path(CreateDocumentPathParams { vault_id }): Path<CreateDocumentPathParams>,
Extension(user): Extension<User>,
@ -51,62 +58,133 @@ pub async fn create_document(
if let Some(ref idempotency_key) = request.idempotency_key {
let existing = state
.database
.get_document_by_idempotency_key(&vault_id, idempotency_key, Some(&mut transaction))
.get_document_by_idempotency_key(&vault_id, idempotency_key, Some(&mut *transaction))
.await
.map_err(server_error)?;
if let Some(existing) = existing {
info!(
"Found existing document with idempotency key `{idempotency_key}`, returning existing document"
);
transaction
.rollback()
.await
.context("Failed to roll back transaction")
.map_err(server_error)?;
return Ok(Json(DocumentUpdateResponse::FastForwardUpdate(
existing.into(),
)));
if existing.is_deleted {
// The document was created (storing the key) and later deleted.
// Don't return the deleted version — it would cause the client
// to delete its local file. Instead, fall through to normal
// create so the client's content is preserved as a new document.
// The unique index excludes deleted rows (WHERE is_deleted = 0),
// so keeping the key does NOT cause a constraint violation —
// the new non-deleted version can safely reuse the same key.
info!(
"Idempotency key `{idempotency_key}` matches a deleted document, ignoring and creating fresh"
);
} else {
// Return the LATEST version of the document, not the version
// that originally stored the key. The document may have been
// modified by other clients since the key was stored, and
// returning a stale version would cause the client to cache
// incorrect content, breaking subsequent diffs.
let latest = state
.database
.get_latest_document(&vault_id, &existing.document_id, Some(&mut *transaction))
.await
.map_err(server_error)?
.unwrap_or(existing);
info!(
"Found existing document with idempotency key `{idempotency_key}`, returning latest version"
);
transaction.rollback().await.map_err(server_error)?;
return Ok(Json(DocumentUpdateResponse::FastForwardUpdate(
latest.into(),
)));
}
}
}
let sanitized_relative_path = sanitize_path(&request.relative_path);
if sanitized_relative_path.is_empty() {
transaction.rollback().await.map_err(server_error)?;
return Err(crate::errors::client_error(anyhow::anyhow!(
"Relative path is empty after sanitization"
)));
}
let new_content = request.content.contents.to_vec();
let latest_version = state
.database
.get_latest_non_deleted_document_by_path(
&vault_id,
&sanitized_relative_path,
Some(&mut transaction),
Some(&mut *transaction),
)
.await
.map_err(server_error)?;
if let Some(latest_version) = latest_version {
info!(
"Document already exists at new location: `{sanitized_relative_path}` when trying to create it in vault `{vault_id}`, merging into existing document"
);
let is_mergeable_text = is_file_type_mergable(
&sanitized_relative_path,
&state.config.server.mergeable_file_extensions,
) && !is_binary(&latest_version.content)
&& !is_binary(&new_content);
return merge_with_stored_version(
&sanitized_relative_path,
&latest_version.content.clone(),
latest_version,
vault_id,
user,
device_id,
state,
&sanitized_relative_path,
request.content.contents.to_vec(),
transaction,
request.idempotency_key,
)
.await;
if is_mergeable_text || new_content == latest_version.content {
return merge_with_stored_version(
MergeInput {
parent_content: &[],
new_content,
idempotency_key: request.idempotency_key,
},
latest_version,
vault_id,
user,
device_id,
state,
transaction,
)
.await;
}
// For non-mergeable (binary) files with different content, don't
// merge — create a separate document at a deconflicted path so
// neither client's data is silently overwritten.
}
// For creates at deconflicted paths (e.g., "file (2).bin"), the client's
// ensureClearPath renamed a local file before uploading. Check if the
// base path (e.g., "file.bin") has a document with identical content.
// If so, merge with it instead of creating a duplicate document.
let base_path = get_base_path(&sanitized_relative_path);
if base_path != sanitized_relative_path {
let base_doc = state
.database
.get_latest_non_deleted_document_by_path(&vault_id, &base_path, Some(&mut *transaction))
.await
.map_err(server_error)?;
if let Some(base_doc) = base_doc
&& new_content == base_doc.content
{
info!(
"Create at deconflicted path `{sanitized_relative_path}` has identical content to document at base path `{base_path}`, merging"
);
return merge_with_stored_version(
MergeInput {
parent_content: &[],
new_content,
idempotency_key: request.idempotency_key,
},
base_doc,
vault_id,
user,
device_id,
state,
transaction,
)
.await;
}
}
let document_id = uuid::Uuid::new_v4();
let last_update_id = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
.get_max_update_id_in_vault(&vault_id, Some(&mut *transaction))
.await
.map_err(server_error)?;
@ -129,7 +207,7 @@ pub async fn create_document(
vault_update_id: last_update_id + 1,
document_id,
relative_path: deduped_path,
content: request.content.contents.to_vec(),
content: new_content,
updated_date: chrono::Utc::now(),
is_deleted: false,
user_id: user.name,

View file

@ -1,4 +1,4 @@
use anyhow::Context;
use anyhow::{Context, anyhow};
use axum::{
Extension, Json,
extract::{Path, State},
@ -16,8 +16,8 @@ use crate::{
},
},
config::user_config::User,
errors::{SyncServerError, server_error},
utils::{normalize::normalize, sanitize_path::sanitize_path},
errors::{SyncServerError, not_found_error, server_error},
utils::normalize::normalize,
};
#[derive(Deserialize)]
@ -37,7 +37,7 @@ pub async fn delete_document(
Extension(user): Extension<User>,
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
State(state): State<AppState>,
Json(request): Json<DeleteDocumentVersion>,
Json(_request): Json<DeleteDocumentVersion>,
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
debug!("Deleting document `{document_id}` in vault `{vault_id}`");
@ -59,6 +59,18 @@ pub async fn delete_document(
.await
.map_err(server_error)?;
if latest_version.is_none() {
transaction
.rollback()
.await
.context("Failed to roll back transaction")
.map_err(server_error)?;
return Err(not_found_error(anyhow!(
"Document `{document_id}` not found in vault `{vault_id}`"
)));
}
if let Some(latest_version) = &latest_version
&& latest_version.is_deleted
{
@ -72,13 +84,14 @@ pub async fn delete_document(
return Ok(Json(latest_version.clone().into()));
}
let latest_content = latest_version.map_or_else(Vec::new, |version| version.content); // in case the document has never existed before deleting it
// latest_version is guaranteed to be Some and not deleted at this point
let latest_version = latest_version.expect("checked above: not None and not deleted");
let new_version = StoredDocumentVersion {
vault_update_id: last_update_id + 1,
document_id,
relative_path: sanitize_path(&request.relative_path),
content: latest_content, // copy the content from the latest version
relative_path: latest_version.relative_path,
content: latest_version.content,
updated_date: chrono::Utc::now(),
is_deleted: true,
user_id: user.name,

View file

@ -16,20 +16,31 @@ impl Header for DeviceIdHeader {
{
let value = values.next().ok_or_else(headers::Error::invalid)?;
Ok(DeviceIdHeader(
value
.to_str()
.map_err(|_| headers::Error::invalid())?
.to_owned(),
))
let s = value.to_str().map_err(|_| headers::Error::invalid())?;
if s.is_empty() || s.len() > 256 {
return Err(headers::Error::invalid());
}
// Only allow safe characters to prevent log injection and similar attacks.
// Covers UUIDs, user-agent strings like "vault-link/1.0 (12345; linux)",
// and human-readable device names.
if !s
.chars()
.all(|c| c.is_ascii_alphanumeric() || "-_./ ();:@+,".contains(c))
{
return Err(headers::Error::invalid());
}
Ok(DeviceIdHeader(s.to_owned()))
}
fn encode<E>(&self, values: &mut E)
where
E: Extend<HeaderValue>,
{
let value = HeaderValue::from_static(Box::leak(self.0.clone().into_boxed_str()));
values.extend(std::iter::once(value));
if let Ok(value) = HeaderValue::from_str(&self.0) {
values.extend(std::iter::once(value));
}
}
}

View file

@ -11,7 +11,7 @@ use crate::{
AppState,
database::models::{DocumentId, DocumentVersion, VaultId, VaultUpdateId},
},
errors::{SyncServerError, not_found_error, server_error},
errors::{SyncServerError, client_error, not_found_error, server_error},
utils::normalize::normalize,
};
@ -52,7 +52,7 @@ pub async fn fetch_document_version(
)?;
if result.document_id != document_id {
return Err(not_found_error(anyhow!(
return Err(client_error(anyhow!(
"Document with document id `{document_id}` does not have a version with id \
`{vault_update_id}`",
)));

View file

@ -11,7 +11,7 @@ use crate::{
AppState,
database::models::{DocumentId, VaultId, VaultUpdateId},
},
errors::{SyncServerError, not_found_error, server_error},
errors::{SyncServerError, client_error, not_found_error, server_error},
utils::normalize::normalize,
};
@ -52,7 +52,7 @@ pub async fn fetch_document_version_content(
)?;
if result.document_id != document_id {
return Err(not_found_error(anyhow!(
return Err(client_error(anyhow!(
"Document with document id `{document_id}` does not have a version with id \
`{vault_update_id}`",
)));

View file

@ -0,0 +1,42 @@
use axum::{
Json,
extract::{Path, State},
};
use log::debug;
use serde::Deserialize;
use crate::{
app_state::{
AppState,
database::models::{DocumentId, DocumentVersionWithoutContent, VaultId},
},
errors::{SyncServerError, server_error},
utils::normalize::normalize,
};
#[derive(Deserialize)]
pub struct FetchDocumentVersionsPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
document_id: DocumentId,
}
#[axum::debug_handler]
pub async fn fetch_document_versions(
Path(FetchDocumentVersionsPathParams {
vault_id,
document_id,
}): Path<FetchDocumentVersionsPathParams>,
State(state): State<AppState>,
) -> Result<Json<Vec<DocumentVersionWithoutContent>>, SyncServerError> {
debug!("Fetching all versions for document `{document_id}` in vault `{vault_id}`");
let versions = state
.database
.get_document_versions(&vault_id, &document_id, None)
.await
.map_err(server_error)?;
Ok(Json(versions))
}

View file

@ -0,0 +1,70 @@
use axum::{
Json,
extract::{Path, Query, State},
};
use log::debug;
use serde::Deserialize;
use super::responses::VaultHistoryResponse;
use crate::{
app_state::{
AppState,
database::models::{VaultId, VaultUpdateId},
},
errors::{SyncServerError, client_error, server_error},
utils::normalize::normalize,
};
const DEFAULT_LIMIT: i64 = 50;
const MAX_LIMIT: i64 = 500;
#[derive(Deserialize)]
pub struct FetchVaultHistoryPathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
}
#[derive(Deserialize)]
pub struct QueryParams {
limit: Option<i64>,
before_update_id: Option<VaultUpdateId>,
}
#[axum::debug_handler]
pub async fn fetch_vault_history(
Path(FetchVaultHistoryPathParams { vault_id }): Path<FetchVaultHistoryPathParams>,
Query(QueryParams {
limit,
before_update_id,
}): Query<QueryParams>,
State(state): State<AppState>,
) -> Result<Json<VaultHistoryResponse>, SyncServerError> {
if let Some(id) = before_update_id
&& id <= 0
{
return Err(client_error(anyhow::anyhow!(
"before_update_id must be a positive integer"
)));
}
let limit = limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT);
debug!(
"Fetching vault history for vault `{vault_id}` (limit={limit}, before={before_update_id:?})"
);
// Fetch one extra row to determine if there are more results
let mut versions = state
.database
.get_vault_history(&vault_id, limit + 1, before_update_id, None)
.await
.map_err(server_error)?;
#[allow(clippy::cast_sign_loss)] // limit is clamped to [1, 500] above
let has_more = versions.len() > limit as usize;
if has_more {
versions.pop();
}
Ok(Json(VaultHistoryResponse { versions, has_more }))
}

View file

@ -1,7 +1,146 @@
use axum::response::{Html, IntoResponse};
use axum::{
body::Body,
extract::{Path, State},
http::{StatusCode, header},
response::{Html, IntoResponse, Response},
};
use log::warn;
use rust_embed::Embed;
pub async fn index() -> impl IntoResponse {
const HTML_CONTENT: &str = include_str!("./assets/index.html");
let html_content = HTML_CONTENT;
Html(html_content)
use crate::app_state::AppState;
#[derive(Embed)]
#[folder = "../frontend/history-ui/dist/"]
struct HistoryUiAssets;
pub async fn index(State(state): State<AppState>) -> impl IntoResponse {
if let Some(proxy_url) = &state.config.server.dev_proxy_url {
let response = proxy_request(proxy_url, "/").await;
if response.status().is_success() {
return response;
}
}
if let Some(content) = HistoryUiAssets::get("index.html") {
Html(
std::str::from_utf8(content.data.as_ref())
.inspect_err(|e| warn!("Embedded index.html is not valid UTF-8: {e}"))
.unwrap_or("<h1>VaultLink</h1>")
.to_owned(),
)
.into_response()
} else {
warn!("No embedded index.html found — history UI may not have been built");
Html("<h1>VaultLink server</h1>".to_owned()).into_response()
}
}
pub async fn spa_assets(
State(state): State<AppState>,
Path(path): Path<String>,
) -> impl IntoResponse {
if let Some(proxy_url) = &state.config.server.dev_proxy_url {
let response = proxy_request(proxy_url, &format!("/assets/{path}")).await;
if response.status().is_success() {
return response;
}
}
// The route is /assets/*path so path is relative to assets/.
// The embedded files include the assets/ prefix from the dist directory.
let full_path = format!("assets/{path}");
if let Some(content) = HistoryUiAssets::get(&full_path) {
let mime = mime_guess::from_path(&full_path).first_or_octet_stream();
return Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, mime.as_ref())
.body(Body::from(content.data.to_vec()))
.unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap_or_else(|_| Response::new(Body::empty()))
});
}
// Asset paths must match an embedded file — no SPA fallback.
// Serving index.html here would return 200 with text/html for missing
// .css/.js files, causing the browser to silently ignore the content.
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("Not found"))
.unwrap_or_else(|_| Response::new(Body::from("Not found")))
}
/// Proxies unmatched paths to the Vite dev server for HMR support
/// (`@vite/client`, `src/`, etc.).
pub async fn vite_proxy(
State(state): State<AppState>,
request: axum::extract::Request,
) -> impl IntoResponse {
let proxy_url = state.config.server.dev_proxy_url.as_deref().unwrap_or("");
let response = proxy_request(proxy_url, request.uri().path()).await;
if !response.status().is_success() {
return Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("Not found"))
.unwrap_or_else(|_| Response::new(Body::from("Not found")));
}
response
}
/// SPA fallback for production: serves index.html for client-side routes
/// (e.g. `/documents/123`). Only used when the dev proxy is disabled.
pub async fn spa_fallback() -> impl IntoResponse {
match HistoryUiAssets::get("index.html") {
Some(content) => Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/html")
.body(Body::from(content.data.to_vec()))
.unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap_or_else(|_| Response::new(Body::empty()))
}),
None => Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("Not found"))
.unwrap_or_else(|_| Response::new(Body::from("Not found"))),
}
}
static DEV_PROXY_CLIENT: std::sync::LazyLock<reqwest::Client> = std::sync::LazyLock::new(|| {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(5))
.build()
.unwrap_or_default()
});
async fn proxy_request(proxy_url: &str, path: &str) -> Response {
let url = format!("{proxy_url}{path}");
match DEV_PROXY_CLIENT.get(&url).send().await {
Ok(resp) => {
let status =
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
let mut builder = Response::builder().status(status);
for (name, value) in resp.headers() {
builder = builder.header(name.clone(), value.clone());
}
let bytes = resp.bytes().await.unwrap_or_default();
builder.body(Body::from(bytes)).unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::empty())
.unwrap_or_else(|_| Response::new(Body::empty()))
})
}
Err(_) => {
// Dev server not running — fall back to embedded assets
Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::empty())
.unwrap_or_else(|_| Response::new(Body::empty()))
}
}
}

View file

@ -0,0 +1,72 @@
use std::sync::{
Arc,
atomic::{AtomicU64, Ordering},
};
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
/// Simple token-bucket rate limiter that refills every second.
#[derive(Clone, Debug)]
pub struct RateLimiter {
inner: Arc<TokenBucket>,
}
#[derive(Debug)]
struct TokenBucket {
tokens: AtomicU64,
max_tokens: u64,
}
impl RateLimiter {
/// Create a new rate limiter. Spawns a background task that refills tokens
/// every second.
///
/// # Panics
///
/// Panics if `max_per_second` is 0.
pub fn new(max_per_second: u64) -> Self {
assert!(
max_per_second > 0,
"max_per_second must be > 0 (use 0 in config to disable rate limiting entirely)"
);
let bucket = Arc::new(TokenBucket {
tokens: AtomicU64::new(max_per_second),
max_tokens: max_per_second,
});
let bucket_clone = bucket.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
loop {
interval.tick().await;
bucket_clone
.tokens
.store(bucket_clone.max_tokens, Ordering::Release);
}
});
Self { inner: bucket }
}
fn try_acquire(&self) -> bool {
self.inner
.tokens
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
if current > 0 { Some(current - 1) } else { None }
})
.is_ok()
}
}
pub async fn rate_limit_middleware(
axum::extract::State(limiter): axum::extract::State<RateLimiter>,
req: Request,
next: Next,
) -> Result<Response, StatusCode> {
if limiter.try_acquire() {
Ok(next.run(req).await)
} else {
Err(StatusCode::TOO_MANY_REQUESTS)
}
}

View file

@ -31,7 +31,7 @@ pub struct UpdateBinaryDocumentVersion {
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct UpdateTextDocumentVersion {
#[ts(as = "i32")]
#[ts(type = "number")]
pub parent_version_id: VaultUpdateId,
pub relative_path: String,
@ -40,9 +40,5 @@ pub struct UpdateTextDocumentVersion {
pub content: Vec<NumberOrText>,
}
#[derive(TS, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct DeleteDocumentVersion {
pub relative_path: String,
}
#[derive(Debug, Deserialize)]
pub struct DeleteDocumentVersion {}

View file

@ -43,6 +43,10 @@ pub async fn resolve_keys(
request.idempotency_keys.len()
);
// Each key lookup is an independent read — no write transaction needed.
// Using create_write_transaction (BEGIN IMMEDIATE) here would hold the
// SQLite write lock for the entire iteration, blocking all concurrent
// creates/updates/deletes and causing server-wide deadlocks under load.
let mut resolved = HashMap::new();
for key in &request.idempotency_keys {
@ -53,11 +57,22 @@ pub async fn resolve_keys(
.map_err(server_error)?;
if let Some(doc) = document {
resolved.insert(key.clone(), doc.document_id.to_string());
// Skip deleted documents — returning their documentId would cause
// the client to assign a stale ID to its pending doc, and the
// subsequent create retry would get a different documentId from the
// server (since create_document falls through for deleted matches),
// leaving the document permanently stuck.
if !doc.is_deleted {
resolved.insert(key.clone(), doc.document_id.to_string());
}
}
}
debug!("Resolved {}/{} idempotency keys", resolved.len(), request.idempotency_keys.len());
debug!(
"Resolved {}/{} idempotency keys",
resolved.len(),
request.idempotency_keys.len()
);
Ok(Json(ResolveKeysResponse { resolved }))
}

View file

@ -36,6 +36,15 @@ pub struct FetchLatestDocumentsResponse {
pub last_update_id: VaultUpdateId,
}
/// Response to a vault history request (paginated).
#[derive(TS, Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct VaultHistoryResponse {
pub versions: Vec<DocumentVersionWithoutContent>,
pub has_more: bool,
}
/// Response to an update document request.
#[derive(TS, Debug, Clone, Serialize)]
#[serde(tag = "type")]

View file

@ -0,0 +1,148 @@
use anyhow::anyhow;
use axum::{
Extension, Json,
extract::{Path, State},
};
use axum_extra::TypedHeader;
use log::{debug, info};
use serde::Deserialize;
use super::device_id_header::DeviceIdHeader;
use crate::{
app_state::{
AppState,
database::models::{
DocumentId, DocumentVersionWithoutContent, StoredDocumentVersion, VaultId,
VaultUpdateId,
},
},
config::user_config::User,
errors::{SyncServerError, client_error, not_found_error, server_error},
utils::{find_first_available_path::find_first_available_path, normalize::normalize},
};
#[derive(Deserialize)]
pub struct RestorePathParams {
#[serde(deserialize_with = "normalize")]
vault_id: VaultId,
document_id: DocumentId,
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RestoreDocumentVersionRequest {
pub vault_update_id: VaultUpdateId,
}
#[axum::debug_handler]
pub async fn restore_document_version(
Path(RestorePathParams {
vault_id,
document_id,
}): Path<RestorePathParams>,
Extension(user): Extension<User>,
TypedHeader(device_id): TypedHeader<DeviceIdHeader>,
State(state): State<AppState>,
Json(request): Json<RestoreDocumentVersionRequest>,
) -> Result<Json<DocumentVersionWithoutContent>, SyncServerError> {
debug!(
"Restoring document `{document_id}` in vault `{vault_id}` to version `{}`",
request.vault_update_id
);
if request.vault_update_id <= 0 {
return Err(client_error(anyhow!(
"Invalid vault_update_id: `{}`",
request.vault_update_id
)));
}
let mut transaction = state
.database
.create_write_transaction(&vault_id)
.await
.map_err(server_error)?;
let target_version = state
.database
.get_document_version(&vault_id, request.vault_update_id, Some(&mut *transaction))
.await
.map_err(server_error)?
.ok_or_else(|| {
not_found_error(anyhow!("Version `{}` not found", request.vault_update_id))
})?;
if target_version.document_id != document_id {
transaction.rollback().await.map_err(server_error)?;
return Err(not_found_error(anyhow!(
"Version `{}` does not belong to document `{document_id}`",
request.vault_update_id,
)));
}
if target_version.is_deleted {
transaction.rollback().await.map_err(server_error)?;
return Err(client_error(anyhow!(
"Cannot restore to a deleted version `{}`",
request.vault_update_id,
)));
}
let existing = state
.database
.get_latest_non_deleted_document_by_path(
&vault_id,
&target_version.relative_path,
Some(&mut *transaction),
)
.await
.map_err(server_error)?;
let restore_path = if let Some(existing_doc) = &existing
&& existing_doc.document_id != document_id
{
find_first_available_path(
&vault_id,
&target_version.relative_path,
&state.database,
&mut transaction,
)
.await
.map_err(server_error)?
} else {
target_version.relative_path.clone()
};
let last_update_id = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut *transaction))
.await
.map_err(server_error)?;
let new_version = StoredDocumentVersion {
vault_update_id: last_update_id + 1,
document_id,
relative_path: restore_path,
content: target_version.content,
updated_date: chrono::Utc::now(),
is_deleted: false,
user_id: user.name.clone(),
device_id: device_id.0.clone(),
has_been_merged: false,
idempotency_key: None,
};
state
.database
.insert_document_version(&vault_id, &new_version, Some(transaction))
.await
.map_err(server_error)?;
info!(
"Restored document `{document_id}` to version `{}` as new version `{}`",
request.vault_update_id, new_version.vault_update_id
);
Ok(Json(new_version.into()))
}

View file

@ -17,7 +17,7 @@ use crate::{
app_state::{
AppState,
database::{
Transaction,
WriteTransaction,
models::{DocumentId, StoredDocumentVersion, VaultId, VaultUpdateId},
},
},
@ -49,7 +49,8 @@ pub async fn update_binary(
State(state): State<AppState>,
TypedMultipart(request): TypedMultipart<UpdateBinaryDocumentVersion>,
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
let parent_document = get_parent_document(&state, &vault_id, request.parent_version_id).await?;
let parent_document =
get_parent_document(&state, &vault_id, &document_id, request.parent_version_id).await?;
let content = request.content.contents.to_vec();
update_document(
@ -77,19 +78,16 @@ pub async fn update_text(
State(state): State<AppState>,
Json(request): Json<UpdateTextDocumentVersion>,
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
let parent_document = get_parent_document(&state, &vault_id, request.parent_version_id).await?;
let parent_document =
get_parent_document(&state, &vault_id, &document_id, request.parent_version_id).await?;
let parent_content = str::from_utf8(&parent_document.content)
.context("Parent document content is not valid UTF-8")
let parent_text = str::from_utf8(&parent_document.content)
.context("Parent version contains binary content; use putBinary instead of putText")
.map_err(client_error)?;
let edited_text = EditedText::from_diff(
parent_content,
request.content,
&*BuiltinTokenizer::Word,
)
.context("Failed to apply given diff to parent document")
.map_err(client_error)?;
let edited_text = EditedText::from_diff(parent_text, request.content, &*BuiltinTokenizer::Word)
.context("Failed to apply given diff to parent document")
.map_err(client_error)?;
let content = edited_text.apply().text().into_bytes();
@ -109,9 +107,10 @@ pub async fn update_text(
async fn get_parent_document(
state: &AppState,
vault_id: &VaultId,
document_id: &DocumentId,
parent_version_id: VaultUpdateId,
) -> Result<StoredDocumentVersion, SyncServerError> {
state
let parent = state
.database
.get_document_version(vault_id, parent_version_id, None)
.await
@ -123,7 +122,15 @@ async fn get_parent_document(
)))
},
Ok,
)
)?;
if &parent.document_id != document_id {
return Err(client_error(anyhow!(
"Parent version `{parent_version_id}` does not belong to document `{document_id}`"
)));
}
Ok(parent)
}
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
@ -141,12 +148,24 @@ async fn update_document(
let sanitized_relative_path = sanitize_path(relative_path);
if sanitized_relative_path.is_empty() {
return Err(client_error(anyhow!(
"Relative path is empty after sanitization"
)));
}
let mut transaction = state
.database
.create_write_transaction(&vault_id)
.await
.map_err(server_error)?;
let last_update_id = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
.await
.map_err(server_error)?;
let latest_version = state
.database
.get_latest_document(&vault_id, &document_id, Some(&mut transaction))
@ -174,43 +193,12 @@ async fn update_document(
)));
}
merge_with_stored_version(
&parent_document.relative_path,
&parent_document.content,
latest_version,
vault_id,
user,
device_id,
state,
&sanitized_relative_path,
content,
transaction,
None,
)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn merge_with_stored_version(
parent_document_path: &str,
parent_document_content: &[u8],
latest_version: StoredDocumentVersion,
vault_id: VaultId,
user: User,
device_id: DeviceIdHeader,
state: AppState,
sanitized_relative_path: &str,
content: Vec<u8>,
mut transaction: Transaction<'_>,
idempotency_key: Option<String>,
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
// Return the latest version if the content and path are the same as the latest
// version
if content == latest_version.content && sanitized_relative_path == latest_version.relative_path
{
info!(
"Document content is the same as the latest version for `{}`, skipping update",
latest_version.document_id
"Document content is the same as the latest version for `{document_id}`, skipping update"
);
transaction
.rollback()
@ -224,47 +212,50 @@ pub async fn merge_with_stored_version(
}
let are_all_participants_mergable = is_file_type_mergable(
sanitized_relative_path,
&sanitized_relative_path,
&state.config.server.mergeable_file_extensions,
) && !is_binary(parent_document_content)
) && !is_binary(&parent_document.content)
&& !is_binary(&latest_version.content)
&& !is_binary(&content);
let merged_content = if are_all_participants_mergable {
info!(
"Merging changes for document `{}` in vault `{vault_id}`",
latest_version.document_id
);
let parent_str = str::from_utf8(parent_document_content)
let (merged_content, is_different_from_request_content) = if are_all_participants_mergable {
info!("Merging changes for document `{document_id}` in vault `{vault_id}`");
let parent_text = str::from_utf8(&parent_document.content)
.context("Parent document content is not valid UTF-8")
.map_err(server_error)?;
let latest_str = str::from_utf8(&latest_version.content)
.map_err(client_error)?;
let latest_text = str::from_utf8(&latest_version.content)
.context("Latest version content is not valid UTF-8")
.map_err(server_error)?;
let content_str = str::from_utf8(&content)
.map_err(client_error)?;
let new_text = str::from_utf8(&content)
.context("New content is not valid UTF-8")
.map_err(server_error)?;
reconcile(
parent_str,
&latest_str.into(),
&content_str.into(),
.map_err(client_error)?;
let merged = reconcile(
parent_text,
&latest_text.into(),
&new_text.into(),
&*BuiltinTokenizer::Word,
)
.apply()
.text()
.into_bytes()
.into_bytes();
let is_different = merged != content;
(merged, is_different)
} else {
content.clone()
(content, false)
};
// We can only update the relative path if we're the first one to do so
let new_relative_path = if parent_document_path == latest_version.relative_path
&& latest_version.relative_path != sanitized_relative_path
// Rename resolution: only apply the client's rename if the document's path
// hasn't changed since this client's parent version. Check the parent
// version's path against the latest version's path. If they differ, another
// client already renamed the document — keep the latest path (first rename
// wins). Content changes from both clients are still merged correctly via
// the 3-way reconcile above, independent of which rename wins.
let new_relative_path = if parent_document.relative_path == latest_version.relative_path
&& sanitized_relative_path != latest_version.relative_path
{
let new_path = find_first_available_path(
&vault_id,
sanitized_relative_path,
&sanitized_relative_path,
&state.database,
&mut transaction,
)
@ -282,16 +273,8 @@ pub async fn merge_with_stored_version(
latest_version.relative_path.clone()
};
let last_update_id = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
.await
.map_err(server_error)?;
let is_different_from_request_content = merged_content != content;
let new_version = StoredDocumentVersion {
document_id: latest_version.document_id,
document_id,
vault_update_id: last_update_id + 1,
relative_path: new_relative_path,
content: merged_content,
@ -300,7 +283,114 @@ pub async fn merge_with_stored_version(
user_id: user.name,
device_id: device_id.0,
has_been_merged: are_all_participants_mergable && is_different_from_request_content,
idempotency_key,
idempotency_key: None,
};
state
.database
.insert_document_version(&vault_id, &new_version, Some(transaction))
.await
.map_err(server_error)?;
Ok(Json(if is_different_from_request_content {
DocumentUpdateResponse::MergingUpdate(new_version.into())
} else {
DocumentUpdateResponse::FastForwardUpdate(new_version.into())
}))
}
pub struct MergeInput<'a> {
pub parent_content: &'a [u8],
pub new_content: Vec<u8>,
pub idempotency_key: Option<String>,
}
#[allow(clippy::too_many_arguments)]
pub async fn merge_with_stored_version(
input: MergeInput<'_>,
latest_version: StoredDocumentVersion,
vault_id: VaultId,
user: User,
device_id: DeviceIdHeader,
state: AppState,
mut transaction: WriteTransaction,
) -> Result<Json<DocumentUpdateResponse>, SyncServerError> {
let document_id = latest_version.document_id;
let last_update_id = state
.database
.get_max_update_id_in_vault(&vault_id, Some(&mut transaction))
.await
.map_err(server_error)?;
let are_all_participants_mergable = is_file_type_mergable(
&latest_version.relative_path,
&state.config.server.mergeable_file_extensions,
) && !is_binary(input.parent_content)
&& !is_binary(&latest_version.content)
&& !is_binary(&input.new_content);
let merged_content = if are_all_participants_mergable {
info!("Merging changes for document `{document_id}` in vault `{vault_id}`");
let parent_text = str::from_utf8(input.parent_content)
.context("Parent content is not valid UTF-8")
.map_err(client_error)?;
let latest_text = str::from_utf8(&latest_version.content)
.context("Latest version content is not valid UTF-8")
.map_err(client_error)?;
let new_text = str::from_utf8(&input.new_content)
.context("New content is not valid UTF-8")
.map_err(client_error)?;
reconcile(
parent_text,
&latest_text.into(),
&new_text.into(),
&*BuiltinTokenizer::Word,
)
.apply()
.text()
.into_bytes()
} else {
input.new_content.clone()
};
let is_different_from_request_content = merged_content != input.new_content;
// When merging during create, keep the latest version's path (the existing
// document's path) rather than the requested path.
let new_relative_path = latest_version.relative_path.clone();
// Short-circuit: if content is identical AND no idempotency key to persist,
// return the existing version without inserting a new row.
if merged_content == latest_version.content
&& new_relative_path == latest_version.relative_path
&& input.idempotency_key.is_none()
{
info!(
"Merged content is the same as the latest version for `{document_id}`, skipping insert"
);
transaction
.rollback()
.await
.context("Failed to roll back transaction")
.map_err(server_error)?;
return Ok(Json(DocumentUpdateResponse::FastForwardUpdate(
latest_version.into(),
)));
}
let new_version = StoredDocumentVersion {
document_id,
vault_update_id: last_update_id + 1,
relative_path: new_relative_path,
content: merged_content,
updated_date: chrono::Utc::now(),
is_deleted: false,
user_id: user.name,
device_id: device_id.0,
has_been_merged: are_all_participants_mergable && is_different_from_request_content,
idempotency_key: input.idempotency_key,
};
state

View file

@ -6,9 +6,11 @@ use axum::{
},
response::Response,
};
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use log::{debug, info, warn};
use serde::Deserialize;
use std::time::Duration;
use crate::{
app_state::{
@ -28,6 +30,20 @@ use crate::{
utils::normalize::normalize,
};
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
/// Tracks a pending (not yet authenticated) WebSocket connection.
/// Decrements the counter when dropped, ensuring cleanup even if
/// the upgrade never completes or auth fails.
struct PendingWsGuard(std::sync::Arc<std::sync::atomic::AtomicUsize>);
impl Drop for PendingWsGuard {
fn drop(&mut self) {
self.0
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
}
}
#[derive(Deserialize)]
pub struct WebSocketPathParams {
#[serde(deserialize_with = "normalize")]
@ -39,13 +55,42 @@ pub async fn websocket_handler(
Path(WebSocketPathParams { vault_id }): Path<WebSocketPathParams>,
State(state): State<AppState>,
) -> Result<Response, SyncServerError> {
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id)))
// Delegating to a non-async helper avoids a known Rust issue where
// temporary borrows of `state` inside an async fn (before the move into
// `on_upgrade`) cause "Send is not general enough" compilation errors.
websocket_handler_inner(ws, vault_id, state)
}
async fn websocket_wrapped(state: AppState, stream: WebSocket, vault_id: VaultId) {
fn websocket_handler_inner(
ws: WebSocketUpgrade,
vault_id: VaultId,
state: AppState,
) -> Result<Response, SyncServerError> {
let current = state
.pending_ws_connections
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if current >= state.config.server.max_pending_websocket_connections {
state
.pending_ws_connections
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
return Err(client_error(anyhow::anyhow!(
"Too many pending WebSocket connections"
)));
}
let guard = PendingWsGuard(state.pending_ws_connections.clone());
Ok(ws.on_upgrade(move |socket| websocket_wrapped(state, socket, vault_id, guard)))
}
async fn websocket_wrapped(
state: AppState,
stream: WebSocket,
vault_id: VaultId,
pending_guard: PendingWsGuard,
) {
info!("WebSocket connection opened on vault `{vault_id}`");
let result = websocket(state, stream, vault_id.clone()).await;
let result = websocket(state, stream, vault_id.clone(), pending_guard).await;
if let Err(err) = result {
debug!("WebSocket connection error on vault `{vault_id}`: {err}");
@ -57,25 +102,53 @@ async fn websocket(
state: AppState,
stream: WebSocket,
vault_id: VaultId,
pending_guard: PendingWsGuard,
) -> Result<(), SyncServerError> {
let (mut sender, mut websocket_receiver) = stream.split();
let authed_handshake = get_authenticated_handshake(
&state,
&vault_id,
websocket_receiver
.next()
.await
.transpose()
.unwrap_or_default(),
)?;
let handshake_msg = tokio::time::timeout(HANDSHAKE_TIMEOUT, websocket_receiver.next())
.await
.map_err(|_| client_error(anyhow::anyhow!("WebSocket handshake timed out")))?
.transpose()
.map_err(|e| client_error(anyhow::anyhow!("WebSocket error during handshake: {e}")))?;
let authed_handshake = get_authenticated_handshake(&state, &vault_id, handshake_msg)?;
info!(
"WebSocket handshake successful for vault `{vault_id}` for `{}`",
authed_handshake.handshake.device_id
);
let mut broadcast_receiver = state.broadcasts.get_receiver(vault_id.clone()).await;
// Auth complete — no longer a pending connection.
drop(pending_guard);
let max_clients = state.config.server.max_clients_per_vault;
let mut broadcast_receiver = match state
.broadcasts
.get_receiver(vault_id.clone(), max_clients)
.await
{
Ok(receiver) => receiver,
Err(err) => {
warn!(
"Vault `{vault_id}` has reached the maximum number of clients ({max_clients}), rejecting connection from `{}`",
authed_handshake.handshake.device_id
);
if let Err(e) = sender
.send(Message::Close(Some(axum::extract::ws::CloseFrame {
code: 4000,
reason: format!(
"Vault has reached the maximum number of clients ({max_clients})"
)
.into(),
})))
.await
{
warn!("Failed to send WebSocket close frame: {e}");
}
return Err(err);
}
};
send_update_over_websocket(
&WebSocketServerMessage::VaultUpdate(WebSocketVaultUpdate {
@ -109,9 +182,9 @@ async fn websocket(
}
let message = match update.message {
WebSocketServerMessage::CursorPositions(
CursorPositionFromServer { clients },
) => WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
clients,
}) => WebSocketServerMessage::CursorPositions(CursorPositionFromServer {
clients: clients
.into_iter()
.filter(|client| client.device_id != device_id)
@ -124,14 +197,11 @@ async fn websocket(
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
warn!(
"WebSocket receiver for device {device_id} lagged by {n} messages, \
disconnecting for re-sync"
"WebSocket receiver lagged, dropped {n} messages — disconnecting client to force full resync"
);
break;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
break;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
}
}
@ -142,26 +212,64 @@ async fn websocket(
let vault_id_clone = vault_id.clone();
let cursor_manager = state.cursors.clone();
let mut receive_task = tokio::spawn(async move {
while let Some(Ok(Message::Text(message))) = websocket_receiver.next().await {
let message: WebSocketClientMessage = serde_json::from_str(&message)
.context("Failed to parse WebSocket message from client")
.map_err(server_error)?;
while let Some(msg) = websocket_receiver.next().await {
match msg {
Ok(Message::Text(message)) => {
let message: WebSocketClientMessage = serde_json::from_str(&message)
.context("Failed to parse WebSocket message from client")
.map_err(client_error)?;
match message {
WebSocketClientMessage::Handshake(_) => {
return Err(client_error(anyhow::anyhow!(
"Unexpected handshake message"
)));
match message {
WebSocketClientMessage::Handshake(_) => {
return Err(client_error(anyhow::anyhow!(
"Unexpected handshake message"
)));
}
WebSocketClientMessage::CursorPositions(cursors) => {
const MAX_CURSOR_DOCUMENTS: usize = 1000;
const MAX_CURSORS_PER_DOCUMENT: usize = 100;
const MAX_RELATIVE_PATH_LEN: usize = 4096;
let docs = cursors.documents_with_cursors;
if docs.len() > MAX_CURSOR_DOCUMENTS {
warn!(
"Cursor update rejected: {} documents exceeds limit of {MAX_CURSOR_DOCUMENTS}",
docs.len()
);
continue;
}
let valid = docs.iter().all(|doc| {
doc.cursors.len() <= MAX_CURSORS_PER_DOCUMENT
&& doc.relative_path.len() <= MAX_RELATIVE_PATH_LEN
});
if !valid {
warn!("Cursor update rejected: a document exceeds cursor or path length limits");
continue;
}
cursor_manager
.update_cursors(
vault_id_clone.clone(),
authed_handshake.user.name.clone(),
&device_id,
docs,
)
.await;
}
WebSocketClientMessage::Ping {} => {
// Ping is a no-op for now; the variant exists for future keep-alive support.
}
}
}
WebSocketClientMessage::CursorPositions(cursors) => {
cursor_manager
.update_cursors(
vault_id_clone.clone(),
authed_handshake.user.name.clone(),
&device_id,
cursors.documents_with_cursors,
)
.await;
Ok(Message::Close(_)) => break,
Ok(Message::Binary(_)) => {
warn!("Received unexpected binary WebSocket message, ignoring");
}
Ok(_) => {} // Ping/Pong frames handled by axum
Err(e) => {
debug!("WebSocket receive error: {e}");
break;
}
}
}
@ -169,38 +277,47 @@ async fn websocket(
Ok::<(), SyncServerError>(())
});
tokio::select! {
_ = &mut send_task => receive_task.abort(),
_ = &mut receive_task => send_task.abort(),
let result: Result<(), SyncServerError> = tokio::select! {
send_result = &mut send_task => {
receive_task.abort();
let _ = receive_task.await;
match send_result {
Err(e) => Err(server_error(
anyhow::Error::from(e).context("WebSocket send task failed"),
)),
Ok(inner) => inner,
}
},
receive_result = &mut receive_task => {
send_task.abort();
let _ = send_task.await;
match receive_result {
Err(e) => Err(server_error(
anyhow::Error::from(e).context("WebSocket receive task failed"),
)),
Ok(inner) => inner,
}
},
};
let result: Result<(), SyncServerError> = (async {
send_task
.await
.context("WebSocket send task failed")
.map_err(client_error)
.and_then(|err| err)?;
receive_task
.await
.context("WebSocket receive task failed")
.map_err(client_error)
.and_then(|err| err)?;
Ok(())
})
.await;
state
.cursors
.remove_cursors_of_device(&vault_id, &authed_handshake.handshake.device_id)
.await;
if result.is_err() {
info!(
"WebSocket disconnected on vault `{vault_id}` for `{}`",
authed_handshake.handshake.device_id
);
match &result {
Ok(()) => {
info!(
"WebSocket disconnected on vault `{vault_id}` for `{}`",
authed_handshake.handshake.device_id
);
}
Err(err) => {
warn!(
"WebSocket error on vault `{vault_id}` for `{}`: {err}",
authed_handshake.handshake.device_id
);
}
}
result

View file

@ -1,3 +1,4 @@
pub mod decode_text;
pub mod dedup_paths;
pub mod find_first_available_path;
pub mod is_binary;

View file

@ -0,0 +1,41 @@
/// Decode bytes as UTF-8.
///
/// Returns `None` if the content is not valid UTF-8.
///
/// Clients are expected to transcode UTF-16 content to UTF-8 before
/// sending, so the server only needs to handle UTF-8 text and binary.
pub fn decode_text(data: &[u8]) -> Option<String> {
std::str::from_utf8(data).ok().map(String::from)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_utf8() {
assert_eq!(decode_text(b"hello"), Some("hello".to_owned()));
}
#[test]
fn test_utf8_with_bom() {
// UTF-8 BOM is valid UTF-8 — the BOM character is preserved in the string
assert_eq!(
decode_text(&[0xEF, 0xBB, 0xBF, b'h', b'i']),
Some("\u{FEFF}hi".to_owned())
);
}
#[test]
fn test_binary_returns_none() {
assert_eq!(decode_text(&[0x80, 0x81, 0x82]), None);
}
#[test]
fn test_nul_bytes_are_valid() {
assert_eq!(
decode_text(b"hello\x00world"),
Some("hello\x00world".to_owned())
);
}
}

View file

@ -1,8 +1,54 @@
use std::sync::LazyLock;
use regex::Regex;
static DEDUP_SUFFIX_REGEX: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r" \((\d+)\)$").expect("invalid regex"));
/// Strip the ` (N)` deconfliction suffix from a path, returning the base path.
/// e.g., `"binary-2 (3).bin"` → `"binary-2.bin"`, `"binary-2.bin"` → `"binary-2.bin"`
pub fn get_base_path(path: &str) -> String {
let mut path_parts = path.split('/').collect::<Vec<_>>();
let Some(file_name) = path_parts.pop() else {
return path.to_owned();
};
if file_name.is_empty() {
return path.to_owned();
}
let file_name = file_name.to_owned();
let mut directory = path_parts.join("/");
if !directory.is_empty() {
directory.push('/');
}
let is_simple_dotfile = file_name.starts_with('.') && file_name.matches('.').count() == 1;
let (stem, extension) = if is_simple_dotfile {
(file_name.clone(), String::new())
} else {
let name_parts = file_name.rsplitn(2, '.').collect::<Vec<_>>();
let mut reverse_parts = name_parts.into_iter().rev();
match (reverse_parts.next(), reverse_parts.next()) {
(Some(s), maybe_ext) => (
s.to_owned(),
maybe_ext.map(|ext| format!(".{ext}")).unwrap_or_default(),
),
_ => unreachable!("Path must have at least one part"),
}
};
let clean_stem = DEDUP_SUFFIX_REGEX.replace(&stem, "").to_string();
format!("{directory}{clean_stem}{extension}")
}
pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
let mut path_parts = path.split('/').collect::<Vec<_>>();
let file_name = path_parts.pop().unwrap().to_owned();
let file_name = path_parts
.pop()
.filter(|s| !s.is_empty())
.unwrap_or(path)
.to_owned();
let mut directory = path_parts.join("/");
if !directory.is_empty() {
@ -29,14 +75,13 @@ pub fn dedup_paths(path: &str) -> impl Iterator<Item = String> {
}
};
let regex = Regex::new(r" \((\d+)\)$").unwrap();
let start_number = regex
let start_number = DEDUP_SUFFIX_REGEX
.captures(&stem)
.and_then(|caps| caps.get(1))
.and_then(|m| m.as_str().parse::<u32>().ok())
.unwrap_or(0);
let clean_stem = regex.replace(&stem, "").to_string();
let clean_stem = DEDUP_SUFFIX_REGEX.replace(&stem, "").to_string();
(start_number..).map(move |dedup_number| {
if dedup_number == 0 {

View file

@ -1,17 +1,26 @@
use crate::app_state::database::models::VaultId;
use crate::{app_state::database::Transaction, utils::dedup_paths::dedup_paths};
use anyhow::Result;
use crate::utils::dedup_paths::dedup_paths;
use anyhow::{Result, bail};
use log::info;
use sqlx::sqlite::SqliteConnection;
const MAX_DEDUP_ATTEMPTS: usize = 100_000;
pub async fn find_first_available_path(
vault_id: &VaultId,
sanitized_relative_path: &str,
database: &crate::app_state::database::Database,
transaction: &mut Transaction<'_>,
connection: &mut SqliteConnection,
) -> Result<String> {
for candidate in dedup_paths(sanitized_relative_path) {
for (attempt, candidate) in dedup_paths(sanitized_relative_path).enumerate() {
if attempt >= MAX_DEDUP_ATTEMPTS {
bail!(
"Could not find an available path after {MAX_DEDUP_ATTEMPTS} attempts for `{sanitized_relative_path}` in vault `{vault_id}`"
);
}
if database
.get_latest_non_deleted_document_by_path(vault_id, &candidate, Some(transaction))
.get_latest_non_deleted_document_by_path(vault_id, &candidate, Some(connection))
.await?
.is_none()
{
@ -24,5 +33,5 @@ pub async fn find_first_available_path(
);
}
unreachable!("dedup_paths produces infinite paths");
bail!("dedup_paths iterator unexpectedly exhausted");
}

View file

@ -1,16 +1,12 @@
/// Heuristically determine if the given data is a binary or a text file's
/// content.
use super::decode_text::decode_text;
/// Determine if the given data is binary (not valid UTF-8).
///
/// Only text inputs can be reconciled using the crate's functions.
/// Clients transcode UTF-16 to UTF-8 at the read boundary, so the
/// server only ever receives UTF-8 text or binary content.
#[must_use]
pub fn is_binary(data: &[u8]) -> bool {
if data.contains(&0) {
// Even though the NUL character is valid in UTF-8, it's highly suspicious in
// human-readable text.
return true;
}
std::str::from_utf8(data).is_err()
decode_text(data).is_none()
}
#[cfg(test)]
@ -19,8 +15,13 @@ mod tests {
#[test]
fn test_is_binary() {
assert!(is_binary(&[0, 159, 146, 150]));
assert!(is_binary(&[0, 12]));
assert!(is_binary(&[0x80, 0x81, 0x82]));
assert!(!is_binary(b"hello"));
}
#[test]
fn test_nul_bytes_in_utf8_are_text() {
assert!(!is_binary(b"hello\x00world"));
assert!(!is_binary(&[0, 12]));
}
}

View file

@ -6,7 +6,7 @@ use std::{
time::{Duration, SystemTime, UNIX_EPOCH},
};
use chrono::{Local, NaiveDateTime};
use chrono::NaiveDateTime;
use tracing_subscriber::fmt::MakeWriter;
#[derive(Clone)]
@ -55,7 +55,7 @@ impl RotatingFileWriter {
let timestamp_str = filename.get(prefix_len..filename.len().checked_sub(4)?)?;
let dt = NaiveDateTime::parse_from_str(timestamp_str, "%Y-%m-%d_%H-%M-%S").ok()?;
let timestamp = dt.and_local_timezone(Local).single()?;
let timestamp = dt.and_utc();
let secs: u64 = timestamp.timestamp().try_into().ok()?;
Some(UNIX_EPOCH + Duration::from_secs(secs))
@ -114,7 +114,7 @@ impl RotatingFileWriter {
}
fn rotate(inner: &mut RotatingFileWriterInner) -> io::Result<()> {
let timestamp = Local::now().format("%Y-%m-%d_%H-%M-%S");
let timestamp = chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S");
let filename = format!("{}.{}.log", inner.file_prefix, timestamp);
let filepath = inner.directory.join(filename);
@ -132,8 +132,14 @@ impl RotatingFileWriter {
impl Write for RotatingFileWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut inner = self.inner.lock().unwrap();
let mut inner = self.inner.lock().unwrap_or_else(|poisoned| {
eprintln!("RotatingFileWriter mutex was poisoned, recovering");
poisoned.into_inner()
});
// Reset file handle after poison recovery so the next branch
// re-opens a valid file rather than writing to a potentially
// half-closed handle.
if inner.current_file.is_none() {
Self::open_or_create_log_file(&mut inner)?;
} else if Self::should_rotate(&inner) {
@ -148,7 +154,10 @@ impl Write for RotatingFileWriter {
}
fn flush(&mut self) -> io::Result<()> {
let mut inner = self.inner.lock().unwrap();
let mut inner = self.inner.lock().unwrap_or_else(|poisoned| {
eprintln!("RotatingFileWriter mutex was poisoned, recovering");
poisoned.into_inner()
});
if let Some(ref mut file) = inner.current_file {
file.flush()
} else {
@ -267,7 +276,7 @@ mod tests {
// Parse the expected time
let expected_dt =
NaiveDateTime::parse_from_str(timestamp_str, "%Y-%m-%d_%H-%M-%S").unwrap();
let expected_timestamp = expected_dt.and_local_timezone(Local).single().unwrap();
let expected_timestamp = expected_dt.and_utc();
let expected_duration =
Duration::from_secs(expected_timestamp.timestamp().try_into().unwrap());
let expected_next = UNIX_EPOCH + expected_duration + rotation_duration;
@ -306,7 +315,7 @@ mod tests {
// Should use the latest file (2025-10-26_14-00-00)
let expected_dt =
NaiveDateTime::parse_from_str("2025-10-26_14-00-00", "%Y-%m-%d_%H-%M-%S").unwrap();
let expected_timestamp = expected_dt.and_local_timezone(Local).single().unwrap();
let expected_timestamp = expected_dt.and_utc();
let expected_duration =
Duration::from_secs(expected_timestamp.timestamp().try_into().unwrap());
let expected_next = UNIX_EPOCH + expected_duration + rotation_duration;