Improve LLM

This commit is contained in:
Andras Schmelczer 2026-03-15 14:05:34 +00:00
parent 02712f41e8
commit 80c093b7ba
16 changed files with 898 additions and 278 deletions

View file

@ -15,7 +15,9 @@ pub const MAX_PRICE_HISTORY_POINTS: usize = 5000;
pub const POSTCODE_SEARCH_OFFSET: f64 = 0.02;
pub const AI_FILTERS_MAX_TOKENS: usize = 2000;
pub const AI_FILTERS_TEMPERATURE: f32 = 0.0;
/// Gemini 3 recommends 1.0; lower values can cause looping or degraded performance.
pub const AI_FILTERS_TEMPERATURE: f32 = 1.0;
pub const AI_FILTERS_WEEKLY_TOKEN_LIMIT: u64 = 10_000_000;
/// Timeout for outbound HTTP service calls (seconds).
pub const SERVICE_CALL_TIMEOUT: u64 = 120;

View file

@ -94,13 +94,13 @@ struct Cli {
#[arg(long, env = "POCKETBASE_ADMIN_PASSWORD")]
pocketbase_admin_password: String,
/// Ollama server URL (e.g. http://ollama:11434)
#[arg(long, env = "OLLAMA_URL")]
ollama_url: String,
/// Gemini API key
#[arg(long, env = "GEMINI_API_KEY")]
gemini_api_key: String,
/// Ollama model name
#[arg(long, env = "OLLAMA_MODEL")]
ollama_model: String,
/// Gemini model name (e.g. gemini-2.0-flash)
#[arg(long, env = "GEMINI_MODEL")]
gemini_model: String,
/// Path to precomputed travel times directory (contains mode subdirs with parquet files)
#[arg(long, env = "TRAVEL_TIMES")]
@ -301,9 +301,7 @@ async fn main() -> anyhow::Result<()> {
"Precomputed features response"
);
let ai_filters_schema = routes::build_ollama_schema(&features_response);
let ai_filters_system_prompt = routes::build_system_prompt(&features_response);
info!("Precomputed AI filters schema and system prompt");
// AI filters system prompt built after travel_time_store is loaded (needs mode counts)
// Record data loading metrics
metrics::record_data_stats(
@ -331,10 +329,7 @@ async fn main() -> anyhow::Result<()> {
&cli.google_oauth_client_secret,
)
.await?;
info!(
"Ollama configured: {} (model: {})",
cli.ollama_url, cli.ollama_model
);
info!("Gemini configured (model: {})", cli.gemini_model);
let tt_path = &cli.travel_times;
if !tt_path.exists() {
bail!(
@ -352,6 +347,23 @@ async fn main() -> anyhow::Result<()> {
Arc::new(store)
};
let mode_destinations: Vec<(String, usize)> = travel_time_store
.available_modes
.iter()
.map(|mode| {
let count = travel_time_store
.destinations
.get(mode.as_str())
.map(|slugs| slugs.len())
.unwrap_or(0);
(mode.clone(), count)
})
.filter(|(_, count)| *count > 0)
.collect();
let ai_filters_system_prompt =
routes::build_system_prompt(&features_response, &mode_destinations);
info!("Precomputed AI filters system prompt");
let token_cache = Arc::new(auth::TokenCache::new());
let state = Arc::new(AppState {
@ -370,16 +382,16 @@ async fn main() -> anyhow::Result<()> {
features_response,
screenshot_url: cli.screenshot_url,
public_url: cli.public_url,
is_dev: index_html.is_none(),
index_html,
http_client,
pocketbase_url: cli.pocketbase_url,
pocketbase_admin_email: cli.pocketbase_admin_email,
pocketbase_admin_password: cli.pocketbase_admin_password,
ollama_url: cli.ollama_url,
ollama_model: cli.ollama_model,
gemini_api_key: cli.gemini_api_key,
gemini_model: cli.gemini_model,
travel_time_store,
token_cache,
ai_filters_schema,
ai_filters_system_prompt,
google_maps_api_key: cli.google_maps_api_key,
stripe_secret_key: cli.stripe_secret_key,
@ -504,7 +516,7 @@ async fn main() -> anyhow::Result<()> {
)
.route(
"/api/ai-filters",
post(move |body| routes::post_ai_filters(state_ai_filters.clone(), body))
post(move |ext, body| routes::post_ai_filters(state_ai_filters.clone(), ext, body))
.layer(ConcurrencyLimitLayer::new(5)),
)
.route(

View file

@ -240,9 +240,11 @@ async fn ensure_user_fields(
let has_is_admin = fields.iter().any(|f| f["name"] == "is_admin");
let has_subscription = fields.iter().any(|f| f["name"] == "subscription");
let has_newsletter = fields.iter().any(|f| f["name"] == "newsletter");
let has_ai_tokens_used = fields.iter().any(|f| f["name"] == "ai_tokens_used");
let has_ai_tokens_week = fields.iter().any(|f| f["name"] == "ai_tokens_week");
if has_is_admin && has_subscription && has_newsletter {
info!("PocketBase users collection already has is_admin, subscription, and newsletter fields");
if has_is_admin && has_subscription && has_newsletter && has_ai_tokens_used && has_ai_tokens_week {
info!("PocketBase users collection already has all required fields");
return Ok(());
}
@ -269,6 +271,20 @@ async fn ensure_user_fields(
}));
}
if !has_ai_tokens_used {
new_fields.push(serde_json::json!({
"name": "ai_tokens_used",
"type": "number",
}));
}
if !has_ai_tokens_week {
new_fields.push(serde_json::json!({
"name": "ai_tokens_week",
"type": "number",
}));
}
let patch_resp = client
.patch(&url)
.header("Authorization", format!("Bearer {token}"))

View file

@ -26,7 +26,7 @@ pub(crate) mod travel_time;
mod travel_destinations;
mod travel_modes;
pub use ai_filters::{build_ollama_schema, build_system_prompt, post_ai_filters};
pub use ai_filters::{build_system_prompt, post_ai_filters};
pub use checkout::post_checkout;
pub use export::get_export;
pub use features::{build_features_response, get_features, FeatureInfo, FeaturesResponse};

View file

@ -2,76 +2,190 @@ use std::sync::Arc;
use axum::http::StatusCode;
use axum::response::Json;
use axum::Extension;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tracing::{info, warn};
use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE};
use crate::auth::OptionalUser;
use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WEEKLY_TOKEN_LIMIT};
use crate::data::slugify;
use crate::pocketbase::auth_superuser;
use crate::routes::{FeatureInfo, FeaturesResponse};
use crate::state::AppState;
use crate::utils::{extract_ollama_content, ollama_chat, strip_think_blocks};
use crate::utils::gemini_chat;
/// Filter state that is already active on the client, supplied alongside a query
/// so the model can refine it instead of starting from scratch.
#[derive(Deserialize)]
pub struct AiFiltersContext {
/// Active feature filters; serialized to JSON as-is and embedded in the user message.
filters: Value,
/// Active travel time filters; an empty list is used when the field is absent.
#[serde(default)]
travel_time: Vec<AiTravelTimeContext>,
}
/// One currently active travel time filter, as reported by the client.
///
/// Rendered into the user message as `- {mode} to {label} ({bounds})`.
#[derive(Deserialize)]
pub struct AiTravelTimeContext {
/// Transport mode of the active filter.
mode: String,
/// Human-readable destination name shown to the model.
label: String,
/// Lower bound of the active range, if any (printed with a "min" suffix — presumably minutes).
min: Option<f32>,
/// Upper bound of the active range, if any (printed with a "min" suffix — presumably minutes).
max: Option<f32>,
}
/// Request body for `POST /api/ai-filters`.
#[derive(Deserialize)]
pub struct AiFiltersRequest {
/// The user's natural-language search request.
query: String,
/// Current filters for conversational refinement (e.g. "make it cheaper")
context: Option<AiFiltersContext>,
}
/// A travel time filter returned to the client.
///
/// Instances are produced by `validate_travel_time_filters`, which only emits
/// entries whose `(mode, slug)` pair exists in the travel time store and which
/// have at least one bound set.
#[derive(Serialize)]
pub struct TravelTimeFilter {
/// Transport mode the filter applies to.
mode: String,
/// Destination slug, verified against the travel time store.
slug: String,
/// Display name for the destination; falls back to the slug when the model gave none.
label: String,
/// Lower bound (clamped to 0-120 by the validator); omitted from the JSON when unset.
#[serde(skip_serializing_if = "Option::is_none")]
min: Option<f32>,
/// Upper bound (clamped to 0-120 by the validator); omitted from the JSON when unset.
#[serde(skip_serializing_if = "Option::is_none")]
max: Option<f32>,
}
/// Response body for `POST /api/ai-filters`.
#[derive(Serialize)]
pub struct AiFiltersResponse {
/// Feature filters produced by `validate_and_convert` from the model output.
filters: Value,
/// Validated travel time filters; the key is omitted from the JSON when the list is empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
travel_time_filters: Vec<TravelTimeFilter>,
/// What the LLM couldn't map to existing filters (empty if everything matched)
#[serde(skip_serializing_if = "String::is_empty")]
notes: String,
}
/// Build a JSON schema for Ollama structured output.
///
/// Uses two arrays (`numeric_filters` and `enum_filters`) instead of one property
/// per feature, because Ollama converts JSON schema to GBNF grammar and a schema
/// with 50+ optional keys causes a combinatorial explosion that crashes the parser.
/// Array-based schema keeps the grammar small and constant-size.
pub fn build_ollama_schema(_features: &FeaturesResponse) -> Value {
json!({
"type": "object",
"properties": {
"numeric_filters": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" },
"bound": { "type": "string", "enum": ["min", "max"] },
"value": { "type": "number" }
},
"required": ["name", "bound", "value"]
}
},
"enum_filters": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" },
"values": { "type": "array", "items": { "type": "string" } }
},
"required": ["name", "values"]
}
},
"notes": {
"type": "string"
}
/// Strip markdown code fences (```json ... ``` or ``` ... ```) from LLM output.
/// Models occasionally wrap JSON in markdown fencing even when told not to.
fn strip_markdown_fences(text: &str) -> &str {
let trimmed = text.trim();
// Try ```json\n...\n``` or ```\n...\n```
if let Some(rest) = trimmed.strip_prefix("```") {
// Skip optional language tag (e.g. "json")
let rest = if let Some(newline_pos) = rest.find('\n') {
&rest[newline_pos + 1..]
} else {
return trimmed;
};
if let Some(content) = rest.strip_suffix("```") {
return content.trim();
}
})
}
trimmed
}
/// Assemble the Gemini `tools` payload that exposes `search_destinations`.
///
/// The declared function takes a free-text `query` and a `mode`; the latter is
/// restricted through a JSON-schema `enum` to the transport modes listed in
/// `state.travel_time_store.available_modes`.
fn build_tool_declarations(state: &AppState) -> Value {
    // Borrow the mode names; they are only needed while the JSON value is built.
    let mut transport_modes: Vec<&str> = Vec::new();
    for mode in state.travel_time_store.available_modes.iter() {
        transport_modes.push(mode.as_str());
    }

    // JSON schema describing the arguments of `search_destinations`.
    let parameters = json!({
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Place name to search for (e.g. 'Manchester', 'Kings Cross', 'Heathrow')"
            },
            "mode": {
                "type": "string",
                "enum": transport_modes,
                "description": "Transport mode to search destinations for"
            }
        },
        "required": ["query", "mode"]
    });

    let search_destinations = json!({
        "name": "search_destinations",
        "description": "Search for available travel time destinations (cities, stations, towns) that have precomputed travel time data. Call this when the user mentions wanting to be near, close to, or within a certain travel time of a specific place.",
        "parameters": parameters
    });

    json!([{ "functionDeclarations": [search_destinations] }])
}
/// Search the places that have precomputed travel time data for `mode`.
///
/// Returns a JSON object with a `results` array (at most 10 entries, each holding
/// `name`, `slug` and `place_type`) and, when nothing was found, a `message`
/// explaining why.
///
/// A place matches when either
/// - every whitespace-separated word of `query` occurs in its lower-cased name
///   (order-independent), or
/// - its slug and the slugified query contain one another.
///
/// A query that is blank (or slugifies to nothing) never matches: without this
/// guard the empty word list / empty slug would match *every* place and the model
/// would be handed ten unrelated destinations.
fn execute_destination_search(state: &AppState, query: &str, mode: &str) -> Value {
    let query_lower = query.to_lowercase();
    let query_words: Vec<&str> = query_lower.split_whitespace().collect();
    let query_slug = slugify(query);

    let tt_store = &state.travel_time_store;
    let pd = &state.place_data;

    let slug_set = match tt_store.destinations.get(mode) {
        Some(slugs) => slugs,
        None => return json!({ "results": [], "message": format!("No travel data available for mode '{}'", mode) }),
    };

    // Collect (index, slug, type rank, population) for every matching place that
    // also has travel time data for this mode.
    let mut matches: Vec<(usize, String, u8, u32)> = pd
        .name_lower
        .iter()
        .enumerate()
        .filter_map(|(idx, name_lower)| {
            // `all()` over an empty iterator is true, so require at least one word.
            let words_match = !query_words.is_empty()
                && query_words.iter().all(|word| name_lower.contains(word));

            let slug = slugify(&pd.name[idx]);
            // Every string contains "", so an empty slug on either side must not count.
            let slug_match = !query_slug.is_empty()
                && !slug.is_empty()
                && (slug.contains(&query_slug) || query_slug.contains(&slug));

            if (words_match || slug_match) && slug_set.contains(&slug) {
                Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
            } else {
                None
            }
        })
        .collect();

    // Sort: type rank asc, population desc
    matches.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(b.3.cmp(&a.3)));
    matches.truncate(10);

    if matches.is_empty() {
        info!(query = query, mode = mode, "Destination search returned no results");
        return json!({
            "results": [],
            "message": format!("No travel time data available for '{}' by {}. This destination cannot be used as a travel time filter.", query, mode)
        });
    }

    let results: Vec<Value> = matches
        .into_iter()
        .map(|(idx, slug, ..)| {
            json!({
                "name": pd.name[idx],
                "slug": slug,
                "place_type": pd.place_type.get(idx).to_string(),
            })
        })
        .collect();

    json!({ "results": results })
}
/// Build the complete system prompt for AI filters.
///
/// Contains: role instructions, feature catalogue, few-shot examples, output rules.
/// Contains: role instructions, feature catalogue, travel time info,
/// few-shot examples, output rules.
/// Precomputed at startup and cached in AppState.
pub fn build_system_prompt(features: &FeaturesResponse) -> String {
pub fn build_system_prompt(
features: &FeaturesResponse,
mode_destinations: &[(String, usize)],
) -> String {
let mut parts = Vec::new();
// Role and task description
parts.push(
"You are a UK property search assistant. \
The user describes their ideal property or area in natural language. \
@ -91,10 +205,61 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
(note: this counts bedrooms + living rooms combined, so 3 bed ~ min 4).\n\
- If the user mentions something that has no matching filter, put it in \"notes\" \
as a short phrase (e.g. \"No filter for: garden, sea view\"). \
If everything was matched, set \"notes\" to an empty string."
If everything was matched, set \"notes\" to an empty string.\n\
\n\
CONVERSATIONAL REFINEMENT:\n\
The user's message may include their currently active filters as context. \
When context is provided:\n\
- \"make it cheaper\" / \"lower the price\" = adjust the existing price filter down\n\
- \"also add ...\" / \"and good schools\" = keep existing filters and add new ones\n\
- \"remove the ...\" / \"drop the ...\" = return filters WITHOUT the mentioned one\n\
- If the request is a completely new search (not a refinement), ignore the context \
and build filters from scratch.\n\
- Always output the COMPLETE set of filters (existing + modified), not just the changes."
.to_string(),
);
// Travel time section with available modes
let modes_list = mode_destinations
.iter()
.map(|(mode, count)| format!("- {} ({} destinations available)", mode, count))
.collect::<Vec<_>>()
.join("\n");
parts.push(format!(
"\n--- TRAVEL TIME FILTERS ---\n\
You can add travel time filters when the user mentions commute times, \
proximity to places, or wanting to be near/within X minutes of somewhere.\n\
\n\
Available transport modes (only use modes that have destinations):\n\
{}\n\
- \"car\" / \"drive\" / \"driving\" = car mode\n\
- \"cycle\" / \"bike\" / \"cycling\" = bicycle mode\n\
- \"walk\" / \"walking\" / \"on foot\" = walking mode\n\
- \"train\" / \"tube\" / \"bus\" / \"public transport\" / \"commute\" = transit mode\n\
\n\
When the user mentions a specific place, you MUST call the search_destinations \
tool to find the exact slug. Use the name and slug from the search results.\n\
If search_destinations returns an empty array, the destination is not available \
mention it in \"notes\" (e.g. \"No travel data for: Gatwick Airport\") and do NOT \
include a travel_time_filter for it.\n\
\n\
Travel time values are in MINUTES (0-120 range).\n\
- \"within 30 minutes\" = max 30\n\
- \"at least 10 minutes\" = min 10\n\
- \"30-45 minute commute\" = min 30, max 45\n\
- If only a max is given, omit min (and vice versa).\n\
\n\
INFERRING TRANSPORT MODE (when the user does not specify one explicitly):\n\
- \"commute\" to a major city centre or station = transit\n\
- \"near\" / \"close to\" a city centre or station = transit\n\
- \"near\" / \"close to\" a smaller town, village, or rural area = car\n\
- \"drive\" / \"driving distance\" / \"driving time\" = always car\n\
- If multiple modes are plausible, prefer transit for urban destinations \
(London, Manchester, Birmingham, Leeds, etc.) and car for everything else.",
modes_list,
));
// Feature catalogue
parts.push("\n--- AVAILABLE FEATURES ---\n".to_string());
for group in &features.groups {
@ -148,6 +313,7 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
Output: {\"numeric_filters\": [{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 400000}], \
\"enum_filters\": [{\"name\": \"Leasehold/Freehold\", \"values\": [\"Freehold\"]}, \
{\"name\": \"Property type\", \"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
\"travel_time_filters\": [], \
\"notes\": \"\"}"
.to_string(),
);
@ -161,7 +327,7 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
{\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \
{\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}, \
{\"name\": \"Number of parks within 2km\", \"bound\": \"min\", \"value\": 3}], \
\"enum_filters\": [], \"notes\": \"\"}"
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
.to_string(),
);
@ -172,18 +338,37 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 4}], \
\"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}, \
{\"name\": \"Max available download speed (Mbps)\", \"values\": [\"100\", \"300\", \"1000\"]}], \
\"travel_time_filters\": [], \
\"notes\": \"No filter for: beach proximity\"}"
.to_string(),
);
parts.push(
"\nUser: \"large family home with a garden near restaurants\"\n\
"\nUser: \"within 30 minutes commute of Kings Cross, under 500k\"\n\
(After calling search_destinations for \"Kings Cross\" with mode \"transit\" \
and getting [{\"name\": \"Kings Cross\", \"slug\": \"kings-cross\", \"place_type\": \"station\"}])\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 500000}], \
\"enum_filters\": [], \
\"travel_time_filters\": [{\"mode\": \"transit\", \"slug\": \"kings-cross\", \
\"label\": \"Kings Cross\", \"bound\": \"max\", \"value\": 30}], \
\"notes\": \"\"}"
.to_string(),
);
parts.push(
"\nUser: \"family home with garden, 45 min drive from Manchester, good schools\"\n\
(After calling search_destinations for \"Manchester\" with mode \"car\" \
and getting [{\"name\": \"Manchester\", \"slug\": \"manchester\", \"place_type\": \"city\"}])\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Total floor area (sqm)\", \"bound\": \"min\", \"value\": 100}, \
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 5}, \
{\"name\": \"Number of restaurants within 2km\", \"bound\": \"min\", \"value\": 10}], \
{\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \
{\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}], \
\"enum_filters\": [{\"name\": \"Property type\", \
\"values\": [\"Detached\", \"Semi-Detached\"]}], \
\"travel_time_filters\": [{\"mode\": \"car\", \"slug\": \"manchester\", \
\"label\": \"Manchester\", \"bound\": \"max\", \"value\": 45}], \
\"notes\": \"No filter for: garden\"}"
.to_string(),
);
@ -191,7 +376,8 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
// Output format reminder
parts.push(
"\n--- OUTPUT FORMAT ---\n\
{\"numeric_filters\": [{\"name\": \"...\", \"bound\": \"min\"|\"max\", \"value\": N}, ...], \"enum_filters\": [...], \"notes\": \"...\"}\n\
{\"numeric_filters\": [...], \"enum_filters\": [...], \"travel_time_filters\": [{\"mode\": \"...\", \"slug\": \"...\", \"label\": \"...\", \"bound\": \"min\"|\"max\", \"value\": N}, ...], \"notes\": \"...\"}\n\
- travel_time_filters: use ONLY slugs returned by search_destinations. If a place isn't found, mention it in notes.\n\
Respond with ONLY the JSON object. No explanation."
.to_string(),
);
@ -199,86 +385,393 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
parts.join("\n")
}
pub async fn post_ai_filters(
state: Arc<AppState>,
Json(req): Json<AiFiltersRequest>,
) -> Result<Json<AiFiltersResponse>, (StatusCode, String)> {
info!(query = %req.query, "POST /api/ai-filters");
let url = format!("{}/api/chat", state.ollama_url);
let body = json!({
"model": state.ollama_model,
"messages": [
{ "role": "system", "content": state.ai_filters_system_prompt },
{ "role": "user", "content": req.query }
],
"stream": false,
"format": state.ai_filters_schema,
"options": {
"temperature": AI_FILTERS_TEMPERATURE,
"num_predict": AI_FILTERS_MAX_TOKENS,
}
});
// Try up to 2 attempts — LLMs occasionally return empty content (e.g. only
// <think> blocks with no JSON output), which is transient and usually
// succeeds on retry.
let mut last_err = None;
for attempt in 0..2 {
let raw = call_ollama_and_parse(&state.http_client, &url, &body).await;
match raw {
Ok(raw) => {
let filters = validate_and_convert(&raw, &state.features_response);
let notes = raw
.get("notes")
.and_then(|val| val.as_str())
.unwrap_or("")
.to_string();
return Ok(Json(AiFiltersResponse { filters, notes }));
}
Err(err) => {
if attempt == 0 {
warn!("LLM attempt 1 failed, retrying: {}", err.1);
}
last_err = Some(err);
}
}
}
Err(last_err.unwrap())
/// Number of whole 7-day periods elapsed since the Unix epoch.
///
/// The value increases by one every 604800 seconds and is stored next to a
/// user's token counter, so a counter stamped with an older week can be treated
/// as expired for weekly rate limiting.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn current_week_number() -> u64 {
    const SECONDS_PER_WEEK: u64 = 604_800;

    let since_epoch = std::time::UNIX_EPOCH
        .elapsed()
        .expect("system time before epoch");

    since_epoch.as_secs() / SECONDS_PER_WEEK
}
/// Call Ollama and parse the response content as JSON.
///
/// Returns an error if: the HTTP call fails, the response is malformed,
/// the content is empty after stripping think blocks, or the content is
/// not valid JSON.
async fn call_ollama_and_parse(
client: &reqwest::Client,
url: &str,
body: &Value,
) -> Result<Value, (StatusCode, String)> {
let json_resp = ollama_chat(client, url, body).await?;
let content = extract_ollama_content(&json_resp)?;
/// Fetch the user's current AI token usage from PocketBase.
/// Returns `(tokens_used, week_number)`.
async fn fetch_ai_usage(
state: &AppState,
user_id: &str,
) -> Result<(u64, u64), (StatusCode, String)> {
let pb_url = state.pocketbase_url.trim_end_matches('/');
let token = auth_superuser(
&state.http_client,
pb_url,
&state.pocketbase_admin_email,
&state.pocketbase_admin_password,
)
.await
.map_err(|err| {
warn!("Failed to auth superuser for AI usage check: {err}");
(StatusCode::BAD_GATEWAY, "Internal error".into())
})?;
let content = strip_think_blocks(content);
let content = content.trim();
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
let resp = state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await
.map_err(|err| {
warn!("Failed to fetch user record for AI usage: {err}");
(StatusCode::BAD_GATEWAY, "Internal error".into())
})?;
if content.is_empty() {
warn!("LLM returned empty content after stripping think blocks");
if !resp.status().is_success() {
let status = resp.status();
warn!("PocketBase user fetch failed ({status})");
return Err((StatusCode::BAD_GATEWAY, "Internal error".into()));
}
let body: Value = resp.json().await.map_err(|err| {
warn!("Failed to parse user record: {err}");
(StatusCode::BAD_GATEWAY, "Internal error".into())
})?;
let tokens_used = body
.get("ai_tokens_used")
.and_then(|val| val.as_u64())
.unwrap_or(0);
let week = body
.get("ai_tokens_week")
.and_then(|val| val.as_u64())
.unwrap_or(0);
Ok((tokens_used, week))
}
/// Persist a user's AI token counter and its week stamp to PocketBase.
///
/// Sends a superuser-authenticated PATCH that sets `ai_tokens_used` and
/// `ai_tokens_week` on the user's record. Best-effort: an auth failure, a
/// transport error or a non-success status is logged with `warn!` and then
/// ignored, so the caller never sees an error from this function.
async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week: u64) {
    let pb_url = state.pocketbase_url.trim_end_matches('/');

    let auth = auth_superuser(
        &state.http_client,
        pb_url,
        &state.pocketbase_admin_email,
        &state.pocketbase_admin_password,
    )
    .await;
    let token = match auth {
        Ok(token) => token,
        Err(err) => {
            warn!("Failed to auth superuser for AI usage update: {err}");
            return;
        }
    };

    let record_url = format!("{pb_url}/api/collections/users/records/{user_id}");
    let payload = json!({
        "ai_tokens_used": tokens_used,
        "ai_tokens_week": week,
    });

    let outcome = state
        .http_client
        .patch(&record_url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&payload)
        .send()
        .await;

    // Only the status matters; the response body is never read.
    match outcome.map(|resp| resp.status()) {
        Ok(status) if status.is_success() => {}
        Ok(status) => warn!("Failed to update AI usage ({status})"),
        Err(err) => warn!("Failed to update AI usage: {err}"),
    }
}
/// Maximum number of round trips (function calls + retries) before giving up.
const MAX_TOOL_ROUNDS: usize = 5;
pub async fn post_ai_filters(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<AiFiltersRequest>,
) -> Result<Json<AiFiltersResponse>, (StatusCode, String)> {
// Auth check
let user = user
.0
.ok_or((StatusCode::UNAUTHORIZED, "Login required".into()))?;
// Email verification check (skipped in dev mode)
if !user.verified && !state.is_dev {
return Err((
StatusCode::BAD_GATEWAY,
"LLM returned empty content (no JSON output)".into(),
StatusCode::FORBIDDEN,
"Please verify your email to use AI filters".into(),
));
}
serde_json::from_str(content).map_err(|err| {
warn!(error = %err, content = %content, "Failed to parse LLM JSON output");
(
StatusCode::BAD_GATEWAY,
format!("Failed to parse LLM output as JSON: {}", err),
// Check weekly token usage
let current_week = current_week_number();
let (stored_tokens, stored_week) = fetch_ai_usage(&state, &user.id).await?;
let tokens_used = if stored_week == current_week {
stored_tokens
} else {
0
};
if tokens_used >= AI_FILTERS_WEEKLY_TOKEN_LIMIT {
return Err((
StatusCode::TOO_MANY_REQUESTS,
"Weekly AI usage limit reached. Resets next week.".into(),
));
}
info!(query = %req.query, user_id = %user.id, "POST /api/ai-filters");
let tools = build_tool_declarations(&state);
// Build user message with optional context for conversational refinement
let user_text = if let Some(ref ctx) = req.context {
let mut msg = String::new();
msg.push_str("Currently active filters:\n");
msg.push_str(&serde_json::to_string(&ctx.filters).unwrap_or_default());
if !ctx.travel_time.is_empty() {
msg.push_str("\nCurrently active travel time filters:\n");
for tt in &ctx.travel_time {
let bounds = match (tt.min, tt.max) {
(Some(min), Some(max)) => format!("{}-{} min", min, max),
(Some(min), None) => format!("min {} min", min),
(None, Some(max)) => format!("max {} min", max),
(None, None) => "no range".to_string(),
};
msg.push_str(&format!("- {} to {} ({})\n", tt.mode, tt.label, bounds));
}
}
msg.push_str(&format!("\nUser request: {}", req.query));
msg
} else {
req.query.clone()
};
let mut contents = vec![json!({
"role": "user",
"parts": [{ "text": user_text }]
})];
let mut total_tokens_accumulated: u64 = 0;
// Function calling loop: model may call search_destinations, we execute and feed back
for round in 0..MAX_TOOL_ROUNDS {
let body = json!({
"systemInstruction": {
"parts": [{ "text": state.ai_filters_system_prompt }]
},
"contents": contents,
"tools": tools,
"generationConfig": {
"temperature": AI_FILTERS_TEMPERATURE,
"maxOutputTokens": AI_FILTERS_MAX_TOKENS,
"thinkingConfig": { "thinkingLevel": "LOW" },
}
});
let json_resp = gemini_chat(
&state.http_client,
&state.gemini_api_key,
&state.gemini_model,
&body,
)
})
.await?;
// Accumulate token usage
total_tokens_accumulated += json_resp
.get("usageMetadata")
.and_then(|md| md.get("totalTokenCount"))
.and_then(|tc| tc.as_u64())
.unwrap_or(0);
let candidate = json_resp
.get("candidates")
.and_then(|cs| cs.get(0))
.and_then(|c| c.get("content"))
.ok_or_else(|| {
warn!("Malformed Gemini response: missing candidates[0].content");
(
StatusCode::BAD_GATEWAY,
"Malformed Gemini response".into(),
)
})?;
let parts = candidate
.get("parts")
.and_then(|p| p.as_array())
.ok_or_else(|| {
warn!("Malformed Gemini response: missing parts array");
(
StatusCode::BAD_GATEWAY,
"Malformed Gemini response".into(),
)
})?;
// Check if the model made a function call.
// Find the full part (includes thoughtSignature required by Gemini 3 models).
if let Some(fc_part) = parts.iter().find(|part| part.get("functionCall").is_some()) {
let fc = fc_part.get("functionCall").unwrap();
let fn_name = fc.get("name").and_then(|n| n.as_str()).unwrap_or("");
let fn_args = fc.get("args").cloned().unwrap_or(json!({}));
info!(
function = fn_name,
round = round,
"AI called tool"
);
let fn_result = if fn_name == "search_destinations" {
let query = fn_args
.get("query")
.and_then(|q| q.as_str())
.unwrap_or("");
let mode = fn_args
.get("mode")
.and_then(|m| m.as_str())
.unwrap_or("transit");
execute_destination_search(&state, query, mode)
} else {
json!({"error": "unknown function"})
};
// Append the model's full response (preserves thoughtSignature) + our function result
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{
"functionResponse": {
"name": fn_name,
"response": { "results": fn_result }
}
}]
}));
// Continue the loop — model will process the results
continue;
}
// Model returned text — extract and parse as JSON
let text = parts
.iter()
.find_map(|part| part.get("text").and_then(|t| t.as_str()))
.unwrap_or("");
let text = strip_markdown_fences(text);
let text = text.trim();
if text.is_empty() {
warn!("Gemini returned empty text content (round {})", round);
// Retry by continuing the loop
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": "Your response was empty. Please output the JSON object." }]
}));
continue;
}
let raw: Value = match serde_json::from_str(text) {
Ok(val) => val,
Err(err) => {
warn!(error = %err, round = round, "Failed to parse Gemini JSON output, retrying");
// Ask the model to fix its output
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": "That was not valid JSON. Please output ONLY the JSON object with numeric_filters, enum_filters, travel_time_filters, and notes." }]
}));
continue;
}
};
let filters = validate_and_convert(&raw, &state.features_response);
let travel_time_filters = validate_travel_time_filters(&raw, &state);
let notes = raw
.get("notes")
.and_then(|val| val.as_str())
.unwrap_or("")
.to_string();
// Update usage with total accumulated tokens
let new_total = tokens_used + total_tokens_accumulated;
update_ai_usage(&state, &user.id, new_total, current_week).await;
return Ok(Json(AiFiltersResponse {
filters,
travel_time_filters,
notes,
}));
}
// Exhausted tool rounds without getting a final text response
warn!("AI exhausted {} tool-calling rounds without final response", MAX_TOOL_ROUNDS);
Err((
StatusCode::BAD_GATEWAY,
"AI could not complete the request".into(),
))
}
/// Validate travel time filters from LLM output against available destinations.
fn validate_travel_time_filters(raw: &Value, state: &AppState) -> Vec<TravelTimeFilter> {
let arr = match raw
.get("travel_time_filters")
.and_then(|val| val.as_array())
{
Some(arr) => arr,
None => return Vec::new(),
};
let tt_store = &state.travel_time_store;
let mut results = Vec::new();
for item in arr {
let mode = match item.get("mode").and_then(|val| val.as_str()) {
Some(mode) => mode,
None => continue,
};
let slug = match item.get("slug").and_then(|val| val.as_str()) {
Some(slug) => slug,
None => continue,
};
let label = item
.get("label")
.and_then(|val| val.as_str())
.unwrap_or(slug);
// Verify this destination actually exists
if !tt_store.has_destination(mode, slug) {
warn!(mode = mode, slug = slug, "AI suggested non-existent destination");
continue;
}
let bound = item.get("bound").and_then(|val| val.as_str());
let value = item.get("value").and_then(|val| val.as_f64());
let (min, max) = match (bound, value) {
(Some("min"), Some(val)) => (Some(val.max(0.0).min(120.0) as f32), None),
(Some("max"), Some(val)) => (None, Some(val.max(0.0).min(120.0) as f32)),
_ => (None, None),
};
// Only include if at least one bound is set
if min.is_some() || max.is_some() {
results.push(TravelTimeFilter {
mode: mode.to_string(),
slug: slug.to_string(),
label: label.to_string(),
min,
max,
});
}
}
results
}
/// Validate LLM output against feature metadata and convert to FeatureFilters format.
@ -374,3 +867,32 @@ fn validate_and_convert(raw: &Value, features: &FeaturesResponse) -> Value {
Value::Object(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The bare JSON every variant below is expected to reduce to.
    const BARE_JSON: &str = r#"{"a": 1}"#;

    /// A fence carrying a `json` language tag is removed.
    #[test]
    fn strip_fences_json_tag() {
        let fenced = format!("```json\n{}\n```", BARE_JSON);
        assert_eq!(strip_markdown_fences(&fenced), BARE_JSON);
    }

    /// A fence without a language tag is removed.
    #[test]
    fn strip_fences_no_tag() {
        let fenced = format!("```\n{}\n```", BARE_JSON);
        assert_eq!(strip_markdown_fences(&fenced), BARE_JSON);
    }

    /// Unfenced input is returned unchanged.
    #[test]
    fn strip_fences_passthrough() {
        assert_eq!(strip_markdown_fences(BARE_JSON), BARE_JSON);
    }

    /// Whitespace around the fence and around the payload is trimmed away.
    #[test]
    fn strip_fences_whitespace() {
        let fenced = format!(" ```json\n {} \n``` ", BARE_JSON);
        assert_eq!(strip_markdown_fences(&fenced), BARE_JSON);
    }
}

View file

@ -178,7 +178,7 @@ pub async fn get_invite(
}
// Dev-only: return a fake valid admin invite without hitting PocketBase
if state.index_html.is_none() && code == DEV_INVITE_CODE {
if state.is_dev && code == DEV_INVITE_CODE {
return Json(InviteValidation {
valid: true,
invite_type: "admin".to_string(),
@ -294,7 +294,7 @@ pub async fn post_redeem_invite(
}
// Dev-only: fake redeem — just return "licensed" without touching PocketBase
if state.index_html.is_none() && req.code == DEV_INVITE_CODE {
if state.is_dev && req.code == DEV_INVITE_CODE {
info!(user_id = %user.id, "Dev invite redeemed (no-op)");
return Json(RedeemResponse {
result: "licensed".to_string(),

View file

@ -34,6 +34,8 @@ pub struct AppState {
pub screenshot_url: String,
/// Public-facing URL for absolute og:image URLs (e.g. https://perfectpostcodes.schmelczer.dev)
pub public_url: String,
/// True when --dist is not provided (no static serving, relaxed auth checks)
pub is_dev: bool,
/// Contents of index.html read at startup, used for crawler OG injection (None when --dist omitted)
pub index_html: Option<String>,
/// Shared HTTP client for proxying to the screenshot service and PocketBase
@ -44,16 +46,14 @@ pub struct AppState {
pub pocketbase_admin_email: String,
/// PocketBase superuser password
pub pocketbase_admin_password: String,
/// Ollama server URL for AI area summaries (e.g. http://ollama:11434)
pub ollama_url: String,
/// Ollama model name for area summaries (e.g. gemma3:12b)
pub ollama_model: String,
/// Gemini API key for AI filters
pub gemini_api_key: String,
/// Gemini model name (e.g. gemini-2.0-flash)
pub gemini_model: String,
/// Precomputed travel time data store
pub travel_time_store: Arc<TravelTimeStore>,
/// Token validation cache (60s TTL)
pub token_cache: Arc<TokenCache>,
/// JSON schema for Ollama structured output in AI filters
pub ai_filters_schema: serde_json::Value,
/// Complete system prompt for AI filters (features + examples + instructions)
pub ai_filters_system_prompt: String,
/// Google Maps API key for Street View metadata lookups

View file

@ -6,7 +6,7 @@ mod llm;
pub use grid_index::GridIndex;
pub use hash::{generate_priorities, splitmix64_hash};
pub use interned_column::InternedColumn;
pub use llm::{extract_ollama_content, ollama_chat, strip_think_blocks};
pub use llm::{extract_gemini_content, gemini_chat};
/// Normalize a UK postcode: uppercase, strip spaces, insert canonical space before inward code.
/// e.g. "e142dg" → "E14 2DG", "E14 2DG" → "E14 2DG", "EC1A1BB" → "EC1A 1BB"

View file

@ -4,68 +4,62 @@ use tracing::warn;
pub type LlmError = (StatusCode, String);
/// Send a chat request to Ollama and return the parsed JSON response.
/// Send a generateContent request to the Gemini API and return the parsed JSON response.
///
/// Handles connection errors, non-success status codes, and JSON parse failures
/// uniformly as `BAD_GATEWAY` errors.
///
/// `model` and `api_key` are interpolated into the request URL; `body` is sent as the
/// JSON payload. On a non-success status the upstream response body is included in the
/// returned error message.
pub async fn ollama_chat(
pub async fn gemini_chat(
client: &reqwest::Client,
url: &str,
api_key: &str,
model: &str,
body: &Value,
) -> Result<Value, LlmError> {
let response = client.post(url).json(body).send().await.map_err(|err| {
warn!(error = %err, "Failed to connect to Ollama");
// NOTE(review): the API key travels in the query string, and the `err` values logged and
// returned below come from reqwest, whose error text presumably includes the request URL —
// confirm, and consider the `x-goog-api-key` header or `Error::without_url()` so the key
// cannot reach logs or callers.
let url = format!(
"https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
model, api_key
);
let response = client.post(&url).json(body).send().await.map_err(|err| {
warn!(error = %err, "Failed to connect to Gemini API");
(
StatusCode::BAD_GATEWAY,
format!("Failed to connect to Ollama: {}", err),
format!("Failed to connect to Gemini API: {}", err),
)
})?;
if !response.status().is_success() {
let status = response.status();
// A body that cannot be read is treated as empty rather than failing the call.
let body_text = response.text().await.unwrap_or_default();
warn!(status = %status, body = %body_text, "Ollama returned error");
warn!(status = %status, body = %body_text, "Gemini API returned error");
return Err((
StatusCode::BAD_GATEWAY,
format!("Ollama error {}: {}", status, body_text),
format!("Gemini API error {}: {}", status, body_text),
));
}
response.json().await.map_err(|err| {
warn!(error = %err, "Failed to parse Ollama response");
warn!(error = %err, "Failed to parse Gemini API response");
(
StatusCode::BAD_GATEWAY,
format!("Failed to parse Ollama response: {}", err),
format!("Failed to parse Gemini API response: {}", err),
)
})
}
/// Extract content from Ollama native response (`message.content`)
pub fn extract_ollama_content(json: &Value) -> Result<&str, LlmError> {
json.get("message")
.and_then(|msg| msg.get("content"))
.and_then(|ct| ct.as_str())
/// Extract text content from Gemini response (`candidates[0].content.parts[0].text`)
///
/// Only the first part of the first candidate is read. Returns a `BAD_GATEWAY` error when
/// any step of that path is absent or the `text` value is not a string.
pub fn extract_gemini_content(json: &Value) -> Result<&str, LlmError> {
json.get("candidates")
.and_then(|c| c.get(0))
.and_then(|c| c.get("content"))
.and_then(|c| c.get("parts"))
.and_then(|p| p.get(0))
.and_then(|p| p.get("text"))
.and_then(|t| t.as_str())
.ok_or_else(|| {
warn!("Malformed Ollama response: missing message.content");
warn!("Malformed Gemini response: missing candidates[0].content.parts[0].text");
(
StatusCode::BAD_GATEWAY,
"Malformed LLM response: missing message.content".into(),
"Malformed Gemini response: missing content".into(),
)
})
}
/// Remove every `<think>...</think>` span from model output.
///
/// Text outside the spans is kept in order. If a `<think>` tag is never closed,
/// everything from that tag to the end of the input is dropped.
pub fn strip_think_blocks(text: &str) -> String {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    let mut cleaned = String::new();
    let mut rest = text;
    loop {
        match rest.split_once(OPEN_TAG) {
            // No further opening tag: the tail is plain output.
            None => {
                cleaned.push_str(rest);
                return cleaned;
            }
            Some((visible, after_open)) => {
                cleaned.push_str(visible);
                match after_open.split_once(CLOSE_TAG) {
                    // Resume scanning right after the closing tag.
                    Some((_hidden, tail)) => rest = tail,
                    // Unterminated block: discard the remainder.
                    None => return cleaned,
                }
            }
        }
    }
}