// perfect-postcode/server-rs/src/main.rs — Perfect Postcode property map server entry point.

mod aggregation;
mod auth;
mod consts;
mod data;
mod features;
mod licensing;
mod metrics;
mod og_middleware;
pub mod parsing;
mod pocketbase;
mod routes;
mod state;
pub mod utils;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{bail, Context};
use consts::SERVICE_CALL_TIMEOUT;
use axum::middleware;
use axum::routing::{any, get, patch, post};
use axum::Router;
use clap::Parser;
use tower::limit::ConcurrencyLimitLayer;
use tower_http::compression::CompressionLayer;
use tower_http::cors::{AllowHeaders, AllowMethods, CorsLayer};
use tower_http::services::{ServeDir, ServeFile};
use tower_http::trace::TraceLayer;
use tracing::info;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;
use state::AppState;
// Command-line / environment configuration for the server.
//
// NOTE: the `///` comments on each field are the `--help` text that clap
// generates, so they are runtime output — edit them deliberately.
//
// Secret-bearing arguments use `hide_env_values = true`. Without it, clap prints
// the *current value* of the backing environment variable in `--help`
// (e.g. `[env: STRIPE_SECRET_KEY=sk_live_...]`), leaking credentials to
// terminals and logs.
#[derive(Parser)]
#[command(name = "perfect-postcode", about = "Perfect Postcode property map server")]
struct Cli {
    // ---- Data inputs (existence is checked at startup) ----
    /// Path to properties.parquet (one row per historical property)
    #[arg(long)]
    properties: PathBuf,
    /// Path to postcode.parquet (one row per postcode with area-level data)
    #[arg(long)]
    postcode_features: PathBuf,
    /// Path to online_listings_buy.parquet
    #[arg(long)]
    listings_buy: PathBuf,
    /// Path to online_listings_rent.parquet
    #[arg(long)]
    listings_rent: PathBuf,
    /// Path to the POI parquet file
    #[arg(long)]
    pois: PathBuf,
    /// Path to the places parquet file
    #[arg(long)]
    places: PathBuf,
    /// Path to the postcode boundaries directory
    #[arg(long)]
    postcodes: PathBuf,
    /// Path to the PMTiles file for map tiles
    #[arg(long)]
    tiles: PathBuf,
    /// Path to the frontend dist directory (optional; disables static serving and OG injection when omitted)
    #[arg(long)]
    dist: Option<PathBuf>,

    // ---- External services ----
    /// URL of the screenshot service (e.g. http://screenshot:8002)
    #[arg(long, env = "SCREENSHOT_URL")]
    screenshot_url: String,
    /// Public-facing URL for absolute og:image URLs
    #[arg(long, env = "PUBLIC_URL")]
    public_url: String,
    /// PocketBase server URL for authentication (e.g. http://localhost:8090)
    #[arg(long, env = "POCKETBASE_URL")]
    pocketbase_url: String,
    /// PocketBase superuser email (for auto-creating collections at startup)
    #[arg(long, env = "POCKETBASE_ADMIN_EMAIL")]
    pocketbase_admin_email: String,
    /// PocketBase superuser password (for auto-creating collections at startup)
    #[arg(long, env = "POCKETBASE_ADMIN_PASSWORD", hide_env_values = true)]
    pocketbase_admin_password: String,
    /// Ollama server URL (e.g. http://ollama:11434)
    #[arg(long, env = "OLLAMA_URL")]
    ollama_url: String,
    /// Ollama model name
    #[arg(long, env = "OLLAMA_MODEL")]
    ollama_model: String,
    /// Path to precomputed travel times directory (contains mode subdirs with parquet files)
    #[arg(long, env = "TRAVEL_TIMES")]
    travel_times: PathBuf,

    // ---- Third-party credentials (values hidden from --help) ----
    /// Google Maps API key for Street View metadata lookups
    #[arg(long, env = "GOOGLE_MAPS_API_KEY", hide_env_values = true)]
    google_maps_api_key: String,
    /// Stripe secret key for checkout sessions
    #[arg(long, env = "STRIPE_SECRET_KEY", hide_env_values = true)]
    stripe_secret_key: String,
    /// Stripe webhook signing secret for verifying webhook signatures
    #[arg(long, env = "STRIPE_WEBHOOK_SECRET", hide_env_values = true)]
    stripe_webhook_secret: String,
    /// Stripe Coupon ID applied when a referral code is used
    #[arg(long, env = "STRIPE_REFERRAL_COUPON_ID")]
    stripe_referral_coupon_id: String,
    /// Google OAuth client ID for PocketBase SSO
    #[arg(long, env = "GOOGLE_OAUTH_CLIENT_ID")]
    google_oauth_client_id: String,
    /// Google OAuth client secret for PocketBase SSO
    #[arg(long, env = "GOOGLE_OAUTH_CLIENT_SECRET", hide_env_values = true)]
    google_oauth_client_secret: String,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let file_appender = tracing_appender::rolling::daily("logs", "server.log");
let (non_blocking, _guard) = tracing_appender::non_blocking(file_appender);
let env_filter =
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
tracing_subscriber::registry()
.with(env_filter)
.with(tracing_subscriber::fmt::layer().with_ansi(true))
.with(
tracing_subscriber::fmt::layer()
.with_ansi(false)
.with_writer(non_blocking),
)
.init();
// Initialize Prometheus metrics
let metrics_handle = metrics::init_metrics();
info!("Prometheus metrics initialized");
let cli = Cli::parse();
for (label, path) in [
("Properties", &cli.properties),
("Postcode features", &cli.postcode_features),
("Listings buy", &cli.listings_buy),
("Listings rent", &cli.listings_rent),
] {
if !path.exists() {
bail!("{} parquet file not found: {}", label, path.display());
}
}
info!(
"Loading property data from {}, {}, {}, {}",
cli.properties.display(),
cli.postcode_features.display(),
cli.listings_buy.display(),
cli.listings_rent.display(),
);
let property_data = data::PropertyData::load(
&cli.properties,
&cli.postcode_features,
&cli.listings_buy,
&cli.listings_rent,
)?;
info!(
rows = property_data.lat.len(),
features = property_data.num_features,
enums = property_data.enum_values.len(),
"Property data loaded"
);
info!("Building spatial grid index (0.01° cells)");
let grid = utils::GridIndex::build(
&property_data.lat,
&property_data.lon,
consts::GRID_CELL_SIZE,
);
info!(
"Precomputing H3 cells at resolution {}",
consts::H3_PRECOMPUTE_MAX
);
let h3_cells = data::precompute_h3(&property_data.lat, &property_data.lon)?;
let poi_path = cli.pois;
if !poi_path.exists() {
bail!("POI parquet file not found: {}", poi_path.display());
}
info!("Loading POI data from {}", poi_path.display());
let poi_data = data::POIData::load(&poi_path)?;
info!(pois = poi_data.lat.len(), "POI data loaded");
info!("Building POI spatial grid index");
let poi_grid = utils::GridIndex::build(&poi_data.lat, &poi_data.lng, consts::GRID_CELL_SIZE);
// Load place data
let places_path = &cli.places;
if !places_path.exists() {
bail!("Places parquet file not found: {}", places_path.display());
}
info!("Loading place data from {}", places_path.display());
let place_data = data::PlaceData::load(places_path)?;
info!(places = place_data.name.len(), "Place data loaded");
// Load postcode boundaries
let postcodes_path = &cli.postcodes;
if !postcodes_path.exists() {
bail!(
"Postcode boundaries not found: {}",
postcodes_path.display()
);
}
info!(
"Loading postcode boundaries from {}",
postcodes_path.display()
);
let postcode_data = data::PostcodeData::load(postcodes_path)?;
info!(
postcodes = postcode_data.postcodes.len(),
"Postcode boundaries loaded"
);
// Initialize tile reader
let tiles_path = &cli.tiles;
if !tiles_path.exists() {
bail!("PMTiles file not found: {}", tiles_path.display());
}
info!("Loading PMTiles from {}", tiles_path.display());
let tile_reader = Arc::new(routes::init_tile_reader(tiles_path).await?);
info!("PMTiles loaded successfully");
let feature_name_to_index: rustc_hash::FxHashMap<String, usize> = property_data
.feature_names
.iter()
.enumerate()
.map(|(idx, name)| (name.clone(), idx))
.collect();
let min_keys: Vec<String> = property_data
.feature_names
.iter()
.map(|name| format!("min_{}", name))
.collect();
let max_keys: Vec<String> = property_data
.feature_names
.iter()
.map(|name| format!("max_{}", name))
.collect();
let avg_keys: Vec<String> = property_data
.feature_names
.iter()
.map(|name| format!("avg_{}", name))
.collect();
let poi_category_groups = poi_data.category_groups()?;
// Read index.html at startup for crawler OG injection (only when --dist is provided)
let index_html = if let Some(ref dist) = cli.dist {
let index_path = dist.join("index.html");
let html = std::fs::read_to_string(&index_path)
.with_context(|| format!("Failed to read {}", index_path.display()))?;
info!("Loaded index.html for OG injection");
Some(html)
} else {
info!("No --dist provided; static serving and OG injection disabled");
None
};
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(SERVICE_CALL_TIMEOUT))
.connect_timeout(Duration::from_secs(5))
.build()
.context("Failed to build HTTP client")?;
info!("Screenshot service configured: {}", cli.screenshot_url);
let features_response = routes::build_features_response(&property_data);
info!(
groups = features_response.groups.len(),
"Precomputed features response"
);
let ai_filters_schema = routes::build_ollama_schema(&features_response);
let ai_filters_system_prompt = routes::build_system_prompt(&features_response);
info!("Precomputed AI filters schema and system prompt");
// Record data loading metrics
metrics::record_data_stats(
property_data.lat.len(),
poi_data.lat.len(),
postcode_data.postcodes.len(),
);
info!("PocketBase configured: {}", cli.pocketbase_url);
pocketbase::ensure_collections(
&http_client,
&cli.pocketbase_url,
&cli.pocketbase_admin_email,
&cli.pocketbase_admin_password,
)
.await?;
pocketbase::ensure_oauth_providers(
&http_client,
&cli.pocketbase_url,
&cli.pocketbase_admin_email,
&cli.pocketbase_admin_password,
&cli.public_url,
&cli.google_oauth_client_id,
&cli.google_oauth_client_secret,
)
.await?;
info!(
"Ollama configured: {} (model: {})",
cli.ollama_url, cli.ollama_model
);
let tt_path = &cli.travel_times;
if !tt_path.exists() {
bail!(
"Travel times directory not found: {}",
tt_path.display()
);
}
info!("Loading travel time data from {}", tt_path.display());
let travel_time_store = {
let store = data::TravelTimeStore::load(tt_path, 50)?;
info!(
modes = store.available_modes.len(),
"Travel time store loaded"
);
Arc::new(store)
};
let token_cache = Arc::new(auth::TokenCache::new());
let state = Arc::new(AppState {
data: property_data,
grid,
h3_cells,
poi_data,
poi_grid,
place_data,
postcode_data,
feature_name_to_index,
min_keys,
max_keys,
avg_keys,
poi_category_groups,
features_response,
screenshot_url: cli.screenshot_url,
public_url: cli.public_url,
index_html,
http_client,
pocketbase_url: cli.pocketbase_url,
pocketbase_admin_email: cli.pocketbase_admin_email,
pocketbase_admin_password: cli.pocketbase_admin_password,
ollama_url: cli.ollama_url,
ollama_model: cli.ollama_model,
travel_time_store,
token_cache,
ai_filters_schema,
ai_filters_system_prompt,
google_maps_api_key: cli.google_maps_api_key,
stripe_secret_key: cli.stripe_secret_key,
stripe_webhook_secret: cli.stripe_webhook_secret,
stripe_referral_coupon_id: cli.stripe_referral_coupon_id,
});
let cors = CorsLayer::new()
.allow_origin(
state
.public_url
.parse::<axum::http::HeaderValue>()
.expect("public_url must be a valid header value"),
)
.allow_methods(AllowMethods::mirror_request())
.allow_headers(AllowHeaders::mirror_request())
.allow_credentials(true);
let state_features = state.clone();
let state_hexagons = state.clone();
let state_postcodes = state.clone();
let state_postcode_lookup = state.clone();
let state_pois = state.clone();
let state_poi_categories = state.clone();
let state_hexagon_properties = state.clone();
let state_hexagon_stats = state.clone();
let state_screenshot = state.clone();
let state_export = state.clone();
let state_crawler = state.clone();
let state_pb = state.clone();
let state_postcode_stats = state.clone();
let state_places = state.clone();
let state_shorten = state.clone();
let state_short_url = state.clone();
let state_ai_filters = state.clone();
let state_streetview = state.clone();
let state_subscription = state.clone();
let state_newsletter = state.clone();
let state_travel_modes = state.clone();
let state_checkout = state.clone();
let state_stripe_webhook = state.clone();
let state_pricing = state.clone();
let state_invites_create = state.clone();
let state_invite_get = state.clone();
let state_redeem_invite = state.clone();
let state_rightmove = state.clone();
let api = Router::new()
.route(
"/api/features",
get(move || routes::get_features(state_features.clone())),
)
.route(
"/api/hexagons",
get(move |ext, query| routes::get_hexagons(state_hexagons.clone(), ext, query)),
)
.route(
"/api/postcodes",
get(move |ext, query| routes::get_postcodes(state_postcodes.clone(), ext, query)),
)
.route(
"/api/postcode/{postcode}",
get(move |path| routes::get_postcode_lookup(state_postcode_lookup.clone(), path)),
)
.route(
"/api/pois",
get(move |query| routes::get_pois(state_pois.clone(), query)),
)
.route(
"/api/poi-categories",
get(move || routes::get_poi_categories(state_poi_categories.clone())),
)
.route(
"/api/places",
get(move |query| routes::get_places(state_places.clone(), query)),
)
.route(
"/api/travel-modes",
get(move || routes::get_travel_modes(state_travel_modes.clone())),
)
.route(
"/api/hexagon-properties",
get(move |ext, query| {
routes::get_hexagon_properties(state_hexagon_properties.clone(), ext, query)
}),
)
.route(
"/api/hexagon-stats",
get(move |ext, query| routes::get_hexagon_stats(state_hexagon_stats.clone(), ext, query)),
)
.route(
"/api/postcode-stats",
get(move |ext, query| routes::get_postcode_stats(state_postcode_stats.clone(), ext, query)),
)
.route(
"/api/screenshot",
get(move |headers, query| routes::get_screenshot(state_screenshot.clone(), headers, query)),
)
.route(
"/api/export",
get(move |headers, ext, query| routes::get_export(state_export.clone(), headers, ext, query))
.layer(ConcurrencyLimitLayer::new(3)),
)
.route("/api/me", get(routes::get_me))
.route(
"/api/shorten",
post(move |body| routes::post_shorten(state_shorten.clone(), body)),
)
.route(
"/api/ai-filters",
post(move |body| routes::post_ai_filters(state_ai_filters.clone(), body))
.layer(ConcurrencyLimitLayer::new(5)),
)
.route(
"/api/streetview",
get(move |query| routes::get_streetview(state_streetview.clone(), query)),
)
.route(
"/api/rightmove-location",
get(move |query| routes::get_rightmove_typeahead(state_rightmove.clone(), query)),
)
.route(
"/api/subscription",
patch(move |ext, body| {
routes::patch_subscription(state_subscription.clone(), ext, body)
}),
)
.route(
"/api/newsletter",
patch(move |ext, body| {
routes::patch_newsletter(state_newsletter.clone(), ext, body)
}),
)
.route(
"/api/pricing",
get(move || routes::get_pricing(state_pricing.clone())),
)
.route(
"/api/checkout",
post(move |ext, body| routes::post_checkout(state_checkout.clone(), ext, body))
.layer(ConcurrencyLimitLayer::new(10)),
)
.route(
"/api/stripe-webhook",
post(move |headers, body| {
routes::post_stripe_webhook(state_stripe_webhook.clone(), headers, body)
}),
)
.route(
"/api/invites",
post(move |ext, body| routes::post_invites(state_invites_create.clone(), ext, body)),
)
.route(
"/api/invite/{code}",
get(move |ext, path| routes::get_invite(state_invite_get.clone(), ext, path)),
)
.route(
"/api/redeem-invite",
post(move |ext, body| {
routes::post_redeem_invite(state_redeem_invite.clone(), ext, body)
}),
)
.route(
"/s/{code}",
get(move |path| routes::get_short_url(state_short_url.clone(), path)),
);
// Add tile routes
let reader_tile = tile_reader.clone();
let reader_style = tile_reader.clone();
let public_url_tiles = state.public_url.clone();
let api = api
.route(
"/api/tiles/{z}/{x}/{y}",
get(move |path| routes::get_tile(axum::extract::State(reader_tile.clone()), path)),
)
.route(
"/api/tiles/style.json",
get(move |query| {
let pu = public_url_tiles.clone();
routes::get_style(axum::extract::State(reader_style.clone()), pu, query)
}),
)
.route("/health", get(|| async { "ok" }))
.route(
"/metrics",
get(move || metrics::metrics_handler(metrics_handle.clone())),
)
.route(
"/pb/{*rest}",
any(move |req| routes::proxy_to_pocketbase(state_pb.clone(), req)),
);
let app = if let Some(ref dist) = cli.dist {
api.fallback_service(
ServeDir::new(dist).fallback(ServeFile::new(dist.join("index.html"))),
)
} else {
api
}
.layer(middleware::from_fn(metrics::track_metrics))
.layer(middleware::from_fn(auth::auth_middleware))
.layer(middleware::from_fn(
move |req: axum::extract::Request, next: middleware::Next| {
let st = state_crawler.clone();
async move {
// Inject state into request extensions for auth + OG middleware
let (mut parts, body) = req.into_parts();
parts.extensions.insert(st);
let req = axum::extract::Request::from_parts(parts, body);
og_middleware::og_middleware(req, next).await
}
},
))
.layer(cors)
.layer(CompressionLayer::new().zstd(true).gzip(true))
.layer(TraceLayer::new_for_http());
// Lock all current and future memory pages to prevent swapping
unsafe {
if libc::mlockall(libc::MCL_CURRENT | libc::MCL_FUTURE) != 0 {
let err = std::io::Error::last_os_error();
tracing::warn!("mlockall failed (need CAP_IPC_LOCK or sufficient RLIMIT_MEMLOCK): {err}");
} else {
info!("All memory pages locked (mlockall)");
}
}
let addr = consts::SERVER_ADDRESS;
let listener = tokio::net::TcpListener::bind(addr)
.await
.with_context(|| format!("Failed to bind to {addr}"))?;
info!("Server listening on {}", addr);
axum::serve(listener, app).await.context("Server error")?;
Ok(())
}