This commit is contained in:
Andras Schmelczer 2026-05-14 20:42:48 +01:00
parent 273d7a83ee
commit 084117cea8
48 changed files with 2283 additions and 890 deletions

View file

@ -4,7 +4,7 @@ use std::sync::Arc;
use std::time::Duration;
use axum::extract::{Query, State};
use axum::http::{header, HeaderMap, StatusCode};
use axum::http::{header, HeaderMap, StatusCode, Uri};
use axum::response::IntoResponse;
use axum::Extension;
use rust_xlsxwriter::{Format, FormatAlign, FormatBorder, Image, Url, Workbook};
@ -16,11 +16,14 @@ use crate::auth::OptionalUser;
use crate::consts::NAN_U16;
use crate::data::{PostcodePoiMetrics, QuantRef};
use crate::features;
use crate::licensing::check_license_bounds;
use crate::licensing::{check_license_bounds, resolve_share_code};
use crate::parsing::{
parse_field_indices_with_poi, parse_filters_with_poi, require_bounds, row_passes_filters,
row_passes_poi_filters,
};
use crate::routes::travel_time::{
load_travel_data, parse_optional_travel, row_passes_travel_filters,
};
use crate::routes::{fetch_screenshot_bytes, FeatureInfo};
use crate::state::SharedState;
@ -29,11 +32,20 @@ const EXPORT_SCREENSHOT_TIMEOUT_SECS: u64 = 12;
/// Height (in pixels) reserved for the screenshot row
const IMAGE_ROW_HEIGHT: f64 = 225.0;
/// Hard cap on the bounding-box area (in degrees²) that may be exported.
/// All of England fits inside ~6° × ~10° ≈ 60 deg². Anything substantially
/// larger is rejected to keep aggregation bounded for non-licensed users
/// who supply share grants outside their expected region, and to avoid
/// minutes-long requests that fan out to millions of rows.
const MAX_EXPORT_BBOX_AREA_DEG2: f64 = 80.0;
#[derive(Deserialize)]
pub struct ExportParams {
bounds: Option<String>,
filters: Option<String>,
travel: Option<String>,
fields: Option<String>,
share: Option<String>,
}
/// Per-postcode accumulator for export aggregation (mean for numeric, mode for enum).
@ -125,6 +137,8 @@ fn build_frontend_params(
center_lon: f64,
zoom: f64,
filters_str: Option<&str>,
travel_params: &[String],
share: Option<&str>,
) -> String {
let mut parts = vec![
format!("lat={:.4}", center_lat),
@ -140,20 +154,53 @@ fn build_frontend_params(
}
}
}
for entry in travel_params {
if !entry.is_empty() {
parts.push(format!("tt={}", urlencoding::encode(entry.trim())));
}
}
if let Some(share) = share.filter(|value| !value.is_empty()) {
parts.push(format!("share={}", urlencoding::encode(share)));
}
parts.join("&")
}
fn collect_travel_state_params(query: Option<&str>) -> Vec<String> {
query
.into_iter()
.flat_map(|qs| url::form_urlencoded::parse(qs.as_bytes()))
.filter_map(|(key, value)| {
if key == "tt" && !value.is_empty() {
Some(value.into_owned())
} else {
None
}
})
.collect()
}
pub async fn get_export(
State(shared): State<Arc<SharedState>>,
headers: HeaderMap,
Extension(user): Extension<OptionalUser>,
uri: Uri,
Query(params): Query<ExportParams>,
) -> Result<impl IntoResponse, axum::response::Response> {
let state = shared.load_state();
let (south, west, north, east) =
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
check_license_bounds(&user.0, (south, west, north, east), None)?;
let area_deg2 = (north - south).max(0.0) * (east - west).max(0.0);
if area_deg2 > MAX_EXPORT_BBOX_AREA_DEG2 {
return Err((
StatusCode::BAD_REQUEST,
"Export area is too large; zoom in further before exporting",
)
.into_response());
}
let share_bounds = resolve_share_code(&state, params.share.as_deref()).await;
check_license_bounds(&user.0, (south, west, north, east), share_bounds)?;
let quant = state.data.quant_ref();
let poi_quant = state.data.poi_metrics.quant_ref();
@ -168,7 +215,14 @@ pub async fn get_export(
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let has_poi_filters = !parsed_poi_filters.is_empty();
let filters_str = params.filters;
let travel_entries = parse_optional_travel(params.travel.as_deref())
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let has_travel_filters = travel_entries
.iter()
.any(|entry| entry.filter_min.is_some() && entry.filter_max.is_some());
let travel_state_params = collect_travel_state_params(uri.query());
let fields_str = params.fields;
let share_code = params.share;
let public_url = state.public_url.clone();
@ -181,8 +235,14 @@ pub async fn get_export(
} else {
12.0
};
let frontend_params =
build_frontend_params(center_lat, center_lon, zoom, filters_str.as_deref());
let frontend_params = build_frontend_params(
center_lat,
center_lon,
zoom,
filters_str.as_deref(),
&travel_state_params,
share_code.as_deref(),
);
// Fetch screenshot (async, before spawn_blocking)
let auth_header = headers.get(header::AUTHORIZATION);
@ -235,14 +295,17 @@ pub async fn get_export(
let enum_values = &state.data.enum_values;
let postcode_data = &state.postcode_data;
let poi_metrics = &state.data.poi_metrics;
let travel_data = load_travel_data(&state.travel_time_store, &travel_entries)?;
let poi_offset = num_features;
let total_export_features = num_features + poi_metrics.num_features();
let (pc_interner, pc_keys) = state.data.postcode_parts();
// Build set of enum feature indices for quick lookup
let enum_indices: FxHashMap<usize, ()> = enum_values.keys().map(|&idx| (idx, ())).collect();
// Group rows by postcode
let mut postcode_rows: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
// Aggregate directly by postcode so large requests don't retain every
// matching property row before sampling the exported postcodes.
let mut postcode_aggs: FxHashMap<usize, PostcodeExportAgg> = FxHashMap::default();
state
.grid
.for_each_in_bounds(south, west, north, east, |row_idx| {
@ -260,31 +323,31 @@ pub async fn get_export(
{
return;
}
let postcode = state.data.postcode(row);
let postcode = pc_interner.resolve(&pc_keys[row]);
if has_travel_filters
&& !row_passes_travel_filters(postcode, &travel_entries, &travel_data)
{
return;
}
if let Some(&pc_idx) = postcode_data.postcode_to_idx.get(postcode) {
postcode_rows.entry(pc_idx).or_default().push(row);
postcode_aggs
.entry(pc_idx)
.or_insert_with(|| PostcodeExportAgg::new(total_export_features))
.add_row(
feature_data,
row,
num_features,
&enum_indices,
&quant,
poi_metrics,
);
}
});
// Aggregate per postcode
let mut postcode_aggs: Vec<(usize, PostcodeExportAgg)> =
Vec::with_capacity(postcode_rows.len());
for (pc_idx, rows) in postcode_rows {
let mut agg = PostcodeExportAgg::new(total_export_features);
for &row in &rows {
agg.add_row(
feature_data,
row,
num_features,
&enum_indices,
&quant,
poi_metrics,
);
}
if agg.count > 0 {
postcode_aggs.push((pc_idx, agg));
}
}
let mut postcode_aggs: Vec<(usize, PostcodeExportAgg)> = postcode_aggs
.into_iter()
.filter(|(_, agg)| agg.count > 0)
.collect();
// Sort by property count descending
postcode_aggs.sort_unstable_by_key(|agg| std::cmp::Reverse(agg.1.count));
@ -460,7 +523,11 @@ pub async fn get_export(
.set_align(FormatAlign::Left);
// Dashboard URL
let dashboard_url = format!("{}/?{}", public_url, frontend_params);
let dashboard_url = format!(
"{}/dashboard?{}",
public_url.trim_end_matches('/'),
frontend_params
);
// Sheet 1: "Selected" (filter features only) with link + screenshot
// Sheet 2: "All Data" (all features)
@ -680,3 +747,42 @@ pub async fn get_export(
bytes,
))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn collect_travel_state_params_accepts_single_tt_param() {
let entry = "transit:bank-tube-station:Bank%20tube%20station:0:52";
let query = format!("bounds=1,2,3,4&tt={}", urlencoding::encode(entry));
assert_eq!(collect_travel_state_params(Some(&query)), vec![entry]);
}
#[test]
fn collect_travel_state_params_preserves_repeated_tt_params() {
let bank = "transit:bank-tube-station:Bank%20tube%20station:0:52";
let kings_cross = "transit:kings-cross:Kings%20Cross:b:0:30";
let query = format!(
"tt={}&filter=Price%3A0%3A100&tt={}",
urlencoding::encode(bank),
urlencoding::encode(kings_cross)
);
assert_eq!(
collect_travel_state_params(Some(&query)),
vec![bank, kings_cross]
);
}
#[test]
fn export_query_deserializes_when_tt_is_a_single_string() {
let uri: Uri = "/api/export?bounds=1,2,3,4&tt=transit%3Abank%3ABank%2520station%3A0%3A52"
.parse()
.unwrap();
let Query(params) = Query::<ExportParams>::try_from_uri(&uri).unwrap();
assert_eq!(params.bounds.as_deref(), Some("1,2,3,4"));
}
}

View file

@ -10,7 +10,8 @@ use tracing::{info, warn};
use crate::auth::{OptionalUser, PocketBaseUser};
use crate::checkout_sessions::{
active_referral_checkout_user, start_license_checkout, CheckoutStart,
active_referral_checkout_user, grant_license_with_pricing_lock, start_license_checkout,
CheckoutStart,
};
use crate::pocketbase::get_superuser_token;
use crate::pocketbase_locks::acquire_pocketbase_lock;
@ -107,6 +108,25 @@ fn validate_invite_code(code: &str) -> Result<(), &'static str> {
Ok(())
}
/// Sanitize the inviter's display name returned to anonymous clients.
/// The value comes from the inviter's email local-part stored in PocketBase;
/// we don't trust it, so strip control chars and HTML-meaningful characters
/// and cap the length. Returns None if nothing usable remains.
fn sanitize_invited_by(raw: &str) -> Option<String> {
const MAX_LEN: usize = 40;
let cleaned: String = raw
.chars()
.filter(|c| !c.is_control() && !matches!(*c, '<' | '>' | '"' | '\'' | '&' | '\\'))
.take(MAX_LEN)
.collect();
let trimmed = cleaned.trim();
if trimmed.is_empty() {
None
} else {
Some(trimmed.to_string())
}
}
fn generate_invite_code() -> String {
use rand::RngExt;
let mut rng = rand::rng();
@ -131,6 +151,48 @@ fn current_unix_secs_string() -> String {
.to_string()
}
/// Fetch the live `is_admin` flag for a user, bypassing any cached token
/// claims. Returns Err with an HTTP response if PocketBase is unreachable
/// or returns an unexpected payload — the caller should propagate that.
async fn verify_is_admin(
state: &AppState,
pb_url: &str,
token: &str,
user_id: &str,
) -> Result<bool, Response> {
if user_id.is_empty()
|| user_id.len() > 32
|| !user_id.bytes().all(|b| b.is_ascii_alphanumeric())
{
return Err(StatusCode::FORBIDDEN.into_response());
}
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
let resp = match state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await
{
Ok(r) => r,
Err(err) => {
warn!("Failed to verify is_admin: {err}");
return Err(StatusCode::BAD_GATEWAY.into_response());
}
};
if !resp.status().is_success() {
return Err(StatusCode::BAD_GATEWAY.into_response());
}
let body: serde_json::Value = match resp.json().await {
Ok(v) => v,
Err(err) => {
warn!("Failed to parse user record for is_admin verify: {err}");
return Err(StatusCode::BAD_GATEWAY.into_response());
}
};
Ok(body["is_admin"].as_bool().unwrap_or(false))
}
async fn lookup_unused_invite(
state: &AppState,
pb_url: &str,
@ -217,35 +279,16 @@ async fn mark_invite_used(
async fn grant_license_for_invite(
state: &AppState,
pb_url: &str,
token: &str,
_pb_url: &str,
_token: &str,
user_id: &str,
) -> Result<(), Response> {
let update_url = format!("{pb_url}/api/collections/users/records/{user_id}");
let resp = match state
.http_client
.patch(&update_url)
.header("Authorization", format!("Bearer {token}"))
.json(&serde_json::json!({ "subscription": "licensed" }))
.send()
grant_license_with_pricing_lock(state, user_id)
.await
{
Ok(resp) => resp,
Err(err) => {
.map_err(|err| {
warn!("Failed to update user subscription for admin invite: {err}");
return Err(StatusCode::BAD_GATEWAY.into_response());
}
};
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
warn!("PocketBase user subscription update failed ({status}): {text}");
return Err(StatusCode::BAD_GATEWAY.into_response());
}
state.token_cache.invalidate_by_user_id(user_id);
Ok(())
StatusCode::BAD_GATEWAY.into_response()
})
}
async fn create_referral_checkout(
@ -289,12 +332,32 @@ pub async fn post_invites(
None => return StatusCode::UNAUTHORIZED.into_response(),
};
let invite_type = if user.is_admin {
match body.invite_type.as_deref() {
Some("referral") => "referral",
_ => "admin",
// Cached token claims could be stale or, in the worst case, tampered with
// upstream of us. For admin-only actions, re-fetch the live record from
// PocketBase and trust only that.
let wants_admin_invite =
user.is_admin && !matches!(body.invite_type.as_deref(), Some("referral"));
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = match get_superuser_token(&state).await {
Ok(t) => t,
Err(err) => {
warn!("Failed to auth as PocketBase superuser: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
} else if user.subscription == "licensed" {
};
let invite_type = if wants_admin_invite {
match verify_is_admin(&state, pb_url, &token, &user.id).await {
Ok(true) => "admin",
Ok(false) => {
warn!(user_id = %user.id, "is_admin claim rejected by live PB lookup");
return (StatusCode::FORBIDDEN, "Not authorised").into_response();
}
Err(response) => return response,
}
} else if user.is_admin || user.subscription == "licensed" {
"referral"
} else {
return (
@ -305,15 +368,6 @@ pub async fn post_invites(
};
let code = generate_invite_code();
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = match get_superuser_token(&state).await {
Ok(t) => t,
Err(err) => {
warn!("Failed to auth as PocketBase superuser: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
let create_url = format!("{pb_url}/api/collections/invites/records");
let res = state
@ -429,7 +483,7 @@ pub async fn get_invite(
let used = !used_by.is_empty();
let created_by = invite["created_by"].as_str().unwrap_or("");
// Look up inviter's name (email local part)
// Look up inviter's name (email local part) — sanitized before returning.
let invited_by = if !created_by.is_empty() {
let user_url = format!("{pb_url}/api/collections/users/records/{created_by}");
match state
@ -444,7 +498,7 @@ pub async fn get_invite(
user_body["email"]
.as_str()
.and_then(|e| e.split('@').next())
.map(String::from)
.and_then(sanitize_invited_by)
}
_ => None,
}
@ -565,11 +619,11 @@ pub async fn post_redeem_invite(
};
if invite_type == "admin" {
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await {
return response;
}
if let Err(response) = grant_license_for_invite(&state, pb_url, &token, &user.id).await {
if let Err(response) = mark_invite_used(&state, pb_url, &token, invite_id, &user.id).await {
return response;
}

View file

@ -3,16 +3,24 @@ use std::sync::Arc;
use axum::extract::{Path, State};
use axum::http::{header, StatusCode};
use axum::response::{Html, IntoResponse, Response};
use axum::Extension;
use axum::Json;
use rand::RngExt;
use serde::{Deserialize, Serialize};
use tracing::warn;
use url::form_urlencoded;
use crate::auth::OptionalUser;
use crate::licensing::{is_valid_share_bounds, share_bounds_from_params, ShareBounds};
use crate::pocketbase::get_superuser_token;
use crate::state::SharedState;
const CODE_LEN: usize = 8;
const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789";
const MAX_QUERY_LEN: usize = 4096;
const MAX_QUERY_PAIRS: usize = 80;
const MAX_PARAM_KEY_LEN: usize = 64;
const MAX_PARAM_VALUE_LEN: usize = 512;
fn generate_code() -> String {
let mut rng = rand::rng();
@ -36,15 +44,178 @@ pub struct ShortenResponse {
struct PbRecord {
code: String,
params: String,
#[serde(skip_serializing_if = "Option::is_none")]
created_by: Option<String>,
click_count: u64,
#[serde(skip_serializing_if = "Option::is_none")]
share_south: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
share_west: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
share_north: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
share_east: Option<f64>,
}
#[derive(Serialize)]
struct ShareLinkListItem {
code: String,
url: String,
og_image_url: String,
params: String,
click_count: u64,
created: String,
}
#[derive(Serialize)]
struct ShareLinksResponse {
links: Vec<ShareLinkListItem>,
}
fn json_number_as_u64(value: &serde_json::Value) -> u64 {
value
.as_u64()
.or_else(|| {
value
.as_f64()
.filter(|n| n.is_finite() && *n > 0.0)
.map(|n| n as u64)
})
.unwrap_or(0)
}
fn sanitized_query_params(params: &str, keep_share: bool) -> Result<String, &'static str> {
let params = params.trim_start_matches('?');
if params.len() > MAX_QUERY_LEN {
return Err("query string is too long");
}
let mut pairs = Vec::new();
for (idx, (key, value)) in form_urlencoded::parse(params.as_bytes()).enumerate() {
if idx >= MAX_QUERY_PAIRS {
return Err("query string has too many parameters");
}
if key == "share" && !keep_share {
continue;
}
if !is_allowed_param_key(&key) {
return Err("query string contains an unsupported parameter");
}
if key.len() > MAX_PARAM_KEY_LEN || value.len() > MAX_PARAM_VALUE_LEN {
return Err("query parameter is too long");
}
if key.chars().any(char::is_control) || value.chars().any(char::is_control) {
return Err("query parameter contains control characters");
}
pairs.push((key.into_owned(), value.into_owned()));
}
let mut out = form_urlencoded::Serializer::new(String::new());
for (key, value) in pairs {
out.append_pair(&key, &value);
}
Ok(out.finish())
}
fn is_allowed_param_key(key: &str) -> bool {
matches!(
key,
"lat"
| "lon"
| "zoom"
| "filter"
| "school"
| "crime"
| "voteShare"
| "ethnicity"
| "amenityDistance"
| "transportDistance"
| "amenityCount2km"
| "amenityCount5km"
| "poi"
| "tab"
| "pc"
| "tt"
| "share"
)
}
fn escape_attr(value: &str) -> String {
value
.replace('&', "&amp;")
.replace('"', "&quot;")
.replace('\'', "&#39;")
.replace('<', "&lt;")
.replace('>', "&gt;")
}
fn user_can_create_share_grant(user: &OptionalUser) -> bool {
user.0
.as_ref()
.is_some_and(|u| u.is_admin || u.subscription == "licensed")
}
fn share_fields(
bounds: Option<ShareBounds>,
) -> (Option<f64>, Option<f64>, Option<f64>, Option<f64>) {
match bounds {
Some(bounds) => (
Some(bounds.south),
Some(bounds.west),
Some(bounds.north),
Some(bounds.east),
),
None => (None, None, None, None),
}
}
fn record_share_bounds(item: &serde_json::Value) -> Option<ShareBounds> {
let bounds = ShareBounds {
south: item.get("share_south")?.as_f64()?,
west: item.get("share_west")?.as_f64()?,
north: item.get("share_north")?.as_f64()?,
east: item.get("share_east")?.as_f64()?,
};
is_valid_share_bounds(bounds).then_some(bounds)
}
fn dashboard_redirect_url(params: &str, code: &str, include_share: bool) -> String {
match (params.is_empty(), include_share) {
(true, false) => "/dashboard".to_string(),
(true, true) => format!("/dashboard?share={code}"),
(false, false) => format!("/dashboard?{params}"),
(false, true) => format!("/dashboard?{params}&share={code}"),
}
}
fn og_image_url(public_url: &str, params: &str) -> String {
if params.is_empty() {
format!("{}/api/screenshot?og=1", public_url.trim_end_matches('/'))
} else {
format!(
"{}/api/screenshot?og=1&{params}",
public_url.trim_end_matches('/')
)
}
}
pub async fn post_shorten(
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<ShortenRequest>,
) -> Response {
let state = shared.load_state();
let pb_url = state.pocketbase_url.trim_end_matches('/');
let can_create_share_grant = user_can_create_share_grant(&user);
let params = match sanitized_query_params(&req.params, !can_create_share_grant) {
Ok(params) => params,
Err(reason) => {
warn!("Rejected short URL params: {reason}");
return (StatusCode::BAD_REQUEST, reason).into_response();
}
};
let token = match get_superuser_token(&state).await {
Ok(t) => t,
Err(err) => {
@ -54,10 +225,22 @@ pub async fn post_shorten(
};
let code = generate_code();
let share_bounds = if can_create_share_grant {
share_bounds_from_params(&params)
} else {
None
};
let (share_south, share_west, share_north, share_east) = share_fields(share_bounds);
let record = PbRecord {
code: code.clone(),
params: req.params,
params,
created_by: user.0.as_ref().map(|u| u.id.clone()),
click_count: 0,
share_south,
share_west,
share_north,
share_east,
};
let res = state
@ -89,6 +272,85 @@ pub async fn post_shorten(
}
}
pub async fn get_share_links(
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
) -> Response {
let state = shared.load_state();
let user = match user.0 {
Some(u) => u,
None => return StatusCode::UNAUTHORIZED.into_response(),
};
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = match get_superuser_token(&state).await {
Ok(t) => t,
Err(err) => {
warn!("PocketBase superuser auth failed: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
let filter = format!("created_by=\"{}\"", user.id);
let url = format!(
"{pb_url}/api/collections/short_urls/records?sort=-created&perPage=200&filter={}",
urlencoding::encode(&filter)
);
let res = match state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await
{
Ok(r) => r,
Err(err) => {
warn!("Failed to list share links: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
if !res.status().is_success() {
let status = res.status();
let text = res.text().await.unwrap_or_default();
warn!("PocketBase list share links failed ({status}): {text}");
return StatusCode::BAD_GATEWAY.into_response();
}
let body: serde_json::Value = match res.json().await {
Ok(v) => v,
Err(err) => {
warn!("Failed to parse share links response: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
let public_url = state.public_url.trim_end_matches('/');
let links: Vec<ShareLinkListItem> = body["items"]
.as_array()
.map(|arr| {
arr.iter()
.map(|item| {
let code = item["code"].as_str().unwrap_or("").to_string();
let params = item["params"].as_str().unwrap_or("").to_string();
ShareLinkListItem {
url: format!("{public_url}/s/{code}"),
code,
og_image_url: og_image_url(public_url, &params),
params,
click_count: json_number_as_u64(&item["click_count"]),
created: item["created"].as_str().unwrap_or("").to_string(),
}
})
.collect()
})
.unwrap_or_default();
Json(ShareLinksResponse { links }).into_response()
}
pub async fn get_short_url(
State(shared): State<Arc<SharedState>>,
Path(code): Path<String>,
@ -132,22 +394,51 @@ pub async fn get_short_url(
}
};
let params = json["items"]
.as_array()
.and_then(|items| items.first())
.and_then(|item| item["params"].as_str());
let item = json["items"].as_array().and_then(|items| items.first());
match params {
Some(params) => {
let redirect_url = if params.is_empty() {
format!("/dashboard?share={code}")
} else {
format!("/dashboard?{params}&share={code}")
match item.and_then(|item| item["params"].as_str().map(|params| (item, params))) {
Some((item, params)) => {
let record_id = item["id"].as_str().unwrap_or("").to_string();
let next_click_count =
json_number_as_u64(&item["click_count"]).saturating_add(1);
let params = match sanitized_query_params(params, true) {
Ok(params) => params,
Err(reason) => {
warn!("Stored short URL params rejected for {code}: {reason}");
return StatusCode::BAD_REQUEST.into_response();
}
};
let og_image_url = format!("{}/api/screenshot?og=1&{params}", state.public_url);
let og_url = format!("{}/s/{code}", state.public_url);
if !record_id.is_empty() {
let update_url =
format!("{pb_url}/api/collections/short_urls/records/{record_id}");
match state
.http_client
.patch(&update_url)
.header("Authorization", format!("Bearer {token}"))
.json(&serde_json::json!({ "click_count": next_click_count }))
.send()
.await
{
Ok(update_resp) if update_resp.status().is_success() => {}
Ok(update_resp) => {
let status = update_resp.status();
let text = update_resp.text().await.unwrap_or_default();
warn!("PocketBase click count update failed ({status}): {text}");
}
Err(err) => warn!("PocketBase click count update failed: {err}"),
}
}
let redirect_url =
dashboard_redirect_url(&params, &code, record_share_bounds(item).is_some());
let og_image_url = og_image_url(&state.public_url, &params);
let og_url = format!("{}/s/{code}", state.public_url.trim_end_matches('/'));
let og_title = "Perfect Postcode | Every neighbourhood in England";
let og_description = "Explore property prices, energy ratings, crime stats, school ratings, and more across England on one interactive map.";
let redirect_url = escape_attr(&redirect_url);
let og_image_url = escape_attr(&og_image_url);
let og_url = escape_attr(&og_url);
let og_title = escape_attr(og_title);
let og_description = escape_attr(og_description);
let html = format!(
r#"<!DOCTYPE html>
@ -168,7 +459,13 @@ pub async fn get_short_url(
</head><body></body></html>"#
);
(
[(header::CACHE_CONTROL, "public, max-age=86400")],
[
(header::CACHE_CONTROL, "no-store"),
(
header::CONTENT_SECURITY_POLICY,
"default-src 'none'; img-src https: data:; base-uri 'none'; form-action 'none'",
),
],
Html(html),
)
.into_response()
@ -187,3 +484,37 @@ pub async fn get_short_url(
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn sanitizes_short_url_params_and_drops_share() {
let params = sanitized_query_params(
"lat=51.5&lon=-0.1&zoom=12&filter=price%3A1%3A2&share=oldcode",
false,
)
.unwrap();
assert_eq!(params, "lat=51.5&lon=-0.1&zoom=12&filter=price%3A1%3A2");
}
#[test]
fn rejects_html_in_unsupported_params() {
assert!(sanitized_query_params("lat=51&x=%22%3E%3Cscript%3E", false).is_err());
}
#[test]
fn can_preserve_existing_share_grant() {
let params =
sanitized_query_params("lat=51.5&lon=-0.1&zoom=12&share=oldcode", true).unwrap();
assert_eq!(params, "lat=51.5&lon=-0.1&zoom=12&share=oldcode");
}
#[test]
fn escapes_html_attributes() {
assert_eq!(escape_attr(r#""'><&"#), "&quot;&#39;&gt;&lt;&amp;");
}
}

View file

@ -9,7 +9,7 @@ use sha2::Sha256;
use tracing::{info, warn};
use crate::checkout_sessions::{
grant_license, mark_checkout_completed, mark_referral_invite_used, verify_checkout_completion,
complete_verified_checkout, reverse_license_for_payment_intent, verify_checkout_completion,
CheckoutCompletion,
};
use crate::state::SharedState;
@ -54,16 +54,52 @@ fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
signed_payload.push(b'.');
signed_payload.extend_from_slice(payload);
signatures.into_iter().any(|sig_hex| {
// Verify every candidate signature without short-circuiting, so the total
// time taken doesn't depend on which (if any) signature matched.
let mut matched = false;
for sig_hex in signatures {
let Ok(sig_bytes) = hex::decode(sig_hex) else {
return false;
continue;
};
let Ok(mut mac) = HmacSha256::new_from_slice(secret.as_bytes()) else {
return false;
continue;
};
mac.update(&signed_payload);
mac.verify_slice(&sig_bytes).is_ok()
})
// verify_slice itself is constant-time.
if mac.verify_slice(&sig_bytes).is_ok() {
matched = true;
}
}
matched
}
fn payment_intent_id_from_object(object: &serde_json::Value) -> Option<&str> {
object["payment_intent"]
.as_str()
.filter(|id| is_safe_stripe_id(id))
}
fn is_safe_stripe_id(id: &str) -> bool {
!id.is_empty()
&& id.len() <= 128
&& id
.bytes()
.all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'-')
}
fn reversal_event_is_actionable(event_type: &str, object: &serde_json::Value) -> bool {
match event_type {
"charge.refunded" => {
object["refunded"].as_bool().unwrap_or(false)
|| object["amount_refunded"].as_u64().unwrap_or(0) > 0
}
"charge.refund.updated" | "refund.created" | "refund.updated" => {
matches!(object["status"].as_str(), Some("succeeded"))
}
"charge.dispute.created" | "charge.dispute.funds_withdrawn" => true,
"charge.dispute.closed" => matches!(object["status"].as_str(), Some("lost")),
_ => false,
}
}
/// Handle Stripe webhook events.
@ -109,40 +145,11 @@ pub async fn post_stripe_webhook(
let session = &event["data"]["object"];
match verify_checkout_completion(&state, session).await {
Ok(CheckoutCompletion::Grant(checkout)) => {
if let Err(err) = mark_referral_invite_used(
&state,
&checkout.referral_invite_id,
&checkout.user_id,
)
.await
{
if let Err(err) = complete_verified_checkout(&state, &checkout).await {
warn!(
user_id = %checkout.user_id,
reservation_id = %checkout.reservation_id,
referral_invite_id = %checkout.referral_invite_id,
"Failed to mark referral invite used after Stripe checkout: {err:?}"
);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
if let Err(err) = grant_license(&state, &checkout.user_id).await {
warn!(
user_id = %checkout.user_id,
reservation_id = %checkout.reservation_id,
"Failed to grant license after Stripe checkout: {err:?}"
);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
if let Err(err) = mark_checkout_completed(
&state,
&checkout.reservation_id,
checkout.paid_amount_pence,
)
.await
{
warn!(
user_id = %checkout.user_id,
reservation_id = %checkout.reservation_id,
"Failed to mark checkout completed after license grant: {err:?}"
"Failed to complete verified Stripe checkout: {err:?}"
);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
@ -163,6 +170,52 @@ pub async fn post_stripe_webhook(
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
}
} else if matches!(
event_type,
"charge.refunded"
| "charge.refund.updated"
| "refund.created"
| "refund.updated"
| "charge.dispute.created"
| "charge.dispute.closed"
| "charge.dispute.funds_withdrawn"
) {
let object = &event["data"]["object"];
let Some(payment_intent_id) = payment_intent_id_from_object(object) else {
warn!(
event_id,
event_type, "Stripe reversal event missing payment intent id"
);
return StatusCode::OK.into_response();
};
if !reversal_event_is_actionable(event_type, object) {
info!(
payment_intent_id,
event_type, "Ignoring non-final Stripe reversal event"
);
return StatusCode::OK.into_response();
}
match reverse_license_for_payment_intent(&state, payment_intent_id, event_type).await {
Ok(Some(user_id)) => {
info!(
user_id,
payment_intent_id, event_type, "Processed Stripe payment reversal event"
);
}
Ok(None) => {
warn!(
payment_intent_id,
event_type, "Stripe reversal event had no matching checkout reservation"
);
}
Err(err) => {
warn!(
payment_intent_id,
event_type, "Failed to process Stripe payment reversal event: {err:?}"
);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
}
}
StatusCode::OK.into_response()