lmao
This commit is contained in:
parent
03445188ea
commit
524580eb25
102 changed files with 36625 additions and 1295 deletions
178
server-rs/src/routes/checkout.rs
Normal file
178
server-rs/src/routes/checkout.rs
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::{Extension, Json};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::pocketbase::auth_superuser;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::pricing::{count_licensed_users, price_for_count};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct CheckoutRequest {
|
||||
referral_code: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct CheckoutResponse {
|
||||
url: String,
|
||||
}
|
||||
|
||||
/// Create a Stripe Checkout session for the lifetime license (or grant for free if in free tier).
|
||||
/// Requires authentication. Optionally accepts a referral code to apply a coupon.
|
||||
pub async fn post_checkout(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Json(req): Json<CheckoutRequest>,
|
||||
) -> Response {
|
||||
let user = match user.0 {
|
||||
Some(u) => u,
|
||||
None => return StatusCode::UNAUTHORIZED.into_response(),
|
||||
};
|
||||
|
||||
let count = match count_licensed_users(&state).await {
|
||||
Ok(c) => c,
|
||||
Err(err) => {
|
||||
warn!("Failed to count licensed users at checkout: {err}");
|
||||
return StatusCode::SERVICE_UNAVAILABLE.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let price_pence = price_for_count(count);
|
||||
let public_url = &state.public_url;
|
||||
let success_url = format!("{public_url}/pricing?license_success=1");
|
||||
|
||||
// Free tier — grant license directly without Stripe
|
||||
if price_pence == 0 {
|
||||
if let Err(err) = grant_license(&state, &user.id).await {
|
||||
warn!(user_id = %user.id, "Failed to grant free license: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
info!(user_id = %user.id, "Granted free early-bird license");
|
||||
return Json(CheckoutResponse { url: success_url }).into_response();
|
||||
}
|
||||
|
||||
// Paid tier — create Stripe checkout with dynamic price
|
||||
let secret_key = &state.stripe_secret_key;
|
||||
let cancel_url = format!("{public_url}/pricing");
|
||||
|
||||
let mut form_params = vec![
|
||||
("mode", "payment".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][unit_amount]",
|
||||
price_pence.to_string(),
|
||||
),
|
||||
("line_items[0][price_data][currency]", "gbp".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][product_data][name]",
|
||||
"Perfect Postcodes Lifetime License".to_string(),
|
||||
),
|
||||
("line_items[0][quantity]", "1".to_string()),
|
||||
("success_url", success_url),
|
||||
("cancel_url", cancel_url),
|
||||
("client_reference_id", user.id.clone()),
|
||||
("customer_email", user.email.clone()),
|
||||
];
|
||||
|
||||
// If a referral code is provided and valid, look it up and apply the coupon
|
||||
if let Some(ref code) = req.referral_code {
|
||||
if validate_referral_invite(&state, code).await {
|
||||
form_params.push(("discounts[0][coupon]", state.stripe_referral_coupon_id.clone()));
|
||||
info!(code = %code, "Applying referral coupon to checkout");
|
||||
}
|
||||
}
|
||||
|
||||
let res = state
|
||||
.http_client
|
||||
.post("https://api.stripe.com/v1/checkout/sessions")
|
||||
.basic_auth(secret_key, None::<&str>)
|
||||
.form(&form_params)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: serde_json::Value = match resp.json().await {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse Stripe response: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
};
|
||||
let url = body["url"].as_str().unwrap_or_default().to_string();
|
||||
if url.is_empty() {
|
||||
warn!("Stripe session missing URL");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
info!(user_id = %user.id, price_pence, "Created Stripe checkout session");
|
||||
Json(CheckoutResponse { url }).into_response()
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
warn!("Stripe checkout failed ({status}): {text}");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Stripe request error: {err}");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Grant a license by updating the user's subscription to "licensed" in PocketBase.
|
||||
async fn grant_license(state: &AppState, user_id: &str) -> anyhow::Result<()> {
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let token = auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password).await?;
|
||||
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "subscription": "licensed" }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("PocketBase update failed ({status}): {text}");
|
||||
}
|
||||
|
||||
state.token_cache.invalidate_by_user_id(user_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a referral invite code exists and is unused.
|
||||
async fn validate_referral_invite(state: &AppState, code: &str) -> bool {
|
||||
// Only allow alphanumeric codes to prevent PocketBase filter injection
|
||||
if code.is_empty()
|
||||
|| code.len() > 20
|
||||
|| !code.bytes().all(|b| b.is_ascii_alphanumeric())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!(
|
||||
"code=\"{}\" && invite_type=\"referral\" && used_by_id=\"\"",
|
||||
code
|
||||
);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/invites/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
|
||||
match state.http_client.get(&url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
body["totalItems"].as_u64().unwrap_or(0) > 0
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
|
@ -5,11 +5,14 @@ use std::sync::Arc;
|
|||
use axum::extract::Query;
|
||||
use axum::http::{header, StatusCode};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Extension;
|
||||
use rust_xlsxwriter::{Format, FormatAlign, FormatBorder, Image, Url, Workbook};
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use serde::Deserialize;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::licensing::check_license_bounds;
|
||||
use crate::parsing::{parse_field_indices, parse_filters, require_bounds, row_passes_filters};
|
||||
use crate::routes::FeatureInfo;
|
||||
use crate::state::AppState;
|
||||
|
|
@ -150,9 +153,14 @@ async fn fetch_screenshot(
|
|||
|
||||
pub async fn get_export(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Query(params): Query<ExportParams>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let (south, west, north, east) = require_bounds(params.bounds)?;
|
||||
) -> Result<impl IntoResponse, axum::response::Response> {
|
||||
let (south, west, north, east) =
|
||||
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
|
||||
|
||||
check_license_bounds(&user.0, (south, west, north, east))
|
||||
.map_err(|(_, resp)| resp)?;
|
||||
|
||||
let filters_str = params.filters.clone();
|
||||
let fields_str = params.fields.clone();
|
||||
|
|
@ -161,7 +169,7 @@ pub async fn get_export(
|
|||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err))?;
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
|
||||
let public_url = state.public_url.clone();
|
||||
|
||||
|
|
@ -269,7 +277,8 @@ pub async fn get_export(
|
|||
let filter_feature_names = extract_filter_feature_names(filters_str.as_deref());
|
||||
|
||||
let field_indices =
|
||||
parse_field_indices(fields_str.as_deref(), &state.feature_name_to_index);
|
||||
parse_field_indices(fields_str.as_deref(), &state.feature_name_to_index)
|
||||
.map_err(|err| err.1)?;
|
||||
|
||||
let all_feature_indices: Vec<usize> = if let Some(ref indices) = field_indices {
|
||||
indices.clone()
|
||||
|
|
@ -564,8 +573,8 @@ pub async fn get_export(
|
|||
Ok(buf)
|
||||
})
|
||||
.await
|
||||
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?
|
||||
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err))?;
|
||||
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?
|
||||
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err).into_response())?;
|
||||
|
||||
Ok((
|
||||
[
|
||||
|
|
|
|||
|
|
@ -4,10 +4,13 @@ use std::sync::Arc;
|
|||
|
||||
use axum::extract::Query;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Json;
|
||||
use axum::response::{IntoResponse, Json};
|
||||
use axum::Extension;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::licensing::check_license_bounds;
|
||||
use crate::parsing::{
|
||||
cell_for_row, h3_cell_bounds, needs_parent, parse_field_set, parse_filters, row_passes_filters,
|
||||
validate_h3_resolution,
|
||||
|
|
@ -70,19 +73,25 @@ pub struct HexagonStatsParams {
|
|||
|
||||
pub async fn get_hexagon_stats(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Query(params): Query<HexagonStatsParams>,
|
||||
) -> Result<Json<HexagonStatsResponse>, (StatusCode, String)> {
|
||||
) -> Result<Json<HexagonStatsResponse>, axum::response::Response> {
|
||||
let cell = h3o::CellIndex::from_str(¶ms.h3).map_err(|error| {
|
||||
warn!(h3 = %params.h3, error = %error, "Invalid H3 cell index");
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Invalid H3 cell: {}", error),
|
||||
)
|
||||
.into_response()
|
||||
})?;
|
||||
let cell_u64: u64 = cell.into();
|
||||
|
||||
let resolution = params.resolution;
|
||||
validate_h3_resolution(resolution)?;
|
||||
validate_h3_resolution(resolution).map_err(IntoResponse::into_response)?;
|
||||
|
||||
// License check using H3 cell bounds
|
||||
let h3_bounds = h3_cell_bounds(cell, 0.0);
|
||||
check_license_bounds(&user.0, h3_bounds).map_err(|(_, resp)| resp)?;
|
||||
|
||||
let h3_str = params.h3.clone();
|
||||
let filters_str = params.filters.clone();
|
||||
|
|
@ -91,7 +100,7 @@ pub async fn get_hexagon_stats(
|
|||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err))?;
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
|
||||
let (fields_specified, field_set) = parse_field_set(params.fields.as_deref());
|
||||
|
|
@ -164,8 +173,8 @@ pub async fn get_hexagon_stats(
|
|||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))?
|
||||
.map_err(|error: String| (StatusCode::INTERNAL_SERVER_ERROR, error))?;
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())?
|
||||
.map_err(|error: String| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?;
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,19 +2,23 @@ use std::sync::Arc;
|
|||
|
||||
use axum::extract::Query;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Json;
|
||||
use axum::response::{IntoResponse, Json};
|
||||
use axum::Extension;
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map, Value};
|
||||
use tracing::info;
|
||||
|
||||
use crate::aggregation::Aggregator;
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::MAX_CELLS_PER_REQUEST;
|
||||
use crate::data::travel_time::TravelData;
|
||||
use crate::licensing::check_license_bounds;
|
||||
use crate::parsing::{
|
||||
bounds_intersect, cell_for_row, h3_cell_bounds, needs_parent, parse_field_indices,
|
||||
parse_filters, require_bounds, row_passes_filters, validate_h3_resolution,
|
||||
};
|
||||
use crate::routes::travel_time::fetch_travel_times;
|
||||
use crate::routes::travel_time::TravelTimeAgg;
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
|
@ -27,64 +31,69 @@ pub struct HexagonParams {
|
|||
resolution: u8,
|
||||
bounds: Option<String>,
|
||||
/// Comma-separated filters: `name:min:max,...`
|
||||
/// Rows must have non-NaN values within [min,max] for each filter.
|
||||
filters: Option<String>,
|
||||
/// Comma-separated feature names to include in min/max aggregation.
|
||||
/// When present (even if empty), only listed features are aggregated and written.
|
||||
/// When absent, all features are included (backward compatible).
|
||||
fields: Option<String>,
|
||||
/// Pipe-separated travel time entries: `lat,lon,mode|lat,lon,mode`
|
||||
/// Each entry requests travel time from hex centroids to that destination via the given mode.
|
||||
/// Pipe-separated travel time entries: `mode:slug|mode:slug:min:max`
|
||||
/// Each entry requests travel time aggregation for that mode+destination.
|
||||
/// Optional min:max applies as a filter (exclude properties outside range).
|
||||
travel: Option<String>,
|
||||
}
|
||||
|
||||
struct TravelEntry {
|
||||
lat: f64,
|
||||
lon: f64,
|
||||
mode: String,
|
||||
slug: String,
|
||||
filter_min: Option<f32>,
|
||||
filter_max: Option<f32>,
|
||||
}
|
||||
|
||||
const VALID_MODES: &[&str] = &["car", "bicycle", "walking", "transit"];
|
||||
|
||||
/// Parse `travel` param into a list of travel entries.
|
||||
/// Format: `lat,lon,mode|lat,lon,mode`
|
||||
fn parse_travel_entries(s: &str) -> Result<Vec<TravelEntry>, String> {
|
||||
/// Format: `mode:slug` or `mode:slug:min:max`
|
||||
fn parse_travel_entries(travel_str: &str) -> Result<Vec<TravelEntry>, String> {
|
||||
let mut entries = Vec::new();
|
||||
let mut seen_modes = Vec::new();
|
||||
for segment in s.split('|') {
|
||||
let parts: Vec<&str> = segment.split(',').collect();
|
||||
if parts.len() != 3 {
|
||||
let mut seen_keys = Vec::new();
|
||||
for segment in travel_str.split('|') {
|
||||
let parts: Vec<&str> = segment.split(':').collect();
|
||||
if parts.len() < 2 {
|
||||
return Err(format!(
|
||||
"each travel entry must be 'lat,lon,mode', got '{}'",
|
||||
"each travel entry must be 'mode:slug' or 'mode:slug:min:max', got '{}'",
|
||||
segment
|
||||
));
|
||||
}
|
||||
let lat: f64 = parts[0]
|
||||
.trim()
|
||||
.parse()
|
||||
.map_err(|_| format!("invalid travel latitude in '{}'", segment))?;
|
||||
let lon: f64 = parts[1]
|
||||
.trim()
|
||||
.parse()
|
||||
.map_err(|_| format!("invalid travel longitude in '{}'", segment))?;
|
||||
let mode = parts[2].trim().to_string();
|
||||
if !VALID_MODES.contains(&mode.as_str()) {
|
||||
return Err(format!(
|
||||
"invalid travel mode '{}', must be one of: {}",
|
||||
mode,
|
||||
VALID_MODES.join(", ")
|
||||
));
|
||||
let mode = parts[0].trim().to_string();
|
||||
let slug = parts[1].trim().to_string();
|
||||
|
||||
let (filter_min, filter_max) = if parts.len() >= 4 {
|
||||
let min: f32 = parts[2]
|
||||
.trim()
|
||||
.parse()
|
||||
.map_err(|_| format!("invalid travel filter min in '{}'", segment))?;
|
||||
let max: f32 = parts[3]
|
||||
.trim()
|
||||
.parse()
|
||||
.map_err(|_| format!("invalid travel filter max in '{}'", segment))?;
|
||||
(Some(min), Some(max))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let key = format!("{}:{}", mode, slug);
|
||||
if seen_keys.contains(&key) {
|
||||
return Err(format!("duplicate travel entry '{}'", key));
|
||||
}
|
||||
if seen_modes.contains(&mode) {
|
||||
return Err(format!("duplicate travel mode '{}'", mode));
|
||||
}
|
||||
seen_modes.push(mode.clone());
|
||||
entries.push(TravelEntry { lat, lon, mode });
|
||||
seen_keys.push(key);
|
||||
entries.push(TravelEntry {
|
||||
mode,
|
||||
slug,
|
||||
filter_min,
|
||||
filter_max,
|
||||
});
|
||||
}
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
/// Build feature maps from aggregated cell data, filtering to only cells that intersect the query bounds.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn build_feature_maps(
|
||||
groups: &FxHashMap<u64, Aggregator>,
|
||||
min_keys: &[String],
|
||||
|
|
@ -92,7 +101,9 @@ fn build_feature_maps(
|
|||
avg_keys: &[String],
|
||||
num_features: usize,
|
||||
indices: Option<&[usize]>,
|
||||
query_bounds: (f64, f64, f64, f64), // (south, west, north, east)
|
||||
query_bounds: (f64, f64, f64, f64),
|
||||
travel_aggs: &[FxHashMap<u64, TravelTimeAgg>],
|
||||
travel_field_keys: &[String],
|
||||
) -> Vec<Map<String, Value>> {
|
||||
let mut features = Vec::with_capacity(groups.len());
|
||||
let (q_south, q_west, q_north, q_east) = query_bounds;
|
||||
|
|
@ -143,6 +154,25 @@ fn build_feature_maps(
|
|||
}
|
||||
}
|
||||
|
||||
// Add travel time aggregation fields
|
||||
for (ti, agg_map) in travel_aggs.iter().enumerate() {
|
||||
if let Some(agg) = agg_map.get(&cell_id) {
|
||||
if agg.count > 0 {
|
||||
let key = &travel_field_keys[ti];
|
||||
let avg = agg.sum / agg.count as f64;
|
||||
if let Some(nm) = serde_json::Number::from_f64(agg.min as f64) {
|
||||
map.insert(format!("min_{key}"), Value::Number(nm));
|
||||
}
|
||||
if let Some(nm) = serde_json::Number::from_f64(agg.max as f64) {
|
||||
map.insert(format!("max_{key}"), Value::Number(nm));
|
||||
}
|
||||
if let Some(nm) = serde_json::Number::from_f64(avg) {
|
||||
map.insert(format!("avg_{key}"), Value::Number(nm));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
features.push(map);
|
||||
}
|
||||
|
||||
|
|
@ -151,12 +181,21 @@ fn build_feature_maps(
|
|||
|
||||
pub async fn get_hexagons(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Query(params): Query<HexagonParams>,
|
||||
) -> Result<Json<HexagonsResponse>, (StatusCode, String)> {
|
||||
) -> Result<Json<HexagonsResponse>, axum::response::Response> {
|
||||
let resolution = params.resolution;
|
||||
validate_h3_resolution(resolution)?;
|
||||
validate_h3_resolution(resolution).map_err(IntoResponse::into_response)?;
|
||||
|
||||
let (south, west, north, east) = require_bounds(params.bounds)?;
|
||||
let (south, west, north, east) =
|
||||
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
|
||||
|
||||
// Skip license check at low resolutions (≤5) — data is too aggregated to be
|
||||
// commercially useful, and the homepage demo needs country-wide access.
|
||||
if resolution > 5 {
|
||||
check_license_bounds(&user.0, (south, west, north, east))
|
||||
.map_err(|(_, resp)| resp)?;
|
||||
}
|
||||
|
||||
let filters_str = params.filters.clone();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
|
|
@ -164,30 +203,49 @@ pub async fn get_hexagons(
|
|||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err))?;
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
|
||||
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index);
|
||||
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
|
||||
.map_err(|err| (err.0, err.1).into_response())?;
|
||||
|
||||
// Parse travel entries
|
||||
let travel_entries = params
|
||||
.travel
|
||||
.as_deref()
|
||||
.filter(|s| !s.is_empty())
|
||||
.filter(|val| !val.is_empty())
|
||||
.map(parse_travel_entries)
|
||||
.transpose()
|
||||
.map_err(|e| (StatusCode::BAD_REQUEST, e))?
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?
|
||||
.unwrap_or_default();
|
||||
|
||||
// Capture what we need for the R5 calls before moving state into spawn_blocking
|
||||
let r5_url = state.r5_url.clone();
|
||||
let http_client = state.http_client.clone();
|
||||
|
||||
let mut response = tokio::task::spawn_blocking(move || -> Result<HexagonsResponse, String> {
|
||||
let response = tokio::task::spawn_blocking(move || -> Result<HexagonsResponse, String> {
|
||||
let t0 = std::time::Instant::now();
|
||||
|
||||
// Load travel time data from precomputed parquet files
|
||||
let travel_data: Vec<TravelData> = if !travel_entries.is_empty() {
|
||||
let store = &state.travel_time_store;
|
||||
travel_entries
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
store
|
||||
.get(&entry.mode, &entry.slug)
|
||||
.map_err(|err| format!("Failed to load travel data: {}", err))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let has_travel = !travel_entries.is_empty();
|
||||
let travel_field_keys: Vec<String> = travel_entries
|
||||
.iter()
|
||||
.map(|te| format!("tt_{}_{}", te.mode, te.slug))
|
||||
.collect();
|
||||
|
||||
let num_features = state.data.num_features;
|
||||
let feature_data = &state.data.feature_data;
|
||||
let (pc_interner, pc_keys) = state.data.postcode_parts();
|
||||
let min_keys = &state.min_keys;
|
||||
let max_keys = &state.max_keys;
|
||||
let avg_keys = &state.avg_keys;
|
||||
|
|
@ -198,49 +256,70 @@ pub async fn get_hexagons(
|
|||
let need_parent = needs_parent(resolution);
|
||||
|
||||
let mut groups: FxHashMap<u64, Aggregator> = FxHashMap::default();
|
||||
let mut travel_aggs: Vec<FxHashMap<u64, TravelTimeAgg>> =
|
||||
(0..travel_entries.len()).map(|_| FxHashMap::default()).collect();
|
||||
|
||||
// Hoist has_selective branch outside the hot loop to avoid per-row branching
|
||||
if let Some(sel_indices) = field_indices.as_deref() {
|
||||
state
|
||||
.grid
|
||||
.for_each_in_bounds(south, west, north, east, |row_idx| {
|
||||
let row = row_idx as usize;
|
||||
if !row_passes_filters(
|
||||
row,
|
||||
&parsed_filters,
|
||||
&parsed_enum_filters,
|
||||
feature_data,
|
||||
num_features,
|
||||
) {
|
||||
return;
|
||||
// Main aggregation loop
|
||||
let aggregate_row =
|
||||
|row: usize,
|
||||
groups: &mut FxHashMap<u64, Aggregator>,
|
||||
travel_aggs: &mut [FxHashMap<u64, TravelTimeAgg>]| {
|
||||
// Regular filters
|
||||
if !row_passes_filters(
|
||||
row,
|
||||
&parsed_filters,
|
||||
&parsed_enum_filters,
|
||||
feature_data,
|
||||
num_features,
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Travel time filter: check each entry with a range
|
||||
let mut travel_minutes: Vec<Option<i16>> = Vec::new();
|
||||
if has_travel {
|
||||
let postcode = pc_interner.resolve(&pc_keys[row]);
|
||||
travel_minutes.reserve(travel_entries.len());
|
||||
for (ti, entry) in travel_entries.iter().enumerate() {
|
||||
let minutes = travel_data[ti].get(postcode).copied();
|
||||
travel_minutes.push(minutes);
|
||||
if let (Some(fmin), Some(fmax)) = (entry.filter_min, entry.filter_max) {
|
||||
match minutes {
|
||||
Some(mins) if (mins as f32) >= fmin && (mins as f32) <= fmax => {}
|
||||
_ => return, // Filtered out
|
||||
}
|
||||
}
|
||||
}
|
||||
let cell_id = cell_for_row(row, precomputed, h3_res, need_parent);
|
||||
let aggregation = groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| Aggregator::new(num_features));
|
||||
}
|
||||
|
||||
let cell_id = cell_for_row(row, precomputed, h3_res, need_parent);
|
||||
|
||||
// Aggregate regular features
|
||||
let aggregation = groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| Aggregator::new(num_features));
|
||||
if let Some(sel_indices) = field_indices.as_deref() {
|
||||
aggregation.add_row_selective(feature_data, row, num_features, sel_indices);
|
||||
});
|
||||
} else {
|
||||
state
|
||||
.grid
|
||||
.for_each_in_bounds(south, west, north, east, |row_idx| {
|
||||
let row = row_idx as usize;
|
||||
if !row_passes_filters(
|
||||
row,
|
||||
&parsed_filters,
|
||||
&parsed_enum_filters,
|
||||
feature_data,
|
||||
num_features,
|
||||
) {
|
||||
return;
|
||||
}
|
||||
let cell_id = cell_for_row(row, precomputed, h3_res, need_parent);
|
||||
let aggregation = groups
|
||||
.entry(cell_id)
|
||||
.or_insert_with(|| Aggregator::new(num_features));
|
||||
} else {
|
||||
aggregation.add_row(feature_data, row, num_features);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Aggregate travel time
|
||||
for (ti, minutes) in travel_minutes.iter().enumerate() {
|
||||
if let Some(mins) = minutes {
|
||||
let agg = travel_aggs[ti]
|
||||
.entry(cell_id)
|
||||
.or_insert_with(TravelTimeAgg::new);
|
||||
agg.add(*mins as f32);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
state
|
||||
.grid
|
||||
.for_each_in_bounds(south, west, north, east, |row_idx| {
|
||||
aggregate_row(row_idx as usize, &mut groups, &mut travel_aggs);
|
||||
});
|
||||
|
||||
let t_agg = t0.elapsed();
|
||||
|
||||
|
|
@ -252,6 +331,8 @@ pub async fn get_hexagons(
|
|||
num_features,
|
||||
field_indices.as_deref(),
|
||||
(south, west, north, east),
|
||||
&travel_aggs,
|
||||
&travel_field_keys,
|
||||
);
|
||||
|
||||
let truncated = features.len() > MAX_CELLS_PER_REQUEST;
|
||||
|
|
@ -268,6 +349,7 @@ pub async fn get_hexagons(
|
|||
bounds = format_args!("{:.4},{:.4},{:.4},{:.4}", south, west, north, east),
|
||||
filters = num_filters,
|
||||
filters_raw = filters_str.as_deref().unwrap_or("-"),
|
||||
travel_entries = travel_entries.len(),
|
||||
agg_ms = format_args!("{:.1}", t_agg.as_secs_f64() * 1000.0),
|
||||
total_ms = format_args!("{:.1}", t_total.as_secs_f64() * 1000.0),
|
||||
"GET /api/hexagons"
|
||||
|
|
@ -276,76 +358,8 @@ pub async fn get_hexagons(
|
|||
Ok(HexagonsResponse { features })
|
||||
})
|
||||
.await
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))?
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error))?;
|
||||
|
||||
// If travel entries were requested and R5 is configured, fetch travel times concurrently.
|
||||
if !travel_entries.is_empty() {
|
||||
let url = r5_url.as_deref().ok_or((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"Travel time queries require routing service (R5_URL not configured)".into(),
|
||||
))?;
|
||||
|
||||
// Collect hex centroids
|
||||
let origins: Vec<[f64; 2]> = response
|
||||
.features
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let lat = f
|
||||
.get("lat")
|
||||
.and_then(|v| v.as_f64())
|
||||
.expect("lat must be present in feature map");
|
||||
let lon = f
|
||||
.get("lon")
|
||||
.and_then(|v| v.as_f64())
|
||||
.expect("lon must be present in feature map");
|
||||
[lat, lon]
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Fire concurrent R5 calls for each travel entry
|
||||
let mut handles = Vec::with_capacity(travel_entries.len());
|
||||
for entry in &travel_entries {
|
||||
let client = http_client.clone();
|
||||
let url = url.to_string();
|
||||
let origins = origins.clone();
|
||||
let dest = [entry.lat, entry.lon];
|
||||
let mode = entry.mode.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
fetch_travel_times(&client, &url, origins, dest, &mode).await
|
||||
}));
|
||||
}
|
||||
|
||||
let mut results = Vec::with_capacity(handles.len());
|
||||
for handle in handles {
|
||||
results.push(handle.await);
|
||||
}
|
||||
for (entry, result) in travel_entries.iter().zip(results) {
|
||||
let travel_times = result
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||||
.map_err(|err| (StatusCode::BAD_GATEWAY, err))?;
|
||||
|
||||
let field_name = format!("travel_time_{}", entry.mode);
|
||||
for (feature, tt) in response.features.iter_mut().zip(&travel_times) {
|
||||
match tt {
|
||||
Some(minutes) => {
|
||||
if let Some(num) = serde_json::Number::from_f64(*minutes) {
|
||||
feature.insert(field_name.clone(), Value::Number(num));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
feature.insert(field_name.clone(), Value::Null);
|
||||
}
|
||||
}
|
||||
}
|
||||
info!(
|
||||
hexagons = response.features.len(),
|
||||
destination = format_args!("{},{}", entry.lat, entry.lon),
|
||||
mode = entry.mode,
|
||||
"Travel times merged"
|
||||
);
|
||||
}
|
||||
}
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())?
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?;
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
|
|
|||
374
server-rs/src/routes/invites.rs
Normal file
374
server-rs/src/routes/invites.rs
Normal file
|
|
@ -0,0 +1,374 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::Path;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::{Extension, Json};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::pocketbase::auth_superuser;
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct InviteResponse {
|
||||
code: String,
|
||||
url: String,
|
||||
invite_type: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct InviteValidation {
|
||||
valid: bool,
|
||||
invite_type: String,
|
||||
used: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct RedeemRequest {
|
||||
code: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RedeemResponse {
|
||||
/// "licensed" if admin invite was redeemed directly, or a checkout URL for referral
|
||||
result: String,
|
||||
/// For referral invites: the Stripe checkout URL with coupon
|
||||
checkout_url: Option<String>,
|
||||
}
|
||||
|
||||
/// Validate that an invite code contains only safe characters (alphanumeric, lowercase).
|
||||
/// Rejects any code that could be used for PocketBase filter injection.
|
||||
fn validate_invite_code(code: &str) -> Result<(), &'static str> {
|
||||
if code.is_empty() || code.len() > 20 {
|
||||
return Err("Invalid invite code length");
|
||||
}
|
||||
if !code.bytes().all(|b| b.is_ascii_alphanumeric()) {
|
||||
return Err("Invalid invite code characters");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_invite_code() -> String {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::rng();
|
||||
let chars: Vec<char> = (0..12)
|
||||
.map(|_| {
|
||||
let idx: u8 = rng.random_range(0..36);
|
||||
if idx < 10 {
|
||||
(b'0' + idx) as char
|
||||
} else {
|
||||
(b'a' + idx - 10) as char
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
chars.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Create an invite. Admins create "admin" invites (free license).
|
||||
/// Licensed non-admin users create "referral" invites (30% off).
|
||||
pub async fn post_invites(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
_body: Json<serde_json::Value>,
|
||||
) -> Response {
|
||||
let user = match user.0 {
|
||||
Some(u) => u,
|
||||
None => return StatusCode::UNAUTHORIZED.into_response(),
|
||||
};
|
||||
|
||||
let invite_type = if user.is_admin {
|
||||
"admin"
|
||||
} else if user.subscription == "licensed" {
|
||||
"referral"
|
||||
} else {
|
||||
return (StatusCode::FORBIDDEN, "Only licensed users can create invites").into_response();
|
||||
};
|
||||
|
||||
let code = generate_invite_code();
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
|
||||
let token = match auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password).await
|
||||
{
|
||||
Ok(t) => t,
|
||||
Err(err) => {
|
||||
warn!("Failed to auth as PocketBase superuser: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let create_url = format!("{pb_url}/api/collections/invites/records");
|
||||
let res = state
|
||||
.http_client
|
||||
.post(&create_url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"code": code,
|
||||
"created_by": user.id,
|
||||
"invite_type": invite_type,
|
||||
"used_by_id": "",
|
||||
"used_at": "",
|
||||
}))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let public_url = &state.public_url;
|
||||
let url = format!("{public_url}/invite/{code}");
|
||||
info!(code = %code, invite_type, user_id = %user.id, "Created invite");
|
||||
Json(InviteResponse {
|
||||
code,
|
||||
url,
|
||||
invite_type: invite_type.to_string(),
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
warn!("Failed to create invite ({status}): {text}");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("PocketBase request error: {err}");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate an invite code. Requires authentication to prevent enumeration.
|
||||
pub async fn get_invite(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Path(code): Path<String>,
|
||||
) -> Response {
|
||||
if user.0.is_none() {
|
||||
return StatusCode::UNAUTHORIZED.into_response();
|
||||
}
|
||||
|
||||
if let Err(msg) = validate_invite_code(&code) {
|
||||
return (StatusCode::BAD_REQUEST, msg).into_response();
|
||||
}
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let filter = format!("code=\"{}\"", code);
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/invites/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
|
||||
let res = match state.http_client.get(&url).send().await {
|
||||
Ok(r) => r,
|
||||
Err(err) => {
|
||||
warn!("Failed to look up invite: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if !res.status().is_success() {
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
|
||||
let body: serde_json::Value = match res.json().await {
|
||||
Ok(v) => v,
|
||||
Err(_) => return StatusCode::BAD_GATEWAY.into_response(),
|
||||
};
|
||||
|
||||
let items = body["items"].as_array();
|
||||
match items.and_then(|arr| arr.first()) {
|
||||
Some(invite) => {
|
||||
let invite_type = invite["invite_type"].as_str().unwrap_or("").to_string();
|
||||
let used_by = invite["used_by_id"].as_str().unwrap_or("");
|
||||
let used = !used_by.is_empty();
|
||||
Json(InviteValidation {
|
||||
valid: true,
|
||||
invite_type,
|
||||
used,
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
None => Json(InviteValidation {
|
||||
valid: false,
|
||||
invite_type: String::new(),
|
||||
used: false,
|
||||
})
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Redeem an invite code. Requires authentication.
|
||||
/// Admin invite: sets subscription to "licensed" directly.
|
||||
/// Referral invite: returns a discounted Stripe checkout URL.
|
||||
pub async fn post_redeem_invite(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Json(req): Json<RedeemRequest>,
|
||||
) -> Response {
|
||||
let user = match user.0 {
|
||||
Some(u) => u,
|
||||
None => return StatusCode::UNAUTHORIZED.into_response(),
|
||||
};
|
||||
|
||||
if let Err(msg) = validate_invite_code(&req.code) {
|
||||
return (StatusCode::BAD_REQUEST, msg).into_response();
|
||||
}
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
|
||||
let token = match auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password).await
|
||||
{
|
||||
Ok(t) => t,
|
||||
Err(err) => {
|
||||
warn!("Failed to auth as PocketBase superuser: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Look up invite
|
||||
let filter = format!(
|
||||
"code=\"{}\" && used_by_id=\"\"",
|
||||
req.code
|
||||
);
|
||||
let lookup_url = format!(
|
||||
"{pb_url}/api/collections/invites/records?filter={}&perPage=1",
|
||||
urlencoding::encode(&filter)
|
||||
);
|
||||
|
||||
let res = match state.http_client.get(&lookup_url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send().await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(err) => {
|
||||
warn!("Failed to look up invite: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let body: serde_json::Value = match res.json().await {
|
||||
Ok(v) => v,
|
||||
Err(_) => return StatusCode::BAD_GATEWAY.into_response(),
|
||||
};
|
||||
|
||||
let invite = match body["items"].as_array().and_then(|arr| arr.first()) {
|
||||
Some(inv) => inv.clone(),
|
||||
None => {
|
||||
return (StatusCode::NOT_FOUND, "Invalid or already used invite code").into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let invite_id = invite["id"].as_str().unwrap_or("");
|
||||
let invite_type = invite["invite_type"].as_str().unwrap_or("");
|
||||
|
||||
// Mark invite as used
|
||||
let now = {
|
||||
let dur = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default();
|
||||
dur.as_secs().to_string()
|
||||
};
|
||||
let _ = state
|
||||
.http_client
|
||||
.patch(&format!(
|
||||
"{pb_url}/api/collections/invites/records/{invite_id}"
|
||||
))
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({
|
||||
"used_by_id": user.id,
|
||||
"used_at": now,
|
||||
}))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
if invite_type == "admin" {
|
||||
// Grant license directly
|
||||
let update_url = format!("{pb_url}/api/collections/users/records/{}", user.id);
|
||||
let res = state
|
||||
.http_client
|
||||
.patch(&update_url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "subscription": "licensed" }))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
state.token_cache.invalidate_by_user_id(&user.id);
|
||||
info!(user_id = %user.id, code = %req.code, "Admin invite redeemed — user licensed");
|
||||
Json(RedeemResponse {
|
||||
result: "licensed".to_string(),
|
||||
checkout_url: None,
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
_ => {
|
||||
warn!("Failed to update user subscription for admin invite");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Referral invite — create discounted checkout with dynamic pricing
|
||||
let count = match super::pricing::count_licensed_users(&state).await {
|
||||
Ok(c) => c,
|
||||
Err(err) => {
|
||||
warn!("Failed to count licensed users for invite checkout: {err}");
|
||||
return StatusCode::SERVICE_UNAVAILABLE.into_response();
|
||||
}
|
||||
};
|
||||
let price_pence = super::pricing::price_for_count(count);
|
||||
|
||||
let secret_key = &state.stripe_secret_key;
|
||||
let public_url = &state.public_url;
|
||||
let success_url = format!("{public_url}/pricing?license_success=1");
|
||||
let cancel_url = format!("{public_url}/pricing");
|
||||
|
||||
let form_params = vec![
|
||||
("mode", "payment".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][unit_amount]",
|
||||
price_pence.to_string(),
|
||||
),
|
||||
("line_items[0][price_data][currency]", "gbp".to_string()),
|
||||
(
|
||||
"line_items[0][price_data][product_data][name]",
|
||||
"Perfect Postcodes Lifetime License".to_string(),
|
||||
),
|
||||
("line_items[0][quantity]", "1".to_string()),
|
||||
("success_url", success_url),
|
||||
("cancel_url", cancel_url),
|
||||
("client_reference_id", user.id.clone()),
|
||||
("customer_email", user.email.clone()),
|
||||
("discounts[0][coupon]", state.stripe_referral_coupon_id.clone()),
|
||||
];
|
||||
|
||||
let stripe_res = state
|
||||
.http_client
|
||||
.post("https://api.stripe.com/v1/checkout/sessions")
|
||||
.basic_auth(secret_key, None::<&str>)
|
||||
.form(&form_params)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match stripe_res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let stripe_body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
let checkout_url = stripe_body["url"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
info!(user_id = %user.id, code = %req.code, "Referral invite redeemed — checkout created");
|
||||
Json(RedeemResponse {
|
||||
result: "checkout".to_string(),
|
||||
checkout_url: Some(checkout_url),
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
_ => {
|
||||
warn!("Failed to create Stripe checkout for referral invite");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
64
server-rs/src/routes/newsletter.rs
Normal file
64
server-rs/src/routes/newsletter.rs
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::{Extension, Json};
|
||||
use serde::Deserialize;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::pocketbase::auth_superuser;
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UpdateNewsletterRequest {
|
||||
newsletter: bool,
|
||||
}
|
||||
|
||||
pub async fn patch_newsletter(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Json(req): Json<UpdateNewsletterRequest>,
|
||||
) -> Response {
|
||||
let user = match user.0 {
|
||||
Some(u) => u,
|
||||
None => return StatusCode::UNAUTHORIZED.into_response(),
|
||||
};
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
|
||||
let token = match auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password).await
|
||||
{
|
||||
Ok(t) => t,
|
||||
Err(err) => {
|
||||
warn!("Failed to authenticate as PocketBase superuser: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let url = format!("{pb_url}/api/collections/users/records/{}", user.id);
|
||||
let res = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "newsletter": req.newsletter }))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
state.token_cache.invalidate_by_user_id(&user.id);
|
||||
StatusCode::OK.into_response()
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
warn!("PocketBase user update failed ({status}): {text}");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("PocketBase request error: {err}");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -23,9 +23,16 @@ pub async fn proxy_to_pocketbase(state: Arc<AppState>, req: Request) -> impl Int
|
|||
let method = req.method().clone();
|
||||
let mut builder = state.http_client.request(method, &url);
|
||||
|
||||
// Forward headers except host
|
||||
// Forward only safe headers (allowlist)
|
||||
const ALLOWED_HEADERS: &[&str] = &[
|
||||
"content-type",
|
||||
"accept",
|
||||
"authorization",
|
||||
"cookie",
|
||||
"accept-language",
|
||||
];
|
||||
for (name, value) in req.headers() {
|
||||
if name != "host" {
|
||||
if ALLOWED_HEADERS.contains(&name.as_str()) {
|
||||
builder = builder.header(name.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,11 +6,13 @@ use axum::response::Json;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
use crate::data::slugify;
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct PlaceResult {
|
||||
name: String,
|
||||
slug: String,
|
||||
place_type: String,
|
||||
lat: f32,
|
||||
lon: f32,
|
||||
|
|
@ -28,6 +30,8 @@ pub struct PlacesResponse {
|
|||
pub struct PlacesParams {
|
||||
q: String,
|
||||
limit: Option<usize>,
|
||||
/// If set, only return places that have travel time data for this mode.
|
||||
mode: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn get_places(
|
||||
|
|
@ -41,33 +45,44 @@ pub async fn get_places(
|
|||
};
|
||||
|
||||
let limit = params.limit.unwrap_or(7).min(20);
|
||||
let mode_filter = params.mode;
|
||||
|
||||
let places = tokio::task::spawn_blocking(move || {
|
||||
let t0 = std::time::Instant::now();
|
||||
let query_lower = query.to_lowercase();
|
||||
let pd = &state.place_data;
|
||||
let tt_store = &state.travel_time_store;
|
||||
|
||||
// Linear scan — ~50-100k rows, <1ms
|
||||
// Tuple: (row_idx, is_exact, is_prefix, type_rank, population, name_len)
|
||||
let mut matches: Vec<(usize, bool, bool, u8, u32, usize)> = pd
|
||||
// Tuple: (row_idx, is_exact, is_prefix, type_rank, population, name_len, slug)
|
||||
let mut matches: Vec<(usize, bool, bool, u8, u32, usize, String)> = pd
|
||||
.name_lower
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, name)| {
|
||||
if name.contains(&query_lower) {
|
||||
let is_exact = name.len() == query_lower.len();
|
||||
let is_prefix = name.starts_with(&query_lower);
|
||||
Some((
|
||||
idx,
|
||||
is_exact,
|
||||
is_prefix,
|
||||
pd.type_rank[idx],
|
||||
pd.population[idx],
|
||||
pd.name[idx].len(),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
if !name.contains(&query_lower) {
|
||||
return None;
|
||||
}
|
||||
let slug = slugify(&pd.name[idx]);
|
||||
|
||||
// If mode filter is set, only include places with travel data
|
||||
if let Some(ref mode) = mode_filter {
|
||||
if !tt_store.has_destination(mode, &slug) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
let is_exact = name.len() == query_lower.len();
|
||||
let is_prefix = name.starts_with(&query_lower);
|
||||
Some((
|
||||
idx,
|
||||
is_exact,
|
||||
is_prefix,
|
||||
pd.type_rank[idx],
|
||||
pd.population[idx],
|
||||
pd.name[idx].len(),
|
||||
slug,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
|
@ -85,12 +100,13 @@ pub async fn get_places(
|
|||
|
||||
let results: Vec<PlaceResult> = matches
|
||||
.iter()
|
||||
.map(|&(idx, ..)| PlaceResult {
|
||||
name: pd.name[idx].clone(),
|
||||
place_type: pd.place_type.get(idx).to_string(),
|
||||
lat: pd.lat[idx],
|
||||
lon: pd.lon[idx],
|
||||
city: pd.city[idx].clone(),
|
||||
.map(|(idx, .., slug)| PlaceResult {
|
||||
name: pd.name[*idx].clone(),
|
||||
slug: slug.clone(),
|
||||
place_type: pd.place_type.get(*idx).to_string(),
|
||||
lat: pd.lat[*idx],
|
||||
lon: pd.lon[*idx],
|
||||
city: pd.city[*idx].clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
|
@ -99,6 +115,7 @@ pub async fn get_places(
|
|||
query = query.as_str(),
|
||||
results = results.len(),
|
||||
scanned = pd.name_lower.len(),
|
||||
mode = mode_filter.as_deref().unwrap_or("-"),
|
||||
ms = format_args!("{:.1}", elapsed.as_secs_f64() * 1000.0),
|
||||
"GET /api/places"
|
||||
);
|
||||
|
|
|
|||
|
|
@ -2,11 +2,14 @@ use std::sync::Arc;
|
|||
|
||||
use axum::extract::Query;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Json;
|
||||
use axum::response::{IntoResponse, Json};
|
||||
use axum::Extension;
|
||||
use serde::Deserialize;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::POSTCODE_SEARCH_OFFSET;
|
||||
use crate::licensing::check_license_point;
|
||||
use crate::parsing::{parse_field_set, parse_filters, row_passes_filters};
|
||||
use crate::state::AppState;
|
||||
|
||||
|
|
@ -24,8 +27,9 @@ pub struct PostcodeStatsParams {
|
|||
|
||||
pub async fn get_postcode_stats(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Query(params): Query<PostcodeStatsParams>,
|
||||
) -> Result<Json<HexagonStatsResponse>, (StatusCode, String)> {
|
||||
) -> Result<Json<HexagonStatsResponse>, axum::response::Response> {
|
||||
// Normalize postcode: uppercase, collapse whitespace
|
||||
let normalized = params
|
||||
.postcode
|
||||
|
|
@ -42,18 +46,23 @@ pub async fn get_postcode_stats(
|
|||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Postcode not found: {}", normalized),
|
||||
));
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
let (centroid_lat, centroid_lon) = state.postcode_data.centroids[pc_idx];
|
||||
|
||||
// License check using postcode centroid
|
||||
check_license_point(&user.0, centroid_lat as f64, centroid_lon as f64)
|
||||
.map_err(|(_, resp)| resp)?;
|
||||
|
||||
let filters_str = params.filters.clone();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
params.filters.as_deref(),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err))?;
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
|
||||
let (fields_specified, field_set) = parse_field_set(params.fields.as_deref());
|
||||
|
|
@ -129,8 +138,8 @@ pub async fn get_postcode_stats(
|
|||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))?
|
||||
.map_err(|error: String| (StatusCode::INTERNAL_SERVER_ERROR, error))?;
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())?
|
||||
.map_err(|error: String| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?;
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,14 +2,17 @@ use std::sync::Arc;
|
|||
|
||||
use axum::extract::{Path, Query};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Json;
|
||||
use axum::response::{IntoResponse, Json};
|
||||
use axum::Extension;
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map, Value};
|
||||
use tracing::info;
|
||||
|
||||
use crate::aggregation::Aggregator;
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::MAX_CELLS_PER_REQUEST;
|
||||
use crate::licensing::check_license_bounds;
|
||||
use crate::parsing::{
|
||||
bounds_intersect, parse_field_indices, parse_filters, require_bounds, row_passes_filters,
|
||||
};
|
||||
|
|
@ -60,9 +63,14 @@ fn build_postcode_geometry(rings: &[Vec<[f32; 2]>]) -> Value {
|
|||
|
||||
pub async fn get_postcodes(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Query(params): Query<PostcodeParams>,
|
||||
) -> Result<Json<PostcodesResponse>, (StatusCode, String)> {
|
||||
let (south, west, north, east) = require_bounds(params.bounds)?;
|
||||
) -> Result<Json<PostcodesResponse>, axum::response::Response> {
|
||||
let (south, west, north, east) =
|
||||
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
|
||||
|
||||
check_license_bounds(&user.0, (south, west, north, east))
|
||||
.map_err(|(_, resp)| resp)?;
|
||||
|
||||
let filters_str = params.filters.clone();
|
||||
let (parsed_filters, parsed_enum_filters) = parse_filters(
|
||||
|
|
@ -70,10 +78,11 @@ pub async fn get_postcodes(
|
|||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err))?;
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
|
||||
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index);
|
||||
let field_indices = parse_field_indices(params.fields.as_deref(), &state.feature_name_to_index)
|
||||
.map_err(|err| (err.0, err.1).into_response())?;
|
||||
|
||||
let response = tokio::task::spawn_blocking(move || -> Result<PostcodesResponse, String> {
|
||||
let postcode_data = &state.postcode_data;
|
||||
|
|
@ -222,7 +231,7 @@ pub async fn get_postcodes(
|
|||
}
|
||||
}
|
||||
|
||||
let truncated = features.len() >= MAX_CELLS_PER_REQUEST;
|
||||
let truncated = features.len() > MAX_CELLS_PER_REQUEST;
|
||||
let t_total = t0.elapsed();
|
||||
info!(
|
||||
postcodes_before_filter,
|
||||
|
|
@ -242,8 +251,8 @@ pub async fn get_postcodes(
|
|||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))?
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error))?;
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())?
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?;
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
|
|
|||
105
server-rs/src/routes/pricing.rs
Normal file
105
server-rs/src/routes/pricing.rs
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::Json;
|
||||
use serde::Serialize;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::pocketbase::auth_superuser;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Pricing tiers: (cumulative user cap, price in pence).
|
||||
const TIERS: &[(u64, u64)] = &[
|
||||
(10, 0), // First 10 users: free
|
||||
(20, 1000), // Next 10: £10
|
||||
(45, 2500), // Next 25: £25
|
||||
(95, 5000), // Next 50: £50
|
||||
];
|
||||
const FINAL_PRICE_PENCE: u64 = 10000; // £100 after 95
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct Tier {
|
||||
up_to: Option<u64>,
|
||||
price_pence: u64,
|
||||
slots: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct PricingResponse {
|
||||
licensed_count: u64,
|
||||
current_price_pence: u64,
|
||||
tiers: Vec<Tier>,
|
||||
}
|
||||
|
||||
/// Determine the price (in pence) for the next user given `count` existing licensed users.
|
||||
pub fn price_for_count(count: u64) -> u64 {
|
||||
for &(cap, price) in TIERS {
|
||||
if count < cap {
|
||||
return price;
|
||||
}
|
||||
}
|
||||
FINAL_PRICE_PENCE
|
||||
}
|
||||
|
||||
/// Count users with subscription="licensed" in PocketBase.
|
||||
pub async fn count_licensed_users(state: &AppState) -> anyhow::Result<u64> {
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let token = auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password).await?;
|
||||
|
||||
let filter = "subscription=\"licensed\"";
|
||||
let url = format!(
|
||||
"{pb_url}/api/collections/users/records?filter={}&perPage=1",
|
||||
urlencoding::encode(filter)
|
||||
);
|
||||
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
anyhow::bail!("PocketBase returned {}", resp.status());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let total = body["totalItems"].as_u64().unwrap_or(0);
|
||||
Ok(total)
|
||||
}
|
||||
|
||||
pub async fn get_pricing(state: Arc<AppState>) -> Response {
|
||||
let count = match count_licensed_users(&state).await {
|
||||
Ok(c) => c,
|
||||
Err(err) => {
|
||||
warn!("Failed to count licensed users: {err}");
|
||||
return StatusCode::SERVICE_UNAVAILABLE.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let current_price = price_for_count(count);
|
||||
|
||||
let mut tiers = Vec::new();
|
||||
let mut prev_cap = 0u64;
|
||||
for &(cap, price) in TIERS {
|
||||
tiers.push(Tier {
|
||||
up_to: Some(cap),
|
||||
price_pence: price,
|
||||
slots: cap - prev_cap,
|
||||
});
|
||||
prev_cap = cap;
|
||||
}
|
||||
tiers.push(Tier {
|
||||
up_to: None,
|
||||
price_pence: FINAL_PRICE_PENCE,
|
||||
slots: 0,
|
||||
});
|
||||
|
||||
Json(PricingResponse {
|
||||
licensed_count: count,
|
||||
current_price_pence: current_price,
|
||||
tiers,
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
|
|
@ -3,12 +3,15 @@ use std::sync::Arc;
|
|||
|
||||
use axum::extract::Query;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Json;
|
||||
use axum::response::{IntoResponse, Json};
|
||||
use axum::Extension;
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT};
|
||||
use crate::licensing::check_license_bounds;
|
||||
use crate::parsing::{
|
||||
cell_for_row, h3_cell_bounds, needs_parent, parse_filters, row_passes_filters,
|
||||
validate_h3_resolution,
|
||||
|
|
@ -90,19 +93,25 @@ fn lookup_enum_value(
|
|||
|
||||
pub async fn get_hexagon_properties(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Query(params): Query<HexagonPropertiesParams>,
|
||||
) -> Result<Json<HexagonPropertiesResponse>, (StatusCode, String)> {
|
||||
) -> Result<Json<HexagonPropertiesResponse>, axum::response::Response> {
|
||||
let cell = h3o::CellIndex::from_str(¶ms.h3).map_err(|error| {
|
||||
warn!(h3 = %params.h3, error = %error, "Invalid H3 cell index");
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Invalid H3 cell: {}", error),
|
||||
)
|
||||
.into_response()
|
||||
})?;
|
||||
let cell_u64: u64 = cell.into();
|
||||
|
||||
let resolution = params.resolution;
|
||||
validate_h3_resolution(resolution)?;
|
||||
validate_h3_resolution(resolution).map_err(IntoResponse::into_response)?;
|
||||
|
||||
// License check using H3 cell bounds
|
||||
let h3_bounds = h3_cell_bounds(cell, 0.0);
|
||||
check_license_bounds(&user.0, h3_bounds).map_err(|(_, resp)| resp)?;
|
||||
|
||||
let h3_str = params.h3.clone();
|
||||
let filters_str = params.filters.clone();
|
||||
|
|
@ -111,7 +120,7 @@ pub async fn get_hexagon_properties(
|
|||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
)
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err))?;
|
||||
.map_err(|err| (StatusCode::BAD_REQUEST, err).into_response())?;
|
||||
let num_filters = parsed_filters.len() + parsed_enum_filters.len();
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
|
|
@ -249,8 +258,8 @@ pub async fn get_hexagon_properties(
|
|||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))?
|
||||
.map_err(|error: String| (StatusCode::INTERNAL_SERVER_ERROR, error))?;
|
||||
.map_err(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response())?
|
||||
.map_err(|error: String| (StatusCode::INTERNAL_SERVER_ERROR, error).into_response())?;
|
||||
|
||||
Ok(Json(result))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use rustc_hash::FxHashMap;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::consts::MAX_PRICE_HISTORY_POINTS;
|
||||
use crate::data::FeatureStats;
|
||||
|
|
@ -78,6 +79,13 @@ pub fn compute_feature_stats(
|
|||
let idx = value as usize;
|
||||
if idx < value_counts.len() {
|
||||
value_counts[idx] += 1;
|
||||
} else {
|
||||
warn!(
|
||||
feature = feature_name.as_str(),
|
||||
idx,
|
||||
max = value_counts.len(),
|
||||
"Enum index out of bounds — possible data/schema mismatch"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
129
server-rs/src/routes/stripe_webhook.rs
Normal file
129
server-rs/src/routes/stripe_webhook.rs
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::body::Bytes;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::pocketbase::auth_superuser;
|
||||
use crate::state::AppState;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// Verify Stripe webhook signature (v1 scheme).
|
||||
fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
|
||||
// Parse timestamp and signature from header: "t=TIMESTAMP,v1=SIGNATURE"
|
||||
let mut timestamp = None;
|
||||
let mut signature = None;
|
||||
for part in sig_header.split(',') {
|
||||
if let Some(ts) = part.strip_prefix("t=") {
|
||||
timestamp = Some(ts);
|
||||
} else if let Some(sig) = part.strip_prefix("v1=") {
|
||||
signature = Some(sig);
|
||||
}
|
||||
}
|
||||
|
||||
let (ts, sig_hex) = match (timestamp, signature) {
|
||||
(Some(t), Some(s)) => (t, s),
|
||||
_ => return false,
|
||||
};
|
||||
|
||||
// Compute expected signature: HMAC-SHA256(secret, "TIMESTAMP.PAYLOAD")
|
||||
let signed_payload = format!("{ts}.{}", String::from_utf8_lossy(payload));
|
||||
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return false,
|
||||
};
|
||||
mac.update(signed_payload.as_bytes());
|
||||
|
||||
// Decode the provided hex signature and verify with constant-time comparison
|
||||
let sig_bytes = match hex::decode(sig_hex) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(_) => return false,
|
||||
};
|
||||
mac.verify_slice(&sig_bytes).is_ok()
|
||||
}
|
||||
|
||||
/// Handle Stripe webhook events.
|
||||
/// On `checkout.session.completed`, updates the user's subscription to "licensed".
|
||||
pub async fn post_stripe_webhook(
|
||||
state: Arc<AppState>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
let webhook_secret = &state.stripe_webhook_secret;
|
||||
|
||||
let sig_header = match headers.get("stripe-signature").and_then(|h| h.to_str().ok()) {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
warn!("Missing Stripe-Signature header");
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if !verify_signature(&body, sig_header, webhook_secret) {
|
||||
warn!("Invalid Stripe webhook signature");
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
}
|
||||
|
||||
let event: serde_json::Value = match serde_json::from_slice(&body) {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse webhook body: {err}");
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let event_type = event["type"].as_str().unwrap_or("");
|
||||
info!(event_type, "Received Stripe webhook");
|
||||
|
||||
if event_type == "checkout.session.completed" {
|
||||
let user_id = event["data"]["object"]["client_reference_id"]
|
||||
.as_str()
|
||||
.unwrap_or("");
|
||||
if user_id.is_empty() {
|
||||
warn!("checkout.session.completed missing client_reference_id");
|
||||
return StatusCode::OK.into_response();
|
||||
}
|
||||
|
||||
// Update user subscription to "licensed" via PocketBase superuser auth
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let token = match auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password)
|
||||
.await
|
||||
{
|
||||
Ok(t) => t,
|
||||
Err(err) => {
|
||||
warn!("Failed to auth as PocketBase superuser in webhook: {err}");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let res = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "subscription": "licensed" }))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
state.token_cache.invalidate_by_user_id(user_id);
|
||||
info!(user_id, "User subscription updated to licensed via Stripe webhook");
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
warn!(user_id, "Failed to update user subscription ({status}): {text}");
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(user_id, "PocketBase request error in webhook: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StatusCode::OK.into_response()
|
||||
}
|
||||
78
server-rs/src/routes/subscription.rs
Normal file
78
server-rs/src/routes/subscription.rs
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::{Extension, Json};
|
||||
use serde::Deserialize;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::pocketbase::auth_superuser;
|
||||
use crate::state::AppState;
|
||||
|
||||
const VALID_SUBSCRIPTIONS: &[&str] = &["free", "licensed"];
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UpdateSubscriptionRequest {
|
||||
subscription: String,
|
||||
}
|
||||
|
||||
pub async fn patch_subscription(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Json(req): Json<UpdateSubscriptionRequest>,
|
||||
) -> Response {
|
||||
let user = match user.0 {
|
||||
Some(u) => u,
|
||||
None => return StatusCode::UNAUTHORIZED.into_response(),
|
||||
};
|
||||
|
||||
if !user.is_admin {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
if !VALID_SUBSCRIPTIONS.contains(&req.subscription.as_str()) {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Invalid subscription: {}", req.subscription),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
|
||||
let token = match auth_superuser(&state.http_client, pb_url, &state.pocketbase_admin_email, &state.pocketbase_admin_password).await
|
||||
{
|
||||
Ok(t) => t,
|
||||
Err(err) => {
|
||||
warn!("Failed to authenticate as PocketBase superuser: {err}");
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let url = format!("{pb_url}/api/collections/users/records/{}", user.id);
|
||||
let res = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&serde_json::json!({ "subscription": req.subscription }))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
state.token_cache.invalidate_by_user_id(&user.id);
|
||||
StatusCode::OK.into_response()
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
warn!("PocketBase user update failed ({status}): {text}");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("PocketBase request error: {err}");
|
||||
StatusCode::BAD_GATEWAY.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::{Path, Query, State};
|
||||
use axum::http::{header, HeaderMap, StatusCode};
|
||||
use axum::http::{header, StatusCode};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use pmtiles::async_reader::AsyncPmTilesReader;
|
||||
use pmtiles::MmapBackend;
|
||||
|
|
@ -40,7 +40,7 @@ pub struct StyleParams {
|
|||
|
||||
pub async fn get_style(
|
||||
State(reader): State<Arc<TileReader>>,
|
||||
headers: HeaderMap,
|
||||
public_url: String,
|
||||
Query(params): Query<StyleParams>,
|
||||
) -> Result<Response, (StatusCode, String)> {
|
||||
let is_dark = params.theme.as_deref() == Some("dark");
|
||||
|
|
@ -50,7 +50,7 @@ pub async fn get_style(
|
|||
warn!(error = %err, "Failed to get PMTiles metadata");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to get PMTiles metadata: {err}"),
|
||||
"Failed to read tile metadata".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
|
|
@ -59,7 +59,7 @@ pub async fn get_style(
|
|||
warn!(error = %err, "Failed to parse PMTiles metadata JSON");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to parse PMTiles metadata: {err}"),
|
||||
"Failed to parse tile metadata".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
|
|
@ -70,15 +70,8 @@ pub async fn get_style(
|
|||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
// Build absolute tile URL using the request host
|
||||
let host = headers
|
||||
.get(header::HOST)
|
||||
.and_then(|hv| hv.to_str().ok())
|
||||
.ok_or((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Missing Host header".into(),
|
||||
))?;
|
||||
let tile_url = format!("http://{}/api/tiles/{{z}}/{{x}}/{{y}}", host);
|
||||
// Build absolute tile URL using the configured public URL (not the Host header)
|
||||
let tile_url = format!("{}/api/tiles/{{z}}/{{x}}/{{y}}", public_url.trim_end_matches('/'));
|
||||
let style = build_style(is_dark, &layers, &tile_url);
|
||||
|
||||
Ok((
|
||||
|
|
|
|||
38
server-rs/src/routes/travel_modes.rs
Normal file
38
server-rs/src/routes/travel_modes.rs
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Json;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct TravelModeInfo {
|
||||
mode: String,
|
||||
destinations: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct TravelModesResponse {
|
||||
modes: Vec<TravelModeInfo>,
|
||||
}
|
||||
|
||||
pub async fn get_travel_modes(
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Json<TravelModesResponse>, (StatusCode, String)> {
|
||||
let store = &state.travel_time_store;
|
||||
let modes = store
|
||||
.available_modes
|
||||
.iter()
|
||||
.map(|mode| TravelModeInfo {
|
||||
mode: mode.clone(),
|
||||
destinations: store
|
||||
.destinations
|
||||
.get(mode)
|
||||
.map(|slugs| slugs.len())
|
||||
.unwrap_or(0),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Json(TravelModesResponse { modes }))
|
||||
}
|
||||
|
|
@ -1,72 +1,30 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct R5Request {
|
||||
origin: [f64; 2],
|
||||
destinations: Vec<[f64; 2]>,
|
||||
mode: String,
|
||||
/// Per-hex-cell travel time aggregation.
|
||||
pub struct TravelTimeAgg {
|
||||
pub min: f32,
|
||||
pub max: f32,
|
||||
pub sum: f64,
|
||||
pub count: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct R5Response {
|
||||
travel_times: Vec<f64>,
|
||||
}
|
||||
|
||||
/// Call the R5 Java service to compute one-to-many travel times.
|
||||
///
|
||||
/// `origins` are hex centroids as `[lat, lon]`.
|
||||
/// `destination` is the user-chosen point as `[lat, lon]`.
|
||||
/// `mode` is one of "car", "bicycle", "walking", "transit".
|
||||
///
|
||||
/// R5 computes from destination to all origins (one-to-many from the user's chosen point).
|
||||
/// Returns a Vec of travel times in minutes (one per origin), with None for unreachable.
|
||||
pub async fn fetch_travel_times(
|
||||
client: &reqwest::Client,
|
||||
r5_url: &str,
|
||||
origins: Vec<[f64; 2]>,
|
||||
destination: [f64; 2],
|
||||
mode: &str,
|
||||
) -> Result<Vec<Option<f64>>, String> {
|
||||
if origins.is_empty() {
|
||||
return Ok(vec![]);
|
||||
impl TravelTimeAgg {
|
||||
pub fn new() -> Self {
|
||||
TravelTimeAgg {
|
||||
min: f32::INFINITY,
|
||||
max: f32::NEG_INFINITY,
|
||||
sum: 0.0,
|
||||
count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
let body = R5Request {
|
||||
origin: destination,
|
||||
destinations: origins,
|
||||
mode: mode.to_string(),
|
||||
};
|
||||
|
||||
let resp = client
|
||||
.post(format!("{}/travel-times", r5_url))
|
||||
.json(&body)
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("R5 request failed: {}", e);
|
||||
format!("R5 routing error: {}", e)
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
warn!("R5 returned {}: {}", status, body);
|
||||
return Err(format!("R5 returned {}: {}", status, body));
|
||||
#[inline]
|
||||
pub fn add(&mut self, value: f32) {
|
||||
if value < self.min {
|
||||
self.min = value;
|
||||
}
|
||||
if value > self.max {
|
||||
self.max = value;
|
||||
}
|
||||
self.sum += value as f64;
|
||||
self.count += 1;
|
||||
}
|
||||
|
||||
let r5_resp: R5Response = resp.json().await.map_err(|e| {
|
||||
warn!("Failed to parse R5 response: {}", e);
|
||||
format!("Failed to parse R5 response: {}", e)
|
||||
})?;
|
||||
|
||||
// R5 returns -1 for unreachable destinations
|
||||
let travel_times: Vec<Option<f64>> = r5_resp
|
||||
.travel_times
|
||||
.into_iter()
|
||||
.map(|t| if t < 0.0 { None } else { Some(t) })
|
||||
.collect();
|
||||
|
||||
Ok(travel_times)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue