This commit is contained in:
Andras Schmelczer 2026-02-15 22:39:49 +00:00
parent 03445188ea
commit 524580eb25
102 changed files with 36625 additions and 1295 deletions

View file

@ -0,0 +1,178 @@
use std::sync::Arc;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::{Extension, Json};
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::pocketbase::auth_superuser;
use crate::state::AppState;
use super::pricing::{count_licensed_users, price_for_count};
#[derive(Deserialize)]
pub struct CheckoutRequest {
referral_code: Option<String>,
}
#[derive(Serialize)]
struct CheckoutResponse {
url: String,
}
/// Create a Stripe Checkout session for the lifetime license (or grant for free if in free tier).
/// Requires authentication. Optionally accepts a referral code to apply a coupon.
pub async fn post_checkout(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<CheckoutRequest>,
) -> Response {
let user = match user.0 {
Some(u) => u,
None => return StatusCode::UNAUTHORIZED.into_response(),
};
let count = match count_licensed_users(&state).await {
Ok(c) => c,
Err(err) => {
warn!("Failed to count licensed users at checkout: {err}");
return StatusCode::SERVICE_UNAVAILABLE.into_response();
}
};
let price_pence = price_for_count(count);
let public_url = &state.public_url;
let success_url = format!("{public_url}/pricing?license_success=1");
// Free tier — grant license directly without Stripe
if price_pence == 0 {
if let Err(err) = grant_license(&state, &user.id).await {
warn!(user_id = %user.id, "Failed to grant free license: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
info!(user_id = %user.id, "Granted free early-bird license");
return Json(CheckoutResponse { url: success_url }).into_response();
}
// Paid tier — create Stripe checkout with dynamic price
let secret_key = &state.stripe_secret_key;
let cancel_url = format!("{public_url}/pricing");
let mut form_params = vec![
("mode", "payment".to_string()),
(
"line_items[0][price_data][unit_amount]",
price_pence.to_string(),
),
("line_items[0][price_data][currency]", "gbp".to_string()),
(
"line_items[0][price_data][product_data][name]",
"Perfect Postcodes Lifetime License".to_string(),
),
("line_items[0][quantity]", "1".to_string()),
("success_url", success_url),
("cancel_url", cancel_url),
("client_reference_id", user.id.clone()),
("customer_email", user.email.clone()),
];
// If a referral code is provided and valid, look it up and apply the coupon
if let Some(ref code) = req.referral_code {
if validate_referral_invite(&state, code).await {
form_params.push(("discounts[0][coupon]", state.stripe_referral_coupon_id.clone()));
info!(code = %code, "Applying referral coupon to checkout");
}
}
let res = state
.http_client
.post("https://api.stripe.com/v1/checkout/sessions")
.basic_auth(secret_key, None::<&str>)
.form(&form_params)
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
let body: serde_json::Value = match resp.json().await {
Ok(v) => v,
Err(err) => {
warn!("Failed to parse Stripe response: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
let url = body["url"].as_str().unwrap_or_default().to_string();
if url.is_empty() {
warn!("Stripe session missing URL");
return StatusCode::BAD_GATEWAY.into_response();
}
info!(user_id = %user.id, price_pence, "Created Stripe checkout session");
Json(CheckoutResponse { url }).into_response()
}
Ok(resp) => {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
warn!("Stripe checkout failed ({status}): {text}");
StatusCode::BAD_GATEWAY.into_response()
}
Err(err) => {
warn!("Stripe request error: {err}");
StatusCode::BAD_GATEWAY.into_response()
}
}
}
/// Grant a license by updating the user's subscription to "licensed" in PocketBase.
async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password).await?;
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
let resp = state
.http_client
.patch(&url)
.header("Authorization", format!("Bearer {token}"))
.json(&serde_json::json!({ "subscription": "licensed" }))
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
anyhow::bail!("PocketBase update failed ({status}): {text}");
}
state.token_cache.invalidate_by_user_id(user_id);
Ok(())
}
/// Check if a referral invite code exists and is unused.
async fn validate_referral_invite(state: &AppState, code: &str) -> bool {
// Only allow alphanumeric codes to prevent PocketBase filter injection
if code.is_empty()
|| code.len() > 20
|| !code.bytes().all(|b| b.is_ascii_alphanumeric())
{
return false;
}
let pb_url = state.pocketbase_url.trim_end_matches('/');
let filter = format!(
"code=\"{}\" && invite_type=\"referral\" && used_by_id=\"\"",
code
);
let url = format!(
"{pb_url}/api/collections/invites/records?filter={}&perPage=1",
urlencoding::encode(&filter)
);
match state.http_client.get(&url).send().await {
Ok(resp) if resp.status().is_success() => {
let body: serde_json::Value = resp.json().await.unwrap_or_default();
body["totalItems"].as_u64().unwrap_or(0) > 0
}
_ => false,
}
}

View file

@ -5,11 +5,14 @@ use std::sync::Arc;
use axum::extract::Query;
use axum::http::{header, StatusCode};
use axum::response::IntoResponse;
use axum::Extension;
use rust_xlsxwriter::{Format, FormatAlign, FormatBorder, Image, Url, Workbook};
use rustc_hash::{FxHashMap, FxHashSet};
use serde::Deserialize;
use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::licensing::check_license_bounds;
use crate::parsing::{parse_field_indices, parse_filters, require_bounds, row_passes_filters};
use crate::routes::FeatureInfo;
use crate::state::AppState;
@ -150,9 +153,14 @@ async fn fetch_screenshot(
pub async fn get_export(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<ExportParams>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
let (south, west, north, east) = require_bounds(params.bounds)?;
) -> Result<impl IntoResponse, axum::response::Response> {
let (south, west, north, east) =
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
check_license_bounds(&user.0, (south, west, north, east))
.map_err(|(_, resp)| resp)?;
let filters_str = params.filters.clone();
let fields_str = params.fields.clone();
@ -161,7 +169,7 @@ pub async fn get_export(
&state.feature_name_to_index,
&state.data.enum_values,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err))?;
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let public_url = state.public_url.clone();
@ -269,7 +277,8 @@ pub async fn get_export(
let filter_feature_names = extract_filter_feature_names(filters_str.as_deref());
let field_indices =
parse_field_indices(fields_str.as_deref(), &state.feature_name_to_index);
parse_field_indices(fields_str.as_deref(), &state.feature_name_to_index)
.map_err(|err| err.1)?;
let all_feature_indices: Vec<usize> = if let Some(ref indices) = field_indices {
indices.clone()
@ -564,8 +573,8 @@ pub async fn get_export(
Ok(buf)
})
.await
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err))?;
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err).into_response())?;
Ok((
[

View file

@ -4,10 +4,13 @@ use std::sync::Arc;
use axum::extract::Query;
use axum::http::StatusCode;
use axum::response::Json;
use axum::response::{IntoResponse, Json};
use axum::Extension;
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::licensing::check_license_bounds;
use crate::parsing::{
cell_for_row, h3_cell_bounds, needs_parent, parse_field_set, parse_filters, row_passes_filters,
validate_h3_resolution,
@ -70,19 +73,25 @@ pub struct HexagonStatsParams {
pub async fn get_hexagon_stats(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<HexagonStatsParams>,
) -> Result<Json<HexagonStatsResponse>, (StatusCode, String)> {
) -> Result<Json<HexagonStatsResponse>, axum::response::Response> {
let cell = h3o::CellIndex::from_str(&params.h3).map_err(|error| {
warn!(h3 = %params.h3, error = %error, "Invalid H3 cell index");
(
StatusCode::BAD_REQUEST,
format!("Invalid H3 cell: {}", error),
)
.into_response()
})?;
let cell_u64: u64 = cell.into();
let resolution = params.resolution;
validate_h3_resolution(resolution)?;
validate_h3_resolution(resolution).map_err(IntoResponse::into_response)?;
// License check using H3 cell bounds
let h3_bounds = h3_cell_bounds(cell, 0.0);
check_license_bounds(&user.0, h3_bounds).map_err(|(_, resp)| resp)?;
let h3_str = params.h3.clone();
let filters_str = params.filters.clone();
@ -91,7 +100,7 @@ pub async fn get_hexagon_stats(
&state.feature_name_to_index,
&state.data.enum_values,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err))?;
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
let (fields_specified, field_set) = parse_field_set(params.fields.as_deref());
@ -164,8 +173,8 @@ pub async fn get_hexagon_stats(
})
})
.await
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))?
.map_err(|error: String| (StatusCode::INTERNAL_SERVER_ERROR, error))?;
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())?
.map_err(|error: String| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?;
Ok(Json(response))
}

