Improve LLM

This commit is contained in:
Andras Schmelczer 2026-03-15 14:05:34 +00:00
parent 02712f41e8
commit 80c093b7ba
16 changed files with 898 additions and 278 deletions

View file

@ -21,6 +21,15 @@ jobs:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up uv
uses: astral-sh/setup-uv@v4
- name: Download map assets (fonts, sprites, twemoji)
run: uv run python -m pipeline.download.map_assets --output frontend/public/assets
- name: Download arcgis data for finder
run: uv run python -m pipeline.download.arcgis --output property-data/arcgis_data.parquet
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
@ -64,12 +73,6 @@ jobs:
cache-from: type=gha,scope=screenshot cache-from: type=gha,scope=screenshot
cache-to: type=gha,mode=max,scope=screenshot cache-to: type=gha,mode=max,scope=screenshot
- name: Set up uv
uses: astral-sh/setup-uv@v4
- name: Download arcgis data for finder
run: uv run python -m pipeline.download.arcgis --output property-data/arcgis_data.parquet
- name: Build and push finder service - name: Build and push finder service
uses: docker/build-push-action@v6 uses: docker/build-push-action@v6
with: with:

View file

@ -36,8 +36,8 @@ services:
POCKETBASE_ADMIN_EMAIL: *pb-email POCKETBASE_ADMIN_EMAIL: *pb-email
POCKETBASE_ADMIN_PASSWORD: *pb-password POCKETBASE_ADMIN_PASSWORD: *pb-password
SCREENSHOT_URL: http://screenshot:8002 SCREENSHOT_URL: http://screenshot:8002
OLLAMA_URL: http://host.docker.internal:11434 GEMINI_API_KEY: AIzaSyC2mQDcEwILHM3uOE2C-lxUQbQrKTX9Xi4
OLLAMA_MODEL: gpt-oss:20b GEMINI_MODEL: gemini-3-flash-preview
PUBLIC_URL: https://perfectpostcodes.schmelczer.dev PUBLIC_URL: https://perfectpostcodes.schmelczer.dev
GOOGLE_MAPS_API_KEY: "AIzaSyBgBn9LjrxHCjb9j1LZbLYpEdCJj-NkHPY" GOOGLE_MAPS_API_KEY: "AIzaSyBgBn9LjrxHCjb9j1LZbLYpEdCJj-NkHPY"
STRIPE_SECRET_KEY: sk_test_51SyVcePRjj2bdyn1HLkatQ5onwp8kamm41tjMcRdxXnJYWVPsVd9usMTOSNtNdGhrjbsrtNbgTdKXICg2qBiocEn00PvNDC0d3 STRIPE_SECRET_KEY: sk_test_51SyVcePRjj2bdyn1HLkatQ5onwp8kamm41tjMcRdxXnJYWVPsVd9usMTOSNtNdGhrjbsrtNbgTdKXICg2qBiocEn00PvNDC0d3

View file

