diff --git a/server-rs/src/main.rs b/server-rs/src/main.rs index 881a4ef..a8a3373 100644 --- a/server-rs/src/main.rs +++ b/server-rs/src/main.rs @@ -422,157 +422,78 @@ async fn main() -> anyhow::Result<()> { .allow_headers(AllowHeaders::mirror_request()) .allow_credentials(true); - // Each route closure captures a clone of `shared` and calls `load_state()` - // at request time to get the latest `Arc`. This enables hot-reload: + // Handlers use Axum's State extractor to get Arc, then call + // load_state() to get the current Arc. This enables hot-reload: // the reload endpoint swaps in a new AppState, and subsequent requests pick it up. - macro_rules! s { - () => { - shared.clone() - }; - } - - let (s1, s2, s3, s4, s5, s6) = (s!(), s!(), s!(), s!(), s!(), s!()); - let (s7, s8, s9, s10, s11, s12) = (s!(), s!(), s!(), s!(), s!(), s!()); - let (s13, s14, s15, s16, s17, s18) = (s!(), s!(), s!(), s!(), s!(), s!()); - let (s19, s20, s21, s22, s23, s24) = (s!(), s!(), s!(), s!(), s!(), s!()); - let (s25, s26, s27, s28, s29) = (s!(), s!(), s!(), s!(), s!()); let s_crawler = shared.clone(); - let s_pb = shared.clone(); - let s_reload = shared.clone(); - let api = Router::new() - .route( - "/api/features", - get(move || routes::get_features(s1.load_state())), - ) - .route( - "/api/hexagons", - get(move |ext, query| routes::get_hexagons(s2.load_state(), ext, query)), - ) - .route( - "/api/postcodes", - get(move |ext, query| routes::get_postcodes(s3.load_state(), ext, query)), - ) - .route( - "/api/postcode/{postcode}", - get(move |path| routes::get_postcode_lookup(s4.load_state(), path)), - ) - .route( - "/api/pois", - get(move |query| routes::get_pois(s5.load_state(), query)), - ) - .route( - "/api/poi-categories", - get(move || routes::get_poi_categories(s6.load_state())), - ) - .route( - "/api/places", - get(move |query| routes::get_places(s7.load_state(), query)), - ) - .route( - "/api/travel-modes", - get(move || routes::get_travel_modes(s8.load_state())), - ) - .route( - "/api/travel-destinations", - get(move |query| routes::get_travel_destinations(s9.load_state(), query)), - ) - .route( - "/api/journey", - get(move |query| routes::get_journey(s10.load_state(), query)), - ) - .route( - "/api/hexagon-properties", - get(move |ext, query| routes::get_hexagon_properties(s11.load_state(), ext, query)), - ) - .route( - "/api/hexagon-stats", - get(move |ext, query| routes::get_hexagon_stats(s12.load_state(), ext, query)), - ) - .route( - "/api/postcode-stats", - get(move |ext, query| routes::get_postcode_stats(s13.load_state(), ext, query)), - ) - .route( - "/api/postcode-properties", - get(move |ext, query| routes::get_postcode_properties(s14.load_state(), ext, query)), - ) - .route( - "/api/screenshot", - get(move |headers, query| routes::get_screenshot(s15.load_state(), headers, query)), - ) - .route( - "/api/export", - get(move |headers, ext, query| { - routes::get_export(s16.load_state(), headers, ext, query) - }) - .layer(ConcurrencyLimitLayer::new(3)), - ) - .route("/api/me", get(routes::get_me)) - .route( - "/api/shorten", - post(move |body| routes::post_shorten(s17.load_state(), body)), - ) - .route( - "/api/ai-filters", - post(move |ext, body| routes::post_ai_filters(s18.load_state(), ext, body)) - .layer(ConcurrencyLimitLayer::new(5)), - ) - .route( - "/api/streetview", - get(move |query| routes::get_streetview(s19.load_state(), query)), - ) - .route( - "/api/newsletter", - patch(move |ext, body| routes::patch_newsletter(s20.load_state(), ext, body)), - ) - .route( - "/api/pricing", - get(move || routes::get_pricing(s21.load_state())), - ) - .route( - "/api/checkout", - post(move |ext, body| routes::post_checkout(s22.load_state(), ext, body)) - .layer(ConcurrencyLimitLayer::new(10)), - ) - .route( - "/api/stripe-webhook", - post(move |headers, body| routes::post_stripe_webhook(s23.load_state(), headers, body)), - ) - .route( - "/api/invites", - get(move |ext| routes::get_invites(s24.load_state(), ext)) - .post(move |ext, body| routes::post_invites(s25.load_state(), ext, body)), - ) - .route( - "/api/invite/{code}", - get(move |ext, path| routes::get_invite(s26.load_state(), ext, path)), - ) - .route( - "/api/redeem-invite", - post(move |ext, body| routes::post_redeem_invite(s27.load_state(), ext, body)), - ) - .route( - "/s/{code}", - get(move |path| routes::get_short_url(s28.load_state(), path)), - ) - .route( - "/api/telemetry", - post(move |ext, headers, body| { - let _ = s29.load_state(); - routes::post_telemetry(ext, headers, body) - }), - ) - .route( - "/api/reload", - post(move || routes::post_reload(s_reload.clone())), - ); - - // Add tile routes let reader_tile = tile_reader.clone(); let reader_style = tile_reader.clone(); let public_url_tiles = initial_state.public_url.clone(); - let api = api + + let api = Router::new() + .route("/api/features", get(routes::get_features)) + .route("/api/hexagons", get(routes::get_hexagons)) + .route("/api/postcodes", get(routes::get_postcodes)) + .route( + "/api/postcode/{postcode}", + get(routes::get_postcode_lookup), + ) + .route("/api/pois", get(routes::get_pois)) + .route("/api/poi-categories", get(routes::get_poi_categories)) + .route("/api/places", get(routes::get_places)) + .route("/api/travel-modes", get(routes::get_travel_modes)) + .route( + "/api/travel-destinations", + get(routes::get_travel_destinations), + ) + .route("/api/journey", get(routes::get_journey)) + .route( + "/api/hexagon-properties", + get(routes::get_hexagon_properties), + ) + .route("/api/hexagon-stats", get(routes::get_hexagon_stats)) + .route("/api/postcode-stats", get(routes::get_postcode_stats)) + .route( + "/api/postcode-properties", + get(routes::get_postcode_properties), + ) + .route("/api/screenshot", get(routes::get_screenshot)) + .route( + "/api/export", + get(routes::get_export).layer(ConcurrencyLimitLayer::new(3)), + ) + .route("/api/me", get(routes::get_me)) + .route("/api/shorten", post(routes::post_shorten)) + .route( + "/api/ai-filters", + post(routes::post_ai_filters).layer(ConcurrencyLimitLayer::new(5)), + ) + .route("/api/streetview", get(routes::get_streetview)) + .route("/api/newsletter", patch(routes::patch_newsletter)) + .route("/api/pricing", get(routes::get_pricing)) + .route( + "/api/checkout", + post(routes::post_checkout).layer(ConcurrencyLimitLayer::new(10)), + ) + .route( + "/api/stripe-webhook", + post(routes::post_stripe_webhook), + ) + .route( + "/api/invites", + get(routes::get_invites).post(routes::post_invites), + ) + .route("/api/invite/{code}", get(routes::get_invite)) + .route("/api/redeem-invite", post(routes::post_redeem_invite)) + .route("/s/{code}", get(routes::get_short_url)) + .route("/api/telemetry", post(routes::post_telemetry)) + .route("/api/reload", post(routes::post_reload)) + .route( + "/pb/{*rest}", + any(routes::proxy_to_pocketbase), + ) + // Tile routes use a different state type — kept as closures .route( "/api/tiles/{z}/{x}/{y}", get(move |path| routes::get_tile(axum::extract::State(reader_tile.clone()), path)), @@ -589,10 +510,7 @@ async fn main() -> anyhow::Result<()> { "/metrics", get(move || metrics::metrics_handler(metrics_handle.clone())), ) - .route( - "/pb/{*rest}", - any(move |req| routes::proxy_to_pocketbase(s_pb.load_state(), req)), - ); + .with_state(shared.clone()); let app = if let Some(ref dist) = cli.dist { api.fallback_service(ServeDir::new(dist).fallback(ServeFile::new(dist.join("index.html")))) diff --git a/server-rs/src/routes/ai_filters.rs b/server-rs/src/routes/ai_filters.rs index 2ec1f6c..d451676 100644 --- a/server-rs/src/routes/ai_filters.rs +++ b/server-rs/src/routes/ai_filters.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use axum::extract::State; use axum::http::StatusCode; use axum::response::Json; use axum::Extension; @@ -13,7 +14,7 @@ use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WE use crate::data::slugify; use crate::pocketbase::auth_superuser; use crate::routes::{FeatureInfo, FeaturesResponse}; -use crate::state::AppState; +use crate::state::{AppState, SharedState}; use crate::utils::gemini_chat; #[derive(Deserialize)] @@ -501,10 +502,11 @@ async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week const MAX_TOOL_ROUNDS: usize = 5; pub async fn post_ai_filters( - state: Arc, + State(shared): State>, Extension(user): Extension, Json(req): Json, ) -> Result, (StatusCode, String)> { + let state = shared.load_state(); // Auth check let user = user .0 diff --git a/server-rs/src/routes/checkout.rs b/server-rs/src/routes/checkout.rs index 463ed74..2fda1b2 100644 --- a/server-rs/src/routes/checkout.rs +++ b/server-rs/src/routes/checkout.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::{Extension, Json}; @@ -8,7 +9,7 @@ use tracing::{info, warn}; use crate::auth::OptionalUser; use crate::pocketbase::auth_superuser; -use crate::state::AppState; +use crate::state::{AppState, SharedState}; use super::pricing::{count_licensed_users, price_for_count}; @@ -25,10 +26,11 @@ struct CheckoutResponse { /// Create a Stripe Checkout session for the lifetime license (or grant for free if in free tier). /// Requires authentication. Optionally accepts a referral code to apply a coupon. pub async fn post_checkout( - state: Arc, + State(shared): State>, Extension(user): Extension, Json(req): Json, ) -> Response { + let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), diff --git a/server-rs/src/routes/export.rs b/server-rs/src/routes/export.rs index 7a45ba4..bbd735a 100644 --- a/server-rs/src/routes/export.rs +++ b/server-rs/src/routes/export.rs @@ -2,7 +2,7 @@ use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use axum::extract::Query; +use axum::extract::{Query, State}; use axum::http::{header, HeaderMap, StatusCode}; use axum::response::IntoResponse; use axum::Extension; @@ -17,7 +17,7 @@ use crate::data::QuantRef; use crate::licensing::check_license_bounds; use crate::parsing::{parse_field_indices, parse_filters, require_bounds, row_passes_filters}; use crate::routes::{fetch_screenshot_bytes, FeatureInfo}; -use crate::state::AppState; +use crate::state::SharedState; const MAX_EXPORT_POSTCODES: usize = 250; /// Height (in pixels) reserved for the screenshot row @@ -125,11 +125,12 @@ fn build_frontend_params( } pub async fn get_export( - state: Arc, + State(shared): State>, headers: HeaderMap, Extension(user): Extension, Query(params): Query, ) -> Result { + let state = shared.load_state(); let (south, west, north, east) = require_bounds(params.bounds).map_err(IntoResponse::into_response)?; diff --git a/server-rs/src/routes/features.rs b/server-rs/src/routes/features.rs index bde721c..20c082f 100644 --- a/server-rs/src/routes/features.rs +++ b/server-rs/src/routes/features.rs @@ -1,12 +1,13 @@ use std::sync::Arc; +use axum::extract::State; use axum::response::Json; use serde::Serialize; use tracing::info; use crate::data::{Histogram, PropertyData}; use crate::features::{ENUM_FEATURE_GROUPS, FEATURE_GROUPS}; -use crate::state::AppState; +use crate::state::SharedState; fn is_empty(val: &str) -> bool { val.is_empty() @@ -154,7 +155,8 @@ pub fn build_features_response(data: &PropertyData) -> FeaturesResponse { FeaturesResponse { groups } } -pub async fn get_features(state: Arc) -> Json { +pub async fn get_features(State(shared): State>) -> Json { + let state = shared.load_state(); info!("GET /api/features"); Json(state.features_response.clone()) } diff --git a/server-rs/src/routes/hexagon_stats.rs b/server-rs/src/routes/hexagon_stats.rs index 362cdbb..38d1c4b 100644 --- a/server-rs/src/routes/hexagon_stats.rs +++ b/server-rs/src/routes/hexagon_stats.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; -use axum::extract::Query; +use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Json}; use axum::Extension; @@ -16,7 +16,7 @@ use crate::parsing::{ cell_for_row_cached, h3_cell_bounds, needs_parent, parse_field_set, parse_filters, row_passes_filters, validate_h3_resolution, }; -use crate::state::AppState; +use crate::state::SharedState; use super::stats; @@ -79,10 +79,11 @@ pub struct HexagonStatsParams { } pub async fn get_hexagon_stats( - state: Arc, + State(shared): State>, Extension(user): Extension, Query(params): Query, ) -> Result, axum::response::Response> { + let state = shared.load_state(); let cell = h3o::CellIndex::from_str(¶ms.h3).map_err(|error| { warn!(h3 = %params.h3, error = %error, "Invalid H3 cell index"); ( diff --git a/server-rs/src/routes/hexagons.rs b/server-rs/src/routes/hexagons.rs index 8dc0508..a91fd27 100644 --- a/server-rs/src/routes/hexagons.rs +++ b/server-rs/src/routes/hexagons.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use axum::extract::Query; +use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Json}; use axum::Extension; @@ -21,7 +21,7 @@ use crate::parsing::{ row_passes_filters, validate_h3_resolution, }; use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg}; -use crate::state::AppState; +use crate::state::SharedState; /// Row count threshold above which we use rayon parallel aggregation. const PARALLEL_THRESHOLD: usize = 50_000; @@ -182,10 +182,11 @@ fn build_feature_maps( } pub async fn get_hexagons( - state: Arc, + State(shared): State>, Extension(user): Extension, Query(params): Query, ) -> Result, axum::response::Response> { + let state = shared.load_state(); let resolution = params.resolution; validate_h3_resolution(resolution).map_err(IntoResponse::into_response)?; diff --git a/server-rs/src/routes/invites.rs b/server-rs/src/routes/invites.rs index a52a207..10d8324 100644 --- a/server-rs/src/routes/invites.rs +++ b/server-rs/src/routes/invites.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use axum::extract::Path; +use axum::extract::{Path, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::{Extension, Json}; @@ -9,7 +9,7 @@ use tracing::{info, warn}; use crate::auth::OptionalUser; use crate::pocketbase::auth_superuser; -use crate::state::AppState; +use crate::state::SharedState; #[derive(Serialize)] struct InviteResponse { @@ -90,10 +90,11 @@ fn generate_invite_code() -> String { /// Create an invite. Admins create "admin" invites (free license) by default, /// but can explicitly request "referral" type. Licensed non-admin users always create "referral" invites (30% off). pub async fn post_invites( - state: Arc, + State(shared): State>, Extension(user): Extension, Json(body): Json, ) -> Response { + let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), @@ -179,10 +180,11 @@ const DEV_INVITE_CODE: &str = "devdevdevdev"; /// Validate an invite code. Public endpoint — codes are 12-char random alphanumeric /// so enumeration is impractical, and the response only reveals valid/invalid + type. pub async fn get_invite( - state: Arc, + State(shared): State>, Extension(_user): Extension, Path(code): Path, ) -> Response { + let state = shared.load_state(); if let Err(msg) = validate_invite_code(&code) { return (StatusCode::BAD_REQUEST, msg).into_response(); } @@ -297,10 +299,11 @@ pub async fn get_invite( /// Admin invite: sets subscription to "licensed" directly. /// Referral invite: returns a discounted Stripe checkout URL. pub async fn post_redeem_invite( - state: Arc, + State(shared): State>, Extension(user): Extension, Json(req): Json, ) -> Response { + let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), @@ -486,9 +489,10 @@ pub async fn post_redeem_invite( /// List invites. Admins see all invites; licensed users see only their own. pub async fn get_invites( - state: Arc, + State(shared): State>, Extension(user): Extension, ) -> Response { + let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), diff --git a/server-rs/src/routes/journey.rs b/server-rs/src/routes/journey.rs index 835e9ff..31d3db2 100644 --- a/server-rs/src/routes/journey.rs +++ b/server-rs/src/routes/journey.rs @@ -1,10 +1,11 @@ use std::sync::Arc; +use axum::extract::State; use axum::http::StatusCode; use axum::response::Json; use serde::{Deserialize, Serialize}; -use crate::state::AppState; +use crate::state::SharedState; #[derive(Deserialize)] pub struct JourneyQuery { @@ -24,9 +25,10 @@ pub struct JourneyResponse { } pub async fn get_journey( - state: Arc, + State(shared): State>, query: axum::extract::Query, ) -> Result, (StatusCode, String)> { + let state = shared.load_state(); let store = &state.travel_time_store; if !store.has_destination(&query.mode, &query.slug) { diff --git a/server-rs/src/routes/newsletter.rs b/server-rs/src/routes/newsletter.rs index 04ff8b1..28a71bc 100644 --- a/server-rs/src/routes/newsletter.rs +++ b/server-rs/src/routes/newsletter.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::{Extension, Json}; @@ -8,7 +9,7 @@ use tracing::warn; use crate::auth::OptionalUser; use crate::pocketbase::auth_superuser; -use crate::state::AppState; +use crate::state::SharedState; #[derive(Deserialize)] pub struct UpdateNewsletterRequest { @@ -16,10 +17,11 @@ pub struct UpdateNewsletterRequest { } pub async fn patch_newsletter( - state: Arc, + State(shared): State>, Extension(user): Extension, Json(req): Json, ) -> Response { + let state = shared.load_state(); let user = match user.0 { Some(u) => u, None => return StatusCode::UNAUTHORIZED.into_response(), diff --git a/server-rs/src/routes/pb_proxy.rs b/server-rs/src/routes/pb_proxy.rs index 4655d8d..125c084 100644 --- a/server-rs/src/routes/pb_proxy.rs +++ b/server-rs/src/routes/pb_proxy.rs @@ -2,12 +2,12 @@ use std::sync::{Arc, LazyLock}; use std::time::Duration; use axum::body::Body; -use axum::extract::Request; +use axum::extract::{Request, State}; use axum::http::{HeaderName, StatusCode}; use axum::response::{IntoResponse, Response}; use tracing::warn; -use crate::state::AppState; +use crate::state::SharedState; /// Dedicated HTTP client for proxying — does not follow redirects so 3xx /// responses are passed through to the browser (needed for OAuth flows). @@ -22,7 +22,8 @@ static PROXY_CLIENT: LazyLock = LazyLock::new(|| { .expect("Failed to build proxy HTTP client") }); -pub async fn proxy_to_pocketbase(state: Arc, req: Request) -> impl IntoResponse { +pub async fn proxy_to_pocketbase(State(shared): State>, req: Request) -> impl IntoResponse { + let state = shared.load_state(); let pb_url = state.pocketbase_url.trim_end_matches('/'); let path = req.uri().path(); diff --git a/server-rs/src/routes/places.rs b/server-rs/src/routes/places.rs index d8b75bc..9de9a7a 100644 --- a/server-rs/src/routes/places.rs +++ b/server-rs/src/routes/places.rs @@ -1,13 +1,13 @@ use std::sync::Arc; -use axum::extract::Query; +use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::Json; use serde::{Deserialize, Serialize}; use tracing::info; use crate::data::slugify; -use crate::state::AppState; +use crate::state::SharedState; #[derive(Serialize)] pub struct PlaceResult { @@ -35,9 +35,10 @@ pub struct PlacesParams { } pub async fn get_places( - state: Arc, + State(shared): State>, Query(params): Query, ) -> Result, (StatusCode, String)> { + let state = shared.load_state(); let query = if params.q.is_empty() { return Err((StatusCode::BAD_REQUEST, "'q' must not be empty".into())); } else { diff --git a/server-rs/src/routes/pois.rs b/server-rs/src/routes/pois.rs index e079c6b..7b38822 100644 --- a/server-rs/src/routes/pois.rs +++ b/server-rs/src/routes/pois.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use axum::extract::Query; +use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::Json; use serde::{Deserialize, Serialize}; @@ -9,7 +9,7 @@ use tracing::info; use crate::consts::MAX_POIS_PER_REQUEST; use crate::data::POICategoryGroup; use crate::parsing::require_bounds; -use crate::state::AppState; +use crate::state::SharedState; #[derive(Serialize)] #[allow(clippy::upper_case_acronyms)] @@ -36,9 +36,10 @@ pub struct POIParams { } pub async fn get_pois( - state: Arc, + State(shared): State>, Query(params): Query, ) -> Result, (StatusCode, String)> { + let state = shared.load_state(); let (south, west, north, east) = require_bounds(params.bounds)?; let category_filter: Option> = params @@ -127,7 +128,8 @@ pub struct POICategoriesResponse { groups: Vec, } -pub async fn get_poi_categories(state: Arc) -> Json { +pub async fn get_poi_categories(State(shared): State>) -> Json { + let state = shared.load_state(); let groups: Vec = state.poi_category_groups.to_vec(); let total: usize = groups.iter().map(|group| group.categories.len()).sum(); diff --git a/server-rs/src/routes/postcode_properties.rs b/server-rs/src/routes/postcode_properties.rs index 4043923..c6f6214 100644 --- a/server-rs/src/routes/postcode_properties.rs +++ b/server-rs/src/routes/postcode_properties.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use axum::extract::Query; +use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Json}; use axum::Extension; @@ -11,7 +11,7 @@ use crate::auth::OptionalUser; use crate::consts::{DEFAULT_PROPERTIES_LIMIT, MAX_PROPERTIES_LIMIT, POSTCODE_SEARCH_OFFSET}; use crate::licensing::check_license_point; use crate::parsing::{parse_filters, row_passes_filters}; -use crate::state::AppState; +use crate::state::SharedState; use crate::utils::normalize_postcode; use super::properties::{HexagonPropertiesResponse, Property}; @@ -25,10 +25,11 @@ pub struct PostcodePropertiesParams { } pub async fn get_postcode_properties( - state: Arc, + State(shared): State>, Extension(user): Extension, Query(params): Query, ) -> Result, axum::response::Response> { + let state = shared.load_state(); let normalized = normalize_postcode(¶ms.postcode); let pc_idx = match state.postcode_data.postcode_to_idx.get(&normalized) { diff --git a/server-rs/src/routes/postcode_stats.rs b/server-rs/src/routes/postcode_stats.rs index fdab6f5..1351aa0 100644 --- a/server-rs/src/routes/postcode_stats.rs +++ b/server-rs/src/routes/postcode_stats.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use axum::extract::Query; +use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Json}; use axum::Extension; @@ -11,7 +11,7 @@ use crate::auth::OptionalUser; use crate::consts::POSTCODE_SEARCH_OFFSET; use crate::licensing::check_license_point; use crate::parsing::{parse_field_set, parse_filters, row_passes_filters}; -use crate::state::AppState; +use crate::state::SharedState; use crate::utils::normalize_postcode; use super::hexagon_stats::HexagonStatsResponse; @@ -27,10 +27,11 @@ pub struct PostcodeStatsParams { } pub async fn get_postcode_stats( - state: Arc, + State(shared): State>, Extension(user): Extension, Query(params): Query, ) -> Result, axum::response::Response> { + let state = shared.load_state(); let normalized = normalize_postcode(¶ms.postcode); // Look up postcode centroid for spatial search diff --git a/server-rs/src/routes/postcodes.rs b/server-rs/src/routes/postcodes.rs index 54d78df..118948b 100644 --- a/server-rs/src/routes/postcodes.rs +++ b/server-rs/src/routes/postcodes.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use axum::extract::{Path, Query}; +use axum::extract::{Path, Query, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Json}; use axum::Extension; @@ -19,7 +19,7 @@ use crate::parsing::{ bounds_intersect, parse_field_indices, parse_filters, require_bounds, row_passes_filters, }; use crate::routes::travel_time::{parse_optional_travel, TravelTimeAgg}; -use crate::state::AppState; +use crate::state::SharedState; use crate::utils::normalize_postcode; #[derive(Serialize)] @@ -40,10 +40,11 @@ pub struct PostcodeParams { } pub async fn get_postcodes( - state: Arc, + State(shared): State>, Extension(user): Extension, Query(params): Query, ) -> Result, axum::response::Response> { + let state = shared.load_state(); let (south, west, north, east) = require_bounds(params.bounds).map_err(IntoResponse::into_response)?; @@ -312,9 +313,10 @@ pub async fn get_postcodes( /// Look up a single postcode and return its centroid coordinates and geometry. pub async fn get_postcode_lookup( - state: Arc, + State(shared): State>, Path(postcode): Path, ) -> Result, StatusCode> { + let state = shared.load_state(); let normalized = normalize_postcode(&postcode); let postcode_data = &state.postcode_data; diff --git a/server-rs/src/routes/pricing.rs b/server-rs/src/routes/pricing.rs index 0008b07..e7428c6 100644 --- a/server-rs/src/routes/pricing.rs +++ b/server-rs/src/routes/pricing.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::Json; @@ -7,7 +8,7 @@ use serde::Serialize; use tracing::warn; use crate::pocketbase::auth_superuser; -use crate::state::AppState; +use crate::state::{AppState, SharedState}; /// Pricing tiers: (cumulative user cap, price in pence). const TIERS: &[(u64, u64)] = &[ @@ -75,7 +76,8 @@ pub async fn count_licensed_users(state: &AppState) -> anyhow::Result { Ok(total) } -pub async fn get_pricing(state: Arc) -> Response { +pub async fn get_pricing(State(shared): State>) -> Response { + let state = shared.load_state(); let count = match count_licensed_users(&state).await { Ok(c) => c, Err(err) => { diff --git a/server-rs/src/routes/properties.rs b/server-rs/src/routes/properties.rs index 1796b3a..c499318 100644 --- a/server-rs/src/routes/properties.rs +++ b/server-rs/src/routes/properties.rs @@ -1,7 +1,7 @@ use std::str::FromStr; use std::sync::Arc; -use axum::extract::Query; +use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Json}; use axum::Extension; @@ -17,7 +17,7 @@ use crate::parsing::{ cell_for_row_cached, h3_cell_bounds, needs_parent, parse_filters, row_passes_filters, validate_h3_resolution, }; -use crate::state::AppState; +use crate::state::{AppState, SharedState}; #[derive(Deserialize)] pub struct HexagonPropertiesParams { @@ -173,10 +173,11 @@ pub fn build_property( } pub async fn get_hexagon_properties( - state: Arc, + State(shared): State>, Extension(user): Extension, Query(params): Query, ) -> Result, axum::response::Response> { + let state = shared.load_state(); let cell = h3o::CellIndex::from_str(¶ms.h3).map_err(|error| { warn!(h3 = %params.h3, error = %error, "Invalid H3 cell index"); ( diff --git a/server-rs/src/routes/reload.rs b/server-rs/src/routes/reload.rs index e9374fb..f8cd79d 100644 --- a/server-rs/src/routes/reload.rs +++ b/server-rs/src/routes/reload.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use std::time::Instant; +use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Json, Response}; use serde_json::json; @@ -13,7 +14,7 @@ use crate::routes::{build_features_response, build_system_prompt}; use crate::state::{AppState, SharedState}; use crate::utils::GridIndex; -pub async fn post_reload(shared: Arc) -> Response { +pub async fn post_reload(State(shared): State>) -> Response { if !shared.try_start_reload() { return (StatusCode::CONFLICT, "Reload already in progress").into_response(); } diff --git a/server-rs/src/routes/screenshot.rs b/server-rs/src/routes/screenshot.rs index e4f06d4..d6aa8f4 100644 --- a/server-rs/src/routes/screenshot.rs +++ b/server-rs/src/routes/screenshot.rs @@ -1,12 +1,13 @@ use std::sync::Arc; +use axum::extract::State; use axum::http::header::HeaderValue; use axum::http::{header, HeaderMap, StatusCode, Uri}; use axum::response::IntoResponse; use metrics::histogram; use tracing::{info, warn}; -use crate::state::AppState; +use crate::state::{AppState, SharedState}; /// Fetch a JPEG screenshot from the screenshot service. /// Used by both the `/api/screenshot` proxy and the xlsx export. @@ -39,10 +40,11 @@ pub async fn fetch_screenshot_bytes( } pub async fn get_screenshot( - state: Arc, + State(shared): State>, headers: HeaderMap, uri: Uri, ) -> impl IntoResponse { + let state = shared.load_state(); let qs = uri.query().unwrap_or_default(); let auth = headers.get(header::AUTHORIZATION); let is_og = qs.contains("og=1"); diff --git a/server-rs/src/routes/shorten.rs b/server-rs/src/routes/shorten.rs index e63b498..17905c0 100644 --- a/server-rs/src/routes/shorten.rs +++ b/server-rs/src/routes/shorten.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use axum::extract::Path; +use axum::extract::{Path, State}; use axum::http::StatusCode; use axum::response::{IntoResponse, Redirect, Response}; use axum::Json; @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use tracing::warn; use crate::pocketbase::auth_superuser; -use crate::state::AppState; +use crate::state::SharedState; const CODE_LEN: usize = 8; const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789"; @@ -38,7 +38,8 @@ struct PbRecord { params: String, } -pub async fn post_shorten(state: Arc, Json(req): Json) -> Response { +pub async fn post_shorten(State(shared): State>, Json(req): Json) -> Response { + let state = shared.load_state(); let pb_url = state.pocketbase_url.trim_end_matches('/'); let token = match auth_superuser( @@ -92,7 +93,9 @@ pub async fn post_shorten(state: Arc, Json(req): Json) } } -pub async fn get_short_url(state: Arc, Path(code): Path) -> Response { +pub async fn get_short_url(State(shared): State>, Path(code): Path) -> Response { + let state = shared.load_state(); + if code.is_empty() || code.len() > 20 || !code.bytes().all(|b| b.is_ascii_alphanumeric()) { return StatusCode::BAD_REQUEST.into_response(); } diff --git a/server-rs/src/routes/streetview.rs b/server-rs/src/routes/streetview.rs index d611ae3..0b806b3 100644 --- a/server-rs/src/routes/streetview.rs +++ b/server-rs/src/routes/streetview.rs @@ -1,11 +1,12 @@ use std::sync::Arc; +use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Json}; use serde::{Deserialize, Serialize}; use tracing::warn; -use crate::state::AppState; +use crate::state::SharedState; #[derive(Deserialize)] pub struct StreetViewQuery { @@ -28,9 +29,10 @@ struct StreetViewResponse { } pub async fn get_streetview( - state: Arc, + State(shared): State>, query: axum::extract::Query, ) -> impl IntoResponse { + let state = shared.load_state(); let url = format!( "https://maps.googleapis.com/maps/api/streetview/metadata?location={},{}&radius=1000&source=outdoor&key={}", query.lat, query.lon, state.google_maps_api_key diff --git a/server-rs/src/routes/stripe_webhook.rs b/server-rs/src/routes/stripe_webhook.rs index 092e373..7ffe5fe 100644 --- a/server-rs/src/routes/stripe_webhook.rs +++ b/server-rs/src/routes/stripe_webhook.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use axum::body::Bytes; +use axum::extract::State; use axum::http::{HeaderMap, StatusCode}; use axum::response::{IntoResponse, Response}; use hmac::{Hmac, Mac}; @@ -8,7 +9,7 @@ use sha2::Sha256; use tracing::{info, warn}; use crate::pocketbase::auth_superuser; -use crate::state::AppState; +use crate::state::SharedState; type HmacSha256 = Hmac; @@ -49,10 +50,11 @@ fn verify_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool { /// Handle Stripe webhook events. /// On `checkout.session.completed`, updates the user's subscription to "licensed". pub async fn post_stripe_webhook( - state: Arc, + State(shared): State>, headers: HeaderMap, body: Bytes, ) -> Response { + let state = shared.load_state(); let webhook_secret = &state.stripe_webhook_secret; let sig_header = match headers diff --git a/server-rs/src/routes/travel_destinations.rs b/server-rs/src/routes/travel_destinations.rs index eeb4456..518fb96 100644 --- a/server-rs/src/routes/travel_destinations.rs +++ b/server-rs/src/routes/travel_destinations.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use axum::extract::Query; +use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::Json; use rustc_hash::FxHashSet; @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; use tracing::info; use crate::data::slugify; -use crate::state::AppState; +use crate::state::SharedState; #[derive(Serialize)] pub struct DestinationResult { @@ -30,9 +30,10 @@ pub struct DestinationsParams { } pub async fn get_travel_destinations( - state: Arc, + State(shared): State>, Query(params): Query, ) -> Result, (StatusCode, String)> { + let state = shared.load_state(); let mode = params.mode; let destinations = tokio::task::spawn_blocking(move || { diff --git a/server-rs/src/routes/travel_modes.rs b/server-rs/src/routes/travel_modes.rs index dc90f51..f4f5421 100644 --- a/server-rs/src/routes/travel_modes.rs +++ b/server-rs/src/routes/travel_modes.rs @@ -1,10 +1,11 @@ use std::sync::Arc; +use axum::extract::State; use axum::http::StatusCode; use axum::response::Json; use serde::Serialize; -use crate::state::AppState; +use crate::state::SharedState; #[derive(Serialize)] pub struct TravelModeInfo { @@ -18,8 +19,9 @@ pub struct TravelModesResponse { } pub async fn get_travel_modes( - state: Arc, + State(shared): State>, ) -> Result, (StatusCode, String)> { + let state = shared.load_state(); let store = &state.travel_time_store; let modes = store .available_modes