View file

@ -2,19 +2,23 @@ use std::sync::Arc;
use axum::extract::Query;
use axum::http::StatusCode;
use axum::response::Json;
use axum::response::{IntoResponse, Json};
use axum::Extension;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::info;
use crate::aggregation::Aggregator;
use crate::auth::OptionalUser;
use crate::consts::MAX_CELLS_PER_REQUEST;
use crate::data::travel_time::TravelData;
use crate::licensing::check_license_bounds;
use crate::parsing::{
bounds_intersect, cell_for_row, h3_cell_bounds, needs_parent, parse_field_indices,
parse_filters, require_bounds, row_passes_filters, validate_h3_resolution,
};
use crate::routes::travel_time::fetch_travel_times;
use crate::routes::travel_time::TravelTimeAgg;
use crate::state::AppState;
#[derive(Serialize)]
@ -27,64 +31,69 @@ pub struct HexagonParams {
resolution: u8,
bounds: Option<String>,
/// Comma-separated filters: `name:min:max,...`
/// Rows must have non-NaN values within [min,max] for each filter.
filters: Option<String>,
/// Comma-separated feature names to include in min/max aggregation.
/// When present (even if empty), only listed features are aggregated and written.
/// When absent, all features are included (backward compatible).
fields: Option<String>,
/// Pipe-separated travel time entries: `lat,lon,mode|lat,lon,mode`
/// Each entry requests travel time from hex centroids to that destination via the given mode.
/// Pipe-separated travel time entries: `mode:slug|mode:slug:min:max`
/// Each entry requests travel time aggregation for that mode+destination.
/// Optional min:max applies as a filter (exclude properties outside range).
travel: Option<String>,
}
struct TravelEntry {
lat: f64,
lon: f64,
mode: String,
slug: String,
filter_min: Option<f32>,
filter_max: Option<f32>,
}
const VALID_MODES: &[&str] = &["car", "bicycle", "walking", "transit"];
/// Parse `travel` param into a list of travel entries.
/// Format: `lat,lon,mode|lat,lon,mode`
fn parse_travel_entries(s: &str) -> Result<Vec<TravelEntry>, String> {
/// Format: `mode:slug` or `mode:slug:min:max`
fn parse_travel_entries(travel_str: &str) -> Result<Vec<TravelEntry>, String> {
let mut entries = Vec::new();
let mut seen_modes = Vec::new();
for segment in s.split('|') {
let parts: Vec<&str> = segment.split(',').collect();
if parts.len() != 3 {
let mut seen_keys = Vec::new();
for segment in travel_str.split('|') {
let parts: Vec<&str> = segment.split(':').collect();
if parts.len() < 2 {
return Err(format!(
"each travel entry must be 'lat,lon,mode', got '{}'",
"each travel entry must be 'mode:slug' or 'mode:slug:min:max', got '{}'",
segment
));
}
let lat: f64 = parts[0]
.trim()
.parse()
.map_err(|_| format!("invalid travel latitude in '{}'", segment))?;
let lon: f64 = parts[1]
.trim()
.parse()
.map_err(|_| format!("invalid travel longitude in '{}'", segment))?;
let mode = parts[2].trim().to_string();
if !VALID_MODES.contains(&mode.as_str()) {
return Err(format!(
"invalid travel mode '{}', must be one of: {}",
mode,
VALID_MODES.join(", ")
));
let mode = parts[0].trim().to_string();
let slug = parts[1].trim().to_string();
let (filter_min, filter_max) = if parts.len() >= 4 {
let min: f32 = parts[2]
.trim()
.parse()
.map_err(|_| format!("invalid travel filter min in '{}'", segment))?;
let max: f32 = parts[3]
.trim()
.parse()
.map_err(|_| format!("invalid travel filter max in '{}'", segment))?;
(Some(min), Some(max))
} else {
(None, None)
};
let key = format!("{}:{}", mode, slug);
if seen_keys.contains(&key) {
return Err(format!("duplicate travel entry '{}'", key));
}
if seen_modes.contains(&mode) {
return Err(format!("duplicate travel mode '{}'", mode));
}
seen_modes.push(mode.clone());
entries.push(TravelEntry { lat, lon, mode });
seen_keys.push(key);
entries.push(TravelEntry {
mode,
slug,
filter_min,
filter_max,
});
}
Ok(entries)
}
/// Build feature maps from aggregated cell data, filtering to only cells that intersect the query bounds.
#[allow(clippy::too_many_arguments)]
fn build_feature_maps(
groups: &FxHashMap<u64, Aggregator>,
min_keys: &[String],
@ -92,7 +101,9 @@ fn build_feature_maps(
avg_keys: &[String],
num_features: usize,
indices: Option<&[usize]>,
query_bounds: (f64, f64, f64, f64), // (south, west, north, east)
query_bounds: (f64, f64, f64, f64),
travel_aggs: &[FxHashMap<u64, TravelTimeAgg>],
travel_field_keys: &[String],
) -> Vec<Map<String, Value>> {
let mut features = Vec::with_capacity(groups.len());
let (q_south, q_west, q_north, q_east) = query_bounds;
@ -143,6 +154,25 @@ fn build_feature_maps(
}
}
// Add travel time aggregation fields
for (ti, agg_map) in travel_aggs.iter().enumerate() {
if let Some(agg) = agg_map.get(&cell_id) {
if agg.count > 0 {
let key = &travel_field_keys[ti];
let avg = agg.sum / agg.count as f64;
if let Some(nm) = serde_json::Number::from_f64(agg.min as f64) {
map.insert(format!("min_{key}"), Value::Number(nm));
}
if let Some(nm) = serde_json::Number::from_f64(agg.max as f64) {
map.insert(format!("max_{key}"), Value::Number(nm));
}
if let Some(nm) = serde_json::Number::from_f64(avg) {
map.insert(format!("avg_{key}"), Value::Number(nm));
}
}
}
}
features.push(map);
}
@ -151,12 +181,21 @@ fn build_feature_maps(
pub async fn get_hexagons(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<HexagonParams>,
) -> Result<Json<HexagonsResponse>, (StatusCode, String)> {
) -> Result<Json<HexagonsResponse>, axum::response::Response> {
let resolution = params.resolution;
validate_h3_resolution(resolution)?;
validate_h3_resolution(resolution).map_err(IntoResponse::into_response)?;
let (south, west, north, east) = require_bounds(params.bounds)?;
let (south, west, north, east) =
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
// Skip license check at low resolutions (≤5) — data is too aggregated to be
// commercially useful, and the homepage demo needs country-wide access.
if resolution > 5 {
check_license_bounds(&user.0, (south, west, north, east))
.map_err(|(_, resp)| resp)?;
}
let filters_str = params.filters.clone();
let (parsed_filters, parsed_enum_filters) = parse_filters(
@ -164,30 +203,49 @@ pub async fn get_hexagons(
&state.feature_name_to_index,
&state.data.enum_values,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err))?;
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index);
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
.map_err(|err| (err.0, err.1).into_response())?;
// Parse travel entries
let travel_entries = params
.travel
.as_deref()
.filter(|s| !s.is_empty())
.filter(|val| !val.is_empty())
.map(parse_travel_entries)
.transpose()
.map_err(|e| (StatusCode::BAD_REQUEST, e))?
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?
.unwrap_or_default();
// Capture what we need for the R5 calls before moving state into spawn_blocking
let r5_url = state.r5_url.clone();
let http_client = state.http_client.clone();
let mut response = tokio::task::spawn_blocking(move || -> Result<HexagonsResponse, String> {
let response = tokio::task::spawn_blocking(move || -> Result<HexagonsResponse, String> {
let t0 = std::time::Instant::now();
// Load travel time data from precomputed parquet files
let travel_data: Vec<TravelData> = if !travel_entries.is_empty() {
let store = &state.travel_time_store;
travel_entries
.iter()
.map(|entry| {
store
.get(&entry.mode, &entry.slug)
.map_err(|err| format!("Failed to load travel data: {}", err))
})
.collect::<Result<Vec<_>, _>>()?
} else {
Vec::new()
};
let has_travel = !travel_entries.is_empty();
let travel_field_keys: Vec<String> = travel_entries
.iter()
.map(|te| format!("tt_{}_{}", te.mode, te.slug))
.collect();
let num_features = state.data.num_features;
let feature_data = &state.data.feature_data;
let (pc_interner, pc_keys) = state.data.postcode_parts();
let min_keys = &state.min_keys;
let max_keys = &state.max_keys;
let avg_keys = &state.avg_keys;
@ -198,49 +256,70 @@ pub async fn get_hexagons(
let need_parent = needs_parent(resolution);
let mut groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
let mut travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> =
(0..travel_entries.len()).map(|_| FxHashMap::default()).collect();
// Hoist has_selective branch outside the hot loop to avoid per-row branching
if let Some(sel_indices) = field_indices.as_deref() {
state
.grid
.for_each_in_bounds(south, west, north, east, |row_idx| {
let row = row_idx as usize;
if !row_passes_filters(
row,
&parsed_filters,
&parsed_enum_filters,
feature_data,
num_features,
) {
return;
// Main aggregation loop
let aggregate_row =
|row: usize,
groups: &mut FxHashMap<u64, Aggregator>,
travel_aggs: &mut [FxHashMap<u64, TravelTimeAgg>]| {
// Regular filters
if !row_passes_filters(
row,
&parsed_filters,
&parsed_enum_filters,
feature_data,
num_features,
) {
return;
}
// Travel time filter: check each entry with a range
let mut travel_minutes: Vec<Option<i16>> = Vec::new();
if has_travel {
let postcode = pc_interner.resolve(&pc_keys[row]);
travel_minutes.reserve(travel_entries.len());
for (ti, entry) in travel_entries.iter().enumerate() {
let minutes = travel_data[ti].get(postcode).copied();
travel_minutes.push(minutes);
if let (Some(fmin), Some(fmax)) = (entry.filter_min, entry.filter_max) {
match minutes {
Some(mins) if (mins as f32) >= fmin && (mins as f32) <= fmax => {}
_ => return, // Filtered out
}
}
}
let cell_id = cell_for_row(row, precomputed, h3_res, need_parent);
let aggregation = groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
}
let cell_id = cell_for_row(row, precomputed, h3_res, need_parent);
// Aggregate regular features
let aggregation = groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
if let Some(sel_indices) = field_indices.as_deref() {
aggregation.add_row_selective(feature_data, row, num_features, sel_indices);
});
} else {
state
.grid
.for_each_in_bounds(south, west, north, east, |row_idx| {
let row = row_idx as usize;
if !row_passes_filters(
row,
&parsed_filters,
&parsed_enum_filters,
feature_data,
num_features,
) {
return;
}
let cell_id = cell_for_row(row, precomputed, h3_res, need_parent);
let aggregation = groups
.entry(cell_id)
.or_insert_with(|| Aggregator::new(num_features));
} else {
aggregation.add_row(feature_data, row, num_features);
});
}
}
// Aggregate travel time
for (ti, minutes) in travel_minutes.iter().enumerate() {
if let Some(mins) = minutes {
let agg = travel_aggs[ti]
.entry(cell_id)
.or_insert_with(TravelTimeAgg::new);
agg.add(*mins as f32);
}
}
};
state
.grid
.for_each_in_bounds(south, west, north, east, |row_idx| {
aggregate_row(row_idx as usize, &mut groups, &mut travel_aggs);
});
let t_agg = t0.elapsed();
@ -252,6 +331,8 @@ pub async fn get_hexagons(
num_features,
field_indices.as_deref(),
(south, west, north, east),
&travel_aggs,
&travel_field_keys,
);
let truncated = features.len() > MAX_CELLS_PER_REQUEST;
@ -268,6 +349,7 @@ pub async fn get_hexagons(
bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east),
filters = num_filters,
filters_raw = filters_str.as_deref().unwrap_or("-"),
travel_entries = travel_entries.len(),
agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0),
total_ms = format_args!("{:.1}", t_total.as_secs_f64() * 1000.0),
"GET /api/hexagons"
@ -276,76 +358,8 @@ pub async fn get_hexagons(
Ok(HexagonsResponse { features })
})
.await
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))?
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error))?;
// If travel entries were requested and R5 is configured, fetch travel times concurrently.
if !travel_entries.is_empty() {
let url = r5_url.as_deref().ok_or((
StatusCode::SERVICE_UNAVAILABLE,
"Travel time queries require routing service (R5_URL not configured)".into(),
))?;
// Collect hex centroids
let origins: Vec<[f64; 2]> = response
.features
.iter()
.map(|f| {
let lat = f
.get("lat")
.and_then(|v| v.as_f64())
.expect("lat must be present in feature map");
let lon = f
.get("lon")
.and_then(|v| v.as_f64())
.expect("lon must be present in feature map");
[lat, lon]
})
.collect();
// Fire concurrent R5 calls for each travel entry
let mut handles = Vec::with_capacity(travel_entries.len());
for entry in &travel_entries {
let client = http_client.clone();
let url = url.to_string();
let origins = origins.clone();
let dest = [entry.lat, entry.lon];
let mode = entry.mode.clone();
handles.push(tokio::spawn(async move {
fetch_travel_times(&client, &url, origins, dest, &mode).await
}));
}
let mut results = Vec::with_capacity(handles.len());
for handle in handles {
results.push(handle.await);
}
for (entry, result) in travel_entries.iter().zip(results) {
let travel_times = result
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.map_err(|err| (StatusCode::BAD_GATEWAY, err))?;
let field_name = format!("travel_time_{}", entry.mode);
for (feature, tt) in response.features.iter_mut().zip(&travel_times) {
match tt {
Some(minutes) => {
if let Some(num) = serde_json::Number::from_f64(*minutes) {
feature.insert(field_name.clone(), Value::Number(num));
}
}
None => {
feature.insert(field_name.clone(), Value::Null);
}
}
}
info!(
hexagons = response.features.len(),
destination = format_args!("{},{}", entry.lat, entry.lon),
mode = entry.mode,
"Travel times merged"
);
}
}
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())?
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?;
Ok(Json(response))
}

View file

@ -0,0 +1,374 @@
use std::sync::Arc;
use axum::extract::Path;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::{Extension, Json};
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::pocketbase::auth_superuser;
use crate::state::AppState;
#[derive(Serialize)]
struct InviteResponse {
code: String,
url: String,
invite_type: String,
}
#[derive(Serialize)]
struct InviteValidation {
valid: bool,
invite_type: String,
used: bool,
}
#[derive(Deserialize)]
pub struct RedeemRequest {
code: String,
}
#[derive(Serialize)]
struct RedeemResponse {
/// "licensed" if admin invite was redeemed directly, or a checkout URL for referral
result: String,
/// For referral invites: the Stripe checkout URL with coupon
checkout_url: Option<String>,
}
/// Validate that an invite code contains only safe characters (alphanumeric, lowercase).
/// Rejects any code that could be used for PocketBase filter injection.
fn validate_invite_code(code: &str) -> Result<(), &'static str> {
if code.is_empty() || code.len() > 20 {
return Err("Invalid invite code length");
}
if !code.bytes().all(|b| b.is_ascii_alphanumeric()) {
return Err("Invalid invite code characters");
}
Ok(())
}
fn generate_invite_code() -> String {
use rand::Rng;
let mut rng = rand::rng();
let chars: Vec<char> = (0..12)
.map(|_| {
let idx: u8 = rng.random_range(0..36);
if idx < 10 {
(b'0' + idx) as char
} else {
(b'a' + idx - 10) as char
}
})
.collect();
chars.into_iter().collect()
}
/// Create an invite. Admins create "admin" invites (free license).
/// Licensed non-admin users create "referral" invites (30% off).
pub async fn post_invites(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
_body: Json<serde_json::Value>,
) -> Response {
let user = match user.0 {
Some(u) => u,
None => return StatusCode::UNAUTHORIZED.into_response(),
};
let invite_type = if user.is_admin {
"admin"
} else if user.subscription == "licensed" {
"referral"
} else {
return (StatusCode::FORBIDDEN, "Only licensed users can create invites").into_response();
};
let code = generate_invite_code();
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = match auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password).await
{
Ok(t) => t,
Err(err) => {
warn!("Failed to auth as PocketBase superuser: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
let create_url = format!("{pb_url}/api/collections/invites/records");
let res = state
.http_client
.post(&create_url)
.header("Authorization", format!("Bearer {token}"))
.json(&serde_json::json!({
"code": code,
"created_by": user.id,
"invite_type": invite_type,
"used_by_id": "",
"used_at": "",
}))
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
let public_url = &state.public_url;
let url = format!("{public_url}/invite/{code}");
info!(code = %code, invite_type, user_id = %user.id, "Created invite");
Json(InviteResponse {
code,
url,
invite_type: invite_type.to_string(),
})
.into_response()
}
Ok(resp) => {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
warn!("Failed to create invite ({status}): {text}");
StatusCode::BAD_GATEWAY.into_response()
}
Err(err) => {
warn!("PocketBase request error: {err}");
StatusCode::BAD_GATEWAY.into_response()
}
}
}
/// Validate an invite code. Requires authentication to prevent enumeration.
pub async fn get_invite(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Path(code): Path<String>,
) -> Response {
if user.0.is_none() {
return StatusCode::UNAUTHORIZED.into_response();
}
if let Err(msg) = validate_invite_code(&code) {
return (StatusCode::BAD_REQUEST, msg).into_response();
}
let pb_url = state.pocketbase_url.trim_end_matches('/');
let filter = format!("code=\"{}\"", code);
let url = format!(
"{pb_url}/api/collections/invites/records?filter={}&perPage=1",
urlencoding::encode(&filter)
);
let res = match state.http_client.get(&url).send().await {
Ok(r) => r,
Err(err) => {
warn!("Failed to look up invite: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
if !res.status().is_success() {
return StatusCode::BAD_GATEWAY.into_response();
}
let body: serde_json::Value = match res.json().await {
Ok(v) => v,
Err(_) => return StatusCode::BAD_GATEWAY.into_response(),
};
let items = body["items"].as_array();
match items.and_then(|arr| arr.first()) {
Some(invite) => {
let invite_type = invite["invite_type"].as_str().unwrap_or("").to_string();
let used_by = invite["used_by_id"].as_str().unwrap_or("");
let used = !used_by.is_empty();
Json(InviteValidation {
valid: true,
invite_type,
used,
})
.into_response()
}
None => Json(InviteValidation {
valid: false,
invite_type: String::new(),
used: false,
})
.into_response(),
}
}
/// Redeem an invite code. Requires authentication.
/// Admin invite: sets subscription to "licensed" directly.
/// Referral invite: returns a discounted Stripe checkout URL.
pub async fn post_redeem_invite(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<RedeemRequest>,
) -> Response {
let user = match user.0 {
Some(u) => u,
None => return StatusCode::UNAUTHORIZED.into_response(),
};
if let Err(msg) = validate_invite_code(&req.code) {
return (StatusCode::BAD_REQUEST, msg).into_response();
}
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = match auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password).await
{
Ok(t) => t,
Err(err) => {
warn!("Failed to auth as PocketBase superuser: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
// Look up invite
let filter = format!(
"code=\"{}\" && used_by_id=\"\"",
req.code
);
let lookup_url = format!(
"{pb_url}/api/collections/invites/records?filter={}&perPage=1",
urlencoding::encode(&filter)
);
let res = match state.http_client.get(&lookup_url)
.header("Authorization", format!("Bearer {token}"))
.send().await
{
Ok(r) => r,
Err(err) => {
warn!("Failed to look up invite: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
let body: serde_json::Value = match res.json().await {
Ok(v) => v,
Err(_) => return StatusCode::BAD_GATEWAY.into_response(),
};
let invite = match body["items"].as_array().and_then(|arr| arr.first()) {
Some(inv) => inv.clone(),
None => {
return (StatusCode::NOT_FOUND, "Invalid or already used invite code").into_response()
}
};
let invite_id = invite["id"].as_str().unwrap_or("");
let invite_type = invite["invite_type"].as_str().unwrap_or("");
// Mark invite as used
let now = {
let dur = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default();
dur.as_secs().to_string()
};
let _ = state
.http_client
.patch(&format!(
"{pb_url}/api/collections/invites/records/{invite_id}"
))
.header("Authorization", format!("Bearer {token}"))
.json(&serde_json::json!({
"used_by_id": user.id,
"used_at": now,
}))
.send()
.await;
if invite_type == "admin" {
// Grant license directly
let update_url = format!("{pb_url}/api/collections/users/records/{}", user.id);
let res = state
.http_client
.patch(&update_url)
.header("Authorization", format!("Bearer {token}"))
.json(&serde_json::json!({ "subscription": "licensed" }))
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
state.token_cache.invalidate_by_user_id(&user.id);
info!(user_id = %user.id, code = %req.code, "Admin invite redeemed — user licensed");
Json(RedeemResponse {
result: "licensed".to_string(),
checkout_url: None,
})
.into_response()
}
_ => {
warn!("Failed to update user subscription for admin invite");
StatusCode::BAD_GATEWAY.into_response()
}
}
} else {
// Referral invite — create discounted checkout with dynamic pricing
let count = match super::pricing::count_licensed_users(&state).await {
Ok(c) => c,
Err(err) => {
warn!("Failed to count licensed users for invite checkout: {err}");
return StatusCode::SERVICE_UNAVAILABLE.into_response();
}
};
let price_pence = super::pricing::price_for_count(count);
let secret_key = &state.stripe_secret_key;
let public_url = &state.public_url;
let success_url = format!("{public_url}/pricing?license_success=1");
let cancel_url = format!("{public_url}/pricing");
let form_params = vec![
("mode", "payment".to_string()),
(
"line_items[0][price_data][unit_amount]",
price_pence.to_string(),
),
("line_items[0][price_data][currency]", "gbp".to_string()),
(
"line_items[0][price_data][product_data][name]",
"Perfect Postcodes Lifetime License".to_string(),
),
("line_items[0][quantity]", "1".to_string()),
("success_url", success_url),
("cancel_url", cancel_url),
("client_reference_id", user.id.clone()),
("customer_email", user.email.clone()),
("discounts[0][coupon]", state.stripe_referral_coupon_id.clone()),
];
let stripe_res = state
.http_client
.post("https://api.stripe.com/v1/checkout/sessions")
.basic_auth(secret_key, None::<&str>)
.form(&form_params)
.send()
.await;
match stripe_res {
Ok(resp) if resp.status().is_success() => {
let stripe_body: serde_json::Value = resp.json().await.unwrap_or_default();
let checkout_url = stripe_body["url"]
.as_str()
.unwrap_or_default()
.to_string();
info!(user_id = %user.id, code = %req.code, "Referral invite redeemed — checkout created");
Json(RedeemResponse {
result: "checkout".to_string(),
checkout_url: Some(checkout_url),
})
.into_response()
}
_ => {
warn!("Failed to create Stripe checkout for referral invite");
StatusCode::BAD_GATEWAY.into_response()
}
}
}
}

View file

@ -0,0 +1,64 @@
use std::sync::Arc;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::{Extension, Json};
use serde::Deserialize;
use tracing::warn;
use crate::auth::OptionalUser;
use crate::pocketbase::auth_superuser;
use crate::state::AppState;
#[derive(Deserialize)]
pub struct UpdateNewsletterRequest {
newsletter: bool,
}
pub async fn patch_newsletter(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<UpdateNewsletterRequest>,
) -> Response {
let user = match user.0 {
Some(u) => u,
None => return StatusCode::UNAUTHORIZED.into_response(),
};
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = match auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password).await
{
Ok(t) => t,
Err(err) => {
warn!("Failed to authenticate as PocketBase superuser: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
let url = format!("{pb_url}/api/collections/users/records/{}", user.id);
let res = state
.http_client
.patch(&url)
.header("Authorization", format!("Bearer {token}"))
.json(&serde_json::json!({ "newsletter": req.newsletter }))
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
state.token_cache.invalidate_by_user_id(&user.id);
StatusCode::OK.into_response()
}
Ok(resp) => {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
warn!("PocketBase user update failed ({status}): {text}");
StatusCode::BAD_GATEWAY.into_response()
}
Err(err) => {
warn!("PocketBase request error: {err}");
StatusCode::BAD_GATEWAY.into_response()
}
}
}

View file

@ -23,9 +23,16 @@ pub async fn proxy_to_pocketbase(state: Arc<AppState>, req: Request) -> impl Int
let method = req.method().clone();
let mut builder = state.http_client.request(method, &url);
// Forward headers except host
// Forward only safe headers (allowlist)
const ALLOWED_HEADERS: &[&str] = &[
"content-type",
"accept",
"authorization",
"cookie",
"accept-language",
];
for (name, value) in req.headers() {
if name != "host" {
if ALLOWED_HEADERS.contains(&name.as_str()) {
builder = builder.header(name.clone(), value.clone());
}
}

View file

@ -6,11 +6,13 @@ use axum::response::Json;
use serde::{Deserialize, Serialize};
use tracing::info;
use crate::data::slugify;
use crate::state::AppState;
#[derive(Serialize)]
pub struct PlaceResult {
name: String,
slug: String,
place_type: String,
lat: f32,
lon: f32,
@ -28,6 +30,8 @@ pub struct PlacesResponse {
pub struct PlacesParams {
q: String,
limit: Option<usize>,
/// If set, only return places that have travel time data for this mode.
mode: Option<String>,
}
pub async fn get_places(
@ -41,33 +45,44 @@ pub async fn get_places(
};
let limit = params.limit.unwrap_or(7).min(20);
let mode_filter = params.mode;
let places = tokio::task::spawn_blocking(move || {
let t0 = std::time::Instant::now();
let query_lower = query.to_lowercase();
let pd = &state.place_data;
let tt_store = &state.travel_time_store;
// Linear scan — ~50-100k rows, <1ms
// Tuple: (row_idx, is_exact, is_prefix, type_rank, population, name_len)
let mut matches: Vec<(usize, bool, bool, u8, u32, usize)> = pd
// Tuple: (row_idx, is_exact, is_prefix, type_rank, population, name_len, slug)
let mut matches: Vec<(usize, bool, bool, u8, u32, usize, String)> = pd
.name_lower
.iter()
.enumerate()
.filter_map(|(idx, name)| {
if name.contains(&query_lower) {
let is_exact = name.len() == query_lower.len();
let is_prefix = name.starts_with(&query_lower);
Some((
idx,
is_exact,
is_prefix,
pd.type_rank[idx],
pd.population[idx],
pd.name[idx].len(),
))
} else {
None
if !name.contains(&query_lower) {
return None;
}
let slug = slugify(&pd.name[idx]);
// If mode filter is set, only include places with travel data
if let Some(ref mode) = mode_filter {
if !tt_store.has_destination(mode, &slug) {
return None;
}
}
let is_exact = name.len() == query_lower.len();
let is_prefix = name.starts_with(&query_lower);
Some((
idx,
is_exact,
is_prefix,
pd.type_rank[idx],
pd.population[idx],
pd.name[idx].len(),
slug,
))
})
.collect();
@ -85,12 +100,13 @@ pub async fn get_places(
let results: Vec<PlaceResult> = matches
.iter()
.map(|&(idx, ..)| PlaceResult {
name: pd.name[idx].clone(),
place_type: pd.place_type.get(idx).to_string(),
lat: pd.lat[idx],
lon: pd.lon[idx],
city: pd.city[idx].clone(),
.map(|(idx, .., slug)| PlaceResult {
name: pd.name[*idx].clone(),
slug: slug.clone(),
place_type: pd.place_type.get(*idx).to_string(),
lat: pd.lat[*idx],
lon: pd.lon[*idx],
city: pd.city[*idx].clone(),
})
.collect();
@ -99,6 +115,7 @@ pub async fn get_places(
query = query.as_str(),
results = results.len(),
scanned = pd.name_lower.len(),
mode = mode_filter.as_deref().unwrap_or("-"),
ms = format_args!("{:.1}", elapsed.as_secs_f64() * 1000.0),
"GET /api/places"
);

View file

@ -2,11 +2,14 @@ use std::sync::Arc;
use axum::extract::Query;
use axum::http::StatusCode;
use axum::response::Json;
use axum::response::{IntoResponse, Json};
use axum::Extension;
use serde::Deserialize;
use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::consts::POSTCODE_SEARCH_OFFSET;
use crate::licensing::check_license_point;
use crate::parsing::{parse_field_set, parse_filters, row_passes_filters};
use crate::state::AppState;
@ -24,8 +27,9 @@ pub struct PostcodeStatsParams {
pub async fn get_postcode_stats(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<PostcodeStatsParams>,
) -> Result<Json<HexagonStatsResponse>, (StatusCode, String)> {
) -> Result<Json<HexagonStatsResponse>, axum::response::Response> {
// Normalize postcode: uppercase, collapse whitespace
let normalized = params
.postcode
@ -42,18 +46,23 @@ pub async fn get_postcode_stats(
return Err((
StatusCode::NOT_FOUND,
format!("Postcode not found: {}", normalized),
));
)
.into_response());
}
};
let (centroid_lat, centroid_lon) = state.postcode_data.centroids[pc_idx];
// License check using postcode centroid
check_license_point(&user.0, centroid_lat as f64, centroid_lon as f64)
.map_err(|(_, resp)| resp)?;
let filters_str = params.filters.clone();
let (parsed_filters, parsed_enum_filters) = parse_filters(
params.filters.as_deref(),
&state.feature_name_to_index,
&state.data.enum_values,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err))?;
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
let (fields_specified, field_set) = parse_field_set(params.fields.as_deref());
@ -129,8 +138,8 @@ pub async fn get_postcode_stats(
})
})
.await
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))?
.map_err(|error: String| (StatusCode::INTERNAL_SERVER_ERROR, error))?;
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())?
.map_err(|error: String| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?;
Ok(Json(response))
}

View file

@ -2,14 +2,17 @@ use std::sync::Arc;
use axum::extract::{Path, Query};
use axum::http::StatusCode;
use axum::response::Json;
use axum::response::{IntoResponse, Json};
use axum::Extension;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::info;
use crate::aggregation::Aggregator;
use crate::auth::OptionalUser;
use crate::consts::MAX_CELLS_PER_REQUEST;
use crate::licensing::check_license_bounds;
use crate::parsing::{
bounds_intersect, parse_field_indices, parse_filters, require_bounds, row_passes_filters,
};
@ -60,9 +63,14 @@ fn build_postcode_geometry(rings: &[Vec<[f32; 2]>]) -> Value {
pub async fn get_postcodes(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<PostcodeParams>,
) -> Result<Json<PostcodesResponse>, (StatusCode, String)> {
let (south, west, north, east) = require_bounds(params.bounds)?;
) -> Result<Json<PostcodesResponse>, axum::response::Response> {
let (south, west, north, east) =
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
check_license_bounds(&user.0, (south, west, north, east))
.map_err(|(_, resp)| resp)?;
let filters_str = params.filters.clone();
let (parsed_filters, parsed_enum_filters) = parse_filters(
@ -70,10 +78,11 @@ pub async fn get_postcodes(
&state.feature_name_to_index,
&state.data.enum_values,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err))?;
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index);
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
.map_err(|err| (err.0, err.1).into_response())?;
let response = tokio::task::spawn_blocking(move || -> Result<PostcodesResponse, String> {
let postcode_data = &state.postcode_data;
@ -222,7 +231,7 @@ pub async fn get_postcodes(
}
}
let truncated = features.len() >= MAX_CELLS_PER_REQUEST;
let truncated = features.len() > MAX_CELLS_PER_REQUEST;
let t_total = t0.elapsed();
info!(
postcodes_before_filter,
@ -242,8 +251,8 @@ pub async fn get_postcodes(
})
})
.await
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))?
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error))?;
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())?
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?;
Ok(Json(response))
}

View file

@ -0,0 +1,105 @@
use std::sync::Arc;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde::Serialize;
use tracing::warn;
use crate::pocketbase::auth_superuser;
use crate::state::AppState;
/// Pricing tiers: (cumulative user cap, price in pence).
const TIERS: &[(u64, u64)] = &[
(10, 0), // First 10 users: free
(20, 1000), // Next 10: £10
(45, 2500), // Next 25: £25
(95, 5000), // Next 50: £50
];
const FINAL_PRICE_PENCE: u64 = 10000; // £100 after 95
#[derive(Serialize)]
pub struct Tier {
up_to: Option<u64>,
price_pence: u64,
slots: u64,
}
#[derive(Serialize)]
pub struct PricingResponse {
licensed_count: u64,
current_price_pence: u64,
tiers: Vec<Tier>,
}
/// Determine the price (in pence) for the next user given `count` existing licensed users.
pub fn price_for_count(count: u64) -> u64 {
for &(cap, price) in TIERS {
if count < cap {
return price;
}
}
FINAL_PRICE_PENCE
}
/// Count users with subscription="licensed" in PocketBase.
pub async fn count_licensed_users(state: &AppState) -> anyhow::Result<u64> {
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password).await?;
let filter = "subscription=\"licensed\"";
let url = format!(
"{pb_url}/api/collections/users/records?filter={}&perPage=1",
urlencoding::encode(filter)
);
let resp = state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
if !resp.status().is_success() {
anyhow::bail!("PocketBase returned {}", resp.status());
}
let body: serde_json::Value = resp.json().await?;
let total = body["totalItems"].as_u64().unwrap_or(0);
Ok(total)
}
pub async fn get_pricing(state: Arc<AppState>) -> Response {
let count = match count_licensed_users(&state).await {
Ok(c) => c,
Err(err) => {
warn!("Failed to count licensed users: {err}");
return StatusCode::SERVICE_UNAVAILABLE.into_response();
}
};
let current_price = price_for_count(count);
let mut tiers = Vec::new();
let mut prev_cap = 0u64;
for &(cap, price) in TIERS {
tiers.push(Tier {
up_to: Some(cap),
price_pence: price,
slots: cap - prev_cap,
});
prev_cap = cap;
}
tiers.push(Tier {
up_to: None,
price_pence: FINAL_PRICE_PENCE,
slots: 0,
});
Json(PricingResponse {
licensed_count: count,
current_price_pence: current_price,
tiers,
})
.into_response()
}

View file

@ -3,12 +3,15 @@ use std::sync::Arc;
use axum::extract::Query;
use axum::http::StatusCode;
use axum::response::Json;
use axum::response::{IntoResponse, Json};
use axum::Extension;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT};
use crate::licensing::check_license_bounds;
use crate::parsing::{
cell_for_row, h3_cell_bounds, needs_parent, parse_filters, row_passes_filters,
validate_h3_resolution,
@ -90,19 +93,25 @@ fn lookup_enum_value(
pub async fn get_hexagon_properties(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<HexagonPropertiesParams>,
) -> Result<Json<HexagonPropertiesResponse>, (StatusCode, String)> {
) -> Result<Json<HexagonPropertiesResponse>, axum::response::Response> {
let cell = h3o::CellIndex::from_str(&params.h3).map_err(|error| {
warn!(h3 = %params.h3, error = %error, "Invalid H3 cell index");
(
StatusCode::BAD_REQUEST,
format!("Invalid H3 cell: {}", error),
)
.into_response()
})?;
let cell_u64: u64 = cell.into();
let resolution = params.resolution;
validate_h3_resolution(resolution)?;
validate_h3_resolution(resolution).map_err(IntoResponse::into_response)?;
// License check using H3 cell bounds
let h3_bounds = h3_cell_bounds(cell, 0.0);
check_license_bounds(&user.0, h3_bounds).map_err(|(_, resp)| resp)?;
let h3_str = params.h3.clone();
let filters_str = params.filters.clone();
@ -111,7 +120,7 @@ pub async fn get_hexagon_properties(
&state.feature_name_to_index,
&state.data.enum_values,
)
.map_err(|err| (StatusCode::BAD_REQUEST, err))?;
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
let result = tokio::task::spawn_blocking(move || {
@ -249,8 +258,8 @@ pub async fn get_hexagon_properties(
})
})
.await
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))?
.map_err(|error: String| (StatusCode::INTERNAL_SERVER_ERROR, error))?;
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())?
.map_err(|error: String| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?;
Ok(Json(result))
}

View file

@ -1,6 +1,7 @@
use std::collections::{HashMap, HashSet};
use rustc_hash::FxHashMap;
use tracing::warn;
use crate::consts::MAX_PRICE_HISTORY_POINTS;
use crate::data::FeatureStats;
@ -78,6 +79,13 @@ pub fn compute_feature_stats(
let idx = value as usize;
if idx < value_counts.len() {
value_counts[idx] += 1;
} else {
warn!(
feature = feature_name.as_str(),
idx,
max = value_counts.len(),
"Enum index out of bounds — possible data/schema mismatch"
);
}
}
}