@ -42,6 +42,7 @@ interface AiFilterInputProps {
error: string | null; error: string | null;
errorType: AiFilterErrorType | null; errorType: AiFilterErrorType | null;
notes: string | null; notes: string | null;
summary: string | null;
onSubmit: (query: string) => void; onSubmit: (query: string) => void;
isLoggedIn: boolean; isLoggedIn: boolean;
onLoginRequired: () => void; onLoginRequired: () => void;
@ -52,6 +53,7 @@ export default memo(function AiFilterInput({
error, error,
errorType, errorType,
notes, notes,
summary,
onSubmit, onSubmit,
isLoggedIn, isLoggedIn,
onLoginRequired, onLoginRequired,
@ -88,7 +90,7 @@ export default memo(function AiFilterInput({
); );
const hasContent = query.trim().length > 0; const hasContent = query.trim().length > 0;
const showExamples = expanded && !hasContent && !loading && !error && !notes; const showExamples = expanded && !hasContent && !loading && !error && !notes && !summary;
if (!expanded) { if (!expanded) {
return ( return (
@ -173,6 +175,11 @@ export default memo(function AiFilterInput({
{error} {error}
</p> </p>
)} )}
{summary && !error && !loading && (
<p className="mt-1 text-xs text-teal-600 dark:text-teal-400">
{summary}
</p>
)}
{notes && !error && !loading && ( {notes && !error && !loading && (
<p className="mt-1 text-xs text-warm-500 dark:text-warm-400 italic"> <p className="mt-1 text-xs text-warm-500 dark:text-warm-400 italic">
{notes} {notes}

View file

@ -92,6 +92,7 @@ interface FiltersProps {
aiFilterError: string | null; aiFilterError: string | null;
aiFilterErrorType: AiFilterErrorType | null; aiFilterErrorType: AiFilterErrorType | null;
aiFilterNotes: string | null; aiFilterNotes: string | null;
aiFilterSummary: string | null;
onAiFilterSubmit: (query: string) => void; onAiFilterSubmit: (query: string) => void;
isLoggedIn: boolean; isLoggedIn: boolean;
onLoginRequired: () => void; onLoginRequired: () => void;
@ -127,6 +128,7 @@ export default memo(function Filters({
aiFilterError, aiFilterError,
aiFilterErrorType, aiFilterErrorType,
aiFilterNotes, aiFilterNotes,
aiFilterSummary,
onAiFilterSubmit, onAiFilterSubmit,
isLoggedIn, isLoggedIn,
onLoginRequired, onLoginRequired,
@ -285,7 +287,7 @@ export default memo(function Filters({
</div> </div>
<div ref={scrollRef} className="md:flex-1 md:overflow-y-auto"> <div ref={scrollRef} className="md:flex-1 md:overflow-y-auto">
<AiFilterInput loading={aiFilterLoading} error={aiFilterError} errorType={aiFilterErrorType} notes={aiFilterNotes} onSubmit={onAiFilterSubmit} isLoggedIn={isLoggedIn} onLoginRequired={onLoginRequired} /> <AiFilterInput loading={aiFilterLoading} error={aiFilterError} errorType={aiFilterErrorType} notes={aiFilterNotes} summary={aiFilterSummary} onSubmit={onAiFilterSubmit} isLoggedIn={isLoggedIn} onLoginRequired={onLoginRequired} />
<div className="px-3 pb-2 space-y-2"> <div className="px-3 pb-2 space-y-2">
<div className="flex rounded-lg bg-warm-100 dark:bg-warm-800 p-0.5"> <div className="flex rounded-lg bg-warm-100 dark:bg-warm-800 p-0.5">
{(['historical', 'buy', 'rent'] as const).map((type) => { {(['historical', 'buy', 'rent'] as const).map((type) => {

View file

@ -148,22 +148,33 @@ export default function MapPage({
const handleAiFilterSubmit = useCallback( const handleAiFilterSubmit = useCallback(
async (query: string) => { async (query: string) => {
const result = await aiFilters.fetchAiFilters(query); // Build context from current filters for conversational refinement
const context = {
filters,
travelTime: travelTime.activeEntries.map((entry) => ({
mode: entry.mode,
label: entry.label,
min: entry.timeRange?.[0],
max: entry.timeRange?.[1],
})),
};
const hasContext =
Object.keys(context.filters).length > 0 || context.travelTime.length > 0;
const result = await aiFilters.fetchAiFilters(query, hasContext ? context : undefined);
if (!result) return; if (!result) return;
handleSetFilters(result.filters); handleSetFilters(result.filters);
// Apply travel time filters from AI // Always sync travel time entries — clear stale ones when AI returns none
if (result.travelTimeFilters.length > 0) { const newEntries = result.travelTimeFilters.map((tt) => ({
const newEntries = result.travelTimeFilters.map((tt) => ({ mode: tt.mode,
mode: tt.mode, slug: tt.slug,
slug: tt.slug, label: tt.label,
label: tt.label, timeRange: [tt.min ?? 0, tt.max ?? 120] as [number, number],
timeRange: [tt.min ?? 0, tt.max ?? 120] as [number, number], useBest: false,
useBest: false, }));
})); travelTime.handleSetEntries(newEntries);
travelTime.handleSetEntries(newEntries);
}
}, },
[aiFilters.fetchAiFilters, handleSetFilters, travelTime.handleSetEntries] [aiFilters.fetchAiFilters, handleSetFilters, travelTime.handleSetEntries, travelTime.activeEntries, filters]
); );
const handleTravelTimeSetDestination = useCallback( const handleTravelTimeSetDestination = useCallback(
@ -514,6 +525,7 @@ export default function MapPage({
aiFilterError={aiFilters.error} aiFilterError={aiFilters.error}
aiFilterErrorType={aiFilters.errorType} aiFilterErrorType={aiFilters.errorType}
aiFilterNotes={aiFilters.notes} aiFilterNotes={aiFilters.notes}
aiFilterSummary={aiFilters.summary}
onAiFilterSubmit={handleAiFilterSubmit} onAiFilterSubmit={handleAiFilterSubmit}
isLoggedIn={!!user} isLoggedIn={!!user}
onLoginRequired={onRegisterClick ?? (() => {})} onLoginRequired={onRegisterClick ?? (() => {})}

View file

@ -1,6 +1,6 @@
import { useState, useCallback, useRef } from 'react'; import { useState, useCallback, useRef } from 'react';
import type { FeatureFilters } from '../types'; import type { FeatureFilters } from '../types';
import type { TransportMode } from './useTravelTime'; import type { TransportMode, TravelTimeEntry } from './useTravelTime';
import { apiUrl, authHeaders, logNonAbortError } from '../lib/api'; import { apiUrl, authHeaders, logNonAbortError } from '../lib/api';
export interface AiTravelTimeFilter { export interface AiTravelTimeFilter {
@ -11,20 +11,54 @@ export interface AiTravelTimeFilter {
max?: number; max?: number;
} }
interface AiFiltersResult { export interface AiFiltersResult {
filters: FeatureFilters; filters: FeatureFilters;
travelTimeFilters: AiTravelTimeFilter[]; travelTimeFilters: AiTravelTimeFilter[];
notes: string; notes: string;
/** Human-readable summary of what was set */
summary: string;
} }
export type AiFilterErrorType = 'auth' | 'verification' | 'limit' | 'error'; export type AiFilterErrorType = 'auth' | 'verification' | 'limit' | 'error';
/** Context of currently active filters, sent for conversational refinement. */
export interface AiFiltersContext {
filters: FeatureFilters;
travelTime: { mode: string; label: string; min?: number; max?: number }[];
}
interface UseAiFiltersResult { interface UseAiFiltersResult {
fetchAiFilters: (query: string) => Promise<AiFiltersResult | null>; fetchAiFilters: (query: string, context?: AiFiltersContext) => Promise<AiFiltersResult | null>;
loading: boolean; loading: boolean;
error: string | null; error: string | null;
errorType: AiFilterErrorType | null; errorType: AiFilterErrorType | null;
notes: string | null; notes: string | null;
summary: string | null;
}
/**
 * Build a human-readable summary of the AI result.
 *
 * Numeric range filters (two-element arrays whose first item is a number) are
 * listed by name only; any other array-valued filter is listed as
 * `name: a, b, c`. Each travel time filter is rendered as `<mode> to <label>`
 * followed by its bounds: `min-max min` when both are present, `< max min` or
 * `> min min` when only one is, and nothing when neither is.
 *
 * @param filters - Feature filters returned by the AI endpoint.
 * @param travelTimeFilters - Travel time filters returned by the AI endpoint.
 * @returns A one-line summary, or `'No filters set'` when there is nothing to list.
 */
function buildSummary(
  filters: FeatureFilters,
  travelTimeFilters: AiTravelTimeFilter[]
): string {
  // The response JSON is cast, not validated, so a bound may arrive as `null`.
  const isSet = (bound: number | null | undefined): bound is number =>
    bound !== undefined && bound !== null;

  const parts: string[] = [];

  // `?? {}` keeps a response without a `filters` object from throwing here.
  for (const [name, value] of Object.entries(filters ?? {})) {
    if (!Array.isArray(value)) continue;
    if (value.length === 2 && typeof value[0] === 'number') {
      parts.push(name);
    } else {
      parts.push(`${name}: ${(value as string[]).join(', ')}`);
    }
  }

  for (const tt of travelTimeFilters) {
    let bounds = '';
    if (isSet(tt.min) && isSet(tt.max)) {
      // Previously only the upper bound was shown, hiding the lower one.
      bounds = `${tt.min}-${tt.max} min`;
    } else if (isSet(tt.max)) {
      bounds = `< ${tt.max} min`;
    } else if (isSet(tt.min)) {
      bounds = `> ${tt.min} min`;
    }
    parts.push(`${tt.mode} to ${tt.label} ${bounds}`.trim());
  }

  if (parts.length === 0) return 'No filters set';
  return `Set ${parts.length} filter${parts.length > 1 ? 's' : ''}: ${parts.join(', ')}`;
}
export function useAiFilters(): UseAiFiltersResult { export function useAiFilters(): UseAiFiltersResult {
@ -32,77 +66,93 @@ export function useAiFilters(): UseAiFiltersResult {
const [error, setError] = useState<string | null>(null); const [error, setError] = useState<string | null>(null);
const [errorType, setErrorType] = useState<AiFilterErrorType | null>(null); const [errorType, setErrorType] = useState<AiFilterErrorType | null>(null);
const [notes, setNotes] = useState<string | null>(null); const [notes, setNotes] = useState<string | null>(null);
const [summary, setSummary] = useState<string | null>(null);
const abortRef = useRef<AbortController | null>(null); const abortRef = useRef<AbortController | null>(null);
const fetchAiFilters = useCallback(async (query: string): Promise<AiFiltersResult | null> => { const fetchAiFilters = useCallback(
abortRef.current?.abort(); async (query: string, context?: AiFiltersContext): Promise<AiFiltersResult | null> => {
const controller = new AbortController(); abortRef.current?.abort();
abortRef.current = controller; const controller = new AbortController();
abortRef.current = controller;
setLoading(true); setLoading(true);
setError(null); setError(null);
setErrorType(null); setErrorType(null);
setNotes(null); setNotes(null);
setSummary(null);
try { try {
const url = apiUrl('ai-filters'); const url = apiUrl('ai-filters');
const response = await fetch( const bodyObj: Record<string, unknown> = { query };
url, if (context) {
authHeaders({ bodyObj.context = {
method: 'POST', filters: context.filters,
headers: { 'Content-Type': 'application/json' }, travel_time: context.travelTime,
body: JSON.stringify({ query }), };
signal: controller.signal,
})
);
if (!response.ok) {
const text = await response.text();
if (response.status === 401) {
setErrorType('auth');
setError(text || 'Login required');
} else if (response.status === 403) {
setErrorType('verification');
setError(text || 'Email verification required');
} else if (response.status === 429) {
setErrorType('limit');
setError(text || 'Weekly usage limit reached');
} else {
setErrorType('error');
setError(text || `HTTP ${response.status}`);
} }
const response = await fetch(
url,
authHeaders({
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(bodyObj),
signal: controller.signal,
})
);
if (!response.ok) {
const text = await response.text();
if (response.status === 401) {
setErrorType('auth');
setError(text || 'Login required');
} else if (response.status === 403) {
setErrorType('verification');
setError(text || 'Email verification required');
} else if (response.status === 429) {
setErrorType('limit');
setError(text || 'Weekly usage limit reached');
} else {
setErrorType('error');
setError(text || `HTTP ${response.status}`);
}
setLoading(false);
return null;
}
const json = await response.json();
const travelTimeFilters: AiTravelTimeFilter[] = (json.travel_time_filters || []).map(
(tt: { mode: string; slug: string; label: string; min?: number; max?: number }) => ({
mode: tt.mode as TransportMode,
slug: tt.slug,
label: tt.label,
min: tt.min,
max: tt.max,
})
);
const filters = json.filters as FeatureFilters;
const summaryText = buildSummary(filters, travelTimeFilters);
const result: AiFiltersResult = {
filters,
travelTimeFilters,
notes: json.notes || '',
summary: summaryText,
};
setNotes(result.notes || null);
setSummary(summaryText);
setLoading(false);
return result;
} catch (err) {
if (controller.signal.aborted) return null;
logNonAbortError('ai-filters', err);
const message = err instanceof Error ? err.message : 'Failed to generate filters';
setErrorType('error');
setError(message);
setLoading(false); setLoading(false);
return null; return null;
} }
},
[]
);
const json = await response.json(); return { fetchAiFilters, loading, error, errorType, notes, summary };
const travelTimeFilters: AiTravelTimeFilter[] = (json.travel_time_filters || []).map(
(tt: { mode: string; slug: string; label: string; min?: number; max?: number }) => ({
mode: tt.mode as TransportMode,
slug: tt.slug,
label: tt.label,
min: tt.min,
max: tt.max,
})
);
const result: AiFiltersResult = {
filters: json.filters as FeatureFilters,
travelTimeFilters,
notes: json.notes || '',
};
setNotes(result.notes || null);
setLoading(false);
return result;
} catch (err) {
if (controller.signal.aborted) return null;
logNonAbortError('ai-filters', err);
const message = err instanceof Error ? err.message : 'Failed to generate filters';
setErrorType('error');
setError(message);
setLoading(false);
return null;
}
}, []);
return { fetchAiFilters, loading, error, errorType, notes };
} }

View file

@ -7,18 +7,18 @@ import polars as pl
from shapely.geometry import Point from shapely.geometry import Point
from tqdm import tqdm from tqdm import tqdm
from pipeline.utils.england_geometry import load_england_polygon from pipeline.utils.england_geometry import (
ENGLAND_BBOX_EAST,
ENGLAND_BBOX_NORTH,
ENGLAND_BBOX_SOUTH,
ENGLAND_BBOX_WEST,
load_england_polygon,
)
BATCH_SIZE = 50_000 BATCH_SIZE = 50_000
MIN_OCCURENCE_COUNT = 20 MIN_OCCURENCE_COUNT = 20
# Bounding box for fast pre-filtering before the precise polygon check
ENGLAND_BBOX_WEST = -6.45
ENGLAND_BBOX_SOUTH = 49.85
ENGLAND_BBOX_EAST = 1.77
ENGLAND_BBOX_NORTH = 55.82
POI_TAG_KEYS: list[str] = [ POI_TAG_KEYS: list[str] = [
"amenity", "amenity",
"building", "building",

View file

@ -15,7 +15,9 @@ pub const MAX_PRICE_HISTORY_POINTS: usize = 5000;
pub const POSTCODE_SEARCH_OFFSET: f64 = 0.02; pub const POSTCODE_SEARCH_OFFSET: f64 = 0.02;
pub const AI_FILTERS_MAX_TOKENS: usize = 2000; pub const AI_FILTERS_MAX_TOKENS: usize = 2000;
pub const AI_FILTERS_TEMPERATURE: f32 = 0.0; /// Gemini 3 recommends 1.0; lower values can cause looping or degraded performance.
pub const AI_FILTERS_TEMPERATURE: f32 = 1.0;
pub const AI_FILTERS_WEEKLY_TOKEN_LIMIT: u64 = 10_000_000;
/// Timeout for outbound HTTP service calls (seconds). /// Timeout for outbound HTTP service calls (seconds).
pub const SERVICE_CALL_TIMEOUT: u64 = 120; pub const SERVICE_CALL_TIMEOUT: u64 = 120;

View file

@ -94,13 +94,13 @@ struct Cli {
#[arg(long, env = "POCKETBASE_ADMIN_PASSWORD")] #[arg(long, env = "POCKETBASE_ADMIN_PASSWORD")]
pocketbase_admin_password: String, pocketbase_admin_password: String,
/// Ollama server URL (e.g. http://ollama:11434) /// Gemini API key
#[arg(long, env = "OLLAMA_URL")] #[arg(long, env = "GEMINI_API_KEY")]
ollama_url: String, gemini_api_key: String,
/// Ollama model name /// Gemini model name (e.g. gemini-2.0-flash)
#[arg(long, env = "OLLAMA_MODEL")] #[arg(long, env = "GEMINI_MODEL")]
ollama_model: String, gemini_model: String,
/// Path to precomputed travel times directory (contains mode subdirs with parquet files) /// Path to precomputed travel times directory (contains mode subdirs with parquet files)
#[arg(long, env = "TRAVEL_TIMES")] #[arg(long, env = "TRAVEL_TIMES")]
@ -301,9 +301,7 @@ async fn main() -> anyhow::Result<()> {
"Precomputed features response" "Precomputed features response"
); );
let ai_filters_schema = routes::build_ollama_schema(&features_response); // AI filters system prompt built after travel_time_store is loaded (needs mode counts)
let ai_filters_system_prompt = routes::build_system_prompt(&features_response);
info!("Precomputed AI filters schema and system prompt");
// Record data loading metrics // Record data loading metrics
metrics::record_data_stats( metrics::record_data_stats(
@ -331,10 +329,7 @@ async fn main() -> anyhow::Result<()> {
&cli.google_oauth_client_secret, &cli.google_oauth_client_secret,
) )
.await?; .await?;
info!( info!("Gemini configured (model: {})", cli.gemini_model);
"Ollama configured: {} (model: {})",
cli.ollama_url, cli.ollama_model
);
let tt_path = &cli.travel_times; let tt_path = &cli.travel_times;
if !tt_path.exists() { if !tt_path.exists() {
bail!( bail!(
@ -352,6 +347,23 @@ async fn main() -> anyhow::Result<()> {
Arc::new(store) Arc::new(store)
}; };
let mode_destinations: Vec<(String, usize)> = travel_time_store
.available_modes
.iter()
.map(|mode| {
let count = travel_time_store
.destinations
.get(mode.as_str())
.map(|slugs| slugs.len())
.unwrap_or(0);
(mode.clone(), count)
})
.filter(|(_, count)| *count > 0)
.collect();
let ai_filters_system_prompt =
routes::build_system_prompt(&features_response, &mode_destinations);
info!("Precomputed AI filters system prompt");
let token_cache = Arc::new(auth::TokenCache::new()); let token_cache = Arc::new(auth::TokenCache::new());
let state = Arc::new(AppState { let state = Arc::new(AppState {
@ -370,16 +382,16 @@ async fn main() -> anyhow::Result<()> {
features_response, features_response,
screenshot_url: cli.screenshot_url, screenshot_url: cli.screenshot_url,
public_url: cli.public_url, public_url: cli.public_url,
is_dev: index_html.is_none(),
index_html, index_html,
http_client, http_client,
pocketbase_url: cli.pocketbase_url, pocketbase_url: cli.pocketbase_url,
pocketbase_admin_email: cli.pocketbase_admin_email, pocketbase_admin_email: cli.pocketbase_admin_email,
pocketbase_admin_password: cli.pocketbase_admin_password, pocketbase_admin_password: cli.pocketbase_admin_password,
ollama_url: cli.ollama_url, gemini_api_key: cli.gemini_api_key,
ollama_model: cli.ollama_model, gemini_model: cli.gemini_model,
travel_time_store, travel_time_store,
token_cache, token_cache,
ai_filters_schema,
ai_filters_system_prompt, ai_filters_system_prompt,
google_maps_api_key: cli.google_maps_api_key, google_maps_api_key: cli.google_maps_api_key,
stripe_secret_key: cli.stripe_secret_key, stripe_secret_key: cli.stripe_secret_key,
@ -504,7 +516,7 @@ async fn main() -> anyhow::Result<()> {
) )
.route( .route(
"/api/ai-filters", "/api/ai-filters",
post(move |body| routes::post_ai_filters(state_ai_filters.clone(), body)) post(move |ext, body| routes::post_ai_filters(state_ai_filters.clone(), ext, body))
.layer(ConcurrencyLimitLayer::new(5)), .layer(ConcurrencyLimitLayer::new(5)),
) )
.route( .route(

View file

@ -240,9 +240,11 @@ async fn ensure_user_fields(
let has_is_admin = fields.iter().any(|f| f["name"] == "is_admin"); let has_is_admin = fields.iter().any(|f| f["name"] == "is_admin");
let has_subscription = fields.iter().any(|f| f["name"] == "subscription"); let has_subscription = fields.iter().any(|f| f["name"] == "subscription");
let has_newsletter = fields.iter().any(|f| f["name"] == "newsletter"); let has_newsletter = fields.iter().any(|f| f["name"] == "newsletter");
let has_ai_tokens_used = fields.iter().any(|f| f["name"] == "ai_tokens_used");
let has_ai_tokens_week = fields.iter().any(|f| f["name"] == "ai_tokens_week");
if has_is_admin && has_subscription && has_newsletter { if has_is_admin && has_subscription && has_newsletter && has_ai_tokens_used && has_ai_tokens_week {
info!("PocketBase users collection already has is_admin, subscription, and newsletter fields"); info!("PocketBase users collection already has all required fields");
return Ok(()); return Ok(());
} }
@ -269,6 +271,20 @@ async fn ensure_user_fields(
})); }));
} }
if !has_ai_tokens_used {
new_fields.push(serde_json::json!({
"name": "ai_tokens_used",
"type": "number",
}));
}
if !has_ai_tokens_week {
new_fields.push(serde_json::json!({
"name": "ai_tokens_week",
"type": "number",
}));
}
let patch_resp = client let patch_resp = client
.patch(&url) .patch(&url)
.header("Authorization", format!("Bearer {token}")) .header("Authorization", format!("Bearer {token}"))

View file

@ -26,7 +26,7 @@ pub(crate) mod travel_time;
mod travel_destinations; mod travel_destinations;
mod travel_modes; mod travel_modes;
pub use ai_filters::{build_ollama_schema, build_system_prompt, post_ai_filters}; pub use ai_filters::{build_system_prompt, post_ai_filters};
pub use checkout::post_checkout; pub use checkout::post_checkout;
pub use export::get_export; pub use export::get_export;
pub use features::{build_features_response, get_features, FeatureInfo, FeaturesResponse}; pub use features::{build_features_response, get_features, FeatureInfo, FeaturesResponse};

View file

@ -2,76 +2,190 @@ use std::sync::Arc;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::response::Json; use axum::response::Json;
use axum::Extension;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
use tracing::{info, warn}; use tracing::{info, warn};
use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE}; use crate::auth::OptionalUser;
use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WEEKLY_TOKEN_LIMIT};
use crate::data::slugify;
use crate::pocketbase::auth_superuser;
use crate::routes::{FeatureInfo, FeaturesResponse}; use crate::routes::{FeatureInfo, FeaturesResponse};
use crate::state::AppState; use crate::state::AppState;
use crate::utils::{extract_ollama_content, ollama_chat, strip_think_blocks}; use crate::utils::gemini_chat;
/// Currently active filters sent by the client together with a query, so a
/// follow-up request can refine the existing search.
#[derive(Deserialize)]
pub struct AiFiltersContext {
    /// Active feature filters, kept as raw JSON.
    /// NOTE(review): presumably keyed by feature name — confirm against the client payload.
    filters: Value,
    /// Active travel time entries; an absent field deserializes to an empty list.
    #[serde(default)]
    travel_time: Vec<AiTravelTimeContext>,
}
/// One active travel time entry inside [`AiFiltersContext`].
#[derive(Deserialize)]
pub struct AiTravelTimeContext {
    /// Transport mode of the entry.
    mode: String,
    /// Label of the entry's destination as sent by the client.
    label: String,
    /// Lower bound of the time range, if set.
    /// NOTE(review): presumably minutes, like the travel time filters — confirm.
    min: Option<f32>,
    /// Upper bound of the time range, if set.
    /// NOTE(review): presumably minutes, like the travel time filters — confirm.
    max: Option<f32>,
}
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct AiFiltersRequest { pub struct AiFiltersRequest {
query: String, query: String,
/// Current filters for conversational refinement (e.g. "make it cheaper")
context: Option<AiFiltersContext>,
}
#[derive(Serialize)]
pub struct TravelTimeFilter {
mode: String,
slug: String,
label: String,
#[serde(skip_serializing_if = "Option::is_none")]
min: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
max: Option<f32>,
} }
#[derive(Serialize)] #[derive(Serialize)]
pub struct AiFiltersResponse { pub struct AiFiltersResponse {
filters: Value, filters: Value,
#[serde(skip_serializing_if = "Vec::is_empty")]
travel_time_filters: Vec<TravelTimeFilter>,
/// What the LLM couldn't map to existing filters (empty if everything matched) /// What the LLM couldn't map to existing filters (empty if everything matched)
#[serde(skip_serializing_if = "String::is_empty")] #[serde(skip_serializing_if = "String::is_empty")]
notes: String, notes: String,
} }
/// Build a JSON schema for Ollama structured output. /// Strip markdown code fences (```json ... ``` or ``` ... ```) from LLM output.
/// /// Models occasionally wrap JSON in markdown fencing even when told not to.
/// Uses two arrays (`numeric_filters` and `enum_filters`) instead of one property fn strip_markdown_fences(text: &str) -> &str {
/// per feature, because Ollama converts JSON schema to GBNF grammar and a schema let trimmed = text.trim();
/// with 50+ optional keys causes a combinatorial explosion that crashes the parser.
/// Array-based schema keeps the grammar small and constant-size. // Try ```json\n...\n``` or ```\n...\n```
pub fn build_ollama_schema(_features: &FeaturesResponse) -> Value { if let Some(rest) = trimmed.strip_prefix("```") {
json!({ // Skip optional language tag (e.g. "json")
"type": "object", let rest = if let Some(newline_pos) = rest.find('\n') {
"properties": { &rest[newline_pos + 1..]
"numeric_filters": { } else {
"type": "array", return trimmed;
"items": { };
"type": "object", if let Some(content) = rest.strip_suffix("```") {
"properties": { return content.trim();
"name": { "type": "string" },
"bound": { "type": "string", "enum": ["min", "max"] },
"value": { "type": "number" }
},
"required": ["name", "bound", "value"]
}
},
"enum_filters": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" },
"values": { "type": "array", "items": { "type": "string" } }
},
"required": ["name", "values"]
}
},
"notes": {
"type": "string"
}
} }
}) }
trimmed
}
/// Assemble the Gemini tool declarations exposing the `search_destinations` function.
///
/// The function takes a free-text `query` and a `mode`; the allowed `mode`
/// values are the modes listed in the travel time store of `state`.
fn build_tool_declarations(state: &AppState) -> Value {
    let mode_names: Vec<&str> = state
        .travel_time_store
        .available_modes
        .iter()
        .map(|name| name.as_str())
        .collect();

    // Schema of the two arguments the model must supply.
    let query_param = json!({
        "type": "string",
        "description": "Place name to search for (e.g. 'Manchester', 'Kings Cross', 'Heathrow')"
    });
    let mode_param = json!({
        "type": "string",
        "enum": mode_names,
        "description": "Transport mode to search destinations for"
    });
    let parameters = json!({
        "type": "object",
        "properties": {
            "query": query_param,
            "mode": mode_param
        },
        "required": ["query", "mode"]
    });

    let search_destinations = json!({
        "name": "search_destinations",
        "description": "Search for available travel time destinations (cities, stations, towns) that have precomputed travel time data. Call this when the user mentions wanting to be near, close to, or within a certain travel time of a specific place.",
        "parameters": parameters
    });

    json!([{ "functionDeclarations": [search_destinations] }])
}
/// Execute a destination search against PlaceData + TravelTimeStore.
/// Returns matching destinations as a JSON value with `results` and optional `message`.
///
/// Uses word-based matching: all words in the query must appear somewhere in the
/// place name (order-independent). Also matches against slugs for short queries.
fn execute_destination_search(state: &AppState, query: &str, mode: &str) -> Value {
let query_lower = query.to_lowercase();
let query_words: Vec<&str> = query_lower.split_whitespace().collect();
let query_slug = slugify(query);
let tt_store = &state.travel_time_store;
let pd = &state.place_data;
let slug_set = match tt_store.destinations.get(mode) {
Some(slugs) => slugs,
None => return json!({ "results": [], "message": format!("No travel data available for mode '{}'", mode) }),
};
// Find places matching the query that have travel time data.
// A place matches if ALL query words appear in its name, OR its slug matches the query slug.
let mut matches: Vec<(usize, String, u8, u32)> = pd
.name_lower
.iter()
.enumerate()
.filter_map(|(idx, name_lower)| {
let words_match = query_words.iter().all(|word| name_lower.contains(word));
let slug = slugify(&pd.name[idx]);
let slug_match = slug.contains(&query_slug) || query_slug.contains(&slug);
if !words_match && !slug_match {
return None;
}
if slug_set.contains(&slug) {
Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
} else {
None
}
})
.collect();
// Sort: type rank asc, population desc
matches.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(b.3.cmp(&a.3)));
matches.truncate(10);
if matches.is_empty() {
info!(query = query, mode = mode, "Destination search returned no results");
return json!({
"results": [],
"message": format!("No travel time data available for '{}' by {}. This destination cannot be used as a travel time filter.", query, mode)
});
}
let results: Vec<Value> = matches
.into_iter()
.map(|(idx, slug, ..)| {
json!({
"name": pd.name[idx],
"slug": slug,
"place_type": pd.place_type.get(idx).to_string(),
})
})
.collect();
json!({ "results": results })
} }
/// Build the complete system prompt for AI filters. /// Build the complete system prompt for AI filters.
/// ///
/// Contains: role instructions, feature catalogue, few-shot examples, output rules. /// Contains: role instructions, feature catalogue, travel time info,
/// few-shot examples, output rules.
/// Precomputed at startup and cached in AppState. /// Precomputed at startup and cached in AppState.
pub fn build_system_prompt(features: &FeaturesResponse) -> String { pub fn build_system_prompt(
features: &FeaturesResponse,
mode_destinations: &[(String, usize)],
) -> String {
let mut parts = Vec::new(); let mut parts = Vec::new();
// Role and task description
parts.push( parts.push(
"You are a UK property search assistant. \ "You are a UK property search assistant. \
The user describes their ideal property or area in natural language. \ The user describes their ideal property or area in natural language. \
@ -91,10 +205,61 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
(note: this counts bedrooms + living rooms combined, so 3 bed ~ min 4).\n\ (note: this counts bedrooms + living rooms combined, so 3 bed ~ min 4).\n\
- If the user mentions something that has no matching filter, put it in \"notes\" \ - If the user mentions something that has no matching filter, put it in \"notes\" \
as a short phrase (e.g. \"No filter for: garden, sea view\"). \ as a short phrase (e.g. \"No filter for: garden, sea view\"). \
If everything was matched, set \"notes\" to an empty string." If everything was matched, set \"notes\" to an empty string.\n\
\n\
CONVERSATIONAL REFINEMENT:\n\
The user's message may include their currently active filters as context. \
When context is provided:\n\
- \"make it cheaper\" / \"lower the price\" = adjust the existing price filter down\n\
- \"also add ...\" / \"and good schools\" = keep existing filters and add new ones\n\
- \"remove the ...\" / \"drop the ...\" = return filters WITHOUT the mentioned one\n\
- If the request is a completely new search (not a refinement), ignore the context \
and build filters from scratch.\n\
- Always output the COMPLETE set of filters (existing + modified), not just the changes."
.to_string(), .to_string(),
); );
// Travel time section with available modes
let modes_list = mode_destinations
.iter()
.map(|(mode, count)| format!("- {} ({} destinations available)", mode, count))
.collect::<Vec<_>>()
.join("\n");
parts.push(format!(
"\n--- TRAVEL TIME FILTERS ---\n\
You can add travel time filters when the user mentions commute times, \
proximity to places, or wanting to be near/within X minutes of somewhere.\n\
\n\
Available transport modes (only use modes that have destinations):\n\
{}\n\
- \"car\" / \"drive\" / \"driving\" = car mode\n\
- \"cycle\" / \"bike\" / \"cycling\" = bicycle mode\n\
- \"walk\" / \"walking\" / \"on foot\" = walking mode\n\
- \"train\" / \"tube\" / \"bus\" / \"public transport\" / \"commute\" = transit mode\n\
\n\
When the user mentions a specific place, you MUST call the search_destinations \
tool to find the exact slug. Use the name and slug from the search results.\n\
If search_destinations returns an empty array, the destination is not available \
mention it in \"notes\" (e.g. \"No travel data for: Gatwick Airport\") and do NOT \
include a travel_time_filter for it.\n\
\n\
Travel time values are in MINUTES (0-120 range).\n\
- \"within 30 minutes\" = max 30\n\
- \"at least 10 minutes\" = min 10\n\
- \"30-45 minute commute\" = min 30, max 45\n\
- If only a max is given, omit min (and vice versa).\n\
\n\
INFERRING TRANSPORT MODE (when the user does not specify one explicitly):\n\
- \"commute\" to a major city centre or station = transit\n\
- \"near\" / \"close to\" a city centre or station = transit\n\
- \"near\" / \"close to\" a smaller town, village, or rural area = car\n\
- \"drive\" / \"driving distance\" / \"driving time\" = always car\n\
- If multiple modes are plausible, prefer transit for urban destinations \
(London, Manchester, Birmingham, Leeds, etc.) and car for everything else.",
modes_list,
));
// Feature catalogue // Feature catalogue
parts.push("\n--- AVAILABLE FEATURES ---\n".to_string()); parts.push("\n--- AVAILABLE FEATURES ---\n".to_string());
for group in &features.groups { for group in &features.groups {
@ -148,6 +313,7 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
Output: {\"numeric_filters\": [{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 400000}], \ Output: {\"numeric_filters\": [{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 400000}], \
\"enum_filters\": [{\"name\": \"Leasehold/Freehold\", \"values\": [\"Freehold\"]}, \ \"enum_filters\": [{\"name\": \"Leasehold/Freehold\", \"values\": [\"Freehold\"]}, \
{\"name\": \"Property type\", \"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \ {\"name\": \"Property type\", \"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
\"travel_time_filters\": [], \
\"notes\": \"\"}" \"notes\": \"\"}"
.to_string(), .to_string(),
); );
@ -161,7 +327,7 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
{\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \ {\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \
{\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}, \ {\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}, \
{\"name\": \"Number of parks within 2km\", \"bound\": \"min\", \"value\": 3}], \ {\"name\": \"Number of parks within 2km\", \"bound\": \"min\", \"value\": 3}], \
\"enum_filters\": [], \"notes\": \"\"}" \"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
.to_string(), .to_string(),
); );
@ -172,18 +338,37 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 4}], \ {\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 4}], \
\"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}, \ \"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}, \
{\"name\": \"Max available download speed (Mbps)\", \"values\": [\"100\", \"300\", \"1000\"]}], \ {\"name\": \"Max available download speed (Mbps)\", \"values\": [\"100\", \"300\", \"1000\"]}], \
\"travel_time_filters\": [], \
\"notes\": \"No filter for: beach proximity\"}" \"notes\": \"No filter for: beach proximity\"}"
.to_string(), .to_string(),
); );
parts.push( parts.push(
"\nUser: \"large family home with a garden near restaurants\"\n\ "\nUser: \"within 30 minutes commute of Kings Cross, under 500k\"\n\
(After calling search_destinations for \"Kings Cross\" with mode \"transit\" \
and getting [{\"name\": \"Kings Cross\", \"slug\": \"kings-cross\", \"place_type\": \"station\"}])\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 500000}], \
\"enum_filters\": [], \
\"travel_time_filters\": [{\"mode\": \"transit\", \"slug\": \"kings-cross\", \
\"label\": \"Kings Cross\", \"bound\": \"max\", \"value\": 30}], \
\"notes\": \"\"}"
.to_string(),
);
parts.push(
"\nUser: \"family home with garden, 45 min drive from Manchester, good schools\"\n\
(After calling search_destinations for \"Manchester\" with mode \"car\" \
and getting [{\"name\": \"Manchester\", \"slug\": \"manchester\", \"place_type\": \"city\"}])\n\
Output: {\"numeric_filters\": [\ Output: {\"numeric_filters\": [\
{\"name\": \"Total floor area (sqm)\", \"bound\": \"min\", \"value\": 100}, \ {\"name\": \"Total floor area (sqm)\", \"bound\": \"min\", \"value\": 100}, \
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 5}, \ {\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 5}, \
{\"name\": \"Number of restaurants within 2km\", \"bound\": \"min\", \"value\": 10}], \ {\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \
{\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}], \
\"enum_filters\": [{\"name\": \"Property type\", \ \"enum_filters\": [{\"name\": \"Property type\", \
\"values\": [\"Detached\", \"Semi-Detached\"]}], \ \"values\": [\"Detached\", \"Semi-Detached\"]}], \
\"travel_time_filters\": [{\"mode\": \"car\", \"slug\": \"manchester\", \
\"label\": \"Manchester\", \"bound\": \"max\", \"value\": 45}], \
\"notes\": \"No filter for: garden\"}" \"notes\": \"No filter for: garden\"}"
.to_string(), .to_string(),
); );
@ -191,7 +376,8 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
// Output format reminder // Output format reminder
parts.push( parts.push(
"\n--- OUTPUT FORMAT ---\n\ "\n--- OUTPUT FORMAT ---\n\
{\"numeric_filters\": [{\"name\": \"...\", \"bound\": \"min\"|\"max\", \"value\": N}, ...], \"enum_filters\": [...], \"notes\": \"...\"}\n\ {\"numeric_filters\": [...], \"enum_filters\": [...], \"travel_time_filters\": [{\"mode\": \"...\", \"slug\": \"...\", \"label\": \"...\", \"bound\": \"min\"|\"max\", \"value\": N}, ...], \"notes\": \"...\"}\n\
- travel_time_filters: use ONLY slugs returned by search_destinations. If a place isn't found, mention it in notes.\n\
Respond with ONLY the JSON object. No explanation." Respond with ONLY the JSON object. No explanation."
.to_string(), .to_string(),
); );
@ -199,86 +385,393 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
parts.join("\n") parts.join("\n")
} }
/// Index of the current 7-day window, counted from the Unix epoch.
///
/// The value is `seconds_since_epoch / 604800`, so it grows by one every
/// 7 days and never goes backwards while the system clock moves forward.
/// Windows are aligned to the epoch rather than to calendar weeks. Weekly
/// rate limiting compares this number with the week stored next to a user's
/// token counter: a different number means the counter belongs to an older
/// window and is treated as zero.
///
/// If the system clock reports a time before the epoch, week `0` is returned
/// instead of panicking inside a request handler.
fn current_week_number() -> u64 {
    /// Length of one rate-limiting window in seconds (7 days).
    const SECONDS_PER_WEEK: u64 = 7 * 24 * 60 * 60;

    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs() / SECONDS_PER_WEEK)
}
/// Call Ollama and parse the response content as JSON. /// Fetch the user's current AI token usage from PocketBase.
/// /// Returns `(tokens_used, week_number)`.
/// Returns an error if: the HTTP call fails, the response is malformed, async fn fetch_ai_usage(
/// the content is empty after stripping think blocks, or the content is state: &AppState,
/// not valid JSON. user_id: &str,
async fn call_ollama_and_parse( ) -> Result<(u64, u64), (StatusCode, String)> {
client: &reqwest::Client, let pb_url = state.pocketbase_url.trim_end_matches('/');
url: &str, let token = auth_superuser(
body: &Value, &state.http_client,
) -> Result<Value, (StatusCode, String)> { pb_url,
let json_resp = ollama_chat(client, url, body).await?; &state.pocketbase_admin_email,
let content = extract_ollama_content(&json_resp)?; &state.pocketbase_admin_password,
)
.await
.map_err(|err| {
warn!("Failed to auth superuser for AI usage check: {err}");
(StatusCode::BAD_GATEWAY, "Internal error".into())
})?;
let content = strip_think_blocks(content); let url = format!("{pb_url}/api/collections/users/records/{user_id}");
let content = content.trim(); let resp = state
.http_client
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await
.map_err(|err| {
warn!("Failed to fetch user record for AI usage: {err}");
(StatusCode::BAD_GATEWAY, "Internal error".into())
})?;
if content.is_empty() { if !resp.status().is_success() {
warn!("LLM returned empty content after stripping think blocks"); let status = resp.status();
warn!("PocketBase user fetch failed ({status})");
return Err((StatusCode::BAD_GATEWAY, "Internal error".into()));
}
let body: Value = resp.json().await.map_err(|err| {
warn!("Failed to parse user record: {err}");
(StatusCode::BAD_GATEWAY, "Internal error".into())
})?;
let tokens_used = body
.get("ai_tokens_used")
.and_then(|val| val.as_u64())
.unwrap_or(0);
let week = body
.get("ai_tokens_week")
.and_then(|val| val.as_u64())
.unwrap_or(0);
Ok((tokens_used, week))
}
/// Persist a user's AI token counter and its week number to PocketBase.
///
/// Authenticates as the PocketBase superuser and PATCHes `ai_tokens_used`
/// and `ai_tokens_week` on the `users` record for `user_id`.
///
/// This is best-effort: every failure is logged with `warn!` and swallowed,
/// so the caller's response is never affected by a failed write.
async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week: u64) {
    let base = state.pocketbase_url.trim_end_matches('/');

    let auth = auth_superuser(
        &state.http_client,
        base,
        &state.pocketbase_admin_email,
        &state.pocketbase_admin_password,
    )
    .await;
    let admin_token = match auth {
        Ok(tk) => tk,
        Err(err) => {
            warn!("Failed to auth superuser for AI usage update: {err}");
            return;
        }
    };

    let endpoint = format!("{base}/api/collections/users/records/{user_id}");
    let payload = json!({
        "ai_tokens_used": tokens_used,
        "ai_tokens_week": week,
    });

    let outcome = state
        .http_client
        .patch(&endpoint)
        .header("Authorization", format!("Bearer {admin_token}"))
        .json(&payload)
        .send()
        .await;

    match outcome {
        Err(err) => warn!("Failed to update AI usage: {err}"),
        Ok(resp) => {
            let status = resp.status();
            if !status.is_success() {
                warn!("Failed to update AI usage ({status})");
            }
        }
    }
}
/// Maximum number of round trips (function calls + retries) before giving up.
const MAX_TOOL_ROUNDS: usize = 5;
pub async fn post_ai_filters(
state: Arc<AppState>,
Extension(user): Extension<OptionalUser>,
Json(req): Json<AiFiltersRequest>,
) -> Result<Json<AiFiltersResponse>, (StatusCode, String)> {
// Auth check
let user = user
.0
.ok_or((StatusCode::UNAUTHORIZED, "Login required".into()))?;
// Email verification check (skipped in dev mode)
if !user.verified && !state.is_dev {
return Err(( return Err((
StatusCode::BAD_GATEWAY, StatusCode::FORBIDDEN,
"LLM returned empty content (no JSON output)".into(), "Please verify your email to use AI filters".into(),
)); ));
} }
serde_json::from_str(content).map_err(|err| { // Check weekly token usage
warn!(error = %err, content = %content, "Failed to parse LLM JSON output"); let current_week = current_week_number();
( let (stored_tokens, stored_week) = fetch_ai_usage(&state, &user.id).await?;
StatusCode::BAD_GATEWAY, let tokens_used = if stored_week == current_week {
format!("Failed to parse LLM output as JSON: {}", err), stored_tokens
} else {
0
};
if tokens_used >= AI_FILTERS_WEEKLY_TOKEN_LIMIT {
return Err((
StatusCode::TOO_MANY_REQUESTS,
"Weekly AI usage limit reached. Resets next week.".into(),
));
}
info!(query = %req.query, user_id = %user.id, "POST /api/ai-filters");
let tools = build_tool_declarations(&state);
// Build user message with optional context for conversational refinement
let user_text = if let Some(ref ctx) = req.context {
let mut msg = String::new();
msg.push_str("Currently active filters:\n");
msg.push_str(&serde_json::to_string(&ctx.filters).unwrap_or_default());
if !ctx.travel_time.is_empty() {
msg.push_str("\nCurrently active travel time filters:\n");
for tt in &ctx.travel_time {
let bounds = match (tt.min, tt.max) {
(Some(min), Some(max)) => format!("{}-{} min", min, max),
(Some(min), None) => format!("min {} min", min),
(None, Some(max)) => format!("max {} min", max),
(None, None) => "no range".to_string(),
};
msg.push_str(&format!("- {} to {} ({})\n", tt.mode, tt.label, bounds));
}
}
msg.push_str(&format!("\nUser request: {}", req.query));
msg
} else {
req.query.clone()
};
let mut contents = vec![json!({
"role": "user",
"parts": [{ "text": user_text }]
})];
let mut total_tokens_accumulated: u64 = 0;
// Function calling loop: model may call search_destinations, we execute and feed back
for round in 0..MAX_TOOL_ROUNDS {
let body = json!({
"systemInstruction": {
"parts": [{ "text": state.ai_filters_system_prompt }]
},
"contents": contents,
"tools": tools,
"generationConfig": {
"temperature": AI_FILTERS_TEMPERATURE,
"maxOutputTokens": AI_FILTERS_MAX_TOKENS,
"thinkingConfig": { "thinkingLevel": "LOW" },
}
});
let json_resp = gemini_chat(
&state.http_client,
&state.gemini_api_key,
&state.gemini_model,
&body,
) )
}) .await?;
// Accumulate token usage
total_tokens_accumulated += json_resp
.get("usageMetadata")
.and_then(|md| md.get("totalTokenCount"))
.and_then(|tc| tc.as_u64())
.unwrap_or(0);
let candidate = json_resp
.get("candidates")
.and_then(|cs| cs.get(0))
.and_then(|c| c.get("content"))
.ok_or_else(|| {
warn!("Malformed Gemini response: missing candidates[0].content");
(
StatusCode::BAD_GATEWAY,
"Malformed Gemini response".into(),
)
})?;
let parts = candidate
.get("parts")
.and_then(|p| p.as_array())
.ok_or_else(|| {
warn!("Malformed Gemini response: missing parts array");
(
StatusCode::BAD_GATEWAY,
"Malformed Gemini response".into(),
)
})?;
// Check if the model made a function call.
// Find the full part (includes thoughtSignature required by Gemini 3 models).
if let Some(fc_part) = parts.iter().find(|part| part.get("functionCall").is_some()) {
let fc = fc_part.get("functionCall").unwrap();
let fn_name = fc.get("name").and_then(|n| n.as_str()).unwrap_or("");
let fn_args = fc.get("args").cloned().unwrap_or(json!({}));
info!(
function = fn_name,
round = round,
"AI called tool"
);
let fn_result = if fn_name == "search_destinations" {
let query = fn_args
.get("query")
.and_then(|q| q.as_str())
.unwrap_or("");
let mode = fn_args
.get("mode")
.and_then(|m| m.as_str())
.unwrap_or("transit");
execute_destination_search(&state, query, mode)
} else {
json!({"error": "unknown function"})
};
// Append the model's full response (preserves thoughtSignature) + our function result
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{
"functionResponse": {
"name": fn_name,
"response": { "results": fn_result }
}
}]
}));
// Continue the loop — model will process the results
continue;
}
// Model returned text — extract and parse as JSON
let text = parts
.iter()
.find_map(|part| part.get("text").and_then(|t| t.as_str()))
.unwrap_or("");
let text = strip_markdown_fences(text);
let text = text.trim();
if text.is_empty() {
warn!("Gemini returned empty text content (round {})", round);
// Retry by continuing the loop
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": "Your response was empty. Please output the JSON object." }]
}));
continue;
}
let raw: Value = match serde_json::from_str(text) {
Ok(val) => val,
Err(err) => {
warn!(error = %err, round = round, "Failed to parse Gemini JSON output, retrying");
// Ask the model to fix its output
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": "That was not valid JSON. Please output ONLY the JSON object with numeric_filters, enum_filters, travel_time_filters, and notes." }]
}));
continue;
}
};
let filters = validate_and_convert(&raw, &state.features_response);
let travel_time_filters = validate_travel_time_filters(&raw, &state);
let notes = raw
.get("notes")
.and_then(|val| val.as_str())
.unwrap_or("")
.to_string();
// Update usage with total accumulated tokens
let new_total = tokens_used + total_tokens_accumulated;
update_ai_usage(&state, &user.id, new_total, current_week).await;
return Ok(Json(AiFiltersResponse {
filters,
travel_time_filters,
notes,
}));
}
// Exhausted tool rounds without getting a final text response
warn!("AI exhausted {} tool-calling rounds without final response", MAX_TOOL_ROUNDS);
Err((
StatusCode::BAD_GATEWAY,
"AI could not complete the request".into(),
))
}
/// Validate travel time filters from LLM output against available destinations.
fn validate_travel_time_filters(raw: &Value, state: &AppState) -> Vec<TravelTimeFilter> {
let arr = match raw
.get("travel_time_filters")
.and_then(|val| val.as_array())
{
Some(arr) => arr,
None => return Vec::new(),
};
let tt_store = &state.travel_time_store;
let mut results = Vec::new();
for item in arr {
let mode = match item.get("mode").and_then(|val| val.as_str()) {
Some(mode) => mode,
None => continue,
};
let slug = match item.get("slug").and_then(|val| val.as_str()) {
Some(slug) => slug,
None => continue,
};
let label = item
.get("label")
.and_then(|val| val.as_str())
.unwrap_or(slug);
// Verify this destination actually exists
if !tt_store.has_destination(mode, slug) {
warn!(mode = mode, slug = slug, "AI suggested non-existent destination");
continue;
}
let bound = item.get("bound").and_then(|val| val.as_str());
let value = item.get("value").and_then(|val| val.as_f64());
let (min, max) = match (bound, value) {
(Some("min"), Some(val)) => (Some(val.max(0.0).min(120.0) as f32), None),
(Some("max"), Some(val)) => (None, Some(val.max(0.0).min(120.0) as f32)),
_ => (None, None),
};
// Only include if at least one bound is set
if min.is_some() || max.is_some() {
results.push(TravelTimeFilter {
mode: mode.to_string(),
slug: slug.to_string(),
label: label.to_string(),
min,
max,
});
}
}
results
} }
/// Validate LLM output against feature metadata and convert to FeatureFilters format. /// Validate LLM output against feature metadata and convert to FeatureFilters format.
@ -374,3 +867,32 @@ fn validate_and_convert(raw: &Value, features: &FeaturesResponse) -> Value {
Value::Object(result) Value::Object(result)
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// The bare payload every fenced or unfenced variant must reduce to.
    const BARE_JSON: &str = r#"{"a": 1}"#;

    /// A fence tagged with a language (```json) is removed.
    #[test]
    fn strip_fences_json_tag() {
        let stripped = strip_markdown_fences("```json\n{\"a\": 1}\n```");
        assert_eq!(stripped, BARE_JSON);
    }

    /// An untagged fence is removed as well.
    #[test]
    fn strip_fences_no_tag() {
        let stripped = strip_markdown_fences("```\n{\"a\": 1}\n```");
        assert_eq!(stripped, BARE_JSON);
    }

    /// Input without any fence comes back unchanged.
    #[test]
    fn strip_fences_passthrough() {
        let stripped = strip_markdown_fences(BARE_JSON);
        assert_eq!(stripped, BARE_JSON);
    }

    /// Whitespace around the fence and around the payload is trimmed.
    #[test]
    fn strip_fences_whitespace() {
        let stripped = strip_markdown_fences("  ```json\n  {\"a\": 1}  \n```  ");
        assert_eq!(stripped, BARE_JSON);
    }
}

View file

@ -178,7 +178,7 @@ pub async fn get_invite(
} }
// Dev-only: return a fake valid admin invite without hitting PocketBase // Dev-only: return a fake valid admin invite without hitting PocketBase
if state.index_html.is_none() && code == DEV_INVITE_CODE { if state.is_dev && code == DEV_INVITE_CODE {
return Json(InviteValidation { return Json(InviteValidation {
valid: true, valid: true,
invite_type: "admin".to_string(), invite_type: "admin".to_string(),
@ -294,7 +294,7 @@ pub async fn post_redeem_invite(
} }
// Dev-only: fake redeem — just return "licensed" without touching PocketBase // Dev-only: fake redeem — just return "licensed" without touching PocketBase
if state.index_html.is_none() && req.code == DEV_INVITE_CODE { if state.is_dev && req.code == DEV_INVITE_CODE {
info!(user_id = %user.id, "Dev invite redeemed (no-op)"); info!(user_id = %user.id, "Dev invite redeemed (no-op)");
return Json(RedeemResponse { return Json(RedeemResponse {
result: "licensed".to_string(), result: "licensed".to_string(),

View file

@ -34,6 +34,8 @@ pub struct AppState {
pub screenshot_url: String, pub screenshot_url: String,
/// Public-facing URL for absolute og:image URLs (e.g. https://perfectpostcodes.schmelczer.dev) /// Public-facing URL for absolute og:image URLs (e.g. https://perfectpostcodes.schmelczer.dev)
pub public_url: String, pub public_url: String,
/// True when --dist is not provided (no static serving, relaxed auth checks)
pub is_dev: bool,
/// Contents of index.html read at startup, used for crawler OG injection (None when --dist omitted) /// Contents of index.html read at startup, used for crawler OG injection (None when --dist omitted)
pub index_html: Option<String>, pub index_html: Option<String>,
/// Shared HTTP client for proxying to the screenshot service and PocketBase /// Shared HTTP client for proxying to the screenshot service and PocketBase
@ -44,16 +46,14 @@ pub struct AppState {
pub pocketbase_admin_email: String, pub pocketbase_admin_email: String,
/// PocketBase superuser password /// PocketBase superuser password
pub pocketbase_admin_password: String, pub pocketbase_admin_password: String,
/// Ollama server URL for AI area summaries (e.g. http://ollama:11434) /// Gemini API key for AI filters
pub ollama_url: String, pub gemini_api_key: String,
/// Ollama model name for area summaries (e.g. gemma3:12b) /// Gemini model name (e.g. gemini-2.0-flash)
pub ollama_model: String, pub gemini_model: String,
/// Precomputed travel time data store /// Precomputed travel time data store
pub travel_time_store: Arc<TravelTimeStore>, pub travel_time_store: Arc<TravelTimeStore>,
/// Token validation cache (60s TTL) /// Token validation cache (60s TTL)
pub token_cache: Arc<TokenCache>, pub token_cache: Arc<TokenCache>,
/// JSON schema for Ollama structured output in AI filters
pub ai_filters_schema: serde_json::Value,
/// Complete system prompt for AI filters (features + examples + instructions) /// Complete system prompt for AI filters (features + examples + instructions)
pub ai_filters_system_prompt: String, pub ai_filters_system_prompt: String,
/// Google Maps API key for Street View metadata lookups /// Google Maps API key for Street View metadata lookups

View file

@ -6,7 +6,7 @@ mod llm;
pub use grid_index::GridIndex; pub use grid_index::GridIndex;
pub use hash::{generate_priorities, splitmix64_hash}; pub use hash::{generate_priorities, splitmix64_hash};
pub use interned_column::InternedColumn; pub use interned_column::InternedColumn;
pub use llm::{extract_ollama_content, ollama_chat, strip_think_blocks}; pub use llm::{extract_gemini_content, gemini_chat};
/// Normalize a UK postcode: uppercase, strip spaces, insert canonical space before inward code. /// Normalize a UK postcode: uppercase, strip spaces, insert canonical space before inward code.
/// e.g. "e142dg" → "E14 2DG", "E14 2DG" → "E14 2DG", "EC1A1BB" → "EC1A 1BB" /// e.g. "e142dg" → "E14 2DG", "E14 2DG" → "E14 2DG", "EC1A1BB" → "EC1A 1BB"

View file

@ -4,68 +4,62 @@ use tracing::warn;
pub type LlmError = (StatusCode, String); pub type LlmError = (StatusCode, String);
/// Send a chat request to Ollama and return the parsed JSON response. /// Send a generateContent request to the Gemini API and return the parsed JSON response.
/// ///
/// Handles connection errors, non-success status codes, and JSON parse failures /// Handles connection errors, non-success status codes, and JSON parse failures
/// uniformly as `BAD_GATEWAY` errors. /// uniformly as `BAD_GATEWAY` errors.
pub async fn ollama_chat( pub async fn gemini_chat(
client: &reqwest::Client, client: &reqwest::Client,
url: &str, api_key: &str,
model: &str,
body: &Value, body: &Value,
) -> Result<Value, LlmError> { ) -> Result<Value, LlmError> {
let response = client.post(url).json(body).send().await.map_err(|err| { let url = format!(
warn!(error = %err, "Failed to connect to Ollama"); "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
model, api_key
);
let response = client.post(&url).json(body).send().await.map_err(|err| {
warn!(error = %err, "Failed to connect to Gemini API");
( (
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Failed to connect to Ollama: {}", err), format!("Failed to connect to Gemini API: {}", err),
) )
})?; })?;
if !response.status().is_success() { if !response.status().is_success() {
let status = response.status(); let status = response.status();
let body_text = response.text().await.unwrap_or_default(); let body_text = response.text().await.unwrap_or_default();
warn!(status = %status, body = %body_text, "Ollama returned error"); warn!(status = %status, body = %body_text, "Gemini API returned error");
return Err(( return Err((
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Ollama error {}: {}", status, body_text), format!("Gemini API error {}: {}", status, body_text),
)); ));
} }
response.json().await.map_err(|err| { response.json().await.map_err(|err| {
warn!(error = %err, "Failed to parse Ollama response"); warn!(error = %err, "Failed to parse Gemini API response");
( (
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Failed to parse Ollama response: {}", err), format!("Failed to parse Gemini API response: {}", err),
) )
}) })
} }
/// Extract content from Ollama native response (`message.content`) /// Extract text content from Gemini response (`candidates[0].content.parts[0].text`)
pub fn extract_ollama_content(json: &Value) -> Result<&str, LlmError> { pub fn extract_gemini_content(json: &Value) -> Result<&str, LlmError> {
json.get("message") json.get("candidates")
.and_then(|msg| msg.get("content")) .and_then(|c| c.get(0))
.and_then(|ct| ct.as_str()) .and_then(|c| c.get("content"))
.and_then(|c| c.get("parts"))
.and_then(|p| p.get(0))
.and_then(|p| p.get("text"))
.and_then(|t| t.as_str())
.ok_or_else(|| { .ok_or_else(|| {
warn!("Malformed Ollama response: missing message.content"); warn!("Malformed Gemini response: missing candidates[0].content.parts[0].text");
( (
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
"Malformed LLM response: missing message.content".into(), "Malformed Gemini response: missing content".into(),
) )
}) })
} }
/// Remove every `<think>...</think>` block from model output.
///
/// Text outside the blocks is kept verbatim and in order. If a `<think>`
/// tag is never closed, everything from that tag onwards is dropped. A
/// stray `</think>` without an opening tag is left untouched.
pub fn strip_think_blocks(text: &str) -> String {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    let mut kept = String::new();
    let mut rest = text;

    loop {
        let open_at = match rest.find(OPEN_TAG) {
            Some(idx) => idx,
            None => {
                // No further blocks: the tail is ordinary text.
                kept.push_str(rest);
                break;
            }
        };

        kept.push_str(&rest[..open_at]);
        let block = &rest[open_at..];

        match block.find(CLOSE_TAG) {
            Some(close_at) => rest = &block[close_at + CLOSE_TAG.len()..],
            // Unterminated block: discard the remainder.
            None => break,
        }
    }

    kept
}