Shared state

This commit is contained in:
Andras Schmelczer 2026-03-17 21:08:32 +00:00
parent 53fff3efaa
commit 15fa09430b
25 changed files with 174 additions and 215 deletions

View file

@ -422,157 +422,78 @@ async fn main() -> anyhow::Result<()> {
.allow_headers(AllowHeaders::mirror_request())
.allow_credentials(true);
// Each route closure captures a clone of `shared` and calls `load_state()`
// at request time to get the latest `Arc<AppState>`. This enables hot-reload:
// Handlers use Axum's State extractor to get Arc<SharedState>, then call
// load_state() to get the current Arc<AppState>. This enables hot-reload:
// the reload endpoint swaps in a new AppState, and subsequent requests pick it up.
macro_rules! s {
() => {
shared.clone()
};
}
let (s1, s2, s3, s4, s5, s6) = (s!(), s!(), s!(), s!(), s!(), s!());
let (s7, s8, s9, s10, s11, s12) = (s!(), s!(), s!(), s!(), s!(), s!());
let (s13, s14, s15, s16, s17, s18) = (s!(), s!(), s!(), s!(), s!(), s!());
let (s19, s20, s21, s22, s23, s24) = (s!(), s!(), s!(), s!(), s!(), s!());
let (s25, s26, s27, s28, s29) = (s!(), s!(), s!(), s!(), s!());
let s_crawler = shared.clone();
let s_pb = shared.clone();
let s_reload = shared.clone();
let api = Router::new()
.route(
"/api/features",
get(move || routes::get_features(s1.load_state())),
)
.route(
"/api/hexagons",
get(move |ext, query| routes::get_hexagons(s2.load_state(), ext, query)),
)
.route(
"/api/postcodes",
get(move |ext, query| routes::get_postcodes(s3.load_state(), ext, query)),
)
.route(
"/api/postcode/{postcode}",
get(move |path| routes::get_postcode_lookup(s4.load_state(), path)),
)
.route(
"/api/pois",
get(move |query| routes::get_pois(s5.load_state(), query)),
)
.route(
"/api/poi-categories",
get(move || routes::get_poi_categories(s6.load_state())),
)
.route(
"/api/places",
get(move |query| routes::get_places(s7.load_state(), query)),
)
.route(
"/api/travel-modes",
get(move || routes::get_travel_modes(s8.load_state())),
)
.route(
"/api/travel-destinations",
get(move |query| routes::get_travel_destinations(s9.load_state(), query)),
)
.route(
"/api/journey",
get(move |query| routes::get_journey(s10.load_state(), query)),
)
.route(
"/api/hexagon-properties",
get(move |ext, query| routes::get_hexagon_properties(s11.load_state(), ext, query)),
)
.route(
"/api/hexagon-stats",
get(move |ext, query| routes::get_hexagon_stats(s12.load_state(), ext, query)),
)
.route(
"/api/postcode-stats",
get(move |ext, query| routes::get_postcode_stats(s13.load_state(), ext, query)),
)
.route(
"/api/postcode-properties",
get(move |ext, query| routes::get_postcode_properties(s14.load_state(), ext, query)),
)
.route(
"/api/screenshot",
get(move |headers, query| routes::get_screenshot(s15.load_state(), headers, query)),
)
.route(
"/api/export",
get(move |headers, ext, query| {
routes::get_export(s16.load_state(), headers, ext, query)
})
.layer(ConcurrencyLimitLayer::new(3)),
)
.route("/api/me", get(routes::get_me))
.route(
"/api/shorten",
post(move |body| routes::post_shorten(s17.load_state(), body)),
)
.route(
"/api/ai-filters",
post(move |ext, body| routes::post_ai_filters(s18.load_state(), ext, body))
.layer(ConcurrencyLimitLayer::new(5)),
)
.route(
"/api/streetview",
get(move |query| routes::get_streetview(s19.load_state(), query)),
)
.route(
"/api/newsletter",
patch(move |ext, body| routes::patch_newsletter(s20.load_state(), ext, body)),
)
.route(
"/api/pricing",
get(move || routes::get_pricing(s21.load_state())),
)
.route(
"/api/checkout",
post(move |ext, body| routes::post_checkout(s22.load_state(), ext, body))
.layer(ConcurrencyLimitLayer::new(10)),
)
.route(
"/api/stripe-webhook",
post(move |headers, body| routes::post_stripe_webhook(s23.load_state(), headers, body)),
)
.route(
"/api/invites",
get(move |ext| routes::get_invites(s24.load_state(), ext))
.post(move |ext, body| routes::post_invites(s25.load_state(), ext, body)),
)
.route(
"/api/invite/{code}",
get(move |ext, path| routes::get_invite(s26.load_state(), ext, path)),
)
.route(
"/api/redeem-invite",
post(move |ext, body| routes::post_redeem_invite(s27.load_state(), ext, body)),
)
.route(
"/s/{code}",
get(move |path| routes::get_short_url(s28.load_state(), path)),
)
.route(
"/api/telemetry",
post(move |ext, headers, body| {
let _ = s29.load_state();
routes::post_telemetry(ext, headers, body)
}),
)
.route(
"/api/reload",
post(move || routes::post_reload(s_reload.clone())),
);
// Add tile routes
let reader_tile = tile_reader.clone();
let reader_style = tile_reader.clone();
let public_url_tiles = initial_state.public_url.clone();
let api = api
let api = Router::new()
.route("/api/features", get(routes::get_features))
.route("/api/hexagons", get(routes::get_hexagons))
.route("/api/postcodes", get(routes::get_postcodes))
.route(
"/api/postcode/{postcode}",
get(routes::get_postcode_lookup),
)
.route("/api/pois", get(routes::get_pois))
.route("/api/poi-categories", get(routes::get_poi_categories))
.route("/api/places", get(routes::get_places))
.route("/api/travel-modes", get(routes::get_travel_modes))
.route(
"/api/travel-destinations",
get(routes::get_travel_destinations),
)
.route("/api/journey", get(routes::get_journey))
.route(
"/api/hexagon-properties",
get(routes::get_hexagon_properties),
)
.route("/api/hexagon-stats", get(routes::get_hexagon_stats))
.route("/api/postcode-stats", get(routes::get_postcode_stats))
.route(
"/api/postcode-properties",
get(routes::get_postcode_properties),
)
.route("/api/screenshot", get(routes::get_screenshot))
.route(
"/api/export",
get(routes::get_export).layer(ConcurrencyLimitLayer::new(3)),
)
.route("/api/me", get(routes::get_me))
.route("/api/shorten", post(routes::post_shorten))
.route(
"/api/ai-filters",
post(routes::post_ai_filters).layer(ConcurrencyLimitLayer::new(5)),
)
.route("/api/streetview", get(routes::get_streetview))
.route("/api/newsletter", patch(routes::patch_newsletter))
.route("/api/pricing", get(routes::get_pricing))
.route(
"/api/checkout",
post(routes::post_checkout).layer(ConcurrencyLimitLayer::new(10)),
)
.route(
"/api/stripe-webhook",
post(routes::post_stripe_webhook),
)
.route(
"/api/invites",
get(routes::get_invites).post(routes::post_invites),
)
.route("/api/invite/{code}", get(routes::get_invite))
.route("/api/redeem-invite", post(routes::post_redeem_invite))
.route("/s/{code}", get(routes::get_short_url))
.route("/api/telemetry", post(routes::post_telemetry))
.route("/api/reload", post(routes::post_reload))
.route(
"/pb/{*rest}",
any(routes::proxy_to_pocketbase),
)
// Tile routes use a different state type — kept as closures
.route(
"/api/tiles/{z}/{x}/{y}",
get(move |path| routes::get_tile(axum::extract::State(reader_tile.clone()), path)),
@ -589,10 +510,7 @@ async fn main() -> anyhow::Result<()> {
"/metrics",
get(move || metrics::metrics_handler(metrics_handle.clone())),
)
.route(
"/pb/{*rest}",
any(move |req| routes::proxy_to_pocketbase(s_pb.load_state(), req)),
);
.with_state(shared.clone());
let app = if let Some(ref dist) = cli.dist {
api.fallback_service(ServeDir::new(dist).fallback(ServeFile::new(dist.join("index.html"))))

View file

@ -1,5 +1,6 @@
use std::sync::Arc;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::Json;
use axum::Extension;
@ -13,7 +14,7 @@ use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WE
use crate::data::slugify;
use crate::pocketbase::auth_superuser;
use crate::routes::{FeatureInfo, FeaturesResponse};
use crate::state::AppState;
use crate::state::{AppState, SharedState};
use crate::utils::gemini_chat;
#[derive(Deserialize)]
@ -501,10 +502,11 @@ async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week
const MAX_TOOL_ROUNDS: usize = 5;
pub async fn post_ai_filters(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<AiFiltersRequest>,
) -> Result<Json<AiFiltersResponse>, (StatusCode, String)> {
let state = shared.load_state();
// Auth check
let user = user
.0

View file

@ -1,5 +1,6 @@
use std::sync::Arc;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::{Extension, Json};
@ -8,7 +9,7 @@ use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::pocketbase::auth_superuser;
use crate::state::AppState;
use crate::state::{AppState, SharedState};
use super::pricing::{count_licensed_users, price_for_count};
@ -25,10 +26,11 @@ struct CheckoutResponse {
/// Create a Stripe Checkout session for the lifetime license (or grant for free if in free tier).
/// Requires authentication. Optionally accepts a referral code to apply a coupon.
pub async fn post_checkout(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<CheckoutRequest>,
) -> Response {
let state = shared.load_state();
let user = match user.0 {
Some(u) => u,
None => return StatusCode::UNAUTHORIZED.into_response(),

View file

@ -2,7 +2,7 @@ use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use axum::extract::Query;
use axum::extract::{Query, State};
use axum::http::{header, HeaderMap, StatusCode};
use axum::response::IntoResponse;
use axum::Extension;
@ -17,7 +17,7 @@ use crate::data::QuantRef;
use crate::licensing::check_license_bounds;
use crate::parsing::{parse_field_indices, parse_filters, require_bounds, row_passes_filters};
use crate::routes::{fetch_screenshot_bytes, FeatureInfo};
use crate::state::AppState;
use crate::state::SharedState;
const MAX_EXPORT_POSTCODES: usize = 250;
/// Height (in pixels) reserved for the screenshot row
@ -125,11 +125,12 @@ fn build_frontend_params(
}
pub async fn get_export(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
headers: HeaderMap,
Extension(user): Extension<OptionalUser>,
Query(params): Query<ExportParams>,
) -> Result<impl IntoResponse, axum::response::Response> {
let state = shared.load_state();
let (south, west, north, east) =
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;

View file

@ -1,12 +1,13 @@
use std::sync::Arc;
use axum::extract::State;
use axum::response::Json;
use serde::Serialize;
use tracing::info;
use crate::data::{Histogram, PropertyData};
use crate::features::{ENUM_FEATURE_GROUPS, FEATURE_GROUPS};
use crate::state::AppState;
use crate::state::SharedState;
fn is_empty(val: &str) -> bool {
val.is_empty()
@ -154,7 +155,8 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse {
FeaturesResponse { groups }
}
pub async fn get_features(state: Arc<AppState>) -> Json<FeaturesResponse> {
pub async fn get_features(State(shared): State<Arc<SharedState>>) -> Json<FeaturesResponse> {
let state = shared.load_state();
info!("GET /api/features");
Json(state.features_response.clone())
}

View file

@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use axum::extract::Query;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json};
use axum::Extension;
@ -16,7 +16,7 @@ use crate::parsing::{
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_set, parse_filters,
row_passes_filters, validate_h3_resolution,
};
use crate::state::AppState;
use crate::state::SharedState;
use super::stats;
@ -79,10 +79,11 @@ pub struct HexagonStatsParams {
}
pub async fn get_hexagon_stats(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<HexagonStatsParams>,
) -> Result<Json<HexagonStatsResponse>, axum::response::Response> {
let state = shared.load_state();
let cell = h3o::CellIndex::from_str(&params.h3).map_err(|error| {
warn!(h3 = %params.h3, error = %error, "Invalid H3 cell index");
(

View file

@ -1,6 +1,6 @@
use std::sync::Arc;
use axum::extract::Query;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json};
use axum::Extension;
@ -21,7 +21,7 @@ use crate::parsing::{
row_passes_filters, validate_h3_resolution,
};
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
use crate::state::AppState;
use crate::state::SharedState;
/// Row count threshold above which we use rayon parallel aggregation.
const PARALLEL_THRESHOLD: usize = 50_000;
@ -182,10 +182,11 @@ fn build_feature_maps(
}
pub async fn get_hexagons(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<HexagonParams>,
) -> Result<Json<HexagonsResponse>, axum::response::Response> {
let state = shared.load_state();
let resolution = params.resolution;
validate_h3_resolution(resolution).map_err(IntoResponse::into_response)?;

View file

@ -1,6 +1,6 @@
use std::sync::Arc;
use axum::extract::Path;
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::{Extension, Json};
@ -9,7 +9,7 @@ use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::pocketbase::auth_superuser;
use crate::state::AppState;
use crate::state::SharedState;
#[derive(Serialize)]
struct InviteResponse {
@ -90,10 +90,11 @@ fn generate_invite_code() -> String {
/// Create an invite. Admins create "admin" invites (free license) by default,
/// but can explicitly request "referral" type. Licensed non-admin users always create "referral" invites (30% off).
pub async fn post_invites(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Json(body): Json<CreateInviteRequest>,
) -> Response {
let state = shared.load_state();
let user = match user.0 {
Some(u) => u,
None => return StatusCode::UNAUTHORIZED.into_response(),
@ -179,10 +180,11 @@ const DEV_INVITE_CODE: &str = "devdevdevdev";
/// Validate an invite code. Public endpoint — codes are 12-char random alphanumeric
/// so enumeration is impractical, and the response only reveals valid/invalid + type.
pub async fn get_invite(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(_user): Extension<OptionalUser>,
Path(code): Path<String>,
) -> Response {
let state = shared.load_state();
if let Err(msg) = validate_invite_code(&code) {
return (StatusCode::BAD_REQUEST, msg).into_response();
}
@ -297,10 +299,11 @@ pub async fn get_invite(
/// Admin invite: sets subscription to "licensed" directly.
/// Referral invite: returns a discounted Stripe checkout URL.
pub async fn post_redeem_invite(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<RedeemRequest>,
) -> Response {
let state = shared.load_state();
let user = match user.0 {
Some(u) => u,
None => return StatusCode::UNAUTHORIZED.into_response(),
@ -486,9 +489,10 @@ pub async fn post_redeem_invite(
/// List invites. Admins see all invites; licensed users see only their own.
pub async fn get_invites(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
) -> Response {
let state = shared.load_state();
let user = match user.0 {
Some(u) => u,
None => return StatusCode::UNAUTHORIZED.into_response(),

View file

@ -1,10 +1,11 @@
use std::sync::Arc;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::Json;
use serde::{Deserialize, Serialize};
use crate::state::AppState;
use crate::state::SharedState;
#[derive(Deserialize)]
pub struct JourneyQuery {
@ -24,9 +25,10 @@ pub struct JourneyResponse {
}
pub async fn get_journey(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
query: axum::extract::Query<JourneyQuery>,
) -> Result<Json<JourneyResponse>, (StatusCode, String)> {
let state = shared.load_state();
let store = &state.travel_time_store;
if !store.has_destination(&query.mode, &query.slug) {

View file

@ -1,5 +1,6 @@
use std::sync::Arc;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::{Extension, Json};
@ -8,7 +9,7 @@ use tracing::warn;
use crate::auth::OptionalUser;
use crate::pocketbase::auth_superuser;
use crate::state::AppState;
use crate::state::SharedState;
#[derive(Deserialize)]
pub struct UpdateNewsletterRequest {
@ -16,10 +17,11 @@ pub struct UpdateNewsletterRequest {
}
pub async fn patch_newsletter(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<UpdateNewsletterRequest>,
) -> Response {
let state = shared.load_state();
let user = match user.0 {
Some(u) => u,
None => return StatusCode::UNAUTHORIZED.into_response(),

View file

@ -2,12 +2,12 @@ use std::sync::{Arc, LazyLock};
use std::time::Duration;
use axum::body::Body;
use axum::extract::Request;
use axum::extract::{Request, State};
use axum::http::{HeaderName, StatusCode};
use axum::response::{IntoResponse, Response};
use tracing::warn;
use crate::state::AppState;
use crate::state::SharedState;
/// Dedicated HTTP client for proxying — does not follow redirects so 3xx
/// responses are passed through to the browser (needed for OAuth flows).
@ -22,7 +22,8 @@ static PROXY_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
.expect("Failed to build proxy HTTP client")
});
pub async fn proxy_to_pocketbase(state: Arc<AppState>, req: Request) -> impl IntoResponse {
pub async fn proxy_to_pocketbase(State(shared): State<Arc<SharedState>>, req: Request) -> impl IntoResponse {
let state = shared.load_state();
let pb_url = state.pocketbase_url.trim_end_matches('/');
let path = req.uri().path();

View file

@ -1,13 +1,13 @@
use std::sync::Arc;
use axum::extract::Query;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::Json;
use serde::{Deserialize, Serialize};
use tracing::info;
use crate::data::slugify;
use crate::state::AppState;
use crate::state::SharedState;
#[derive(Serialize)]
pub struct PlaceResult {
@ -35,9 +35,10 @@ pub struct PlacesParams {
}
pub async fn get_places(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Query(params): Query<PlacesParams>,
) -> Result<Json<PlacesResponse>, (StatusCode, String)> {
let state = shared.load_state();
let query = if params.q.is_empty() {
return Err((StatusCode::BAD_REQUEST, "'q' must not be empty".into()));
} else {

View file

@ -1,6 +1,6 @@
use std::sync::Arc;
use axum::extract::Query;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::Json;
use serde::{Deserialize, Serialize};
@ -9,7 +9,7 @@ use tracing::info;
use crate::consts::MAX_POIS_PER_REQUEST;
use crate::data::POICategoryGroup;
use crate::parsing::require_bounds;
use crate::state::AppState;
use crate::state::SharedState;
#[derive(Serialize)]
#[allow(clippy::upper_case_acronyms)]
@ -36,9 +36,10 @@ pub struct POIParams {
}
pub async fn get_pois(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Query(params): Query<POIParams>,
) -> Result<Json<POIsResponse>, (StatusCode, String)> {
let state = shared.load_state();
let (south, west, north, east) = require_bounds(params.bounds)?;
let category_filter: Option<rustc_hash::FxHashSet<u16>> = params
@ -127,7 +128,8 @@ pub struct POICategoriesResponse {
groups: Vec<POICategoryGroup>,
}
pub async fn get_poi_categories(state: Arc<AppState>) -> Json<POICategoriesResponse> {
pub async fn get_poi_categories(State(shared): State<Arc<SharedState>>) -> Json<POICategoriesResponse> {
let state = shared.load_state();
let groups: Vec<POICategoryGroup> = state.poi_category_groups.to_vec();
let total: usize = groups.iter().map(|group| group.categories.len()).sum();

View file

@ -1,6 +1,6 @@
use std::sync::Arc;
use axum::extract::Query;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json};
use axum::Extension;
@ -11,7 +11,7 @@ use crate::auth::OptionalUser;
use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT, POSTCODE_SEARCH_OFFSET};
use crate::licensing::check_license_point;
use crate::parsing::{parse_filters, row_passes_filters};
use crate::state::AppState;
use crate::state::SharedState;
use crate::utils::normalize_postcode;
use super::properties::{HexagonPropertiesResponse, Property};
@ -25,10 +25,11 @@ pub struct PostcodePropertiesParams {
}
pub async fn get_postcode_properties(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<PostcodePropertiesParams>,
) -> Result<Json<HexagonPropertiesResponse>, axum::response::Response> {
let state = shared.load_state();
let normalized = normalize_postcode(&params.postcode);
let pc_idx = match state.postcode_data.postcode_to_idx.get(&normalized) {

View file

@ -1,6 +1,6 @@
use std::sync::Arc;
use axum::extract::Query;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json};
use axum::Extension;
@ -11,7 +11,7 @@ use crate::auth::OptionalUser;
use crate::consts::POSTCODE_SEARCH_OFFSET;
use crate::licensing::check_license_point;
use crate::parsing::{parse_field_set, parse_filters, row_passes_filters};
use crate::state::AppState;
use crate::state::SharedState;
use crate::utils::normalize_postcode;
use super::hexagon_stats::HexagonStatsResponse;
@ -27,10 +27,11 @@ pub struct PostcodeStatsParams {
}
pub async fn get_postcode_stats(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<PostcodeStatsParams>,
) -> Result<Json<HexagonStatsResponse>, axum::response::Response> {
let state = shared.load_state();
let normalized = normalize_postcode(&params.postcode);
// Look up postcode centroid for spatial search

View file

@ -1,6 +1,6 @@
use std::sync::Arc;
use axum::extract::{Path, Query};
use axum::extract::{Path, Query, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json};
use axum::Extension;
@ -19,7 +19,7 @@ use crate::parsing::{
bounds_intersect, parse_field_indices, parse_filters, require_bounds, row_passes_filters,
};
use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg};
use crate::state::AppState;
use crate::state::SharedState;
use crate::utils::normalize_postcode;
#[derive(Serialize)]
@ -40,10 +40,11 @@ pub struct PostcodeParams {
}
pub async fn get_postcodes(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<PostcodeParams>,
) -> Result<Json<PostcodesResponse>, axum::response::Response> {
let state = shared.load_state();
let (south, west, north, east) =
require_bounds(params.bounds).map_err(IntoResponse::into_response)?;
@ -312,9 +313,10 @@ pub async fn get_postcodes(
/// Look up a single postcode and return its centroid coordinates and geometry.
pub async fn get_postcode_lookup(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Path(postcode): Path<String>,
) -> Result<Json<Value>, StatusCode> {
let state = shared.load_state();
let normalized = normalize_postcode(&postcode);
let postcode_data = &state.postcode_data;

View file

@ -1,5 +1,6 @@
use std::sync::Arc;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
@ -7,7 +8,7 @@ use serde::Serialize;
use tracing::warn;
use crate::pocketbase::auth_superuser;
use crate::state::AppState;
use crate::state::{AppState, SharedState};
/// Pricing tiers: (cumulative user cap, price in pence).
const TIERS: &[(u64, u64)] = &[
@ -75,7 +76,8 @@ pub async fn count_licensed_users(state: &AppState) -> anyhow::Result<u64> {
Ok(total)
}
pub async fn get_pricing(state: Arc<AppState>) -> Response {
pub async fn get_pricing(State(shared): State<Arc<SharedState>>) -> Response {
let state = shared.load_state();
let count = match count_licensed_users(&state).await {
Ok(c) => c,
Err(err) => {

View file

@ -1,7 +1,7 @@
use std::str::FromStr;
use std::sync::Arc;
use axum::extract::Query;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json};
use axum::Extension;
@ -17,7 +17,7 @@ use crate::parsing::{
cell_for_row_cached, h3_cell_bounds, needs_parent, parse_filters, row_passes_filters,
validate_h3_resolution,
};
use crate::state::AppState;
use crate::state::{AppState, SharedState};
#[derive(Deserialize)]
pub struct HexagonPropertiesParams {
@ -173,10 +173,11 @@ pub fn build_property(
}
pub async fn get_hexagon_properties(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Extension(user): Extension<OptionalUser>,
Query(params): Query<HexagonPropertiesParams>,
) -> Result<Json<HexagonPropertiesResponse>, axum::response::Response> {
let state = shared.load_state();
let cell = h3o::CellIndex::from_str(&params.h3).map_err(|error| {
warn!(h3 = %params.h3, error = %error, "Invalid H3 cell index");
(

View file

@ -1,6 +1,7 @@
use std::sync::Arc;
use std::time::Instant;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json, Response};
use serde_json::json;
@ -13,7 +14,7 @@ use crate::routes::{build_features_response, build_system_prompt};
use crate::state::{AppState, SharedState};
use crate::utils::GridIndex;
pub async fn post_reload(shared: Arc<SharedState>) -> Response {
pub async fn post_reload(State(shared): State<Arc<SharedState>>) -> Response {
if !shared.try_start_reload() {
return (StatusCode::CONFLICT, "Reload already in progress").into_response();
}

View file

@ -1,12 +1,13 @@
use std::sync::Arc;
use axum::extract::State;
use axum::http::header::HeaderValue;
use axum::http::{header, HeaderMap, StatusCode, Uri};
use axum::response::IntoResponse;
use metrics::histogram;
use tracing::{info, warn};
use crate::state::AppState;
use crate::state::{AppState, SharedState};
/// Fetch a JPEG screenshot from the screenshot service.
/// Used by both the `/api/screenshot` proxy and the xlsx export.
@ -39,10 +40,11 @@ pub async fn fetch_screenshot_bytes(
}
pub async fn get_screenshot(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
headers: HeaderMap,
uri: Uri,
) -> impl IntoResponse {
let state = shared.load_state();
let qs = uri.query().unwrap_or_default();
let auth = headers.get(header::AUTHORIZATION);
let is_og = qs.contains("og=1");

View file

@ -1,6 +1,6 @@
use std::sync::Arc;
use axum::extract::Path;
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Redirect, Response};
use axum::Json;
@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};
use tracing::warn;
use crate::pocketbase::auth_superuser;
use crate::state::AppState;
use crate::state::SharedState;
const CODE_LEN: usize = 8;
const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789";
@ -38,7 +38,8 @@ struct PbRecord {
params: String,
}
pub async fn post_shorten(state: Arc<AppState>, Json(req): Json<ShortenRequest>) -> Response {
pub async fn post_shorten(State(shared): State<Arc<SharedState>>, Json(req): Json<ShortenRequest>) -> Response {
let state = shared.load_state();
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = match auth_superuser(
@ -92,7 +93,9 @@ pub async fn post_shorten(state: Arc<AppState>, Json(req): Json<ShortenRequest>)
}
}
pub async fn get_short_url(state: Arc<AppState>, Path(code): Path<String>) -> Response {
pub async fn get_short_url(State(shared): State<Arc<SharedState>>, Path(code): Path<String>) -> Response {
let state = shared.load_state();
if code.is_empty() || code.len() > 20 || !code.bytes().all(|b| b.is_ascii_alphanumeric()) {
return StatusCode::BAD_REQUEST.into_response();
}

View file

@ -1,11 +1,12 @@
use std::sync::Arc;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json};
use serde::{Deserialize, Serialize};
use tracing::warn;
use crate::state::AppState;
use crate::state::SharedState;
#[derive(Deserialize)]
pub struct StreetViewQuery {
@ -28,9 +29,10 @@ struct StreetViewResponse {
}
pub async fn get_streetview(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
query: axum::extract::Query<StreetViewQuery>,
) -> impl IntoResponse {
let state = shared.load_state();
let url = format!(
"https://maps.googleapis.com/maps/api/streetview/metadata?location={},{}&radius=1000&source=outdoor&key={}",
query.lat, query.lon, state.google_maps_api_key

View file

@ -1,6 +1,7 @@
use std::sync::Arc;
use axum::body::Bytes;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use hmac::{Hmac, Mac};
@ -8,7 +9,7 @@ use sha2::Sha256;
use tracing::{info, warn};
use crate::pocketbase::auth_superuser;
use crate::state::AppState;
use crate::state::SharedState;
type HmacSha256 = Hmac<Sha256>;
@ -49,10 +50,11 @@ fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
/// Handle Stripe webhook events.
/// On `checkout.session.completed`, updates the user's subscription to "licensed".
pub async fn post_stripe_webhook(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
headers: HeaderMap,
body: Bytes,
) -> Response {
let state = shared.load_state();
let webhook_secret = &state.stripe_webhook_secret;
let sig_header = match headers

View file

@ -1,6 +1,6 @@
use std::sync::Arc;
use axum::extract::Query;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::Json;
use rustc_hash::FxHashSet;
@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
use tracing::info;
use crate::data::slugify;
use crate::state::AppState;
use crate::state::SharedState;
#[derive(Serialize)]
pub struct DestinationResult {
@ -30,9 +30,10 @@ pub struct DestinationsParams {
}
pub async fn get_travel_destinations(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
Query(params): Query<DestinationsParams>,
) -> Result<Json<DestinationsResponse>, (StatusCode, String)> {
let state = shared.load_state();
let mode = params.mode;
let destinations = tokio::task::spawn_blocking(move || {

View file

@ -1,10 +1,11 @@
use std::sync::Arc;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::Json;
use serde::Serialize;
use crate::state::AppState;
use crate::state::SharedState;
#[derive(Serialize)]
pub struct TravelModeInfo {
@ -18,8 +19,9 @@ pub struct TravelModesResponse {
}
pub async fn get_travel_modes(
state: Arc<AppState>,
State(shared): State<Arc<SharedState>>,
) -> Result<Json<TravelModesResponse>, (StatusCode, String)> {
let state = shared.load_state();
let store = &state.travel_time_store;
let modes = store
.available_modes