Improve LLM
This commit is contained in:
parent
02712f41e8
commit
80c093b7ba
16 changed files with 898 additions and 278 deletions
|
|
@ -15,7 +15,9 @@ pub const MAX_PRICE_HISTORY_POINTS: usize = 5000;
|
|||
pub const POSTCODE_SEARCH_OFFSET: f64 = 0.02;
|
||||
|
||||
pub const AI_FILTERS_MAX_TOKENS: usize = 2000;
|
||||
pub const AI_FILTERS_TEMPERATURE: f32 = 0.0;
|
||||
/// Gemini 3 recommends 1.0; lower values can cause looping or degraded performance.
|
||||
pub const AI_FILTERS_TEMPERATURE: f32 = 1.0;
|
||||
pub const AI_FILTERS_WEEKLY_TOKEN_LIMIT: u64 = 10_000_000;
|
||||
|
||||
/// Timeout for outbound HTTP service calls (seconds).
|
||||
pub const SERVICE_CALL_TIMEOUT: u64 = 120;
|
||||
|
|
|
|||
|
|
@ -94,13 +94,13 @@ struct Cli {
|
|||
#[arg(long, env = "POCKETBASE_ADMIN_PASSWORD")]
|
||||
pocketbase_admin_password: String,
|
||||
|
||||
/// Ollama server URL (e.g. http://ollama:11434)
|
||||
#[arg(long, env = "OLLAMA_URL")]
|
||||
ollama_url: String,
|
||||
/// Gemini API key
|
||||
#[arg(long, env = "GEMINI_API_KEY")]
|
||||
gemini_api_key: String,
|
||||
|
||||
/// Ollama model name
|
||||
#[arg(long, env = "OLLAMA_MODEL")]
|
||||
ollama_model: String,
|
||||
/// Gemini model name (e.g. gemini-2.0-flash)
|
||||
#[arg(long, env = "GEMINI_MODEL")]
|
||||
gemini_model: String,
|
||||
|
||||
/// Path to precomputed travel times directory (contains mode subdirs with parquet files)
|
||||
#[arg(long, env = "TRAVEL_TIMES")]
|
||||
|
|
@ -301,9 +301,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
"Precomputed features response"
|
||||
);
|
||||
|
||||
let ai_filters_schema = routes::build_ollama_schema(&features_response);
|
||||
let ai_filters_system_prompt = routes::build_system_prompt(&features_response);
|
||||
info!("Precomputed AI filters schema and system prompt");
|
||||
// AI filters system prompt built after travel_time_store is loaded (needs mode counts)
|
||||
|
||||
// Record data loading metrics
|
||||
metrics::record_data_stats(
|
||||
|
|
@ -331,10 +329,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
&cli.google_oauth_client_secret,
|
||||
)
|
||||
.await?;
|
||||
info!(
|
||||
"Ollama configured: {} (model: {})",
|
||||
cli.ollama_url, cli.ollama_model
|
||||
);
|
||||
info!("Gemini configured (model: {})", cli.gemini_model);
|
||||
let tt_path = &cli.travel_times;
|
||||
if !tt_path.exists() {
|
||||
bail!(
|
||||
|
|
@ -352,6 +347,23 @@ async fn main() -> anyhow::Result<()> {
|
|||
Arc::new(store)
|
||||
};
|
||||
|
||||
let mode_destinations: Vec<(String, usize)> = travel_time_store
|
||||
.available_modes
|
||||
.iter()
|
||||
.map(|mode| {
|
||||
let count = travel_time_store
|
||||
.destinations
|
||||
.get(mode.as_str())
|
||||
.map(|slugs| slugs.len())
|
||||
.unwrap_or(0);
|
||||
(mode.clone(), count)
|
||||
})
|
||||
.filter(|(_, count)| *count > 0)
|
||||
.collect();
|
||||
let ai_filters_system_prompt =
|
||||
routes::build_system_prompt(&features_response, &mode_destinations);
|
||||
info!("Precomputed AI filters system prompt");
|
||||
|
||||
let token_cache = Arc::new(auth::TokenCache::new());
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
|
|
@ -370,16 +382,16 @@ async fn main() -> anyhow::Result<()> {
|
|||
features_response,
|
||||
screenshot_url: cli.screenshot_url,
|
||||
public_url: cli.public_url,
|
||||
is_dev: index_html.is_none(),
|
||||
index_html,
|
||||
http_client,
|
||||
pocketbase_url: cli.pocketbase_url,
|
||||
pocketbase_admin_email: cli.pocketbase_admin_email,
|
||||
pocketbase_admin_password: cli.pocketbase_admin_password,
|
||||
ollama_url: cli.ollama_url,
|
||||
ollama_model: cli.ollama_model,
|
||||
gemini_api_key: cli.gemini_api_key,
|
||||
gemini_model: cli.gemini_model,
|
||||
travel_time_store,
|
||||
token_cache,
|
||||
ai_filters_schema,
|
||||
ai_filters_system_prompt,
|
||||
google_maps_api_key: cli.google_maps_api_key,
|
||||
stripe_secret_key: cli.stripe_secret_key,
|
||||
|
|
@ -504,7 +516,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
)
|
||||
.route(
|
||||
"/api/ai-filters",
|
||||
post(move |body| routes::post_ai_filters(state_ai_filters.clone(), body))
|
||||
post(move |ext, body| routes::post_ai_filters(state_ai_filters.clone(), ext, body))
|
||||
.layer(ConcurrencyLimitLayer::new(5)),
|
||||
)
|
||||
.route(
|
||||
|
|
|
|||
|
|
@ -240,9 +240,11 @@ async fn ensure_user_fields(
|
|||
let has_is_admin = fields.iter().any(|f| f["name"] == "is_admin");
|
||||
let has_subscription = fields.iter().any(|f| f["name"] == "subscription");
|
||||
let has_newsletter = fields.iter().any(|f| f["name"] == "newsletter");
|
||||
let has_ai_tokens_used = fields.iter().any(|f| f["name"] == "ai_tokens_used");
|
||||
let has_ai_tokens_week = fields.iter().any(|f| f["name"] == "ai_tokens_week");
|
||||
|
||||
if has_is_admin && has_subscription && has_newsletter {
|
||||
info!("PocketBase users collection already has is_admin, subscription, and newsletter fields");
|
||||
if has_is_admin && has_subscription && has_newsletter && has_ai_tokens_used && has_ai_tokens_week {
|
||||
info!("PocketBase users collection already has all required fields");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
|
@ -269,6 +271,20 @@ async fn ensure_user_fields(
|
|||
}));
|
||||
}
|
||||
|
||||
if !has_ai_tokens_used {
|
||||
new_fields.push(serde_json::json!({
|
||||
"name": "ai_tokens_used",
|
||||
"type": "number",
|
||||
}));
|
||||
}
|
||||
|
||||
if !has_ai_tokens_week {
|
||||
new_fields.push(serde_json::json!({
|
||||
"name": "ai_tokens_week",
|
||||
"type": "number",
|
||||
}));
|
||||
}
|
||||
|
||||
let patch_resp = client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ pub(crate) mod travel_time;
|
|||
mod travel_destinations;
|
||||
mod travel_modes;
|
||||
|
||||
pub use ai_filters::{build_ollama_schema, build_system_prompt, post_ai_filters};
|
||||
pub use ai_filters::{build_system_prompt, post_ai_filters};
|
||||
pub use checkout::post_checkout;
|
||||
pub use export::get_export;
|
||||
pub use features::{build_features_response, get_features, FeatureInfo, FeaturesResponse};
|
||||
|
|
|
|||
|
|
@ -2,76 +2,190 @@ use std::sync::Arc;
|
|||
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Json;
|
||||
use axum::Extension;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE};
|
||||
use crate::auth::OptionalUser;
|
||||
use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WEEKLY_TOKEN_LIMIT};
|
||||
use crate::data::slugify;
|
||||
use crate::pocketbase::auth_superuser;
|
||||
use crate::routes::{FeatureInfo, FeaturesResponse};
|
||||
use crate::state::AppState;
|
||||
use crate::utils::{extract_ollama_content, ollama_chat, strip_think_blocks};
|
||||
use crate::utils::gemini_chat;
|
||||
|
||||
/// Optional conversational context sent with an AI filters request: the
/// filters that are currently active on the client.
#[derive(Deserialize)]
|
||||
pub struct AiFiltersContext {
|
||||
/// Currently active filters as arbitrary JSON; serialized verbatim into the
/// user message under "Currently active filters:".
filters: Value,
|
||||
/// Currently active travel time filters; defaults to an empty list when the
/// field is absent from the request.
#[serde(default)]
|
||||
travel_time: Vec<AiTravelTimeContext>,
|
||||
}
|
||||
|
||||
/// One currently active travel time filter, as supplied by the client in the
/// request context.
#[derive(Deserialize)]
|
||||
pub struct AiTravelTimeContext {
|
||||
/// Transport mode name; rendered as-is into the user message.
mode: String,
|
||||
/// Human-readable destination label; rendered as-is into the user message.
label: String,
|
||||
/// Lower bound of the range, if any (rendered with a "min" unit suffix).
min: Option<f32>,
|
||||
/// Upper bound of the range, if any (rendered with a "min" unit suffix).
max: Option<f32>,
|
||||
}
|
||||
|
||||
/// JSON request body for `POST /api/ai-filters`.
#[derive(Deserialize)]
|
||||
pub struct AiFiltersRequest {
|
||||
/// The user's natural-language request.
query: String,
|
||||
/// Current filters for conversational refinement (e.g. "make it cheaper")
|
||||
context: Option<AiFiltersContext>,
|
||||
}
|
||||
|
||||
/// A travel time filter returned to the client as part of the AI filters response.
#[derive(Serialize)]
|
||||
pub struct TravelTimeFilter {
|
||||
/// Transport mode name.
mode: String,
|
||||
/// Destination slug (presumably one returned by the destination search — confirm
/// against the validation code, which is not visible here).
slug: String,
|
||||
/// Human-readable destination label.
label: String,
|
||||
/// Lower bound; omitted from the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
min: Option<f32>,
|
||||
/// Upper bound; omitted from the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max: Option<f32>,
|
||||
}
|
||||
|
||||
/// JSON response body for `POST /api/ai-filters`.
#[derive(Serialize)]
|
||||
pub struct AiFiltersResponse {
|
||||
/// The resulting filters as JSON.
filters: Value,
|
||||
/// Travel time filters; the key is omitted from the JSON when the list is empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
travel_time_filters: Vec<TravelTimeFilter>,
|
||||
/// What the LLM couldn't map to existing filters (empty if everything matched)
/// The key is omitted from the JSON when the string is empty.
|
||||
#[serde(skip_serializing_if = "String::is_empty")]
|
||||
notes: String,
|
||||
}
|
||||
|
||||
/// Build a JSON schema for Ollama structured output.
|
||||
///
|
||||
/// Uses two arrays (`numeric_filters` and `enum_filters`) instead of one property
|
||||
/// per feature, because Ollama converts JSON schema to GBNF grammar and a schema
|
||||
/// with 50+ optional keys causes a combinatorial explosion that crashes the parser.
|
||||
/// Array-based schema keeps the grammar small and constant-size.
|
||||
pub fn build_ollama_schema(_features: &FeaturesResponse) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"numeric_filters": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"bound": { "type": "string", "enum": ["min", "max"] },
|
||||
"value": { "type": "number" }
|
||||
},
|
||||
"required": ["name", "bound", "value"]
|
||||
}
|
||||
},
|
||||
"enum_filters": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"values": { "type": "array", "items": { "type": "string" } }
|
||||
},
|
||||
"required": ["name", "values"]
|
||||
}
|
||||
},
|
||||
"notes": {
|
||||
"type": "string"
|
||||
}
|
||||
/// Strip a surrounding markdown code fence (```json ... ``` or ``` ... ```) from LLM output.
///
/// Models occasionally wrap JSON in markdown fencing even when told not to.
///
/// The input is trimmed first. If it then starts with an opening fence line
/// (three backticks, an optional language tag, and a newline), that line is
/// dropped, a trailing closing fence is dropped when present, and the
/// remaining payload is returned trimmed. A missing closing fence (for example
/// because the output was cut off) no longer leaves the opening fence line in
/// the result.
///
/// Text that does not start with a fence, or that has no newline after the
/// opening backticks, is returned trimmed but otherwise untouched.
fn strip_markdown_fences(text: &str) -> &str {
    let trimmed = text.trim();

    // Everything up to the first newline after the opening backticks is the
    // (possibly empty) language tag; the payload starts on the next line.
    let body = match trimmed
        .strip_prefix("```")
        .and_then(|after_open| after_open.split_once('\n'))
    {
        Some((_lang_tag, body)) => body,
        None => return trimmed,
    };

    // The closing fence is optional so that an unterminated block still
    // yields its payload rather than the raw fenced text.
    body.strip_suffix("```").unwrap_or(body).trim()
}
|
||||
|
||||
/// Build the Gemini tool declaration for destination search.
|
||||
fn build_tool_declarations(state: &AppState) -> Value {
|
||||
let modes: Vec<&str> = state
|
||||
.travel_time_store
|
||||
.available_modes
|
||||
.iter()
|
||||
.map(|mode| mode.as_str())
|
||||
.collect();
|
||||
|
||||
json!([{
|
||||
"functionDeclarations": [{
|
||||
"name": "search_destinations",
|
||||
"description": "Search for available travel time destinations (cities, stations, towns) that have precomputed travel time data. Call this when the user mentions wanting to be near, close to, or within a certain travel time of a specific place.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Place name to search for (e.g. 'Manchester', 'Kings Cross', 'Heathrow')"
|
||||
},
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": modes,
|
||||
"description": "Transport mode to search destinations for"
|
||||
}
|
||||
},
|
||||
"required": ["query", "mode"]
|
||||
}
|
||||
}]
|
||||
}])
|
||||
}
|
||||
|
||||
/// Execute a destination search against PlaceData + TravelTimeStore.
|
||||
/// Returns matching destinations as a JSON value with `results` and optional `message`.
|
||||
///
|
||||
/// Uses word-based matching: all words in the query must appear somewhere in the
|
||||
/// place name (order-independent). Also matches against slugs for short queries.
|
||||
fn execute_destination_search(state: &AppState, query: &str, mode: &str) -> Value {
|
||||
let query_lower = query.to_lowercase();
|
||||
let query_words: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
let query_slug = slugify(query);
|
||||
let tt_store = &state.travel_time_store;
|
||||
let pd = &state.place_data;
|
||||
|
||||
let slug_set = match tt_store.destinations.get(mode) {
|
||||
Some(slugs) => slugs,
|
||||
None => return json!({ "results": [], "message": format!("No travel data available for mode '{}'", mode) }),
|
||||
};
|
||||
|
||||
// Find places matching the query that have travel time data.
|
||||
// A place matches if ALL query words appear in its name, OR its slug matches the query slug.
|
||||
let mut matches: Vec<(usize, String, u8, u32)> = pd
|
||||
.name_lower
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, name_lower)| {
|
||||
let words_match = query_words.iter().all(|word| name_lower.contains(word));
|
||||
let slug = slugify(&pd.name[idx]);
|
||||
let slug_match = slug.contains(&query_slug) || query_slug.contains(&slug);
|
||||
if !words_match && !slug_match {
|
||||
return None;
|
||||
}
|
||||
if slug_set.contains(&slug) {
|
||||
Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort: type rank asc, population desc
|
||||
matches.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(b.3.cmp(&a.3)));
|
||||
matches.truncate(10);
|
||||
|
||||
if matches.is_empty() {
|
||||
info!(query = query, mode = mode, "Destination search returned no results");
|
||||
return json!({
|
||||
"results": [],
|
||||
"message": format!("No travel time data available for '{}' by {}. This destination cannot be used as a travel time filter.", query, mode)
|
||||
});
|
||||
}
|
||||
|
||||
let results: Vec<Value> = matches
|
||||
.into_iter()
|
||||
.map(|(idx, slug, ..)| {
|
||||
json!({
|
||||
"name": pd.name[idx],
|
||||
"slug": slug,
|
||||
"place_type": pd.place_type.get(idx).to_string(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
json!({ "results": results })
|
||||
}
|
||||
|
||||
/// Build the complete system prompt for AI filters.
|
||||
///
|
||||
/// Contains: role instructions, feature catalogue, few-shot examples, output rules.
|
||||
/// Contains: role instructions, feature catalogue, travel time info,
|
||||
/// few-shot examples, output rules.
|
||||
/// Precomputed at startup and cached in AppState.
|
||||
pub fn build_system_prompt(features: &FeaturesResponse) -> String {
|
||||
pub fn build_system_prompt(
|
||||
features: &FeaturesResponse,
|
||||
mode_destinations: &[(String, usize)],
|
||||
) -> String {
|
||||
let mut parts = Vec::new();
|
||||
|
||||
// Role and task description
|
||||
parts.push(
|
||||
"You are a UK property search assistant. \
|
||||
The user describes their ideal property or area in natural language. \
|
||||
|
|
@ -91,10 +205,61 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
|
|||
(note: this counts bedrooms + living rooms combined, so 3 bed ~ min 4).\n\
|
||||
- If the user mentions something that has no matching filter, put it in \"notes\" \
|
||||
as a short phrase (e.g. \"No filter for: garden, sea view\"). \
|
||||
If everything was matched, set \"notes\" to an empty string."
|
||||
If everything was matched, set \"notes\" to an empty string.\n\
|
||||
\n\
|
||||
CONVERSATIONAL REFINEMENT:\n\
|
||||
The user's message may include their currently active filters as context. \
|
||||
When context is provided:\n\
|
||||
- \"make it cheaper\" / \"lower the price\" = adjust the existing price filter down\n\
|
||||
- \"also add ...\" / \"and good schools\" = keep existing filters and add new ones\n\
|
||||
- \"remove the ...\" / \"drop the ...\" = return filters WITHOUT the mentioned one\n\
|
||||
- If the request is a completely new search (not a refinement), ignore the context \
|
||||
and build filters from scratch.\n\
|
||||
- Always output the COMPLETE set of filters (existing + modified), not just the changes."
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
// Travel time section with available modes
|
||||
let modes_list = mode_destinations
|
||||
.iter()
|
||||
.map(|(mode, count)| format!("- {} ({} destinations available)", mode, count))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
parts.push(format!(
|
||||
"\n--- TRAVEL TIME FILTERS ---\n\
|
||||
You can add travel time filters when the user mentions commute times, \
|
||||
proximity to places, or wanting to be near/within X minutes of somewhere.\n\
|
||||
\n\
|
||||
Available transport modes (only use modes that have destinations):\n\
|
||||
{}\n\
|
||||
- \"car\" / \"drive\" / \"driving\" = car mode\n\
|
||||
- \"cycle\" / \"bike\" / \"cycling\" = bicycle mode\n\
|
||||
- \"walk\" / \"walking\" / \"on foot\" = walking mode\n\
|
||||
- \"train\" / \"tube\" / \"bus\" / \"public transport\" / \"commute\" = transit mode\n\
|
||||
\n\
|
||||
When the user mentions a specific place, you MUST call the search_destinations \
|
||||
tool to find the exact slug. Use the name and slug from the search results.\n\
|
||||
If search_destinations returns an empty array, the destination is not available — \
|
||||
mention it in \"notes\" (e.g. \"No travel data for: Gatwick Airport\") and do NOT \
|
||||
include a travel_time_filter for it.\n\
|
||||
\n\
|
||||
Travel time values are in MINUTES (0-120 range).\n\
|
||||
- \"within 30 minutes\" = max 30\n\
|
||||
- \"at least 10 minutes\" = min 10\n\
|
||||
- \"30-45 minute commute\" = min 30, max 45\n\
|
||||
- If only a max is given, omit min (and vice versa).\n\
|
||||
\n\
|
||||
INFERRING TRANSPORT MODE (when the user does not specify one explicitly):\n\
|
||||
- \"commute\" to a major city centre or station = transit\n\
|
||||
- \"near\" / \"close to\" a city centre or station = transit\n\
|
||||
- \"near\" / \"close to\" a smaller town, village, or rural area = car\n\
|
||||
- \"drive\" / \"driving distance\" / \"driving time\" = always car\n\
|
||||
- If multiple modes are plausible, prefer transit for urban destinations \
|
||||
(London, Manchester, Birmingham, Leeds, etc.) and car for everything else.",
|
||||
modes_list,
|
||||
));
|
||||
|
||||
// Feature catalogue
|
||||
parts.push("\n--- AVAILABLE FEATURES ---\n".to_string());
|
||||
for group in &features.groups {
|
||||
|
|
@ -148,6 +313,7 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
|
|||
Output: {\"numeric_filters\": [{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 400000}], \
|
||||
\"enum_filters\": [{\"name\": \"Leasehold/Freehold\", \"values\": [\"Freehold\"]}, \
|
||||
{\"name\": \"Property type\", \"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
|
||||
\"travel_time_filters\": [], \
|
||||
\"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
|
@ -161,7 +327,7 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
|
|||
{\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \
|
||||
{\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}, \
|
||||
{\"name\": \"Number of parks within 2km\", \"bound\": \"min\", \"value\": 3}], \
|
||||
\"enum_filters\": [], \"notes\": \"\"}"
|
||||
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
|
|
@ -172,18 +338,37 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
|
|||
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 4}], \
|
||||
\"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}, \
|
||||
{\"name\": \"Max available download speed (Mbps)\", \"values\": [\"100\", \"300\", \"1000\"]}], \
|
||||
\"travel_time_filters\": [], \
|
||||
\"notes\": \"No filter for: beach proximity\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"large family home with a garden near restaurants\"\n\
|
||||
"\nUser: \"within 30 minutes commute of Kings Cross, under 500k\"\n\
|
||||
(After calling search_destinations for \"Kings Cross\" with mode \"transit\" \
|
||||
and getting [{\"name\": \"Kings Cross\", \"slug\": \"kings-cross\", \"place_type\": \"station\"}])\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 500000}], \
|
||||
\"enum_filters\": [], \
|
||||
\"travel_time_filters\": [{\"mode\": \"transit\", \"slug\": \"kings-cross\", \
|
||||
\"label\": \"Kings Cross\", \"bound\": \"max\", \"value\": 30}], \
|
||||
\"notes\": \"\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
parts.push(
|
||||
"\nUser: \"family home with garden, 45 min drive from Manchester, good schools\"\n\
|
||||
(After calling search_destinations for \"Manchester\" with mode \"car\" \
|
||||
and getting [{\"name\": \"Manchester\", \"slug\": \"manchester\", \"place_type\": \"city\"}])\n\
|
||||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Total floor area (sqm)\", \"bound\": \"min\", \"value\": 100}, \
|
||||
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 5}, \
|
||||
{\"name\": \"Number of restaurants within 2km\", \"bound\": \"min\", \"value\": 10}], \
|
||||
{\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \
|
||||
{\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}], \
|
||||
\"enum_filters\": [{\"name\": \"Property type\", \
|
||||
\"values\": [\"Detached\", \"Semi-Detached\"]}], \
|
||||
\"travel_time_filters\": [{\"mode\": \"car\", \"slug\": \"manchester\", \
|
||||
\"label\": \"Manchester\", \"bound\": \"max\", \"value\": 45}], \
|
||||
\"notes\": \"No filter for: garden\"}"
|
||||
.to_string(),
|
||||
);
|
||||
|
|
@ -191,7 +376,8 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
|
|||
// Output format reminder
|
||||
parts.push(
|
||||
"\n--- OUTPUT FORMAT ---\n\
|
||||
{\"numeric_filters\": [{\"name\": \"...\", \"bound\": \"min\"|\"max\", \"value\": N}, ...], \"enum_filters\": [...], \"notes\": \"...\"}\n\
|
||||
{\"numeric_filters\": [...], \"enum_filters\": [...], \"travel_time_filters\": [{\"mode\": \"...\", \"slug\": \"...\", \"label\": \"...\", \"bound\": \"min\"|\"max\", \"value\": N}, ...], \"notes\": \"...\"}\n\
|
||||
- travel_time_filters: use ONLY slugs returned by search_destinations. If a place isn't found, mention it in notes.\n\
|
||||
Respond with ONLY the JSON object. No explanation."
|
||||
.to_string(),
|
||||
);
|
||||
|
|
@ -199,86 +385,393 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
|
|||
parts.join("\n")
|
||||
}
|
||||
|
||||
pub async fn post_ai_filters(
|
||||
state: Arc<AppState>,
|
||||
Json(req): Json<AiFiltersRequest>,
|
||||
) -> Result<Json<AiFiltersResponse>, (StatusCode, String)> {
|
||||
info!(query = %req.query, "POST /api/ai-filters");
|
||||
|
||||
let url = format!("{}/api/chat", state.ollama_url);
|
||||
let body = json!({
|
||||
"model": state.ollama_model,
|
||||
"messages": [
|
||||
{ "role": "system", "content": state.ai_filters_system_prompt },
|
||||
{ "role": "user", "content": req.query }
|
||||
],
|
||||
"stream": false,
|
||||
"format": state.ai_filters_schema,
|
||||
"options": {
|
||||
"temperature": AI_FILTERS_TEMPERATURE,
|
||||
"num_predict": AI_FILTERS_MAX_TOKENS,
|
||||
}
|
||||
});
|
||||
|
||||
// Try up to 2 attempts — LLMs occasionally return empty content (e.g. only
|
||||
// <think> blocks with no JSON output), which is transient and usually
|
||||
// succeeds on retry.
|
||||
let mut last_err = None;
|
||||
for attempt in 0..2 {
|
||||
let raw = call_ollama_and_parse(&state.http_client, &url, &body).await;
|
||||
match raw {
|
||||
Ok(raw) => {
|
||||
let filters = validate_and_convert(&raw, &state.features_response);
|
||||
let notes = raw
|
||||
.get("notes")
|
||||
.and_then(|val| val.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
return Ok(Json(AiFiltersResponse { filters, notes }));
|
||||
}
|
||||
Err(err) => {
|
||||
if attempt == 0 {
|
||||
warn!("LLM attempt 1 failed, retrying: {}", err.1);
|
||||
}
|
||||
last_err = Some(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err.unwrap())
|
||||
/// Number of whole 7-day periods (604800 seconds each) elapsed since the Unix epoch.
///
/// The value increases by one every 7 days, so comparing a stored week number
/// with the current one tells whether a weekly counter should start over.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
fn current_week_number() -> u64 {
    const SECONDS_PER_WEEK: u64 = 604_800;

    let since_epoch = std::time::UNIX_EPOCH
        .elapsed()
        .expect("system time before epoch");

    since_epoch.as_secs() / SECONDS_PER_WEEK
}
|
||||
|
||||
/// Call Ollama and parse the response content as JSON.
|
||||
///
|
||||
/// Returns an error if: the HTTP call fails, the response is malformed,
|
||||
/// the content is empty after stripping think blocks, or the content is
|
||||
/// not valid JSON.
|
||||
async fn call_ollama_and_parse(
|
||||
client: &reqwest::Client,
|
||||
url: &str,
|
||||
body: &Value,
|
||||
) -> Result<Value, (StatusCode, String)> {
|
||||
let json_resp = ollama_chat(client, url, body).await?;
|
||||
let content = extract_ollama_content(&json_resp)?;
|
||||
/// Fetch the user's current AI token usage from PocketBase.
|
||||
/// Returns `(tokens_used, week_number)`.
|
||||
async fn fetch_ai_usage(
|
||||
state: &AppState,
|
||||
user_id: &str,
|
||||
) -> Result<(u64, u64), (StatusCode, String)> {
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let token = auth_superuser(
|
||||
&state.http_client,
|
||||
pb_url,
|
||||
&state.pocketbase_admin_email,
|
||||
&state.pocketbase_admin_password,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
warn!("Failed to auth superuser for AI usage check: {err}");
|
||||
(StatusCode::BAD_GATEWAY, "Internal error".into())
|
||||
})?;
|
||||
|
||||
let content = strip_think_blocks(content);
|
||||
let content = content.trim();
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let resp = state
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
warn!("Failed to fetch user record for AI usage: {err}");
|
||||
(StatusCode::BAD_GATEWAY, "Internal error".into())
|
||||
})?;
|
||||
|
||||
if content.is_empty() {
|
||||
warn!("LLM returned empty content after stripping think blocks");
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
warn!("PocketBase user fetch failed ({status})");
|
||||
return Err((StatusCode::BAD_GATEWAY, "Internal error".into()));
|
||||
}
|
||||
|
||||
let body: Value = resp.json().await.map_err(|err| {
|
||||
warn!("Failed to parse user record: {err}");
|
||||
(StatusCode::BAD_GATEWAY, "Internal error".into())
|
||||
})?;
|
||||
|
||||
let tokens_used = body
|
||||
.get("ai_tokens_used")
|
||||
.and_then(|val| val.as_u64())
|
||||
.unwrap_or(0);
|
||||
let week = body
|
||||
.get("ai_tokens_week")
|
||||
.and_then(|val| val.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
Ok((tokens_used, week))
|
||||
}
|
||||
|
||||
/// Update the user's AI token usage in PocketBase.
|
||||
/// Best-effort — logs warnings on failure but does not propagate errors.
|
||||
async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week: u64) {
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
let token = match auth_superuser(
|
||||
&state.http_client,
|
||||
pb_url,
|
||||
&state.pocketbase_admin_email,
|
||||
&state.pocketbase_admin_password,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(tk) => tk,
|
||||
Err(err) => {
|
||||
warn!("Failed to auth superuser for AI usage update: {err}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let url = format!("{pb_url}/api/collections/users/records/{user_id}");
|
||||
let res = state
|
||||
.http_client
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&json!({
|
||||
"ai_tokens_used": tokens_used,
|
||||
"ai_tokens_week": week,
|
||||
}))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {}
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
warn!("Failed to update AI usage ({status})");
|
||||
}
|
||||
Err(err) => warn!("Failed to update AI usage: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Maximum number of round trips (function calls + retries) before giving up.
///
/// Used as the upper bound of the request loop in `post_ai_filters`; each
/// iteration sends one request to the model.
|
||||
const MAX_TOOL_ROUNDS: usize = 5;
|
||||
|
||||
pub async fn post_ai_filters(
|
||||
state: Arc<AppState>,
|
||||
Extension(user): Extension<OptionalUser>,
|
||||
Json(req): Json<AiFiltersRequest>,
|
||||
) -> Result<Json<AiFiltersResponse>, (StatusCode, String)> {
|
||||
// Auth check
|
||||
let user = user
|
||||
.0
|
||||
.ok_or((StatusCode::UNAUTHORIZED, "Login required".into()))?;
|
||||
|
||||
// Email verification check (skipped in dev mode)
|
||||
if !user.verified && !state.is_dev {
|
||||
return Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"LLM returned empty content (no JSON output)".into(),
|
||||
StatusCode::FORBIDDEN,
|
||||
"Please verify your email to use AI filters".into(),
|
||||
));
|
||||
}
|
||||
|
||||
serde_json::from_str(content).map_err(|err| {
|
||||
warn!(error = %err, content = %content, "Failed to parse LLM JSON output");
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Failed to parse LLM output as JSON: {}", err),
|
||||
// Check weekly token usage
|
||||
let current_week = current_week_number();
|
||||
let (stored_tokens, stored_week) = fetch_ai_usage(&state, &user.id).await?;
|
||||
let tokens_used = if stored_week == current_week {
|
||||
stored_tokens
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
if tokens_used >= AI_FILTERS_WEEKLY_TOKEN_LIMIT {
|
||||
return Err((
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
"Weekly AI usage limit reached. Resets next week.".into(),
|
||||
));
|
||||
}
|
||||
|
||||
info!(query = %req.query, user_id = %user.id, "POST /api/ai-filters");
|
||||
|
||||
let tools = build_tool_declarations(&state);
|
||||
|
||||
// Build user message with optional context for conversational refinement
|
||||
let user_text = if let Some(ref ctx) = req.context {
|
||||
let mut msg = String::new();
|
||||
msg.push_str("Currently active filters:\n");
|
||||
msg.push_str(&serde_json::to_string(&ctx.filters).unwrap_or_default());
|
||||
if !ctx.travel_time.is_empty() {
|
||||
msg.push_str("\nCurrently active travel time filters:\n");
|
||||
for tt in &ctx.travel_time {
|
||||
let bounds = match (tt.min, tt.max) {
|
||||
(Some(min), Some(max)) => format!("{}-{} min", min, max),
|
||||
(Some(min), None) => format!("min {} min", min),
|
||||
(None, Some(max)) => format!("max {} min", max),
|
||||
(None, None) => "no range".to_string(),
|
||||
};
|
||||
msg.push_str(&format!("- {} to {} ({})\n", tt.mode, tt.label, bounds));
|
||||
}
|
||||
}
|
||||
msg.push_str(&format!("\nUser request: {}", req.query));
|
||||
msg
|
||||
} else {
|
||||
req.query.clone()
|
||||
};
|
||||
|
||||
let mut contents = vec![json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": user_text }]
|
||||
})];
|
||||
|
||||
let mut total_tokens_accumulated: u64 = 0;
|
||||
|
||||
// Function calling loop: model may call search_destinations, we execute and feed back
|
||||
for round in 0..MAX_TOOL_ROUNDS {
|
||||
let body = json!({
|
||||
"systemInstruction": {
|
||||
"parts": [{ "text": state.ai_filters_system_prompt }]
|
||||
},
|
||||
"contents": contents,
|
||||
"tools": tools,
|
||||
"generationConfig": {
|
||||
"temperature": AI_FILTERS_TEMPERATURE,
|
||||
"maxOutputTokens": AI_FILTERS_MAX_TOKENS,
|
||||
"thinkingConfig": { "thinkingLevel": "LOW" },
|
||||
}
|
||||
});
|
||||
|
||||
let json_resp = gemini_chat(
|
||||
&state.http_client,
|
||||
&state.gemini_api_key,
|
||||
&state.gemini_model,
|
||||
&body,
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Accumulate token usage
|
||||
total_tokens_accumulated += json_resp
|
||||
.get("usageMetadata")
|
||||
.and_then(|md| md.get("totalTokenCount"))
|
||||
.and_then(|tc| tc.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let candidate = json_resp
|
||||
.get("candidates")
|
||||
.and_then(|cs| cs.get(0))
|
||||
.and_then(|c| c.get("content"))
|
||||
.ok_or_else(|| {
|
||||
warn!("Malformed Gemini response: missing candidates[0].content");
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"Malformed Gemini response".into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let parts = candidate
|
||||
.get("parts")
|
||||
.and_then(|p| p.as_array())
|
||||
.ok_or_else(|| {
|
||||
warn!("Malformed Gemini response: missing parts array");
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"Malformed Gemini response".into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Check if the model made a function call.
|
||||
// Find the full part (includes thoughtSignature required by Gemini 3 models).
|
||||
if let Some(fc_part) = parts.iter().find(|part| part.get("functionCall").is_some()) {
|
||||
let fc = fc_part.get("functionCall").unwrap();
|
||||
let fn_name = fc.get("name").and_then(|n| n.as_str()).unwrap_or("");
|
||||
let fn_args = fc.get("args").cloned().unwrap_or(json!({}));
|
||||
|
||||
info!(
|
||||
function = fn_name,
|
||||
round = round,
|
||||
"AI called tool"
|
||||
);
|
||||
|
||||
let fn_result = if fn_name == "search_destinations" {
|
||||
let query = fn_args
|
||||
.get("query")
|
||||
.and_then(|q| q.as_str())
|
||||
.unwrap_or("");
|
||||
let mode = fn_args
|
||||
.get("mode")
|
||||
.and_then(|m| m.as_str())
|
||||
.unwrap_or("transit");
|
||||
execute_destination_search(&state, query, mode)
|
||||
} else {
|
||||
json!({"error": "unknown function"})
|
||||
};
|
||||
|
||||
// Append the model's full response (preserves thoughtSignature) + our function result
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{
|
||||
"functionResponse": {
|
||||
"name": fn_name,
|
||||
"response": { "results": fn_result }
|
||||
}
|
||||
}]
|
||||
}));
|
||||
|
||||
// Continue the loop — model will process the results
|
||||
continue;
|
||||
}
|
||||
|
||||
// Model returned text — extract and parse as JSON
|
||||
let text = parts
|
||||
.iter()
|
||||
.find_map(|part| part.get("text").and_then(|t| t.as_str()))
|
||||
.unwrap_or("");
|
||||
let text = strip_markdown_fences(text);
|
||||
let text = text.trim();
|
||||
|
||||
if text.is_empty() {
|
||||
warn!("Gemini returned empty text content (round {})", round);
|
||||
// Retry by continuing the loop
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": "Your response was empty. Please output the JSON object." }]
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
|
||||
let raw: Value = match serde_json::from_str(text) {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
warn!(error = %err, round = round, "Failed to parse Gemini JSON output, retrying");
|
||||
// Ask the model to fix its output
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": "That was not valid JSON. Please output ONLY the JSON object with numeric_filters, enum_filters, travel_time_filters, and notes." }]
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let filters = validate_and_convert(&raw, &state.features_response);
|
||||
let travel_time_filters = validate_travel_time_filters(&raw, &state);
|
||||
let notes = raw
|
||||
.get("notes")
|
||||
.and_then(|val| val.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
// Update usage with total accumulated tokens
|
||||
let new_total = tokens_used + total_tokens_accumulated;
|
||||
update_ai_usage(&state, &user.id, new_total, current_week).await;
|
||||
|
||||
return Ok(Json(AiFiltersResponse {
|
||||
filters,
|
||||
travel_time_filters,
|
||||
notes,
|
||||
}));
|
||||
}
|
||||
|
||||
// Exhausted tool rounds without getting a final text response
|
||||
warn!("AI exhausted {} tool-calling rounds without final response", MAX_TOOL_ROUNDS);
|
||||
Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"AI could not complete the request".into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Validate travel time filters from LLM output against available destinations.
|
||||
fn validate_travel_time_filters(raw: &Value, state: &AppState) -> Vec<TravelTimeFilter> {
|
||||
let arr = match raw
|
||||
.get("travel_time_filters")
|
||||
.and_then(|val| val.as_array())
|
||||
{
|
||||
Some(arr) => arr,
|
||||
None => return Vec::new(),
|
||||
};
|
||||
|
||||
let tt_store = &state.travel_time_store;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for item in arr {
|
||||
let mode = match item.get("mode").and_then(|val| val.as_str()) {
|
||||
Some(mode) => mode,
|
||||
None => continue,
|
||||
};
|
||||
let slug = match item.get("slug").and_then(|val| val.as_str()) {
|
||||
Some(slug) => slug,
|
||||
None => continue,
|
||||
};
|
||||
let label = item
|
||||
.get("label")
|
||||
.and_then(|val| val.as_str())
|
||||
.unwrap_or(slug);
|
||||
|
||||
// Verify this destination actually exists
|
||||
if !tt_store.has_destination(mode, slug) {
|
||||
warn!(mode = mode, slug = slug, "AI suggested non-existent destination");
|
||||
continue;
|
||||
}
|
||||
|
||||
let bound = item.get("bound").and_then(|val| val.as_str());
|
||||
let value = item.get("value").and_then(|val| val.as_f64());
|
||||
|
||||
let (min, max) = match (bound, value) {
|
||||
(Some("min"), Some(val)) => (Some(val.max(0.0).min(120.0) as f32), None),
|
||||
(Some("max"), Some(val)) => (None, Some(val.max(0.0).min(120.0) as f32)),
|
||||
_ => (None, None),
|
||||
};
|
||||
|
||||
// Only include if at least one bound is set
|
||||
if min.is_some() || max.is_some() {
|
||||
results.push(TravelTimeFilter {
|
||||
mode: mode.to_string(),
|
||||
slug: slug.to_string(),
|
||||
label: label.to_string(),
|
||||
min,
|
||||
max,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Validate LLM output against feature metadata and convert to FeatureFilters format.
|
||||
|
|
@ -374,3 +867,32 @@ fn validate_and_convert(raw: &Value, features: &FeaturesResponse) -> Value {
|
|||
|
||||
Value::Object(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The bare JSON payload every input below is expected to reduce to.
    const BARE_JSON: &str = "{\"a\": 1}";

    /// A fence carrying a `json` language tag is removed together with the tag.
    #[test]
    fn strip_fences_json_tag() {
        assert_eq!(strip_markdown_fences("```json\n{\"a\": 1}\n```"), BARE_JSON);
    }

    /// A fence without a language tag is removed as well.
    #[test]
    fn strip_fences_no_tag() {
        assert_eq!(strip_markdown_fences("```\n{\"a\": 1}\n```"), BARE_JSON);
    }

    /// Input that has no fence at all comes back unchanged.
    #[test]
    fn strip_fences_passthrough() {
        assert_eq!(strip_markdown_fences("{\"a\": 1}"), BARE_JSON);
    }

    /// Whitespace around the fence and around the payload is trimmed away.
    #[test]
    fn strip_fences_whitespace() {
        assert_eq!(
            strip_markdown_fences("  ```json\n  {\"a\": 1}  \n```  "),
            BARE_JSON
        );
    }
}
|
||||
|
|
|
|||
|
|
@ -178,7 +178,7 @@ pub async fn get_invite(
|
|||
}
|
||||
|
||||
// Dev-only: return a fake valid admin invite without hitting PocketBase
|
||||
if state.index_html.is_none() && code == DEV_INVITE_CODE {
|
||||
if state.is_dev && code == DEV_INVITE_CODE {
|
||||
return Json(InviteValidation {
|
||||
valid: true,
|
||||
invite_type: "admin".to_string(),
|
||||
|
|
@ -294,7 +294,7 @@ pub async fn post_redeem_invite(
|
|||
}
|
||||
|
||||
// Dev-only: fake redeem — just return "licensed" without touching PocketBase
|
||||
if state.index_html.is_none() && req.code == DEV_INVITE_CODE {
|
||||
if state.is_dev && req.code == DEV_INVITE_CODE {
|
||||
info!(user_id = %user.id, "Dev invite redeemed (no-op)");
|
||||
return Json(RedeemResponse {
|
||||
result: "licensed".to_string(),
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@ pub struct AppState {
|
|||
pub screenshot_url: String,
|
||||
/// Public-facing URL for absolute og:image URLs (e.g. https://perfectpostcodes.schmelczer.dev)
|
||||
pub public_url: String,
|
||||
/// True when --dist is not provided (no static serving, relaxed auth checks)
|
||||
pub is_dev: bool,
|
||||
/// Contents of index.html read at startup, used for crawler OG injection (None when --dist omitted)
|
||||
pub index_html: Option<String>,
|
||||
/// Shared HTTP client for proxying to the screenshot service and PocketBase
|
||||
|
|
@ -44,16 +46,14 @@ pub struct AppState {
|
|||
pub pocketbase_admin_email: String,
|
||||
/// PocketBase superuser password
|
||||
pub pocketbase_admin_password: String,
|
||||
/// Ollama server URL for AI area summaries (e.g. http://ollama:11434)
|
||||
pub ollama_url: String,
|
||||
/// Ollama model name for area summaries (e.g. gemma3:12b)
|
||||
pub ollama_model: String,
|
||||
/// Gemini API key for AI filters
|
||||
pub gemini_api_key: String,
|
||||
/// Gemini model name (e.g. gemini-2.0-flash)
|
||||
pub gemini_model: String,
|
||||
/// Precomputed travel time data store
|
||||
pub travel_time_store: Arc<TravelTimeStore>,
|
||||
/// Token validation cache (60s TTL)
|
||||
pub token_cache: Arc<TokenCache>,
|
||||
/// JSON schema for Ollama structured output in AI filters
|
||||
pub ai_filters_schema: serde_json::Value,
|
||||
/// Complete system prompt for AI filters (features + examples + instructions)
|
||||
pub ai_filters_system_prompt: String,
|
||||
/// Google Maps API key for Street View metadata lookups
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ mod llm;
|
|||
pub use grid_index::GridIndex;
|
||||
pub use hash::{generate_priorities, splitmix64_hash};
|
||||
pub use interned_column::InternedColumn;
|
||||
pub use llm::{extract_ollama_content, ollama_chat, strip_think_blocks};
|
||||
pub use llm::{extract_gemini_content, gemini_chat};
|
||||
|
||||
/// Normalize a UK postcode: uppercase, strip spaces, insert canonical space before inward code.
|
||||
/// e.g. "e142dg" → "E14 2DG", "E14 2DG" → "E14 2DG", "EC1A1BB" → "EC1A 1BB"
|
||||
|
|
|
|||
|
|
@ -4,68 +4,62 @@ use tracing::warn;
|
|||
|
||||
pub type LlmError = (StatusCode, String);
|
||||
|
||||
/// Send a chat request to Ollama and return the parsed JSON response.
|
||||
/// Send a generateContent request to the Gemini API and return the parsed JSON response.
|
||||
///
|
||||
/// Handles connection errors, non-success status codes, and JSON parse failures
|
||||
/// uniformly as `BAD_GATEWAY` errors.
|
||||
pub async fn ollama_chat(
|
||||
pub async fn gemini_chat(
|
||||
client: &reqwest::Client,
|
||||
url: &str,
|
||||
api_key: &str,
|
||||
model: &str,
|
||||
body: &Value,
|
||||
) -> Result<Value, LlmError> {
|
||||
let response = client.post(url).json(body).send().await.map_err(|err| {
|
||||
warn!(error = %err, "Failed to connect to Ollama");
|
||||
let url = format!(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
|
||||
model, api_key
|
||||
);
|
||||
|
||||
let response = client.post(&url).json(body).send().await.map_err(|err| {
|
||||
warn!(error = %err, "Failed to connect to Gemini API");
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Failed to connect to Ollama: {}", err),
|
||||
format!("Failed to connect to Gemini API: {}", err),
|
||||
)
|
||||
})?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body_text = response.text().await.unwrap_or_default();
|
||||
warn!(status = %status, body = %body_text, "Ollama returned error");
|
||||
warn!(status = %status, body = %body_text, "Gemini API returned error");
|
||||
return Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Ollama error {}: {}", status, body_text),
|
||||
format!("Gemini API error {}: {}", status, body_text),
|
||||
));
|
||||
}
|
||||
|
||||
response.json().await.map_err(|err| {
|
||||
warn!(error = %err, "Failed to parse Ollama response");
|
||||
warn!(error = %err, "Failed to parse Gemini API response");
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Failed to parse Ollama response: {}", err),
|
||||
format!("Failed to parse Gemini API response: {}", err),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract content from Ollama native response (`message.content`)
|
||||
pub fn extract_ollama_content(json: &Value) -> Result<&str, LlmError> {
|
||||
json.get("message")
|
||||
.and_then(|msg| msg.get("content"))
|
||||
.and_then(|ct| ct.as_str())
|
||||
/// Extract text content from Gemini response (`candidates[0].content.parts[0].text`)
|
||||
pub fn extract_gemini_content(json: &Value) -> Result<&str, LlmError> {
|
||||
json.get("candidates")
|
||||
.and_then(|c| c.get(0))
|
||||
.and_then(|c| c.get("content"))
|
||||
.and_then(|c| c.get("parts"))
|
||||
.and_then(|p| p.get(0))
|
||||
.and_then(|p| p.get("text"))
|
||||
.and_then(|t| t.as_str())
|
||||
.ok_or_else(|| {
|
||||
warn!("Malformed Ollama response: missing message.content");
|
||||
warn!("Malformed Gemini response: missing candidates[0].content.parts[0].text");
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
"Malformed LLM response: missing message.content".into(),
|
||||
"Malformed Gemini response: missing content".into(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Remove every `<think>...</think>` block from model output.
///
/// Text outside the blocks is kept in its original order. If a `<think>` tag is never
/// closed, everything from that tag to the end of the input is dropped. A stray
/// `</think>` with no preceding `<think>` is left in place.
pub fn strip_think_blocks(text: &str) -> String {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    let mut kept = String::new();
    let mut rest = text;

    loop {
        match rest.split_once(OPEN_TAG) {
            // No further blocks: whatever is left is ordinary output.
            None => {
                kept.push_str(rest);
                return kept;
            }
            Some((before, after_open)) => {
                kept.push_str(before);
                match after_open.split_once(CLOSE_TAG) {
                    // Skip the block body and resume after the closing tag.
                    Some((_, after_close)) => rest = after_close,
                    // Unterminated block: discard the remainder.
                    None => return kept,
                }
            }
        }
    }
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue