use std::sync::Arc; use axum::extract::{Path, State}; use axum::http::{header, StatusCode}; use axum::response::{Html, IntoResponse, Response}; use axum::Extension; use axum::Json; use rand::RngExt; use serde::{Deserialize, Serialize}; use tracing::warn; use url::form_urlencoded; use crate::auth::OptionalUser; use crate::licensing::{is_valid_share_bounds, share_bounds_from_params, ShareBounds}; use crate::pocketbase::get_superuser_token; use crate::state::SharedState; const CODE_LEN: usize = 8; const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789"; const MAX_QUERY_LEN: usize = 4096; const MAX_QUERY_PAIRS: usize = 80; const MAX_PARAM_KEY_LEN: usize = 64; const MAX_PARAM_VALUE_LEN: usize = 512; fn generate_code() -> String { let mut rng = rand::rng(); (0..CODE_LEN) .map(|_| CHARSET[rng.random_range(0..CHARSET.len())] as char) .collect() } #[derive(Deserialize)] pub struct ShortenRequest { params: String, } #[derive(Serialize)] pub struct ShortenResponse { code: String, url: String, } #[derive(Serialize)] struct PbRecord { code: String, params: String, #[serde(skip_serializing_if = "Option::is_none")] created_by: Option, click_count: u64, #[serde(skip_serializing_if = "Option::is_none")] share_south: Option, #[serde(skip_serializing_if = "Option::is_none")] share_west: Option, #[serde(skip_serializing_if = "Option::is_none")] share_north: Option, #[serde(skip_serializing_if = "Option::is_none")] share_east: Option, } #[derive(Serialize)] struct ShareLinkListItem { code: String, url: String, og_image_url: String, params: String, click_count: u64, created: String, } #[derive(Serialize)] struct ShareLinksResponse { links: Vec, } fn json_number_as_u64(value: &serde_json::Value) -> u64 { value .as_u64() .or_else(|| { value .as_f64() .filter(|n| n.is_finite() && *n > 0.0) .map(|n| n as u64) }) .unwrap_or(0) } fn sanitized_query_params(params: &str, keep_share: bool) -> Result { let params = params.trim_start_matches('?'); if params.len() > MAX_QUERY_LEN { return Err("query string is too long"); } let mut pairs = Vec::new(); for (idx, (key, value)) in form_urlencoded::parse(params.as_bytes()).enumerate() { if idx >= MAX_QUERY_PAIRS { return Err("query string has too many parameters"); } if key == "share" && !keep_share { continue; } if !is_allowed_param_key(&key) { return Err("query string contains an unsupported parameter"); } if key.len() > MAX_PARAM_KEY_LEN || value.len() > MAX_PARAM_VALUE_LEN { return Err("query parameter is too long"); } if key.chars().any(char::is_control) || value.chars().any(char::is_control) { return Err("query parameter contains control characters"); } pairs.push((key.into_owned(), value.into_owned())); } let mut out = form_urlencoded::Serializer::new(String::new()); for (key, value) in pairs { out.append_pair(&key, &value); } Ok(out.finish()) } fn is_allowed_param_key(key: &str) -> bool { matches!( key, "lat" | "lon" | "zoom" | "filter" | "school" | "crime" | "voteShare" | "ethnicity" | "amenityDistance" | "transportDistance" | "amenityCount2km" | "amenityCount5km" | "poi" | "tab" | "pc" | "tt" | "share" ) } fn escape_attr(value: &str) -> String { value .replace('&', "&") .replace('"', """) .replace('\'', "'") .replace('<', "<") .replace('>', ">") } fn user_can_create_share_grant(user: &OptionalUser) -> bool { user.0 .as_ref() .is_some_and(|u| u.is_admin || u.subscription == "licensed") } fn share_fields( bounds: Option, ) -> (Option, Option, Option, Option) { match bounds { Some(bounds) => ( Some(bounds.south), Some(bounds.west), Some(bounds.north), Some(bounds.east), ), None => (None, None, None, None), } } fn record_share_bounds(item: &serde_json::Value) -> Option { let bounds = ShareBounds { south: item.get("share_south")?.as_f64()?, west: item.get("share_west")?.as_f64()?, north: item.get("share_north")?.as_f64()?, east: item.get("share_east")?.as_f64()?, }; is_valid_share_bounds(bounds).then_some(bounds) } fn dashboard_redirect_url(params: &str, code: &str, include_share: bool) -> String { match (params.is_empty(), include_share) { (true, false) => "/dashboard".to_string(), (true, true) => format!("/dashboard?share={code}"), (false, false) => format!("/dashboard?{params}"), (false, true) => format!("/dashboard?{params}&share={code}"), } } fn og_image_url(public_url: &str, params: &str) -> String { if params.is_empty() { format!("{}/api/screenshot?og=1", public_url.trim_end_matches('/')) } else { format!( "{}/api/screenshot?og=1&{params}", public_url.trim_end_matches('/') ) } } pub async fn post_shorten( State(shared): State>, Extension(user): Extension, Json(req): Json, ) -> Response { let state = shared.load_state(); let pb_url = state.pocketbase_url.trim_end_matches('/'); let can_create_share_grant = user_can_create_share_grant(&user); let params = match sanitized_query_params(&req.params, !can_create_share_grant) { Ok(params) => params, Err(reason) => { warn!("Rejected short URL params: {reason}"); return (StatusCode::BAD_REQUEST, reason).into_response(); } }; let token = match get_superuser_token(&state).await { Ok(t) => t, Err(err) => { warn!("PocketBase superuser auth failed: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let code = generate_code(); let share_bounds = if can_create_share_grant { share_bounds_from_params(¶ms) } else { None }; let (share_south, share_west, share_north, share_east) = share_fields(share_bounds); let record = PbRecord { code: code.clone(), params, created_by: user.0.as_ref().map(|u| u.id.clone()), click_count: 0, share_south, share_west, share_north, share_east, }; let res = state .http_client .post(format!("{pb_url}/api/collections/short_urls/records")) .header("Authorization", format!("Bearer {token}")) .json(&record) .send() .await; match res { Ok(resp) if resp.status().is_success() => { let body = ShortenResponse { url: format!("/s/{code}"), code, }; Json(body).into_response() } Ok(resp) => { let status = resp.status(); let text = resp.text().await.unwrap_or_default(); warn!("PocketBase create failed ({status}): {text}"); StatusCode::BAD_GATEWAY.into_response() } Err(err) => { warn!("PocketBase request error: {err}"); StatusCode::BAD_GATEWAY.into_response() } } } pub async fn get_share_links( State(shared): State>, Extension(user): Extension, ) -> Response { let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), }; let pb_url = state.pocketbase_url.trim_end_matches('/'); let token = match get_superuser_token(&state).await { Ok(t) => t, Err(err) => { warn!("PocketBase superuser auth failed: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let filter = format!("created_by=\"{}\"", user.id); let url = format!( "{pb_url}/api/collections/short_urls/records?sort=-created&perPage=200&filter={}", urlencoding::encode(&filter) ); let res = match state .http_client .get(&url) .header("Authorization", format!("Bearer {token}")) .send() .await { Ok(r) => r, Err(err) => { warn!("Failed to list share links: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; if !res.status().is_success() { let status = res.status(); let text = res.text().await.unwrap_or_default(); warn!("PocketBase list share links failed ({status}): {text}"); return StatusCode::BAD_GATEWAY.into_response(); } let body: serde_json::Value = match res.json().await { Ok(v) => v, Err(err) => { warn!("Failed to parse share links response: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let public_url = state.public_url.trim_end_matches('/'); let links: Vec = body["items"] .as_array() .map(|arr| { arr.iter() .map(|item| { let code = item["code"].as_str().unwrap_or("").to_string(); let params = item["params"].as_str().unwrap_or("").to_string(); ShareLinkListItem { url: format!("{public_url}/s/{code}"), code, og_image_url: og_image_url(public_url, ¶ms), params, click_count: json_number_as_u64(&item["click_count"]), created: item["created"].as_str().unwrap_or("").to_string(), } }) .collect() }) .unwrap_or_default(); Json(ShareLinksResponse { links }).into_response() } pub async fn get_short_url( State(shared): State>, Path(code): Path, ) -> Response { let state = shared.load_state(); if code.is_empty() || code.len() > 20 || !code.bytes().all(|b| b.is_ascii_alphanumeric()) { return StatusCode::BAD_REQUEST.into_response(); } let pb_url = state.pocketbase_url.trim_end_matches('/'); let token = match get_superuser_token(&state).await { Ok(t) => t, Err(err) => { warn!("PocketBase superuser auth failed: {err}"); return StatusCode::BAD_GATEWAY.into_response(); } }; let filter = format!("code=\"{code}\""); let url = format!( "{pb_url}/api/collections/short_urls/records?filter={}&perPage=1", urlencoding::encode(&filter) ); let res = state .http_client .get(&url) .header("Authorization", format!("Bearer {token}")) .send() .await; match res { Ok(resp) if resp.status().is_success() => { let json: serde_json::Value = match resp.json().await { Ok(v) => v, Err(err) => { warn!("Failed to parse PocketBase response: {err}"); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } }; let item = json["items"].as_array().and_then(|items| items.first()); match item.and_then(|item| item["params"].as_str().map(|params| (item, params))) { Some((item, params)) => { let record_id = item["id"].as_str().unwrap_or("").to_string(); let next_click_count = json_number_as_u64(&item["click_count"]).saturating_add(1); let params = match sanitized_query_params(params, true) { Ok(params) => params, Err(reason) => { warn!("Stored short URL params rejected for {code}: {reason}"); return StatusCode::BAD_REQUEST.into_response(); } }; if !record_id.is_empty() { let update_url = format!("{pb_url}/api/collections/short_urls/records/{record_id}"); match state .http_client .patch(&update_url) .header("Authorization", format!("Bearer {token}")) .json(&serde_json::json!({ "click_count": next_click_count })) .send() .await { Ok(update_resp) if update_resp.status().is_success() => {} Ok(update_resp) => { let status = update_resp.status(); let text = update_resp.text().await.unwrap_or_default(); warn!("PocketBase click count update failed ({status}): {text}"); } Err(err) => warn!("PocketBase click count update failed: {err}"), } } let redirect_url = dashboard_redirect_url(¶ms, &code, record_share_bounds(item).is_some()); let og_image_url = og_image_url(&state.public_url, ¶ms); let og_url = format!("{}/s/{code}", state.public_url.trim_end_matches('/')); let og_title = "Perfect Postcode | Every neighbourhood in England"; let og_description = "Explore property prices, energy ratings, crime stats, school ratings, and more across England on one interactive map."; let redirect_url = escape_attr(&redirect_url); let og_image_url = escape_attr(&og_image_url); let og_url = escape_attr(&og_url); let og_title = escape_attr(og_title); let og_description = escape_attr(og_description); let html = format!( r#" {og_title} "# ); ( [ (header::CACHE_CONTROL, "no-store"), ( header::CONTENT_SECURITY_POLICY, "default-src 'none'; img-src https: data:; base-uri 'none'; form-action 'none'", ), ], Html(html), ) .into_response() } None => StatusCode::NOT_FOUND.into_response(), } } Ok(resp) => { let status = resp.status(); warn!("PocketBase lookup failed ({status})"); StatusCode::BAD_GATEWAY.into_response() } Err(err) => { warn!("PocketBase request error: {err}"); StatusCode::BAD_GATEWAY.into_response() } } } #[cfg(test)] mod tests { use super::*; #[test] fn sanitizes_short_url_params_and_drops_share() { let params = sanitized_query_params( "lat=51.5&lon=-0.1&zoom=12&filter=price%3A1%3A2&share=oldcode", false, ) .unwrap(); assert_eq!(params, "lat=51.5&lon=-0.1&zoom=12&filter=price%3A1%3A2"); } #[test] fn rejects_html_in_unsupported_params() { assert!(sanitized_query_params("lat=51&x=%22%3E%3Cscript%3E", false).is_err()); } #[test] fn can_preserve_existing_share_grant() { let params = sanitized_query_params("lat=51.5&lon=-0.1&zoom=12&share=oldcode", true).unwrap(); assert_eq!(params, "lat=51.5&lon=-0.1&zoom=12&share=oldcode"); } #[test] fn escapes_html_attributes() { assert_eq!(escape_attr(r#""'><&"#), ""'><&"); } }