View file

@ -0,0 +1,129 @@
use std::sync::Arc;
use axum::body::Bytes;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use tracing::{info, warn};
use crate::pocketbase::auth_superuser;
use crate::state::AppState;
type HmacSha256 = Hmac<Sha256>;
/// Verify Stripe webhook signature (v1 scheme).
fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
// Parse timestamp and signature from header: "t=TIMESTAMP,v1=SIGNATURE"
let mut timestamp = None;
let mut signature = None;
for part in sig_header.split(',') {
if let Some(ts) = part.strip_prefix("t=") {
timestamp = Some(ts);
} else if let Some(sig) = part.strip_prefix("v1=") {
signature = Some(sig);
}
}
let (ts, sig_hex) = match (timestamp, signature) {
(Some(t), Some(s)) => (t, s),
_ => return false,
};
// Compute expected signature: HMAC-SHA256(secret, "TIMESTAMP.PAYLOAD")
let signed_payload = format!("{ts}.{}", String::from_utf8_lossy(payload));
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
Ok(m) => m,
Err(_) => return false,
};
mac.update(signed_payload.as_bytes());
// Decode the provided hex signature and verify with constant-time comparison
let sig_bytes = match hex::decode(sig_hex) {
Ok(bytes) => bytes,
Err(_) => return false,
};
mac.verify_slice(&sig_bytes).is_ok()
}
/// Handle Stripe webhook events.
/// On `checkout.session.completed`, updates the user's subscription to "licensed".
pub async fn post_stripe_webhook(
state: Arc<AppState>,
headers: HeaderMap,
body: Bytes,
) -> Response {
let webhook_secret = &state.stripe_webhook_secret;
let sig_header = match headers.get("stripe-signature").and_then(|h| h.to_str().ok()) {
Some(s) => s,
None => {
warn!("Missing Stripe-Signature header");
return StatusCode::BAD_REQUEST.into_response();
}
};
if !verify_signature(&body, sig_header, webhook_secret) {
warn!("Invalid Stripe webhook signature");
return StatusCode::BAD_REQUEST.into_response();
}
let event: serde_json::Value = match serde_json::from_slice(&body) {
Ok(v) => v,
Err(err) => {
warn!("Failed to parse webhook body: {err}");
return StatusCode::BAD_REQUEST.into_response();
}
};
let event_type = event["type"].as_str().unwrap_or("");
info!(event_type, "Received Stripe webhook");
if event_type == "checkout.session.completed" {
let user_id = event["data"]["object"]["client_reference_id"]
.as_str()
.unwrap_or("");
if user_id.is_empty() {
warn!("checkout.session.completed missing client_reference_id");
return StatusCode::OK.into_response();
}
// Update user subscription to "licensed" via PocketBase superuser auth
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = match auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password)
.await
{
Ok(t) => t,
Err(err) => {
warn!("Failed to auth as PocketBase superuser in webhook: {err}");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
let res = state
.http_client
.patch(&url)
.header("Authorization", format!("Bearer {token}"))
.json(&serde_json::json!({ "subscription": "licensed" }))
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
state.token_cache.invalidate_by_user_id(user_id);
info!(user_id, "User subscription updated to licensed via Stripe webhook");
}
Ok(resp) => {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
warn!(user_id, "Failed to update user subscription ({status}): {text}");
}
Err(err) => {
warn!(user_id, "PocketBase request error in webhook: {err}");
}
}
}
StatusCode::OK.into_response()
}

