264 lines
7.2 KiB
Rust
264 lines
7.2 KiB
Rust
use std::time::{Duration, Instant};
|
|
|
|
use anyhow::{anyhow, bail, Context};
|
|
use rand::RngExt;
|
|
use serde_json::Value;
|
|
use tokio::time::sleep;
|
|
use tracing::warn;
|
|
|
|
use crate::pocketbase::get_superuser_token;
|
|
use crate::state::AppState;
|
|
|
|
/// Name of the PocketBase collection whose records represent held locks.
const LOCK_COLLECTION: &str = "checkout_locks";
|
|
/// Maximum time, in seconds, that `acquire_pocketbase_lock` keeps retrying before giving up.
const LOCK_ACQUIRE_TIMEOUT_SECS: u64 = 10;
|
|
/// Pause, in milliseconds, between two acquisition attempts while the lock is held elsewhere.
const LOCK_RETRY_DELAY_MS: u64 = 100;
|
|
|
|
pub struct PocketBaseLock {
|
|
client: reqwest::Client,
|
|
pb_url: String,
|
|
token: String,
|
|
record_id: Option<String>,
|
|
name: String,
|
|
}
|
|
|
|
/// Minimal view of a lock record that is already stored in PocketBase.
struct ExistingLock {
    /// Record id, used to delete the record when it is stale.
    id: String,
    /// Expiry as Unix seconds; `find_lock` stores `0` when the record carries
    /// no usable value, which makes the lock count as already expired.
    expires_at_unix: u64,
}
|
|
|
|
pub async fn acquire_pocketbase_lock(
|
|
state: &AppState,
|
|
name: &str,
|
|
ttl_secs: u64,
|
|
) -> anyhow::Result<PocketBaseLock> {
|
|
validate_lock_name(name)?;
|
|
|
|
let token = get_superuser_token(state).await?;
|
|
let pb_url = state.pocketbase_url.trim_end_matches('/').to_string();
|
|
let owner = random_owner();
|
|
let deadline = Instant::now() + Duration::from_secs(LOCK_ACQUIRE_TIMEOUT_SECS);
|
|
|
|
loop {
|
|
let now = now_unix_secs();
|
|
if let Some(record_id) =
|
|
try_create_lock(state, &pb_url, &token, name, &owner, now + ttl_secs).await?
|
|
{
|
|
return Ok(PocketBaseLock {
|
|
client: state.http_client.clone(),
|
|
pb_url,
|
|
token,
|
|
record_id: Some(record_id),
|
|
name: name.to_string(),
|
|
});
|
|
}
|
|
|
|
if let Some(existing) = find_lock(state, &pb_url, &token, name).await? {
|
|
if existing.expires_at_unix <= now {
|
|
if let Err(err) = delete_lock_record(state, &pb_url, &token, &existing.id).await {
|
|
warn!(
|
|
lock_name = name,
|
|
lock_id = %existing.id,
|
|
"Failed to delete stale PocketBase lock: {err}"
|
|
);
|
|
}
|
|
continue;
|
|
}
|
|
}
|
|
|
|
if Instant::now() >= deadline {
|
|
bail!("Timed out acquiring PocketBase lock '{name}'");
|
|
}
|
|
|
|
sleep(Duration::from_millis(LOCK_RETRY_DELAY_MS)).await;
|
|
}
|
|
}
|
|
|
|
impl PocketBaseLock {
|
|
pub async fn release(mut self) -> anyhow::Result<()> {
|
|
let Some(record_id) = self.record_id.take() else {
|
|
return Ok(());
|
|
};
|
|
release_lock_record(&self.client, &self.pb_url, &self.token, &record_id)
|
|
.await
|
|
.with_context(|| format!("Failed to release PocketBase lock '{}'", self.name))
|
|
}
|
|
}
|
|
|
|
impl Drop for PocketBaseLock {
|
|
fn drop(&mut self) {
|
|
let Some(record_id) = self.record_id.take() else {
|
|
return;
|
|
};
|
|
|
|
let client = self.client.clone();
|
|
let pb_url = self.pb_url.clone();
|
|
let token = self.token.clone();
|
|
let name = self.name.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(err) = release_lock_record(&client, &pb_url, &token, &record_id).await {
|
|
warn!(
|
|
lock_name = %name,
|
|
lock_id = %record_id,
|
|
"Failed to release PocketBase lock on drop: {err}"
|
|
);
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
async fn try_create_lock(
|
|
state: &AppState,
|
|
pb_url: &str,
|
|
token: &str,
|
|
name: &str,
|
|
owner: &str,
|
|
expires_at_unix: u64,
|
|
) -> anyhow::Result<Option<String>> {
|
|
let url = format!("{pb_url}/api/collections/{LOCK_COLLECTION}/records");
|
|
let resp = state
|
|
.http_client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {token}"))
|
|
.json(&serde_json::json!({
|
|
"name": name,
|
|
"owner": owner,
|
|
"expires_at_unix": expires_at_unix,
|
|
}))
|
|
.send()
|
|
.await?;
|
|
|
|
if resp.status().is_success() {
|
|
let body: Value = resp.json().await?;
|
|
return body["id"]
|
|
.as_str()
|
|
.map(str::to_string)
|
|
.map(Some)
|
|
.ok_or_else(|| anyhow!("PocketBase lock record missing id"));
|
|
}
|
|
|
|
let status = resp.status();
|
|
let text = resp.text().await.unwrap_or_default();
|
|
if status.is_client_error() {
|
|
return Ok(None);
|
|
}
|
|
|
|
Err(anyhow!("PocketBase lock create failed ({status}): {text}"))
|
|
}
|
|
|
|
async fn find_lock(
|
|
state: &AppState,
|
|
pb_url: &str,
|
|
token: &str,
|
|
name: &str,
|
|
) -> anyhow::Result<Option<ExistingLock>> {
|
|
let filter = format!("name=\"{}\"", name);
|
|
let url = format!(
|
|
"{pb_url}/api/collections/{LOCK_COLLECTION}/records?filter={}&perPage=1",
|
|
urlencoding::encode(&filter)
|
|
);
|
|
let resp = state
|
|
.http_client
|
|
.get(&url)
|
|
.header("Authorization", format!("Bearer {token}"))
|
|
.send()
|
|
.await?;
|
|
|
|
ensure_success_ref(&resp).await?;
|
|
|
|
let body: Value = resp.json().await?;
|
|
let Some(item) = body["items"].as_array().and_then(|items| items.first()) else {
|
|
return Ok(None);
|
|
};
|
|
let id = item["id"]
|
|
.as_str()
|
|
.ok_or_else(|| anyhow!("PocketBase lock missing id"))?
|
|
.to_string();
|
|
let expires_at_unix = number_field(item, "expires_at_unix").unwrap_or(0);
|
|
|
|
Ok(Some(ExistingLock {
|
|
id,
|
|
expires_at_unix,
|
|
}))
|
|
}
|
|
|
|
async fn delete_lock_record(
|
|
state: &AppState,
|
|
pb_url: &str,
|
|
token: &str,
|
|
record_id: &str,
|
|
) -> anyhow::Result<()> {
|
|
release_lock_record(&state.http_client, pb_url, token, record_id).await
|
|
}
|
|
|
|
async fn release_lock_record(
|
|
client: &reqwest::Client,
|
|
pb_url: &str,
|
|
token: &str,
|
|
record_id: &str,
|
|
) -> anyhow::Result<()> {
|
|
let url = format!("{pb_url}/api/collections/{LOCK_COLLECTION}/records/{record_id}");
|
|
let resp = client
|
|
.delete(&url)
|
|
.header("Authorization", format!("Bearer {token}"))
|
|
.send()
|
|
.await?;
|
|
|
|
if resp.status().is_success() || resp.status() == reqwest::StatusCode::NOT_FOUND {
|
|
return Ok(());
|
|
}
|
|
|
|
let status = resp.status();
|
|
let text = resp.text().await.unwrap_or_default();
|
|
Err(anyhow!("PocketBase lock delete failed ({status}): {text}"))
|
|
}
|
|
|
|
fn validate_lock_name(name: &str) -> anyhow::Result<()> {
|
|
if name.is_empty() || name.len() > 80 {
|
|
bail!("invalid PocketBase lock name length");
|
|
}
|
|
if !name
|
|
.bytes()
|
|
.all(|b| b.is_ascii_alphanumeric() || b == b':' || b == b'_' || b == b'-')
|
|
{
|
|
bail!("invalid PocketBase lock name characters");
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn random_owner() -> String {
|
|
let mut rng = rand::rng();
|
|
(0..24)
|
|
.map(|_| {
|
|
let idx: u8 = rng.random_range(0..36);
|
|
if idx < 10 {
|
|
(b'0' + idx) as char
|
|
} else {
|
|
(b'a' + idx - 10) as char
|
|
}
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch.
fn now_unix_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
|
|
|
|
fn number_field(value: &Value, field: &str) -> Option<u64> {
|
|
value[field].as_u64().or_else(|| {
|
|
value[field]
|
|
.as_f64()
|
|
.filter(|n| n.is_finite() && *n >= 0.0 && n.fract() == 0.0)
|
|
.map(|n| n as u64)
|
|
})
|
|
}
|
|
|
|
async fn ensure_success_ref(resp: &reqwest::Response) -> anyhow::Result<()> {
|
|
if resp.status().is_success() {
|
|
return Ok(());
|
|
}
|
|
|
|
Err(anyhow!("upstream returned {}", resp.status()))
|
|
}
|