From 80c093b7ba8f149ee6611686da3a9f284af435ce Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Sun, 15 Mar 2026 14:05:34 +0000 Subject: [PATCH] Improve LLM --- .github/workflows/docker-publish.yml | 15 +- docker-compose.yml | 4 +- frontend/src/components/map/AiFilterInput.tsx | 9 +- frontend/src/components/map/Filters.tsx | 4 +- frontend/src/components/map/MapPage.tsx | 38 +- frontend/src/hooks/useAiFilters.ts | 184 +++-- pipeline/download/pois.py | 14 +- server-rs/src/consts.rs | 4 +- server-rs/src/main.rs | 46 +- server-rs/src/pocketbase.rs | 20 +- server-rs/src/routes.rs | 2 +- server-rs/src/routes/ai_filters.rs | 758 +++++++++++++++--- server-rs/src/routes/invites.rs | 4 +- server-rs/src/state.rs | 12 +- server-rs/src/utils.rs | 2 +- server-rs/src/utils/llm.rs | 60 +- 16 files changed, 898 insertions(+), 278 deletions(-) diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index ea7da6e..9dc8577 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -21,6 +21,15 @@ jobs: - name: Checkout uses: actions/checkout@v4 + - name: Set up uv + uses: astral-sh/setup-uv@v4 + + - name: Download map assets (fonts, sprites, twemoji) + run: uv run python -m pipeline.download.map_assets --output frontend/public/assets + + - name: Download arcgis data for finder + run: uv run python -m pipeline.download.arcgis --output property-data/arcgis_data.parquet + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 @@ -64,12 +73,6 @@ jobs: cache-from: type=gha,scope=screenshot cache-to: type=gha,mode=max,scope=screenshot - - name: Set up uv - uses: astral-sh/setup-uv@v4 - - - name: Download arcgis data for finder - run: uv run python -m pipeline.download.arcgis --output property-data/arcgis_data.parquet - - name: Build and push finder service uses: docker/build-push-action@v6 with: diff --git a/docker-compose.yml b/docker-compose.yml index 22fb204..2b2ff44 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -36,8 +36,8 @@ services: POCKETBASE_ADMIN_EMAIL: *pb-email POCKETBASE_ADMIN_PASSWORD: *pb-password SCREENSHOT_URL: http://screenshot:8002 - OLLAMA_URL: http://host.docker.internal:11434 - OLLAMA_MODEL: gpt-oss:20b + GEMINI_API_KEY: AIzaSyC2mQDcEwILHM3uOE2C-lxUQbQrKTX9Xi4 + GEMINI_MODEL: gemini-3-flash-preview PUBLIC_URL: https://perfectpostcodes.schmelczer.dev GOOGLE_MAPS_API_KEY: "AIzaSyBgBn9LjrxHCjb9j1LZbLYpEdCJj-NkHPY" STRIPE_SECRET_KEY: sk_test_51SyVcePRjj2bdyn1HLkatQ5onwp8kamm41tjMcRdxXnJYWVPsVd9usMTOSNtNdGhrjbsrtNbgTdKXICg2qBiocEn00PvNDC0d3 diff --git a/frontend/src/components/map/AiFilterInput.tsx b/frontend/src/components/map/AiFilterInput.tsx index 371e5c0..af7698e 100644 --- a/frontend/src/components/map/AiFilterInput.tsx +++ b/frontend/src/components/map/AiFilterInput.tsx @@ -42,6 +42,7 @@ interface AiFilterInputProps { error: string | null; errorType: AiFilterErrorType | null; notes: string | null; + summary: string | null; onSubmit: (query: string) => void; isLoggedIn: boolean; onLoginRequired: () => void; @@ -52,6 +53,7 @@ export default memo(function AiFilterInput({ error, errorType, notes, + summary, onSubmit, isLoggedIn, onLoginRequired, @@ -88,7 +90,7 @@ export default memo(function AiFilterInput({ ); const hasContent = query.trim().length > 0; - const showExamples = expanded && !hasContent && !loading && !error && !notes; + const showExamples = expanded && !hasContent && !loading && !error && !notes && !summary; if (!expanded) { return ( @@ -173,6 +175,11 @@ export default memo(function AiFilterInput({ {error}

)} + {summary && !error && !loading && ( +

+ {summary} +

+ )} {notes && !error && !loading && (

{notes} diff --git a/frontend/src/components/map/Filters.tsx b/frontend/src/components/map/Filters.tsx index 280148b..308f7ec 100644 --- a/frontend/src/components/map/Filters.tsx +++ b/frontend/src/components/map/Filters.tsx @@ -92,6 +92,7 @@ interface FiltersProps { aiFilterError: string | null; aiFilterErrorType: AiFilterErrorType | null; aiFilterNotes: string | null; + aiFilterSummary: string | null; onAiFilterSubmit: (query: string) => void; isLoggedIn: boolean; onLoginRequired: () => void; @@ -127,6 +128,7 @@ export default memo(function Filters({ aiFilterError, aiFilterErrorType, aiFilterNotes, + aiFilterSummary, onAiFilterSubmit, isLoggedIn, onLoginRequired, @@ -285,7 +287,7 @@ export default memo(function Filters({

- +
{(['historical', 'buy', 'rent'] as const).map((type) => { diff --git a/frontend/src/components/map/MapPage.tsx b/frontend/src/components/map/MapPage.tsx index c232472..0e3b93a 100644 --- a/frontend/src/components/map/MapPage.tsx +++ b/frontend/src/components/map/MapPage.tsx @@ -148,22 +148,33 @@ export default function MapPage({ const handleAiFilterSubmit = useCallback( async (query: string) => { - const result = await aiFilters.fetchAiFilters(query); + // Build context from current filters for conversational refinement + const context = { + filters, + travelTime: travelTime.activeEntries.map((entry) => ({ + mode: entry.mode, + label: entry.label, + min: entry.timeRange?.[0], + max: entry.timeRange?.[1], + })), + }; + const hasContext = + Object.keys(context.filters).length > 0 || context.travelTime.length > 0; + + const result = await aiFilters.fetchAiFilters(query, hasContext ? context : undefined); if (!result) return; handleSetFilters(result.filters); - // Apply travel time filters from AI - if (result.travelTimeFilters.length > 0) { - const newEntries = result.travelTimeFilters.map((tt) => ({ - mode: tt.mode, - slug: tt.slug, - label: tt.label, - timeRange: [tt.min ?? 0, tt.max ?? 120] as [number, number], - useBest: false, - })); - travelTime.handleSetEntries(newEntries); - } + // Always sync travel time entries — clear stale ones when AI returns none + const newEntries = result.travelTimeFilters.map((tt) => ({ + mode: tt.mode, + slug: tt.slug, + label: tt.label, + timeRange: [tt.min ?? 0, tt.max ?? 120] as [number, number], + useBest: false, + })); + travelTime.handleSetEntries(newEntries); }, - [aiFilters.fetchAiFilters, handleSetFilters, travelTime.handleSetEntries] + [aiFilters.fetchAiFilters, handleSetFilters, travelTime.handleSetEntries, travelTime.activeEntries, filters] ); const handleTravelTimeSetDestination = useCallback( @@ -514,6 +525,7 @@ export default function MapPage({ aiFilterError={aiFilters.error} aiFilterErrorType={aiFilters.errorType} aiFilterNotes={aiFilters.notes} + aiFilterSummary={aiFilters.summary} onAiFilterSubmit={handleAiFilterSubmit} isLoggedIn={!!user} onLoginRequired={onRegisterClick ?? (() => {})} diff --git a/frontend/src/hooks/useAiFilters.ts b/frontend/src/hooks/useAiFilters.ts index d4c775e..b436b30 100644 --- a/frontend/src/hooks/useAiFilters.ts +++ b/frontend/src/hooks/useAiFilters.ts @@ -1,6 +1,6 @@ import { useState, useCallback, useRef } from 'react'; import type { FeatureFilters } from '../types'; -import type { TransportMode } from './useTravelTime'; +import type { TransportMode, TravelTimeEntry } from './useTravelTime'; import { apiUrl, authHeaders, logNonAbortError } from '../lib/api'; export interface AiTravelTimeFilter { @@ -11,20 +11,54 @@ export interface AiTravelTimeFilter { max?: number; } -interface AiFiltersResult { +export interface AiFiltersResult { filters: FeatureFilters; travelTimeFilters: AiTravelTimeFilter[]; notes: string; + /** Human-readable summary of what was set */ + summary: string; } export type AiFilterErrorType = 'auth' | 'verification' | 'limit' | 'error'; +/** Context of currently active filters, sent for conversational refinement. */ +export interface AiFiltersContext { + filters: FeatureFilters; + travelTime: { mode: string; label: string; min?: number; max?: number }[]; +} + interface UseAiFiltersResult { - fetchAiFilters: (query: string) => Promise; + fetchAiFilters: (query: string, context?: AiFiltersContext) => Promise; loading: boolean; error: string | null; errorType: AiFilterErrorType | null; notes: string | null; + summary: string | null; +} + +/** Build a human-readable summary of the AI result. */ +function buildSummary( + filters: FeatureFilters, + travelTimeFilters: AiTravelTimeFilter[] +): string { + const parts: string[] = []; + + for (const [name, value] of Object.entries(filters)) { + if (Array.isArray(value) && value.length === 2 && typeof value[0] === 'number') { + parts.push(name); + } else if (Array.isArray(value)) { + parts.push(`${name}: ${(value as string[]).join(', ')}`); + } + } + + for (const tt of travelTimeFilters) { + const bounds = + tt.max !== undefined ? `< ${tt.max} min` : tt.min !== undefined ? `> ${tt.min} min` : ''; + parts.push(`${tt.mode} to ${tt.label} ${bounds}`.trim()); + } + + if (parts.length === 0) return 'No filters set'; + return `Set ${parts.length} filter${parts.length > 1 ? 's' : ''}: ${parts.join(', ')}`; } export function useAiFilters(): UseAiFiltersResult { @@ -32,77 +66,93 @@ export function useAiFilters(): UseAiFiltersResult { const [error, setError] = useState(null); const [errorType, setErrorType] = useState(null); const [notes, setNotes] = useState(null); + const [summary, setSummary] = useState(null); const abortRef = useRef(null); - const fetchAiFilters = useCallback(async (query: string): Promise => { - abortRef.current?.abort(); - const controller = new AbortController(); - abortRef.current = controller; + const fetchAiFilters = useCallback( + async (query: string, context?: AiFiltersContext): Promise => { + abortRef.current?.abort(); + const controller = new AbortController(); + abortRef.current = controller; - setLoading(true); - setError(null); - setErrorType(null); - setNotes(null); + setLoading(true); + setError(null); + setErrorType(null); + setNotes(null); + setSummary(null); - try { - const url = apiUrl('ai-filters'); - const response = await fetch( - url, - authHeaders({ - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ query }), - signal: controller.signal, - }) - ); - - if (!response.ok) { - const text = await response.text(); - if (response.status === 401) { - setErrorType('auth'); - setError(text || 'Login required'); - } else if (response.status === 403) { - setErrorType('verification'); - setError(text || 'Email verification required'); - } else if (response.status === 429) { - setErrorType('limit'); - setError(text || 'Weekly usage limit reached'); - } else { - setErrorType('error'); - setError(text || `HTTP ${response.status}`); + try { + const url = apiUrl('ai-filters'); + const bodyObj: Record = { query }; + if (context) { + bodyObj.context = { + filters: context.filters, + travel_time: context.travelTime, + }; } + const response = await fetch( + url, + authHeaders({ + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(bodyObj), + signal: controller.signal, + }) + ); + + if (!response.ok) { + const text = await response.text(); + if (response.status === 401) { + setErrorType('auth'); + setError(text || 'Login required'); + } else if (response.status === 403) { + setErrorType('verification'); + setError(text || 'Email verification required'); + } else if (response.status === 429) { + setErrorType('limit'); + setError(text || 'Weekly usage limit reached'); + } else { + setErrorType('error'); + setError(text || `HTTP ${response.status}`); + } + setLoading(false); + return null; + } + + const json = await response.json(); + const travelTimeFilters: AiTravelTimeFilter[] = (json.travel_time_filters || []).map( + (tt: { mode: string; slug: string; label: string; min?: number; max?: number }) => ({ + mode: tt.mode as TransportMode, + slug: tt.slug, + label: tt.label, + min: tt.min, + max: tt.max, + }) + ); + const filters = json.filters as FeatureFilters; + const summaryText = buildSummary(filters, travelTimeFilters); + const result: AiFiltersResult = { + filters, + travelTimeFilters, + notes: json.notes || '', + summary: summaryText, + }; + setNotes(result.notes || null); + setSummary(summaryText); + setLoading(false); + return result; + } catch (err) { + if (controller.signal.aborted) return null; + logNonAbortError('ai-filters', err); + const message = err instanceof Error ? err.message : 'Failed to generate filters'; + setErrorType('error'); + setError(message); setLoading(false); return null; } + }, + [] + ); - const json = await response.json(); - const travelTimeFilters: AiTravelTimeFilter[] = (json.travel_time_filters || []).map( - (tt: { mode: string; slug: string; label: string; min?: number; max?: number }) => ({ - mode: tt.mode as TransportMode, - slug: tt.slug, - label: tt.label, - min: tt.min, - max: tt.max, - }) - ); - const result: AiFiltersResult = { - filters: json.filters as FeatureFilters, - travelTimeFilters, - notes: json.notes || '', - }; - setNotes(result.notes || null); - setLoading(false); - return result; - } catch (err) { - if (controller.signal.aborted) return null; - logNonAbortError('ai-filters', err); - const message = err instanceof Error ? err.message : 'Failed to generate filters'; - setErrorType('error'); - setError(message); - setLoading(false); - return null; - } - }, []); - - return { fetchAiFilters, loading, error, errorType, notes }; + return { fetchAiFilters, loading, error, errorType, notes, summary }; } diff --git a/pipeline/download/pois.py b/pipeline/download/pois.py index 0b39f8b..ed08b8e 100644 --- a/pipeline/download/pois.py +++ b/pipeline/download/pois.py @@ -7,18 +7,18 @@ import polars as pl from shapely.geometry import Point from tqdm import tqdm -from pipeline.utils.england_geometry import load_england_polygon +from pipeline.utils.england_geometry import ( + ENGLAND_BBOX_EAST, + ENGLAND_BBOX_NORTH, + ENGLAND_BBOX_SOUTH, + ENGLAND_BBOX_WEST, + load_england_polygon, +) BATCH_SIZE = 50_000 MIN_OCCURENCE_COUNT = 20 -# Bounding box for fast pre-filtering before the precise polygon check -ENGLAND_BBOX_WEST = -6.45 -ENGLAND_BBOX_SOUTH = 49.85 -ENGLAND_BBOX_EAST = 1.77 -ENGLAND_BBOX_NORTH = 55.82 - POI_TAG_KEYS: list[str] = [ "amenity", "building", diff --git a/server-rs/src/consts.rs b/server-rs/src/consts.rs index ba61d81..923e514 100644 --- a/server-rs/src/consts.rs +++ b/server-rs/src/consts.rs @@ -15,7 +15,9 @@ pub const MAX_PRICE_HISTORY_POINTS: usize = 5000; pub const POSTCODE_SEARCH_OFFSET: f64 = 0.02; pub const AI_FILTERS_MAX_TOKENS: usize = 2000; -pub const AI_FILTERS_TEMPERATURE: f32 = 0.0; +/// Gemini 3 recommends 1.0; lower values can cause looping or degraded performance. +pub const AI_FILTERS_TEMPERATURE: f32 = 1.0; +pub const AI_FILTERS_WEEKLY_TOKEN_LIMIT: u64 = 10_000_000; /// Timeout for outbound HTTP service calls (seconds). pub const SERVICE_CALL_TIMEOUT: u64 = 120; diff --git a/server-rs/src/main.rs b/server-rs/src/main.rs index 01281a3..930c411 100644 --- a/server-rs/src/main.rs +++ b/server-rs/src/main.rs @@ -94,13 +94,13 @@ struct Cli { #[arg(long, env = "POCKETBASE_ADMIN_PASSWORD")] pocketbase_admin_password: String, - /// Ollama server URL (e.g. http://ollama:11434) - #[arg(long, env = "OLLAMA_URL")] - ollama_url: String, + /// Gemini API key + #[arg(long, env = "GEMINI_API_KEY")] + gemini_api_key: String, - /// Ollama model name - #[arg(long, env = "OLLAMA_MODEL")] - ollama_model: String, + /// Gemini model name (e.g. gemini-2.0-flash) + #[arg(long, env = "GEMINI_MODEL")] + gemini_model: String, /// Path to precomputed travel times directory (contains mode subdirs with parquet files) #[arg(long, env = "TRAVEL_TIMES")] @@ -301,9 +301,7 @@ async fn main() -> anyhow::Result<()> { "Precomputed features response" ); - let ai_filters_schema = routes::build_ollama_schema(&features_response); - let ai_filters_system_prompt = routes::build_system_prompt(&features_response); - info!("Precomputed AI filters schema and system prompt"); + // AI filters system prompt built after travel_time_store is loaded (needs mode counts) // Record data loading metrics metrics::record_data_stats( @@ -331,10 +329,7 @@ async fn main() -> anyhow::Result<()> { &cli.google_oauth_client_secret, ) .await?; - info!( - "Ollama configured: {} (model: {})", - cli.ollama_url, cli.ollama_model - ); + info!("Gemini configured (model: {})", cli.gemini_model); let tt_path = &cli.travel_times; if !tt_path.exists() { bail!( @@ -352,6 +347,23 @@ async fn main() -> anyhow::Result<()> { Arc::new(store) }; + let mode_destinations: Vec<(String, usize)> = travel_time_store + .available_modes + .iter() + .map(|mode| { + let count = travel_time_store + .destinations + .get(mode.as_str()) + .map(|slugs| slugs.len()) + .unwrap_or(0); + (mode.clone(), count) + }) + .filter(|(_, count)| *count > 0) + .collect(); + let ai_filters_system_prompt = + routes::build_system_prompt(&features_response, &mode_destinations); + info!("Precomputed AI filters system prompt"); + let token_cache = Arc::new(auth::TokenCache::new()); let state = Arc::new(AppState { @@ -370,16 +382,16 @@ async fn main() -> anyhow::Result<()> { features_response, screenshot_url: cli.screenshot_url, public_url: cli.public_url, + is_dev: index_html.is_none(), index_html, http_client, pocketbase_url: cli.pocketbase_url, pocketbase_admin_email: cli.pocketbase_admin_email, pocketbase_admin_password: cli.pocketbase_admin_password, - ollama_url: cli.ollama_url, - ollama_model: cli.ollama_model, + gemini_api_key: cli.gemini_api_key, + gemini_model: cli.gemini_model, travel_time_store, token_cache, - ai_filters_schema, ai_filters_system_prompt, google_maps_api_key: cli.google_maps_api_key, stripe_secret_key: cli.stripe_secret_key, @@ -504,7 +516,7 @@ async fn main() -> anyhow::Result<()> { ) .route( "/api/ai-filters", - post(move |body| routes::post_ai_filters(state_ai_filters.clone(), body)) + post(move |ext, body| routes::post_ai_filters(state_ai_filters.clone(), ext, body)) .layer(ConcurrencyLimitLayer::new(5)), ) .route( diff --git a/server-rs/src/pocketbase.rs b/server-rs/src/pocketbase.rs index ac4aa64..b1a4aa3 100644 --- a/server-rs/src/pocketbase.rs +++ b/server-rs/src/pocketbase.rs @@ -240,9 +240,11 @@ async fn ensure_user_fields( let has_is_admin = fields.iter().any(|f| f["name"] == "is_admin"); let has_subscription = fields.iter().any(|f| f["name"] == "subscription"); let has_newsletter = fields.iter().any(|f| f["name"] == "newsletter"); + let has_ai_tokens_used = fields.iter().any(|f| f["name"] == "ai_tokens_used"); + let has_ai_tokens_week = fields.iter().any(|f| f["name"] == "ai_tokens_week"); - if has_is_admin && has_subscription && has_newsletter { - info!("PocketBase users collection already has is_admin, subscription, and newsletter fields"); + if has_is_admin && has_subscription && has_newsletter && has_ai_tokens_used && has_ai_tokens_week { + info!("PocketBase users collection already has all required fields"); return Ok(()); } @@ -269,6 +271,20 @@ async fn ensure_user_fields( })); } + if !has_ai_tokens_used { + new_fields.push(serde_json::json!({ + "name": "ai_tokens_used", + "type": "number", + })); + } + + if !has_ai_tokens_week { + new_fields.push(serde_json::json!({ + "name": "ai_tokens_week", + "type": "number", + })); + } + let patch_resp = client .patch(&url) .header("Authorization", format!("Bearer {token}")) diff --git a/server-rs/src/routes.rs b/server-rs/src/routes.rs index ae1a416..946e763 100644 --- a/server-rs/src/routes.rs +++ b/server-rs/src/routes.rs @@ -26,7 +26,7 @@ pub(crate) mod travel_time; mod travel_destinations; mod travel_modes; -pub use ai_filters::{build_ollama_schema, build_system_prompt, post_ai_filters}; +pub use ai_filters::{build_system_prompt, post_ai_filters}; pub use checkout::post_checkout; pub use export::get_export; pub use features::{build_features_response, get_features, FeatureInfo, FeaturesResponse}; diff --git a/server-rs/src/routes/ai_filters.rs b/server-rs/src/routes/ai_filters.rs index 8a9d522..b8911b5 100644 --- a/server-rs/src/routes/ai_filters.rs +++ b/server-rs/src/routes/ai_filters.rs @@ -2,76 +2,190 @@ use std::sync::Arc; use axum::http::StatusCode; use axum::response::Json; +use axum::Extension; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use tracing::{info, warn}; -use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE}; +use crate::auth::OptionalUser; +use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WEEKLY_TOKEN_LIMIT}; +use crate::data::slugify; +use crate::pocketbase::auth_superuser; use crate::routes::{FeatureInfo, FeaturesResponse}; use crate::state::AppState; -use crate::utils::{extract_ollama_content, ollama_chat, strip_think_blocks}; +use crate::utils::gemini_chat; + +#[derive(Deserialize)] +pub struct AiFiltersContext { + filters: Value, + #[serde(default)] + travel_time: Vec, +} + +#[derive(Deserialize)] +pub struct AiTravelTimeContext { + mode: String, + label: String, + min: Option, + max: Option, +} #[derive(Deserialize)] pub struct AiFiltersRequest { query: String, + /// Current filters for conversational refinement (e.g. "make it cheaper") + context: Option, +} + +#[derive(Serialize)] +pub struct TravelTimeFilter { + mode: String, + slug: String, + label: String, + #[serde(skip_serializing_if = "Option::is_none")] + min: Option, + #[serde(skip_serializing_if = "Option::is_none")] + max: Option, } #[derive(Serialize)] pub struct AiFiltersResponse { filters: Value, + #[serde(skip_serializing_if = "Vec::is_empty")] + travel_time_filters: Vec, /// What the LLM couldn't map to existing filters (empty if everything matched) #[serde(skip_serializing_if = "String::is_empty")] notes: String, } -/// Build a JSON schema for Ollama structured output. -/// -/// Uses two arrays (`numeric_filters` and `enum_filters`) instead of one property -/// per feature, because Ollama converts JSON schema to GBNF grammar and a schema -/// with 50+ optional keys causes a combinatorial explosion that crashes the parser. -/// Array-based schema keeps the grammar small and constant-size. -pub fn build_ollama_schema(_features: &FeaturesResponse) -> Value { - json!({ - "type": "object", - "properties": { - "numeric_filters": { - "type": "array", - "items": { - "type": "object", - "properties": { - "name": { "type": "string" }, - "bound": { "type": "string", "enum": ["min", "max"] }, - "value": { "type": "number" } - }, - "required": ["name", "bound", "value"] - } - }, - "enum_filters": { - "type": "array", - "items": { - "type": "object", - "properties": { - "name": { "type": "string" }, - "values": { "type": "array", "items": { "type": "string" } } - }, - "required": ["name", "values"] - } - }, - "notes": { - "type": "string" - } +/// Strip markdown code fences (```json ... ``` or ``` ... ```) from LLM output. +/// Models occasionally wrap JSON in markdown fencing even when told not to. +fn strip_markdown_fences(text: &str) -> &str { + let trimmed = text.trim(); + + // Try ```json\n...\n``` or ```\n...\n``` + if let Some(rest) = trimmed.strip_prefix("```") { + // Skip optional language tag (e.g. "json") + let rest = if let Some(newline_pos) = rest.find('\n') { + &rest[newline_pos + 1..] + } else { + return trimmed; + }; + if let Some(content) = rest.strip_suffix("```") { + return content.trim(); } - }) + } + + trimmed +} + +/// Build the Gemini tool declaration for destination search. +fn build_tool_declarations(state: &AppState) -> Value { + let modes: Vec<&str> = state + .travel_time_store + .available_modes + .iter() + .map(|mode| mode.as_str()) + .collect(); + + json!([{ + "functionDeclarations": [{ + "name": "search_destinations", + "description": "Search for available travel time destinations (cities, stations, towns) that have precomputed travel time data. Call this when the user mentions wanting to be near, close to, or within a certain travel time of a specific place.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Place name to search for (e.g. 'Manchester', 'Kings Cross', 'Heathrow')" + }, + "mode": { + "type": "string", + "enum": modes, + "description": "Transport mode to search destinations for" + } + }, + "required": ["query", "mode"] + } + }] + }]) +} + +/// Execute a destination search against PlaceData + TravelTimeStore. +/// Returns matching destinations as a JSON value with `results` and optional `message`. +/// +/// Uses word-based matching: all words in the query must appear somewhere in the +/// place name (order-independent). Also matches against slugs for short queries. +fn execute_destination_search(state: &AppState, query: &str, mode: &str) -> Value { + let query_lower = query.to_lowercase(); + let query_words: Vec<&str> = query_lower.split_whitespace().collect(); + let query_slug = slugify(query); + let tt_store = &state.travel_time_store; + let pd = &state.place_data; + + let slug_set = match tt_store.destinations.get(mode) { + Some(slugs) => slugs, + None => return json!({ "results": [], "message": format!("No travel data available for mode '{}'", mode) }), + }; + + // Find places matching the query that have travel time data. + // A place matches if ALL query words appear in its name, OR its slug matches the query slug. + let mut matches: Vec<(usize, String, u8, u32)> = pd + .name_lower + .iter() + .enumerate() + .filter_map(|(idx, name_lower)| { + let words_match = query_words.iter().all(|word| name_lower.contains(word)); + let slug = slugify(&pd.name[idx]); + let slug_match = slug.contains(&query_slug) || query_slug.contains(&slug); + if !words_match && !slug_match { + return None; + } + if slug_set.contains(&slug) { + Some((idx, slug, pd.type_rank[idx], pd.population[idx])) + } else { + None + } + }) + .collect(); + + // Sort: type rank asc, population desc + matches.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(b.3.cmp(&a.3))); + matches.truncate(10); + + if matches.is_empty() { + info!(query = query, mode = mode, "Destination search returned no results"); + return json!({ + "results": [], + "message": format!("No travel time data available for '{}' by {}. This destination cannot be used as a travel time filter.", query, mode) + }); + } + + let results: Vec = matches + .into_iter() + .map(|(idx, slug, ..)| { + json!({ + "name": pd.name[idx], + "slug": slug, + "place_type": pd.place_type.get(idx).to_string(), + }) + }) + .collect(); + + json!({ "results": results }) } /// Build the complete system prompt for AI filters. /// -/// Contains: role instructions, feature catalogue, few-shot examples, output rules. +/// Contains: role instructions, feature catalogue, travel time info, +/// few-shot examples, output rules. /// Precomputed at startup and cached in AppState. -pub fn build_system_prompt(features: &FeaturesResponse) -> String { +pub fn build_system_prompt( + features: &FeaturesResponse, + mode_destinations: &[(String, usize)], +) -> String { let mut parts = Vec::new(); - // Role and task description parts.push( "You are a UK property search assistant. \ The user describes their ideal property or area in natural language. \ @@ -91,10 +205,61 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String { (note: this counts bedrooms + living rooms combined, so 3 bed ~ min 4).\n\ - If the user mentions something that has no matching filter, put it in \"notes\" \ as a short phrase (e.g. \"No filter for: garden, sea view\"). \ - If everything was matched, set \"notes\" to an empty string." + If everything was matched, set \"notes\" to an empty string.\n\ + \n\ + CONVERSATIONAL REFINEMENT:\n\ + The user's message may include their currently active filters as context. \ + When context is provided:\n\ + - \"make it cheaper\" / \"lower the price\" = adjust the existing price filter down\n\ + - \"also add ...\" / \"and good schools\" = keep existing filters and add new ones\n\ + - \"remove the ...\" / \"drop the ...\" = return filters WITHOUT the mentioned one\n\ + - If the request is a completely new search (not a refinement), ignore the context \ + and build filters from scratch.\n\ + - Always output the COMPLETE set of filters (existing + modified), not just the changes." .to_string(), ); + // Travel time section with available modes + let modes_list = mode_destinations + .iter() + .map(|(mode, count)| format!("- {} ({} destinations available)", mode, count)) + .collect::>() + .join("\n"); + + parts.push(format!( + "\n--- TRAVEL TIME FILTERS ---\n\ + You can add travel time filters when the user mentions commute times, \ + proximity to places, or wanting to be near/within X minutes of somewhere.\n\ + \n\ + Available transport modes (only use modes that have destinations):\n\ + {}\n\ + - \"car\" / \"drive\" / \"driving\" = car mode\n\ + - \"cycle\" / \"bike\" / \"cycling\" = bicycle mode\n\ + - \"walk\" / \"walking\" / \"on foot\" = walking mode\n\ + - \"train\" / \"tube\" / \"bus\" / \"public transport\" / \"commute\" = transit mode\n\ + \n\ + When the user mentions a specific place, you MUST call the search_destinations \ + tool to find the exact slug. Use the name and slug from the search results.\n\ + If search_destinations returns an empty array, the destination is not available — \ + mention it in \"notes\" (e.g. \"No travel data for: Gatwick Airport\") and do NOT \ + include a travel_time_filter for it.\n\ + \n\ + Travel time values are in MINUTES (0-120 range).\n\ + - \"within 30 minutes\" = max 30\n\ + - \"at least 10 minutes\" = min 10\n\ + - \"30-45 minute commute\" = min 30, max 45\n\ + - If only a max is given, omit min (and vice versa).\n\ + \n\ + INFERRING TRANSPORT MODE (when the user does not specify one explicitly):\n\ + - \"commute\" to a major city centre or station = transit\n\ + - \"near\" / \"close to\" a city centre or station = transit\n\ + - \"near\" / \"close to\" a smaller town, village, or rural area = car\n\ + - \"drive\" / \"driving distance\" / \"driving time\" = always car\n\ + - If multiple modes are plausible, prefer transit for urban destinations \ + (London, Manchester, Birmingham, Leeds, etc.) and car for everything else.", + modes_list, + )); + // Feature catalogue parts.push("\n--- AVAILABLE FEATURES ---\n".to_string()); for group in &features.groups { @@ -148,6 +313,7 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String { Output: {\"numeric_filters\": [{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 400000}], \ \"enum_filters\": [{\"name\": \"Leasehold/Freehold\", \"values\": [\"Freehold\"]}, \ {\"name\": \"Property type\", \"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \ + \"travel_time_filters\": [], \ \"notes\": \"\"}" .to_string(), ); @@ -161,7 +327,7 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String { {\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \ {\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}, \ {\"name\": \"Number of parks within 2km\", \"bound\": \"min\", \"value\": 3}], \ - \"enum_filters\": [], \"notes\": \"\"}" + \"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}" .to_string(), ); @@ -172,18 +338,37 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String { {\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 4}], \ \"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}, \ {\"name\": \"Max available download speed (Mbps)\", \"values\": [\"100\", \"300\", \"1000\"]}], \ + \"travel_time_filters\": [], \ \"notes\": \"No filter for: beach proximity\"}" .to_string(), ); parts.push( - "\nUser: \"large family home with a garden near restaurants\"\n\ + "\nUser: \"within 30 minutes commute of Kings Cross, under 500k\"\n\ + (After calling search_destinations for \"Kings Cross\" with mode \"transit\" \ + and getting [{\"name\": \"Kings Cross\", \"slug\": \"kings-cross\", \"place_type\": \"station\"}])\n\ + Output: {\"numeric_filters\": [\ + {\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 500000}], \ + \"enum_filters\": [], \ + \"travel_time_filters\": [{\"mode\": \"transit\", \"slug\": \"kings-cross\", \ + \"label\": \"Kings Cross\", \"bound\": \"max\", \"value\": 30}], \ + \"notes\": \"\"}" + .to_string(), + ); + + parts.push( + "\nUser: \"family home with garden, 45 min drive from Manchester, good schools\"\n\ + (After calling search_destinations for \"Manchester\" with mode \"car\" \ + and getting [{\"name\": \"Manchester\", \"slug\": \"manchester\", \"place_type\": \"city\"}])\n\ Output: {\"numeric_filters\": [\ {\"name\": \"Total floor area (sqm)\", \"bound\": \"min\", \"value\": 100}, \ {\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 5}, \ - {\"name\": \"Number of restaurants within 2km\", \"bound\": \"min\", \"value\": 10}], \ + {\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \ + {\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}], \ \"enum_filters\": [{\"name\": \"Property type\", \ \"values\": [\"Detached\", \"Semi-Detached\"]}], \ + \"travel_time_filters\": [{\"mode\": \"car\", \"slug\": \"manchester\", \ + \"label\": \"Manchester\", \"bound\": \"max\", \"value\": 45}], \ \"notes\": \"No filter for: garden\"}" .to_string(), ); @@ -191,7 +376,8 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String { // Output format reminder parts.push( "\n--- OUTPUT FORMAT ---\n\ - {\"numeric_filters\": [{\"name\": \"...\", \"bound\": \"min\"|\"max\", \"value\": N}, ...], \"enum_filters\": [...], \"notes\": \"...\"}\n\ + {\"numeric_filters\": [...], \"enum_filters\": [...], \"travel_time_filters\": [{\"mode\": \"...\", \"slug\": \"...\", \"label\": \"...\", \"bound\": \"min\"|\"max\", \"value\": N}, ...], \"notes\": \"...\"}\n\ + - travel_time_filters: use ONLY slugs returned by search_destinations. If a place isn't found, mention it in notes.\n\ Respond with ONLY the JSON object. No explanation." .to_string(), ); @@ -199,86 +385,393 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String { parts.join("\n") } -pub async fn post_ai_filters( - state: Arc, - Json(req): Json, -) -> Result, (StatusCode, String)> { - info!(query = %req.query, "POST /api/ai-filters"); - - let url = format!("{}/api/chat", state.ollama_url); - let body = json!({ - "model": state.ollama_model, - "messages": [ - { "role": "system", "content": state.ai_filters_system_prompt }, - { "role": "user", "content": req.query } - ], - "stream": false, - "format": state.ai_filters_schema, - "options": { - "temperature": AI_FILTERS_TEMPERATURE, - "num_predict": AI_FILTERS_MAX_TOKENS, - } - }); - - // Try up to 2 attempts — LLMs occasionally return empty content (e.g. only - // blocks with no JSON output), which is transient and usually - // succeeds on retry. - let mut last_err = None; - for attempt in 0..2 { - let raw = call_ollama_and_parse(&state.http_client, &url, &body).await; - match raw { - Ok(raw) => { - let filters = validate_and_convert(&raw, &state.features_response); - let notes = raw - .get("notes") - .and_then(|val| val.as_str()) - .unwrap_or("") - .to_string(); - return Ok(Json(AiFiltersResponse { filters, notes })); - } - Err(err) => { - if attempt == 0 { - warn!("LLM attempt 1 failed, retrying: {}", err.1); - } - last_err = Some(err); - } - } - } - - Err(last_err.unwrap()) +/// Monotonically increasing week number derived from Unix epoch. +/// Resets every 7 days (604800 seconds). Used for weekly rate limiting. +fn current_week_number() -> u64 { + let secs = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system time before epoch") + .as_secs(); + secs / 604_800 } -/// Call Ollama and parse the response content as JSON. -/// -/// Returns an error if: the HTTP call fails, the response is malformed, -/// the content is empty after stripping think blocks, or the content is -/// not valid JSON. -async fn call_ollama_and_parse( - client: &reqwest::Client, - url: &str, - body: &Value, -) -> Result { - let json_resp = ollama_chat(client, url, body).await?; - let content = extract_ollama_content(&json_resp)?; +/// Fetch the user's current AI token usage from PocketBase. +/// Returns `(tokens_used, week_number)`. +async fn fetch_ai_usage( + state: &AppState, + user_id: &str, +) -> Result<(u64, u64), (StatusCode, String)> { + let pb_url = state.pocketbase_url.trim_end_matches('/'); + let token = auth_superuser( + &state.http_client, + pb_url, + &state.pocketbase_admin_email, + &state.pocketbase_admin_password, + ) + .await + .map_err(|err| { + warn!("Failed to auth superuser for AI usage check: {err}"); + (StatusCode::BAD_GATEWAY, "Internal error".into()) + })?; - let content = strip_think_blocks(content); - let content = content.trim(); + let url = format!("{pb_url}/api/collections/users/records/{user_id}"); + let resp = state + .http_client + .get(&url) + .header("Authorization", format!("Bearer {token}")) + .send() + .await + .map_err(|err| { + warn!("Failed to fetch user record for AI usage: {err}"); + (StatusCode::BAD_GATEWAY, "Internal error".into()) + })?; - if content.is_empty() { - warn!("LLM returned empty content after stripping think blocks"); + if !resp.status().is_success() { + let status = resp.status(); + warn!("PocketBase user fetch failed ({status})"); + return Err((StatusCode::BAD_GATEWAY, "Internal error".into())); + } + + let body: Value = resp.json().await.map_err(|err| { + warn!("Failed to parse user record: {err}"); + (StatusCode::BAD_GATEWAY, "Internal error".into()) + })?; + + let tokens_used = body + .get("ai_tokens_used") + .and_then(|val| val.as_u64()) + .unwrap_or(0); + let week = body + .get("ai_tokens_week") + .and_then(|val| val.as_u64()) + .unwrap_or(0); + + Ok((tokens_used, week)) +} + +/// Update the user's AI token usage in PocketBase. +/// Best-effort — logs warnings on failure but does not propagate errors. +async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week: u64) { + let pb_url = state.pocketbase_url.trim_end_matches('/'); + let token = match auth_superuser( + &state.http_client, + pb_url, + &state.pocketbase_admin_email, + &state.pocketbase_admin_password, + ) + .await + { + Ok(tk) => tk, + Err(err) => { + warn!("Failed to auth superuser for AI usage update: {err}"); + return; + } + }; + + let url = format!("{pb_url}/api/collections/users/records/{user_id}"); + let res = state + .http_client + .patch(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&json!({ + "ai_tokens_used": tokens_used, + "ai_tokens_week": week, + })) + .send() + .await; + + match res { + Ok(resp) if resp.status().is_success() => {} + Ok(resp) => { + let status = resp.status(); + warn!("Failed to update AI usage ({status})"); + } + Err(err) => warn!("Failed to update AI usage: {err}"), + } +} + +/// Maximum number of round trips (function calls + retries) before giving up. +const MAX_TOOL_ROUNDS: usize = 5; + +pub async fn post_ai_filters( + state: Arc, + Extension(user): Extension, + Json(req): Json, +) -> Result, (StatusCode, String)> { + // Auth check + let user = user + .0 + .ok_or((StatusCode::UNAUTHORIZED, "Login required".into()))?; + + // Email verification check (skipped in dev mode) + if !user.verified && !state.is_dev { return Err(( - StatusCode::BAD_GATEWAY, - "LLM returned empty content (no JSON output)".into(), + StatusCode::FORBIDDEN, + "Please verify your email to use AI filters".into(), )); } - serde_json::from_str(content).map_err(|err| { - warn!(error = %err, content = %content, "Failed to parse LLM JSON output"); - ( - StatusCode::BAD_GATEWAY, - format!("Failed to parse LLM output as JSON: {}", err), + // Check weekly token usage + let current_week = current_week_number(); + let (stored_tokens, stored_week) = fetch_ai_usage(&state, &user.id).await?; + let tokens_used = if stored_week == current_week { + stored_tokens + } else { + 0 + }; + + if tokens_used >= AI_FILTERS_WEEKLY_TOKEN_LIMIT { + return Err(( + StatusCode::TOO_MANY_REQUESTS, + "Weekly AI usage limit reached. Resets next week.".into(), + )); + } + + info!(query = %req.query, user_id = %user.id, "POST /api/ai-filters"); + + let tools = build_tool_declarations(&state); + + // Build user message with optional context for conversational refinement + let user_text = if let Some(ref ctx) = req.context { + let mut msg = String::new(); + msg.push_str("Currently active filters:\n"); + msg.push_str(&serde_json::to_string(&ctx.filters).unwrap_or_default()); + if !ctx.travel_time.is_empty() { + msg.push_str("\nCurrently active travel time filters:\n"); + for tt in &ctx.travel_time { + let bounds = match (tt.min, tt.max) { + (Some(min), Some(max)) => format!("{}-{} min", min, max), + (Some(min), None) => format!("min {} min", min), + (None, Some(max)) => format!("max {} min", max), + (None, None) => "no range".to_string(), + }; + msg.push_str(&format!("- {} to {} ({})\n", tt.mode, tt.label, bounds)); + } + } + msg.push_str(&format!("\nUser request: {}", req.query)); + msg + } else { + req.query.clone() + }; + + let mut contents = vec![json!({ + "role": "user", + "parts": [{ "text": user_text }] + })]; + + let mut total_tokens_accumulated: u64 = 0; + + // Function calling loop: model may call search_destinations, we execute and feed back + for round in 0..MAX_TOOL_ROUNDS { + let body = json!({ + "systemInstruction": { + "parts": [{ "text": state.ai_filters_system_prompt }] + }, + "contents": contents, + "tools": tools, + "generationConfig": { + "temperature": AI_FILTERS_TEMPERATURE, + "maxOutputTokens": AI_FILTERS_MAX_TOKENS, + "thinkingConfig": { "thinkingLevel": "LOW" }, + } + }); + + let json_resp = gemini_chat( + &state.http_client, + &state.gemini_api_key, + &state.gemini_model, + &body, ) - }) + .await?; + + // Accumulate token usage + total_tokens_accumulated += json_resp + .get("usageMetadata") + .and_then(|md| md.get("totalTokenCount")) + .and_then(|tc| tc.as_u64()) + .unwrap_or(0); + + let candidate = json_resp + .get("candidates") + .and_then(|cs| cs.get(0)) + .and_then(|c| c.get("content")) + .ok_or_else(|| { + warn!("Malformed Gemini response: missing candidates[0].content"); + ( + StatusCode::BAD_GATEWAY, + "Malformed Gemini response".into(), + ) + })?; + + let parts = candidate + .get("parts") + .and_then(|p| p.as_array()) + .ok_or_else(|| { + warn!("Malformed Gemini response: missing parts array"); + ( + StatusCode::BAD_GATEWAY, + "Malformed Gemini response".into(), + ) + })?; + + // Check if the model made a function call. + // Find the full part (includes thoughtSignature required by Gemini 3 models). + if let Some(fc_part) = parts.iter().find(|part| part.get("functionCall").is_some()) { + let fc = fc_part.get("functionCall").unwrap(); + let fn_name = fc.get("name").and_then(|n| n.as_str()).unwrap_or(""); + let fn_args = fc.get("args").cloned().unwrap_or(json!({})); + + info!( + function = fn_name, + round = round, + "AI called tool" + ); + + let fn_result = if fn_name == "search_destinations" { + let query = fn_args + .get("query") + .and_then(|q| q.as_str()) + .unwrap_or(""); + let mode = fn_args + .get("mode") + .and_then(|m| m.as_str()) + .unwrap_or("transit"); + execute_destination_search(&state, query, mode) + } else { + json!({"error": "unknown function"}) + }; + + // Append the model's full response (preserves thoughtSignature) + our function result + contents.push(candidate.clone()); + contents.push(json!({ + "role": "user", + "parts": [{ + "functionResponse": { + "name": fn_name, + "response": { "results": fn_result } + } + }] + })); + + // Continue the loop — model will process the results + continue; + } + + // Model returned text — extract and parse as JSON + let text = parts + .iter() + .find_map(|part| part.get("text").and_then(|t| t.as_str())) + .unwrap_or(""); + let text = strip_markdown_fences(text); + let text = text.trim(); + + if text.is_empty() { + warn!("Gemini returned empty text content (round {})", round); + // Retry by continuing the loop + contents.push(candidate.clone()); + contents.push(json!({ + "role": "user", + "parts": [{ "text": "Your response was empty. Please output the JSON object." }] + })); + continue; + } + + let raw: Value = match serde_json::from_str(text) { + Ok(val) => val, + Err(err) => { + warn!(error = %err, round = round, "Failed to parse Gemini JSON output, retrying"); + // Ask the model to fix its output + contents.push(candidate.clone()); + contents.push(json!({ + "role": "user", + "parts": [{ "text": "That was not valid JSON. Please output ONLY the JSON object with numeric_filters, enum_filters, travel_time_filters, and notes." }] + })); + continue; + } + }; + + let filters = validate_and_convert(&raw, &state.features_response); + let travel_time_filters = validate_travel_time_filters(&raw, &state); + let notes = raw + .get("notes") + .and_then(|val| val.as_str()) + .unwrap_or("") + .to_string(); + + // Update usage with total accumulated tokens + let new_total = tokens_used + total_tokens_accumulated; + update_ai_usage(&state, &user.id, new_total, current_week).await; + + return Ok(Json(AiFiltersResponse { + filters, + travel_time_filters, + notes, + })); + } + + // Exhausted tool rounds without getting a final text response + warn!("AI exhausted {} tool-calling rounds without final response", MAX_TOOL_ROUNDS); + Err(( + StatusCode::BAD_GATEWAY, + "AI could not complete the request".into(), + )) +} + +/// Validate travel time filters from LLM output against available destinations. +fn validate_travel_time_filters(raw: &Value, state: &AppState) -> Vec { + let arr = match raw + .get("travel_time_filters") + .and_then(|val| val.as_array()) + { + Some(arr) => arr, + None => return Vec::new(), + }; + + let tt_store = &state.travel_time_store; + let mut results = Vec::new(); + + for item in arr { + let mode = match item.get("mode").and_then(|val| val.as_str()) { + Some(mode) => mode, + None => continue, + }; + let slug = match item.get("slug").and_then(|val| val.as_str()) { + Some(slug) => slug, + None => continue, + }; + let label = item + .get("label") + .and_then(|val| val.as_str()) + .unwrap_or(slug); + + // Verify this destination actually exists + if !tt_store.has_destination(mode, slug) { + warn!(mode = mode, slug = slug, "AI suggested non-existent destination"); + continue; + } + + let bound = item.get("bound").and_then(|val| val.as_str()); + let value = item.get("value").and_then(|val| val.as_f64()); + + let (min, max) = match (bound, value) { + (Some("min"), Some(val)) => (Some(val.max(0.0).min(120.0) as f32), None), + (Some("max"), Some(val)) => (None, Some(val.max(0.0).min(120.0) as f32)), + _ => (None, None), + }; + + // Only include if at least one bound is set + if min.is_some() || max.is_some() { + results.push(TravelTimeFilter { + mode: mode.to_string(), + slug: slug.to_string(), + label: label.to_string(), + min, + max, + }); + } + } + + results } /// Validate LLM output against feature metadata and convert to FeatureFilters format. @@ -374,3 +867,32 @@ fn validate_and_convert(raw: &Value, features: &FeaturesResponse) -> Value { Value::Object(result) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn strip_fences_json_tag() { + let input = "```json\n{\"a\": 1}\n```"; + assert_eq!(strip_markdown_fences(input), "{\"a\": 1}"); + } + + #[test] + fn strip_fences_no_tag() { + let input = "```\n{\"a\": 1}\n```"; + assert_eq!(strip_markdown_fences(input), "{\"a\": 1}"); + } + + #[test] + fn strip_fences_passthrough() { + let input = "{\"a\": 1}"; + assert_eq!(strip_markdown_fences(input), "{\"a\": 1}"); + } + + #[test] + fn strip_fences_whitespace() { + let input = " ```json\n {\"a\": 1} \n``` "; + assert_eq!(strip_markdown_fences(input), "{\"a\": 1}"); + } +} diff --git a/server-rs/src/routes/invites.rs b/server-rs/src/routes/invites.rs index d04ed73..1d773ae 100644 --- a/server-rs/src/routes/invites.rs +++ b/server-rs/src/routes/invites.rs @@ -178,7 +178,7 @@ pub async fn get_invite( } // Dev-only: return a fake valid admin invite without hitting PocketBase - if state.index_html.is_none() && code == DEV_INVITE_CODE { + if state.is_dev && code == DEV_INVITE_CODE { return Json(InviteValidation { valid: true, invite_type: "admin".to_string(), @@ -294,7 +294,7 @@ pub async fn post_redeem_invite( } // Dev-only: fake redeem — just return "licensed" without touching PocketBase - if state.index_html.is_none() && req.code == DEV_INVITE_CODE { + if state.is_dev && req.code == DEV_INVITE_CODE { info!(user_id = %user.id, "Dev invite redeemed (no-op)"); return Json(RedeemResponse { result: "licensed".to_string(), diff --git a/server-rs/src/state.rs b/server-rs/src/state.rs index 4130a14..99ebd58 100644 --- a/server-rs/src/state.rs +++ b/server-rs/src/state.rs @@ -34,6 +34,8 @@ pub struct AppState { pub screenshot_url: String, /// Public-facing URL for absolute og:image URLs (e.g. https://perfectpostcodes.schmelczer.dev) pub public_url: String, + /// True when --dist is not provided (no static serving, relaxed auth checks) + pub is_dev: bool, /// Contents of index.html read at startup, used for crawler OG injection (None when --dist omitted) pub index_html: Option, /// Shared HTTP client for proxying to the screenshot service and PocketBase @@ -44,16 +46,14 @@ pub struct AppState { pub pocketbase_admin_email: String, /// PocketBase superuser password pub pocketbase_admin_password: String, - /// Ollama server URL for AI area summaries (e.g. http://ollama:11434) - pub ollama_url: String, - /// Ollama model name for area summaries (e.g. gemma3:12b) - pub ollama_model: String, + /// Gemini API key for AI filters + pub gemini_api_key: String, + /// Gemini model name (e.g. gemini-2.0-flash) + pub gemini_model: String, /// Precomputed travel time data store pub travel_time_store: Arc, /// Token validation cache (60s TTL) pub token_cache: Arc, - /// JSON schema for Ollama structured output in AI filters - pub ai_filters_schema: serde_json::Value, /// Complete system prompt for AI filters (features + examples + instructions) pub ai_filters_system_prompt: String, /// Google Maps API key for Street View metadata lookups diff --git a/server-rs/src/utils.rs b/server-rs/src/utils.rs index 6961b84..2dcc1ec 100644 --- a/server-rs/src/utils.rs +++ b/server-rs/src/utils.rs @@ -6,7 +6,7 @@ mod llm; pub use grid_index::GridIndex; pub use hash::{generate_priorities, splitmix64_hash}; pub use interned_column::InternedColumn; -pub use llm::{extract_ollama_content, ollama_chat, strip_think_blocks}; +pub use llm::{extract_gemini_content, gemini_chat}; /// Normalize a UK postcode: uppercase, strip spaces, insert canonical space before inward code. /// e.g. "e142dg" → "E14 2DG", "E14 2DG" → "E14 2DG", "EC1A1BB" → "EC1A 1BB" diff --git a/server-rs/src/utils/llm.rs b/server-rs/src/utils/llm.rs index 8625ded..58a74ae 100644 --- a/server-rs/src/utils/llm.rs +++ b/server-rs/src/utils/llm.rs @@ -4,68 +4,62 @@ use tracing::warn; pub type LlmError = (StatusCode, String); -/// Send a chat request to Ollama and return the parsed JSON response. +/// Send a generateContent request to the Gemini API and return the parsed JSON response. /// /// Handles connection errors, non-success status codes, and JSON parse failures /// uniformly as `BAD_GATEWAY` errors. -pub async fn ollama_chat( +pub async fn gemini_chat( client: &reqwest::Client, - url: &str, + api_key: &str, + model: &str, body: &Value, ) -> Result { - let response = client.post(url).json(body).send().await.map_err(|err| { - warn!(error = %err, "Failed to connect to Ollama"); + let url = format!( + "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}", + model, api_key + ); + + let response = client.post(&url).json(body).send().await.map_err(|err| { + warn!(error = %err, "Failed to connect to Gemini API"); ( StatusCode::BAD_GATEWAY, - format!("Failed to connect to Ollama: {}", err), + format!("Failed to connect to Gemini API: {}", err), ) })?; if !response.status().is_success() { let status = response.status(); let body_text = response.text().await.unwrap_or_default(); - warn!(status = %status, body = %body_text, "Ollama returned error"); + warn!(status = %status, body = %body_text, "Gemini API returned error"); return Err(( StatusCode::BAD_GATEWAY, - format!("Ollama error {}: {}", status, body_text), + format!("Gemini API error {}: {}", status, body_text), )); } response.json().await.map_err(|err| { - warn!(error = %err, "Failed to parse Ollama response"); + warn!(error = %err, "Failed to parse Gemini API response"); ( StatusCode::BAD_GATEWAY, - format!("Failed to parse Ollama response: {}", err), + format!("Failed to parse Gemini API response: {}", err), ) }) } -/// Extract content from Ollama native response (`message.content`) -pub fn extract_ollama_content(json: &Value) -> Result<&str, LlmError> { - json.get("message") - .and_then(|msg| msg.get("content")) - .and_then(|ct| ct.as_str()) +/// Extract text content from Gemini response (`candidates[0].content.parts[0].text`) +pub fn extract_gemini_content(json: &Value) -> Result<&str, LlmError> { + json.get("candidates") + .and_then(|c| c.get(0)) + .and_then(|c| c.get("content")) + .and_then(|c| c.get("parts")) + .and_then(|p| p.get(0)) + .and_then(|p| p.get("text")) + .and_then(|t| t.as_str()) .ok_or_else(|| { - warn!("Malformed Ollama response: missing message.content"); + warn!("Malformed Gemini response: missing candidates[0].content.parts[0].text"); ( StatusCode::BAD_GATEWAY, - "Malformed LLM response: missing message.content".into(), + "Malformed Gemini response: missing content".into(), ) }) } - -/// Strip `...` blocks from model output -pub fn strip_think_blocks(text: &str) -> String { - let mut result = String::new(); - let mut remaining = text; - while let Some(start) = remaining.find("") { - result.push_str(&remaining[..start]); - if let Some(end) = remaining[start..].find("") { - remaining = &remaining[start + end + 8..]; - } else { - return result; - } - } - result.push_str(remaining); - result -}