View file

@ -0,0 +1,78 @@
use std::sync::Arc;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::{Extension, Json};
use serde::Deserialize;
use tracing::warn;
use crate::auth::OptionalUser;
use crate::pocketbase::auth_superuser;
use crate::state::AppState;
const VALID_SUBSCRIPTIONS: &[&str] = &["free", "licensed"];
#[derive(Deserialize)]
pub struct UpdateSubscriptionRequest {
subscription: String,
}
pub async fn patch_subscription(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<UpdateSubscriptionRequest>,
) -> Response {
let user = match user.0 {
Some(u) => u,
None => return StatusCode::UNAUTHORIZED.into_response(),
};
if !user.is_admin {
return StatusCode::FORBIDDEN.into_response();
}
if !VALID_SUBSCRIPTIONS.contains(&req.subscription.as_str()) {
return (
StatusCode::BAD_REQUEST,
format!("Invalid subscription: {}", req.subscription),
)
.into_response();
}
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = match auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password).await
{
Ok(t) => t,
Err(err) => {
warn!("Failed to authenticate as PocketBase superuser: {err}");
return StatusCode::BAD_GATEWAY.into_response();
}
};
let url = format!("{pb_url}/api/collections/users/records/{}", user.id);
let res = state
.http_client
.patch(&url)
.header("Authorization", format!("Bearer {token}"))
.json(&serde_json::json!({ "subscription": req.subscription }))
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
state.token_cache.invalidate_by_user_id(&user.id);
StatusCode::OK.into_response()
}
Ok(resp) => {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
warn!("PocketBase user update failed ({status}): {text}");
StatusCode::BAD_GATEWAY.into_response()
}
Err(err) => {
warn!("PocketBase request error: {err}");
StatusCode::BAD_GATEWAY.into_response()
}
}
}

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use axum::extract::{Path, Query, State};
use axum::http::{header, HeaderMap, StatusCode};
use axum::http::{header, StatusCode};
use axum::response::{IntoResponse, Response};
use pmtiles::async_reader::AsyncPmTilesReader;
use pmtiles::MmapBackend;
@ -40,7 +40,7 @@ pub struct StyleParams {
pub async fn get_style(
State(reader): State<Arc<TileReader>>,
headers: HeaderMap,
public_url: String,
Query(params): Query<StyleParams>,
) -> Result<Response, (StatusCode, String)> {
let is_dark = params.theme.as_deref() == Some("dark");
@ -50,7 +50,7 @@ pub async fn get_style(
warn!(error = %err, "Failed to get PMTiles metadata");
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get PMTiles metadata: {err}"),
"Failed to read tile metadata".to_string(),
)
})?;
@ -59,7 +59,7 @@ pub async fn get_style(
warn!(error = %err, "Failed to parse PMTiles metadata JSON");
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to parse PMTiles metadata: {err}"),
"Failed to parse tile metadata".to_string(),
)
})?;
@ -70,15 +70,8 @@ pub async fn get_style(
.cloned()
.unwrap_or_default();
// Build absolute tile URL using the request host
let host = headers
.get(header::HOST)
.and_then(|hv| hv.to_str().ok())
.ok_or((
StatusCode::BAD_REQUEST,
"Missing Host header".into(),
))?;
let tile_url = format!("http://{}/api/tiles/{{z}}/{{x}}/{{y}}", host);
// Build absolute tile URL using the configured public URL (not the Host header)
let tile_url = format!("{}/api/tiles/{{z}}/{{x}}/{{y}}", public_url.trim_end_matches('/'));
let style = build_style(is_dark, &layers, &tile_url);
Ok((

View file

@ -0,0 +1,38 @@
use std::sync::Arc;
use axum::http::StatusCode;
use axum::response::Json;
use serde::Serialize;
use crate::state::AppState;
#[derive(Serialize)]
pub struct TravelModeInfo {
mode: String,
destinations: usize,
}
#[derive(Serialize)]
pub struct TravelModesResponse {
modes: Vec<TravelModeInfo>,
}
pub async fn get_travel_modes(
state: Arc<AppState>,
) -> Result<Json<TravelModesResponse>, (StatusCode, String)> {
let store = &state.travel_time_store;
let modes = store
.available_modes
.iter()
.map(|mode| TravelModeInfo {
mode: mode.clone(),
destinations: store
.destinations
.get(mode)
.map(|slugs| slugs.len())
.unwrap_or(0),
})
.collect();
Ok(Json(TravelModesResponse { modes }))
}

View file

@ -1,72 +1,30 @@
use serde::{Deserialize, Serialize};
use tracing::warn;
#[derive(Serialize)]
struct R5Request {
origin: [f64; 2],
destinations: Vec<[f64; 2]>,
mode: String,
/// Per-hex-cell travel time aggregation.
pub struct TravelTimeAgg {
pub min: f32,
pub max: f32,
pub sum: f64,
pub count: u32,
}
#[derive(Deserialize)]
struct R5Response {
travel_times: Vec<f64>,
}
/// Call the R5 Java service to compute one-to-many travel times.
///
/// `origins` are hex centroids as `[lat, lon]`.
/// `destination` is the user-chosen point as `[lat, lon]`.
/// `mode` is one of "car", "bicycle", "walking", "transit".
///
/// R5 computes from destination to all origins (one-to-many from the user's chosen point).
/// Returns a Vec of travel times in minutes (one per origin), with None for unreachable.
pub async fn fetch_travel_times(
client: &reqwest::Client,
r5_url: &str,
origins: Vec<[f64; 2]>,
destination: [f64; 2],
mode: &str,
) -> Result<Vec<Option<f64>>, String> {
if origins.is_empty() {
return Ok(vec![]);
impl TravelTimeAgg {
pub fn new() -> Self {
TravelTimeAgg {
min: f32::INFINITY,
max: f32::NEG_INFINITY,
sum: 0.0,
count: 0,
}
}
let body = R5Request {
origin: destination,
destinations: origins,
mode: mode.to_string(),
};
let resp = client
.post(format!("{}/travel-times", r5_url))
.json(&body)
.timeout(std::time::Duration::from_secs(30))
.send()
.await
.map_err(|e| {
warn!("R5 request failed: {}", e);
format!("R5 routing error: {}", e)
})?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
warn!("R5 returned {}: {}", status, body);
return Err(format!("R5 returned {}: {}", status, body));
#[inline]
pub fn add(&mut self, value: f32) {
if value < self.min {
self.min = value;
}
if value > self.max {
self.max = value;
}
self.sum += value as f64;
self.count += 1;
}
let r5_resp: R5Response = resp.json().await.map_err(|e| {
warn!("Failed to parse R5 response: {}", e);
format!("Failed to parse R5 response: {}", e)
})?;
// R5 returns -1 for unreachable destinations
let travel_times: Vec<Option<f64>> = r5_resp
.travel_times
.into_iter()
.map(|t| if t < 0.0 { None } else { Some(t) })
.collect();
Ok(travel_times)
}