perfect-postcode/server-rs/src/main.rs
2026-05-12 22:13:07 +01:00

653 lines
21 KiB
Rust

#![allow(clippy::min_ident_chars)]
mod aggregation;
mod auth;
mod checkout_sessions;
mod consts;
mod data;
mod features;
mod licensing;
mod metrics;
mod og_middleware;
pub mod parsing;
mod pocketbase;
mod pocketbase_locks;
mod routes;
mod state;
pub mod utils;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{bail, Context};
use axum::http::{header, HeaderValue};
use axum::middleware;
use axum::routing::{any, get, patch, post};
use axum::Router;
use clap::Parser;
use consts::SERVICE_CALL_TIMEOUT;
use tower::limit::ConcurrencyLimitLayer;
use tower_http::compression::CompressionLayer;
use tower_http::cors::{AllowHeaders, AllowMethods, CorsLayer};
use tower_http::services::{ServeDir, ServeFile};
use tower_http::trace::TraceLayer;
use tracing::info;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;
use state::{AppState, SharedState};
/// Returns `true` for paths served by the backend itself (JSON API, PocketBase
/// proxy, short links, health and metrics endpoints) rather than by the static
/// frontend; such responses are left alone by the static cache middleware.
fn is_api_path(path: &str) -> bool {
    const PREFIXES: [&str; 3] = ["/api/", "/pb/", "/s/"];
    const EXACT: [&str; 2] = ["/health", "/metrics"];
    if PREFIXES.iter().any(|prefix| path.starts_with(prefix)) {
        return true;
    }
    EXACT.contains(&path)
}
/// Returns `true` for `.css`/`.js` files whose name embeds a content hash,
/// i.e. a last path segment shaped like `<anything>.<hash>.css|js` where
/// `<hash>` is at least 8 ASCII hex digits.
///
/// A hashed file never changes under the same URL, so it can be cached forever.
fn is_fingerprinted_asset(path: &str) -> bool {
    let file_name = path.rfind('/').map_or(path, |slash| &path[slash + 1..]);
    // Split from the right into at most: extension, hash, everything before.
    let mut pieces = file_name.rsplitn(3, '.');
    let (Some(extension), Some(hash), Some(_prefix)) =
        (pieces.next(), pieces.next(), pieces.next())
    else {
        return false;
    };
    let is_script_or_style = extension == "css" || extension == "js";
    let is_hex_hash = hash.len() >= 8 && hash.chars().all(|ch| ch.is_ascii_hexdigit());
    is_script_or_style && is_hex_hash
}
/// Returns `true` when the final path segment looks like a file name, i.e. it
/// contains a `.` (e.g. `/favicon.ico`), as opposed to an SPA route such as
/// `/postcode/SW1A1AA`.
fn is_static_asset_path(path: &str) -> bool {
    let last_segment = match path.rfind('/') {
        Some(slash) => &path[slash + 1..],
        None => path,
    };
    last_segment.contains('.')
}
/// Response middleware that assigns a `Cache-Control` policy to non-API responses.
///
/// Nothing is changed when the path is an API path (see `is_api_path`) or the
/// handler already set `Cache-Control`. Otherwise:
/// - HTML documents (`Content-Type` containing `text/html`) must always be
///   revalidated, so a new deploy is picked up immediately.
/// - Only successful (2xx) and `304 Not Modified` responses get a public caching
///   policy. Error responses are left without one, so that a transient 404/5xx
///   for an asset URL is not pinned in browser or CDN caches.
/// - Content-hashed `.css`/`.js` files are cached for a year as `immutable`.
/// - Any other path whose last segment has a file extension is cached for an hour.
async fn static_cache_headers(
    request: axum::extract::Request,
    next: middleware::Next,
) -> axum::response::Response {
    // `next.run` consumes the request, so keep an owned copy of the path.
    let path = request.uri().path().to_string();
    let mut response = next.run(request).await;
    if is_api_path(&path) || response.headers().contains_key(header::CACHE_CONTROL) {
        return response;
    }
    let is_html = response
        .headers()
        .get(header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .is_some_and(|content_type| content_type.contains("text/html"));
    let status = response.status();
    // A 304 must carry the caching headers a 200 would have had (RFC 9110 §15.4.5),
    // so it is treated like a success here.
    let publicly_cacheable =
        status.is_success() || status == axum::http::StatusCode::NOT_MODIFIED;
    let cache_control = if is_html {
        Some(HeaderValue::from_static("no-cache, must-revalidate"))
    } else if !publicly_cacheable {
        None
    } else if is_fingerprinted_asset(&path) {
        Some(HeaderValue::from_static("public, max-age=31536000, immutable"))
    } else if is_static_asset_path(&path) {
        Some(HeaderValue::from_static("public, max-age=3600"))
    } else {
        None
    };
    if let Some(value) = cache_control {
        response.headers_mut().insert(header::CACHE_CONTROL, value);
    }
    response
}
/// Current resident set size of this process in KiB, read from the `VmRSS:`
/// line of `/proc/self/status`. Returns `None` if the file cannot be read or
/// no parsable `VmRSS:` line is found.
///
/// Only compiled where `trim_allocator` actually uses it (glibc Linux).
#[cfg(all(target_os = "linux", target_env = "gnu"))]
fn resident_memory_kib() -> Option<u64> {
    let status = std::fs::read_to_string("/proc/self/status").ok()?;
    status.lines().find_map(|line| {
        line.strip_prefix("VmRSS:")?
            .split_whitespace()
            .next()?
            .parse()
            .ok()
    })
}

/// Asks glibc's allocator to return freed heap memory to the OS, and logs the RSS
/// before and after (in MiB) under `label`.
///
/// Intended to be called right after large, transient allocations (e.g. data
/// loading) have been dropped. The gate is `target_env = "gnu"` rather than just
/// Linux because `malloc_trim` is a glibc extension: it is not available on
/// musl Linux targets, which would otherwise fail to compile.
#[cfg(all(target_os = "linux", target_env = "gnu"))]
fn trim_allocator(label: &'static str) {
    let before = resident_memory_kib();
    // SAFETY: `malloc_trim` only operates on glibc's internal allocator state,
    // takes a plain integer padding argument and does not touch any memory
    // owned by Rust values; it is documented as safe to call at any time.
    let trimmed = unsafe { libc::malloc_trim(0) };
    let after = resident_memory_kib();
    if let (Some(before), Some(after)) = (before, after) {
        info!(
            label,
            // `malloc_trim` returns 1 if memory was released to the OS, 0 otherwise.
            trimmed = trimmed != 0,
            rss_before_mib = format_args!("{:.1}", before as f64 / 1024.0),
            rss_after_mib = format_args!("{:.1}", after as f64 / 1024.0),
            released_mib = format_args!("{:.1}", before.saturating_sub(after) as f64 / 1024.0),
            "Allocator trim"
        );
    }
}

/// No-op on platforms without glibc's `malloc_trim` (non-Linux, or musl Linux).
#[cfg(not(all(target_os = "linux", target_env = "gnu")))]
fn trim_allocator(_label: &'static str) {}
// Command-line / environment configuration for the server.
//
// The `///` comments on fields are rendered by clap as `--help` text, so they are
// user-facing output. Every secret-bearing argument sets `hide_env_values = true`:
// without it, clap's `--help` prints the *current value* of the backing
// environment variable, which would leak credentials to anyone (or any log) that
// sees the help output. Prefer supplying secrets via the environment rather than
// `--flag value`, since command-line arguments are visible in process listings.
#[derive(Parser)]
#[command(
    name = "perfect-postcode",
    about = "Perfect Postcode property map server"
)]
struct Cli {
    /// Path to properties.parquet (one row per historical property)
    #[arg(long)]
    properties: PathBuf,
    /// Path to postcode.parquet (one row per postcode with area-level data)
    #[arg(long)]
    postcode_features: PathBuf,
    /// Path to the POI parquet file
    #[arg(long)]
    pois: PathBuf,
    /// Path to the places parquet file
    #[arg(long)]
    places: PathBuf,
    /// Path to the postcode boundaries directory
    #[arg(long)]
    postcodes: PathBuf,
    /// Path to the PMTiles file for map tiles
    #[arg(long)]
    tiles: PathBuf,
    /// Path to the frontend dist directory (optional; disables static serving and OG injection when omitted)
    #[arg(long)]
    dist: Option<PathBuf>,
    /// URL of the screenshot service (e.g. http://screenshot:8002)
    #[arg(long, env = "SCREENSHOT_URL")]
    screenshot_url: String,
    /// Public-facing URL for absolute og:image URLs
    #[arg(long, env = "PUBLIC_URL")]
    public_url: String,
    /// PocketBase server URL for authentication (e.g. http://localhost:8090)
    #[arg(long, env = "POCKETBASE_URL")]
    pocketbase_url: String,
    /// PocketBase superuser email (for auto-creating collections at startup)
    #[arg(long, env = "POCKETBASE_ADMIN_EMAIL")]
    pocketbase_admin_email: String,
    /// PocketBase superuser password (for auto-creating collections at startup)
    #[arg(long, env = "POCKETBASE_ADMIN_PASSWORD", hide_env_values = true)]
    pocketbase_admin_password: String,
    /// Gemini API key
    #[arg(long, env = "GEMINI_API_KEY", hide_env_values = true)]
    gemini_api_key: String,
    /// Gemini model name (e.g. gemini-2.0-flash)
    #[arg(long, env = "GEMINI_MODEL")]
    gemini_model: String,
    /// Path to precomputed travel times directory (contains mode subdirs with parquet files)
    #[arg(long, env = "TRAVEL_TIMES")]
    travel_times: PathBuf,
    /// Google Maps API key for Street View metadata lookups
    #[arg(long, env = "GOOGLE_MAPS_API_KEY", hide_env_values = true)]
    google_maps_api_key: String,
    /// Stripe secret key for checkout sessions
    #[arg(long, env = "STRIPE_SECRET_KEY", hide_env_values = true)]
    stripe_secret_key: String,
    /// Stripe webhook signing secret for verifying webhook signatures
    #[arg(long, env = "STRIPE_WEBHOOK_SECRET", hide_env_values = true)]
    stripe_webhook_secret: String,
    /// Stripe Coupon ID applied when a referral code is used
    #[arg(long, env = "STRIPE_REFERRAL_COUPON_ID")]
    stripe_referral_coupon_id: String,
    /// Google OAuth client ID for PocketBase SSO
    #[arg(long, env = "GOOGLE_OAUTH_CLIENT_ID")]
    google_oauth_client_id: String,
    /// Google OAuth client secret for PocketBase SSO
    #[arg(long, env = "GOOGLE_OAUTH_CLIENT_SECRET", hide_env_values = true)]
    google_oauth_client_secret: String,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let file_appender = tracing_appender::rolling::daily("logs", "server.log");
let (non_blocking, _guard) = tracing_appender::non_blocking(file_appender);
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
tracing_subscriber::registry()
.with(env_filter)
.with(tracing_subscriber::fmt::layer().with_ansi(true))
.with(
tracing_subscriber::fmt::layer()
.with_ansi(false)
.with_writer(non_blocking),
)
.init();
// Initialize Prometheus metrics
let metrics_handle = metrics::init_metrics();
info!("Prometheus metrics initialized");
let cli = Cli::parse();
for (label, path) in [
("Properties", &cli.properties),
("Postcode features", &cli.postcode_features),
] {
if !path.exists() {
bail!("{} parquet file not found: {}", label, path.display());
}
}
info!(
"Loading property data from {}, {}",
cli.properties.display(),
cli.postcode_features.display(),
);
let property_data = data::PropertyData::load(&cli.properties, &cli.postcode_features)?;
trim_allocator("property data load");
info!(
rows = property_data.lat.len(),
features = property_data.num_features,
enums = property_data.enum_values.len(),
"Property data loaded"
);
info!("Building spatial grid index (0.01° cells)");
let grid = utils::GridIndex::build(
&property_data.lat,
&property_data.lon,
consts::GRID_CELL_SIZE,
);
info!(
"Precomputing H3 cells at resolution {}",
consts::H3_PRECOMPUTE_MAX
);
let h3_cells = data::precompute_h3(&property_data.lat, &property_data.lon)?;
let poi_path = cli.pois;
if !poi_path.exists() {
bail!("POI parquet file not found: {}", poi_path.display());
}
info!("Loading POI data from {}", poi_path.display());
let poi_data = data::POIData::load(&poi_path)?;
trim_allocator("poi data load");
info!(pois = poi_data.lat.len(), "POI data loaded");
info!("Building POI spatial grid index");
let poi_grid = utils::GridIndex::build(&poi_data.lat, &poi_data.lng, consts::GRID_CELL_SIZE);
// Load place data
let places_path = &cli.places;
if !places_path.exists() {
bail!("Places parquet file not found: {}", places_path.display());
}
info!("Loading place data from {}", places_path.display());
let place_data = data::PlaceData::load(places_path)?;
trim_allocator("place data load");
info!(places = place_data.name.len(), "Place data loaded");
// Load postcode boundaries
let postcodes_path = &cli.postcodes;
if !postcodes_path.exists() {
bail!(
"Postcode boundaries not found: {}",
postcodes_path.display()
);
}
info!(
"Loading postcode boundaries from {}",
postcodes_path.display()
);
let postcode_data = data::PostcodeData::load(postcodes_path)?;
trim_allocator("postcode boundary load");
info!(
postcodes = postcode_data.postcodes.len(),
"Postcode boundaries loaded"
);
let outcode_data = data::OutcodeData::from_postcode_and_place_data(&postcode_data, &place_data);
// Initialize tile reader
let tiles_path = &cli.tiles;
if !tiles_path.exists() {
bail!("PMTiles file not found: {}", tiles_path.display());
}
info!("Loading PMTiles from {}", tiles_path.display());
let tile_reader = Arc::new(routes::init_tile_reader(tiles_path).await?);
info!("PMTiles loaded successfully");
let feature_name_to_index: rustc_hash::FxHashMap<String, usize> = property_data
.feature_names
.iter()
.enumerate()
.map(|(idx, name)| (name.clone(), idx))
.collect();
let min_keys: Vec<String> = property_data
.feature_names
.iter()
.map(|name| format!("min_{}", name))
.collect();
let max_keys: Vec<String> = property_data
.feature_names
.iter()
.map(|name| format!("max_{}", name))
.collect();
let avg_keys: Vec<String> = property_data
.feature_names
.iter()
.map(|name| format!("avg_{}", name))
.collect();
let poi_category_groups = poi_data.category_groups()?;
let is_dev = if cli.dist.is_some() {
info!("Static frontend serving enabled");
false
} else {
info!("No --dist provided; static serving disabled");
true
};
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(SERVICE_CALL_TIMEOUT))
.connect_timeout(Duration::from_secs(5))
.referer(false)
.build()
.context("Failed to build HTTP client")?;
info!("Screenshot service configured: {}", cli.screenshot_url);
let features_response = routes::build_features_response(&property_data);
info!(
groups = features_response.groups.len(),
"Precomputed features response"
);
// AI filters system prompt built after travel_time_store is loaded (needs mode counts)
// Record data loading metrics
metrics::record_data_stats(
property_data.lat.len(),
poi_data.lat.len(),
postcode_data.postcodes.len(),
);
info!("PocketBase configured: {}", cli.pocketbase_url);
pocketbase::ensure_collections(
&http_client,
&cli.pocketbase_url,
&cli.pocketbase_admin_email,
&cli.pocketbase_admin_password,
)
.await?;
pocketbase::ensure_oauth_providers(
&http_client,
&cli.pocketbase_url,
&cli.pocketbase_admin_email,
&cli.pocketbase_admin_password,
&cli.public_url,
&cli.google_oauth_client_id,
&cli.google_oauth_client_secret,
)
.await?;
info!("Gemini configured (model: {})", cli.gemini_model);
let tt_path = &cli.travel_times;
if !tt_path.exists() {
bail!("Travel times directory not found: {}", tt_path.display());
}
info!("Loading travel time data from {}", tt_path.display());
let travel_time_store = {
let store = data::TravelTimeStore::load(tt_path, 200)?;
info!(
modes = store.available_modes.len(),
"Travel time store loaded"
);
Arc::new(store)
};
let mode_destinations: Vec<(String, usize)> = travel_time_store
.available_modes
.iter()
.map(|mode| {
let count = travel_time_store
.destinations
.get(mode.as_str())
.map(|slugs| slugs.len())
.unwrap_or(0);
(mode.clone(), count)
})
.filter(|(_, count)| *count > 0)
.collect();
let ai_filters_system_prompt =
routes::build_system_prompt(&features_response, &mode_destinations);
info!("Precomputed AI filters system prompt");
let token_cache = Arc::new(auth::TokenCache::new());
let superuser_token_cache = Arc::new(pocketbase::SuperuserTokenCache::new());
let share_cache = Arc::new(licensing::ShareBoundsCache::new());
let app_state = AppState {
data: property_data,
grid,
h3_cells,
poi_data: Arc::new(poi_data),
poi_grid: Arc::new(poi_grid),
place_data: Arc::new(place_data),
postcode_data: Arc::new(postcode_data),
outcode_data: Arc::new(outcode_data),
feature_name_to_index,
min_keys,
max_keys,
avg_keys,
poi_category_groups: Arc::new(poi_category_groups),
features_response,
screenshot_url: cli.screenshot_url,
public_url: cli.public_url,
is_dev,
http_client,
pocketbase_url: cli.pocketbase_url,
pocketbase_admin_email: cli.pocketbase_admin_email,
pocketbase_admin_password: cli.pocketbase_admin_password,
gemini_api_key: cli.gemini_api_key,
gemini_model: cli.gemini_model,
travel_time_store,
token_cache,
superuser_token_cache,
share_cache,
ai_filters_system_prompt,
google_maps_api_key: cli.google_maps_api_key,
stripe_secret_key: cli.stripe_secret_key,
stripe_webhook_secret: cli.stripe_webhook_secret,
stripe_referral_coupon_id: cli.stripe_referral_coupon_id,
};
let shared = Arc::new(SharedState::new(app_state));
// Start background PocketBase metrics poller (users, saved searches/properties counts)
pocketbase::start_metrics_poller(shared.clone());
let initial_state = shared.load_state();
let cors = CorsLayer::new()
.allow_origin(
initial_state
.public_url
.parse::<axum::http::HeaderValue>()
.expect("public_url must be a valid header value"),
)
.allow_methods(AllowMethods::mirror_request())
.allow_headers(AllowHeaders::mirror_request())
.allow_credentials(true);
// Handlers use Axum's State extractor to get Arc<SharedState>, then call
// load_state() to get the current Arc<AppState>.
let s_crawler = shared.clone();
let reader_tile = tile_reader.clone();
let reader_style = tile_reader.clone();
let public_url_tiles = initial_state.public_url.clone();
let api = Router::new()
.route("/api/features", get(routes::get_features))
.route(
"/api/hexagons",
get(routes::get_hexagons).layer(ConcurrencyLimitLayer::new(20)),
)
.route(
"/api/postcodes",
get(routes::get_postcodes).layer(ConcurrencyLimitLayer::new(20)),
)
.route("/api/postcode/{postcode}", get(routes::get_postcode_lookup))
.route("/api/nearest-postcode", get(routes::get_nearest_postcode))
.route(
"/api/pois",
get(routes::get_pois).layer(ConcurrencyLimitLayer::new(20)),
)
.route("/api/poi-categories", get(routes::get_poi_categories))
.route("/api/places", get(routes::get_places))
.route("/api/travel-modes", get(routes::get_travel_modes))
.route(
"/api/travel-destinations",
get(routes::get_travel_destinations),
)
.route("/api/journey", get(routes::get_journey))
.route(
"/api/hexagon-properties",
get(routes::get_hexagon_properties),
)
.route("/api/filter-counts", get(routes::get_filter_counts))
.route("/api/hexagon-stats", get(routes::get_hexagon_stats))
.route("/api/postcode-stats", get(routes::get_postcode_stats))
.route(
"/api/postcode-properties",
get(routes::get_postcode_properties),
)
.route(
"/api/screenshot",
get(routes::get_screenshot).layer(ConcurrencyLimitLayer::new(3)),
)
.route(
"/api/export",
get(routes::get_export).layer(ConcurrencyLimitLayer::new(3)),
)
.route("/api/me", get(routes::get_me))
.route("/api/shorten", post(routes::post_shorten))
.route(
"/api/ai-filters",
post(routes::post_ai_filters).layer(ConcurrencyLimitLayer::new(5)),
)
.route("/api/streetview", get(routes::get_streetview))
.route(
"/api/rightmove-search",
get(routes::get_rightmove_redirect).layer(ConcurrencyLimitLayer::new(10)),
)
.route(
"/api/newsletter",
patch(routes::patch_newsletter).layer(ConcurrencyLimitLayer::new(10)),
)
.route("/api/pricing", get(routes::get_pricing))
.route(
"/api/checkout",
post(routes::post_checkout).layer(ConcurrencyLimitLayer::new(10)),
)
.route("/api/stripe-webhook", post(routes::post_stripe_webhook))
.route(
"/api/invites",
get(routes::get_invites).post(routes::post_invites),
)
.route("/api/invite/{code}", get(routes::get_invite))
.route("/api/redeem-invite", post(routes::post_redeem_invite))
.route("/s/{code}", get(routes::get_short_url))
.route(
"/api/telemetry",
post(routes::post_telemetry).layer(ConcurrencyLimitLayer::new(20)),
)
.route(
"/pb/{*rest}",
any(routes::proxy_to_pocketbase).layer(ConcurrencyLimitLayer::new(10)),
)
// Tile routes use a different state type — kept as closures
.route(
"/api/tiles/{z}/{x}/{y}",
get(move |path| routes::get_tile(axum::extract::State(reader_tile.clone()), path)),
)
.route(
"/api/tiles/style.json",
get(move |query| {
let pu = public_url_tiles.clone();
routes::get_style(axum::extract::State(reader_style.clone()), pu, query)
}),
)
.route("/health", get(|| async { "ok" }))
.route(
"/metrics",
get(move || metrics::metrics_handler(metrics_handle.clone())),
)
.with_state(shared.clone());
let app = if let Some(ref dist) = cli.dist {
api.fallback_service(ServeDir::new(dist).fallback(ServeFile::new(dist.join("index.html"))))
} else {
api
}
.layer(middleware::from_fn(metrics::track_metrics))
.layer(middleware::from_fn(auth::auth_middleware))
.layer(middleware::from_fn(
move |req: axum::extract::Request, next: middleware::Next| {
let st = s_crawler.load_state();
async move {
// Inject state into request extensions for auth + OG middleware
let (mut parts, body) = req.into_parts();
parts.extensions.insert(st);
let req = axum::extract::Request::from_parts(parts, body);
og_middleware::og_middleware(req, next).await
}
},
))
.layer(middleware::from_fn(static_cache_headers))
.layer(cors)
.layer(CompressionLayer::new().zstd(true).gzip(true))
.layer(TraceLayer::new_for_http());
// Lock all current and future memory pages to prevent swapping
unsafe {
if libc::mlockall(libc::MCL_CURRENT | libc::MCL_FUTURE) != 0 {
let err = std::io::Error::last_os_error();
tracing::warn!(
"mlockall failed (need CAP_IPC_LOCK or sufficient RLIMIT_MEMLOCK): {err}"
);
} else {
info!("All memory pages locked (mlockall)");
}
}
let addr = consts::SERVER_ADDRESS;
let listener = tokio::net::TcpListener::bind(addr)
.await
.with_context(|| format!("Failed to bind to {addr}"))?;
info!("Server listening on {}", addr);
axum::serve(listener, app).await.context("Server error")?;
Ok(())
}