From 80c093b7ba8f149ee6611686da3a9f284af435ce Mon Sep 17 00:00:00 2001
From: Andras Schmelczer
Date: Sun, 15 Mar 2026 14:05:34 +0000
Subject: [PATCH] Switch AI filters from Ollama to Gemini with travel-time tool calling
---
.github/workflows/docker-publish.yml | 15 +-
docker-compose.yml | 4 +-
frontend/src/components/map/AiFilterInput.tsx | 9 +-
frontend/src/components/map/Filters.tsx | 4 +-
frontend/src/components/map/MapPage.tsx | 38 +-
frontend/src/hooks/useAiFilters.ts | 184 +++--
pipeline/download/pois.py | 14 +-
server-rs/src/consts.rs | 4 +-
server-rs/src/main.rs | 46 +-
server-rs/src/pocketbase.rs | 20 +-
server-rs/src/routes.rs | 2 +-
server-rs/src/routes/ai_filters.rs | 758 +++++++++++++++---
server-rs/src/routes/invites.rs | 4 +-
server-rs/src/state.rs | 12 +-
server-rs/src/utils.rs | 2 +-
server-rs/src/utils/llm.rs | 60 +-
16 files changed, 898 insertions(+), 278 deletions(-)
diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml
index ea7da6e..9dc8577 100644
--- a/.github/workflows/docker-publish.yml
+++ b/.github/workflows/docker-publish.yml
@@ -21,6 +21,15 @@ jobs:
- name: Checkout
uses: actions/checkout@v4
+ - name: Set up uv
+ uses: astral-sh/setup-uv@v4
+
+ - name: Download map assets (fonts, sprites, twemoji)
+ run: uv run python -m pipeline.download.map_assets --output frontend/public/assets
+
+ - name: Download arcgis data for finder
+ run: uv run python -m pipeline.download.arcgis --output property-data/arcgis_data.parquet
+
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -64,12 +73,6 @@ jobs:
cache-from: type=gha,scope=screenshot
cache-to: type=gha,mode=max,scope=screenshot
- - name: Set up uv
- uses: astral-sh/setup-uv@v4
-
- - name: Download arcgis data for finder
- run: uv run python -m pipeline.download.arcgis --output property-data/arcgis_data.parquet
-
- name: Build and push finder service
uses: docker/build-push-action@v6
with:
diff --git a/docker-compose.yml b/docker-compose.yml
index 22fb204..2b2ff44 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -36,8 +36,8 @@ services:
POCKETBASE_ADMIN_EMAIL: *pb-email
POCKETBASE_ADMIN_PASSWORD: *pb-password
SCREENSHOT_URL: http://screenshot:8002
- OLLAMA_URL: http://host.docker.internal:11434
- OLLAMA_MODEL: gpt-oss:20b
+ GEMINI_API_KEY: ${GEMINI_API_KEY:?set GEMINI_API_KEY in the environment or .env} # never commit API keys; rotate the one previously committed here
+ GEMINI_MODEL: gemini-3-flash-preview
PUBLIC_URL: https://perfectpostcodes.schmelczer.dev
GOOGLE_MAPS_API_KEY: "AIzaSyBgBn9LjrxHCjb9j1LZbLYpEdCJj-NkHPY"
STRIPE_SECRET_KEY: sk_test_51SyVcePRjj2bdyn1HLkatQ5onwp8kamm41tjMcRdxXnJYWVPsVd9usMTOSNtNdGhrjbsrtNbgTdKXICg2qBiocEn00PvNDC0d3
diff --git a/frontend/src/components/map/AiFilterInput.tsx b/frontend/src/components/map/AiFilterInput.tsx
index 371e5c0..af7698e 100644
--- a/frontend/src/components/map/AiFilterInput.tsx
+++ b/frontend/src/components/map/AiFilterInput.tsx
@@ -42,6 +42,7 @@ interface AiFilterInputProps {
error: string | null;
errorType: AiFilterErrorType | null;
notes: string | null;
+ summary: string | null;
onSubmit: (query: string) => void;
isLoggedIn: boolean;
onLoginRequired: () => void;
@@ -52,6 +53,7 @@ export default memo(function AiFilterInput({
error,
errorType,
notes,
+ summary,
onSubmit,
isLoggedIn,
onLoginRequired,
@@ -88,7 +90,7 @@ export default memo(function AiFilterInput({
);
const hasContent = query.trim().length > 0;
- const showExamples = expanded && !hasContent && !loading && !error && !notes;
+ const showExamples = expanded && !hasContent && !loading && !error && !notes && !summary;
if (!expanded) {
return (
@@ -173,6 +175,11 @@ export default memo(function AiFilterInput({
{error}
)}
+ {summary && !error && !loading && (
+
+ {summary}
+
+ )}
{notes && !error && !loading && (
{notes}
diff --git a/frontend/src/components/map/Filters.tsx b/frontend/src/components/map/Filters.tsx
index 280148b..308f7ec 100644
--- a/frontend/src/components/map/Filters.tsx
+++ b/frontend/src/components/map/Filters.tsx
@@ -92,6 +92,7 @@ interface FiltersProps {
aiFilterError: string | null;
aiFilterErrorType: AiFilterErrorType | null;
aiFilterNotes: string | null;
+ aiFilterSummary: string | null;
onAiFilterSubmit: (query: string) => void;
isLoggedIn: boolean;
onLoginRequired: () => void;
@@ -127,6 +128,7 @@ export default memo(function Filters({
aiFilterError,
aiFilterErrorType,
aiFilterNotes,
+ aiFilterSummary,
onAiFilterSubmit,
isLoggedIn,
onLoginRequired,
@@ -285,7 +287,7 @@ export default memo(function Filters({
-
+
{(['historical', 'buy', 'rent'] as const).map((type) => {
diff --git a/frontend/src/components/map/MapPage.tsx b/frontend/src/components/map/MapPage.tsx
index c232472..0e3b93a 100644
--- a/frontend/src/components/map/MapPage.tsx
+++ b/frontend/src/components/map/MapPage.tsx
@@ -148,22 +148,33 @@ export default function MapPage({
const handleAiFilterSubmit = useCallback(
async (query: string) => {
- const result = await aiFilters.fetchAiFilters(query);
+ // Build context from current filters for conversational refinement
+ const context = {
+ filters,
+ travelTime: travelTime.activeEntries.map((entry) => ({
+ mode: entry.mode,
+ label: entry.label,
+ min: entry.timeRange?.[0],
+ max: entry.timeRange?.[1],
+ })),
+ };
+ const hasContext =
+ Object.keys(context.filters).length > 0 || context.travelTime.length > 0;
+
+ const result = await aiFilters.fetchAiFilters(query, hasContext ? context : undefined);
if (!result) return;
handleSetFilters(result.filters);
- // Apply travel time filters from AI
- if (result.travelTimeFilters.length > 0) {
- const newEntries = result.travelTimeFilters.map((tt) => ({
- mode: tt.mode,
- slug: tt.slug,
- label: tt.label,
- timeRange: [tt.min ?? 0, tt.max ?? 120] as [number, number],
- useBest: false,
- }));
- travelTime.handleSetEntries(newEntries);
- }
+ // Always sync travel time entries — clear stale ones when AI returns none
+ const newEntries = result.travelTimeFilters.map((tt) => ({
+ mode: tt.mode,
+ slug: tt.slug,
+ label: tt.label,
+ timeRange: [tt.min ?? 0, tt.max ?? 120] as [number, number],
+ useBest: false,
+ }));
+ travelTime.handleSetEntries(newEntries);
},
- [aiFilters.fetchAiFilters, handleSetFilters, travelTime.handleSetEntries]
+ [aiFilters.fetchAiFilters, handleSetFilters, travelTime.handleSetEntries, travelTime.activeEntries, filters]
);
const handleTravelTimeSetDestination = useCallback(
@@ -514,6 +525,7 @@ export default function MapPage({
aiFilterError={aiFilters.error}
aiFilterErrorType={aiFilters.errorType}
aiFilterNotes={aiFilters.notes}
+ aiFilterSummary={aiFilters.summary}
onAiFilterSubmit={handleAiFilterSubmit}
isLoggedIn={!!user}
onLoginRequired={onRegisterClick ?? (() => {})}
diff --git a/frontend/src/hooks/useAiFilters.ts b/frontend/src/hooks/useAiFilters.ts
index d4c775e..b436b30 100644
--- a/frontend/src/hooks/useAiFilters.ts
+++ b/frontend/src/hooks/useAiFilters.ts
@@ -1,6 +1,6 @@
import { useState, useCallback, useRef } from 'react';
import type { FeatureFilters } from '../types';
-import type { TransportMode } from './useTravelTime';
+import type { TransportMode, TravelTimeEntry } from './useTravelTime';
import { apiUrl, authHeaders, logNonAbortError } from '../lib/api';
export interface AiTravelTimeFilter {
@@ -11,20 +11,54 @@ export interface AiTravelTimeFilter {
max?: number;
}
-interface AiFiltersResult {
+export interface AiFiltersResult {
filters: FeatureFilters;
travelTimeFilters: AiTravelTimeFilter[];
notes: string;
+ /** Human-readable summary of what was set */
+ summary: string;
}
export type AiFilterErrorType = 'auth' | 'verification' | 'limit' | 'error';
+/** Context of currently active filters, sent for conversational refinement. */
+export interface AiFiltersContext {
+ // Active feature filters, keyed by feature name.
+ filters: FeatureFilters;
+ // Active travel-time entries; posted to the server under the `travel_time` key.
+ travelTime: { mode: string; label: string; min?: number; max?: number }[];
+}
+
interface UseAiFiltersResult {
- fetchAiFilters: (query: string) => Promise<AiFiltersResult | null>;
+ /** Resolves to the parsed AI result, or null when the request failed or was aborted. */
+ fetchAiFilters: (query: string, context?: AiFiltersContext) => Promise<AiFiltersResult | null>;
loading: boolean;
error: string | null;
errorType: AiFilterErrorType | null;
notes: string | null;
+ /** Human-readable summary of the last successful result; null while loading or after an error. */
+ summary: string | null;
+}
+
+/**
+ * Build a human-readable summary of the AI result.
+ *
+ * Range filters (two-element numeric arrays) are listed by name only; other
+ * array-valued filters are listed as `name: a, b`. Travel-time filters are
+ * rendered as `mode to label` followed by their bounds. Returns
+ * 'No filters set' when nothing is listed.
+ */
+function buildSummary(
+  filters: FeatureFilters,
+  travelTimeFilters: AiTravelTimeFilter[]
+): string {
+  const parts: string[] = [];
+
+  for (const [name, value] of Object.entries(filters)) {
+    if (Array.isArray(value) && value.length === 2 && typeof value[0] === 'number') {
+      parts.push(name);
+    } else if (Array.isArray(value)) {
+      parts.push(`${name}: ${(value as string[]).join(', ')}`);
+    }
+  }
+
+  for (const tt of travelTimeFilters) {
+    const { min, max } = tt;
+    let bounds = '';
+    if (min !== undefined && max !== undefined) {
+      // Both bounds set (e.g. "30-45 minute commute"): show the whole range
+      // instead of silently dropping the lower bound.
+      bounds = `${min}-${max} min`;
+    } else if (max !== undefined) {
+      bounds = `< ${max} min`;
+    } else if (min !== undefined) {
+      bounds = `> ${min} min`;
+    }
+    parts.push(`${tt.mode} to ${tt.label} ${bounds}`.trim());
+  }
+
+  if (parts.length === 0) return 'No filters set';
+  return `Set ${parts.length} filter${parts.length > 1 ? 's' : ''}: ${parts.join(', ')}`;
}
export function useAiFilters(): UseAiFiltersResult {
@@ -32,77 +66,93 @@ export function useAiFilters(): UseAiFiltersResult {
const [error, setError] = useState(null);
const [errorType, setErrorType] = useState(null);
const [notes, setNotes] = useState(null);
+ const [summary, setSummary] = useState(null);
const abortRef = useRef(null);
- const fetchAiFilters = useCallback(async (query: string): Promise => {
- abortRef.current?.abort();
- const controller = new AbortController();
- abortRef.current = controller;
+ const fetchAiFilters = useCallback(
+ async (query: string, context?: AiFiltersContext): Promise => {
+ abortRef.current?.abort();
+ const controller = new AbortController();
+ abortRef.current = controller;
- setLoading(true);
- setError(null);
- setErrorType(null);
- setNotes(null);
+ setLoading(true);
+ setError(null);
+ setErrorType(null);
+ setNotes(null);
+ setSummary(null);
- try {
- const url = apiUrl('ai-filters');
- const response = await fetch(
- url,
- authHeaders({
- method: 'POST',
- headers: { 'Content-Type': 'application/json' },
- body: JSON.stringify({ query }),
- signal: controller.signal,
- })
- );
-
- if (!response.ok) {
- const text = await response.text();
- if (response.status === 401) {
- setErrorType('auth');
- setError(text || 'Login required');
- } else if (response.status === 403) {
- setErrorType('verification');
- setError(text || 'Email verification required');
- } else if (response.status === 429) {
- setErrorType('limit');
- setError(text || 'Weekly usage limit reached');
- } else {
- setErrorType('error');
- setError(text || `HTTP ${response.status}`);
+ try {
+ const url = apiUrl('ai-filters');
+ const bodyObj: Record = { query };
+ if (context) {
+ bodyObj.context = {
+ filters: context.filters,
+ travel_time: context.travelTime,
+ };
}
+ const response = await fetch(
+ url,
+ authHeaders({
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify(bodyObj),
+ signal: controller.signal,
+ })
+ );
+
+ if (!response.ok) {
+ const text = await response.text();
+ if (response.status === 401) {
+ setErrorType('auth');
+ setError(text || 'Login required');
+ } else if (response.status === 403) {
+ setErrorType('verification');
+ setError(text || 'Email verification required');
+ } else if (response.status === 429) {
+ setErrorType('limit');
+ setError(text || 'Weekly usage limit reached');
+ } else {
+ setErrorType('error');
+ setError(text || `HTTP ${response.status}`);
+ }
+ setLoading(false);
+ return null;
+ }
+
+ const json = await response.json();
+ const travelTimeFilters: AiTravelTimeFilter[] = (json.travel_time_filters || []).map(
+ (tt: { mode: string; slug: string; label: string; min?: number; max?: number }) => ({
+ mode: tt.mode as TransportMode,
+ slug: tt.slug,
+ label: tt.label,
+ min: tt.min,
+ max: tt.max,
+ })
+ );
+ const filters = json.filters as FeatureFilters;
+ const summaryText = buildSummary(filters, travelTimeFilters);
+ const result: AiFiltersResult = {
+ filters,
+ travelTimeFilters,
+ notes: json.notes || '',
+ summary: summaryText,
+ };
+ setNotes(result.notes || null);
+ setSummary(summaryText);
+ setLoading(false);
+ return result;
+ } catch (err) {
+ if (controller.signal.aborted) return null;
+ logNonAbortError('ai-filters', err);
+ const message = err instanceof Error ? err.message : 'Failed to generate filters';
+ setErrorType('error');
+ setError(message);
setLoading(false);
return null;
}
+ },
+ []
+ );
- const json = await response.json();
- const travelTimeFilters: AiTravelTimeFilter[] = (json.travel_time_filters || []).map(
- (tt: { mode: string; slug: string; label: string; min?: number; max?: number }) => ({
- mode: tt.mode as TransportMode,
- slug: tt.slug,
- label: tt.label,
- min: tt.min,
- max: tt.max,
- })
- );
- const result: AiFiltersResult = {
- filters: json.filters as FeatureFilters,
- travelTimeFilters,
- notes: json.notes || '',
- };
- setNotes(result.notes || null);
- setLoading(false);
- return result;
- } catch (err) {
- if (controller.signal.aborted) return null;
- logNonAbortError('ai-filters', err);
- const message = err instanceof Error ? err.message : 'Failed to generate filters';
- setErrorType('error');
- setError(message);
- setLoading(false);
- return null;
- }
- }, []);
-
- return { fetchAiFilters, loading, error, errorType, notes };
+ return { fetchAiFilters, loading, error, errorType, notes, summary };
}
diff --git a/pipeline/download/pois.py b/pipeline/download/pois.py
index 0b39f8b..ed08b8e 100644
--- a/pipeline/download/pois.py
+++ b/pipeline/download/pois.py
@@ -7,18 +7,18 @@ import polars as pl
from shapely.geometry import Point
from tqdm import tqdm
-from pipeline.utils.england_geometry import load_england_polygon
+from pipeline.utils.england_geometry import (
+ ENGLAND_BBOX_EAST,
+ ENGLAND_BBOX_NORTH,
+ ENGLAND_BBOX_SOUTH,
+ ENGLAND_BBOX_WEST,
+ load_england_polygon,
+)
BATCH_SIZE = 50_000
MIN_OCCURENCE_COUNT = 20
-# Bounding box for fast pre-filtering before the precise polygon check
-ENGLAND_BBOX_WEST = -6.45
-ENGLAND_BBOX_SOUTH = 49.85
-ENGLAND_BBOX_EAST = 1.77
-ENGLAND_BBOX_NORTH = 55.82
-
POI_TAG_KEYS: list[str] = [
"amenity",
"building",
diff --git a/server-rs/src/consts.rs b/server-rs/src/consts.rs
index ba61d81..923e514 100644
--- a/server-rs/src/consts.rs
+++ b/server-rs/src/consts.rs
@@ -15,7 +15,9 @@ pub const MAX_PRICE_HISTORY_POINTS: usize = 5000;
pub const POSTCODE_SEARCH_OFFSET: f64 = 0.02;
pub const AI_FILTERS_MAX_TOKENS: usize = 2000;
-pub const AI_FILTERS_TEMPERATURE: f32 = 0.0;
+/// Gemini 3 recommends 1.0; lower values can cause looping or degraded performance.
+pub const AI_FILTERS_TEMPERATURE: f32 = 1.0;
+pub const AI_FILTERS_WEEKLY_TOKEN_LIMIT: u64 = 10_000_000;
/// Timeout for outbound HTTP service calls (seconds).
pub const SERVICE_CALL_TIMEOUT: u64 = 120;
diff --git a/server-rs/src/main.rs b/server-rs/src/main.rs
index 01281a3..930c411 100644
--- a/server-rs/src/main.rs
+++ b/server-rs/src/main.rs
@@ -94,13 +94,13 @@ struct Cli {
#[arg(long, env = "POCKETBASE_ADMIN_PASSWORD")]
pocketbase_admin_password: String,
- /// Ollama server URL (e.g. http://ollama:11434)
- #[arg(long, env = "OLLAMA_URL")]
- ollama_url: String,
+ /// Gemini API key
+ #[arg(long, env = "GEMINI_API_KEY")]
+ gemini_api_key: String,
- /// Ollama model name
- #[arg(long, env = "OLLAMA_MODEL")]
- ollama_model: String,
+ /// Gemini model name (e.g. gemini-3-flash-preview)
+ #[arg(long, env = "GEMINI_MODEL")]
+ gemini_model: String,
/// Path to precomputed travel times directory (contains mode subdirs with parquet files)
#[arg(long, env = "TRAVEL_TIMES")]
@@ -301,9 +301,7 @@ async fn main() -> anyhow::Result<()> {
"Precomputed features response"
);
- let ai_filters_schema = routes::build_ollama_schema(&features_response);
- let ai_filters_system_prompt = routes::build_system_prompt(&features_response);
- info!("Precomputed AI filters schema and system prompt");
+ // AI filters system prompt built after travel_time_store is loaded (needs mode counts)
// Record data loading metrics
metrics::record_data_stats(
@@ -331,10 +329,7 @@ async fn main() -> anyhow::Result<()> {
&cli.google_oauth_client_secret,
)
.await?;
- info!(
- "Ollama configured: {} (model: {})",
- cli.ollama_url, cli.ollama_model
- );
+ info!("Gemini configured (model: {})", cli.gemini_model);
let tt_path = &cli.travel_times;
if !tt_path.exists() {
bail!(
@@ -352,6 +347,23 @@ async fn main() -> anyhow::Result<()> {
Arc::new(store)
};
+ let mode_destinations: Vec<(String, usize)> = travel_time_store
+ .available_modes
+ .iter()
+ .map(|mode| {
+ let count = travel_time_store
+ .destinations
+ .get(mode.as_str())
+ .map(|slugs| slugs.len())
+ .unwrap_or(0);
+ (mode.clone(), count)
+ })
+ .filter(|(_, count)| *count > 0)
+ .collect();
+ let ai_filters_system_prompt =
+ routes::build_system_prompt(&features_response, &mode_destinations);
+ info!("Precomputed AI filters system prompt");
+
let token_cache = Arc::new(auth::TokenCache::new());
let state = Arc::new(AppState {
@@ -370,16 +382,16 @@ async fn main() -> anyhow::Result<()> {
features_response,
screenshot_url: cli.screenshot_url,
public_url: cli.public_url,
+ is_dev: index_html.is_none(),
index_html,
http_client,
pocketbase_url: cli.pocketbase_url,
pocketbase_admin_email: cli.pocketbase_admin_email,
pocketbase_admin_password: cli.pocketbase_admin_password,
- ollama_url: cli.ollama_url,
- ollama_model: cli.ollama_model,
+ gemini_api_key: cli.gemini_api_key,
+ gemini_model: cli.gemini_model,
travel_time_store,
token_cache,
- ai_filters_schema,
ai_filters_system_prompt,
google_maps_api_key: cli.google_maps_api_key,
stripe_secret_key: cli.stripe_secret_key,
@@ -504,7 +516,7 @@ async fn main() -> anyhow::Result<()> {
)
.route(
"/api/ai-filters",
- post(move |body| routes::post_ai_filters(state_ai_filters.clone(), body))
+ post(move |ext, body| routes::post_ai_filters(state_ai_filters.clone(), ext, body))
.layer(ConcurrencyLimitLayer::new(5)),
)
.route(
diff --git a/server-rs/src/pocketbase.rs b/server-rs/src/pocketbase.rs
index ac4aa64..b1a4aa3 100644
--- a/server-rs/src/pocketbase.rs
+++ b/server-rs/src/pocketbase.rs
@@ -240,9 +240,11 @@ async fn ensure_user_fields(
let has_is_admin = fields.iter().any(|f| f["name"] == "is_admin");
let has_subscription = fields.iter().any(|f| f["name"] == "subscription");
let has_newsletter = fields.iter().any(|f| f["name"] == "newsletter");
+ let has_ai_tokens_used = fields.iter().any(|f| f["name"] == "ai_tokens_used");
+ let has_ai_tokens_week = fields.iter().any(|f| f["name"] == "ai_tokens_week");
- if has_is_admin && has_subscription && has_newsletter {
- info!("PocketBase users collection already has is_admin, subscription, and newsletter fields");
+ if has_is_admin && has_subscription && has_newsletter && has_ai_tokens_used && has_ai_tokens_week {
+ info!("PocketBase users collection already has all required fields");
return Ok(());
}
@@ -269,6 +271,20 @@ async fn ensure_user_fields(
}));
}
+ if !has_ai_tokens_used {
+ new_fields.push(serde_json::json!({
+ "name": "ai_tokens_used",
+ "type": "number",
+ }));
+ }
+
+ if !has_ai_tokens_week {
+ new_fields.push(serde_json::json!({
+ "name": "ai_tokens_week",
+ "type": "number",
+ }));
+ }
+
let patch_resp = client
.patch(&url)
.header("Authorization", format!("Bearer {token}"))
diff --git a/server-rs/src/routes.rs b/server-rs/src/routes.rs
index ae1a416..946e763 100644
--- a/server-rs/src/routes.rs
+++ b/server-rs/src/routes.rs
@@ -26,7 +26,7 @@ pub(crate) mod travel_time;
mod travel_destinations;
mod travel_modes;
-pub use ai_filters::{build_ollama_schema, build_system_prompt, post_ai_filters};
+pub use ai_filters::{build_system_prompt, post_ai_filters};
pub use checkout::post_checkout;
pub use export::get_export;
pub use features::{build_features_response, get_features, FeatureInfo, FeaturesResponse};
diff --git a/server-rs/src/routes/ai_filters.rs b/server-rs/src/routes/ai_filters.rs
index 8a9d522..b8911b5 100644
--- a/server-rs/src/routes/ai_filters.rs
+++ b/server-rs/src/routes/ai_filters.rs
@@ -2,76 +2,190 @@ use std::sync::Arc;
use axum::http::StatusCode;
use axum::response::Json;
+use axum::Extension;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tracing::{info, warn};
-use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE};
+use crate::auth::OptionalUser;
+use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WEEKLY_TOKEN_LIMIT};
+use crate::data::slugify;
+use crate::pocketbase::auth_superuser;
use crate::routes::{FeatureInfo, FeaturesResponse};
use crate::state::AppState;
-use crate::utils::{extract_ollama_content, ollama_chat, strip_think_blocks};
+use crate::utils::gemini_chat;
+
+/// Currently active filters sent by the client so the model can refine them.
+#[derive(Deserialize)]
+pub struct AiFiltersContext {
+    /// Active feature filters, kept as raw JSON.
+    filters: Value,
+    /// Active travel-time entries; defaults to empty when the field is absent.
+    #[serde(default)]
+    travel_time: Vec<AiTravelTimeContext>,
+}
+
+/// One travel-time entry of the client-supplied context: transport mode, display label and optional bounds.
+// NOTE(review): the type parameter of `Option` on `min`/`max` appears to have been lost from this patch text — confirm against the repository.
+#[derive(Deserialize)]
+pub struct AiTravelTimeContext {
+ mode: String,
+ label: String,
+ min: Option,
+ max: Option,
+}
#[derive(Deserialize)]
pub struct AiFiltersRequest {
query: String,
+ /// Current filters for conversational refinement (e.g. "make it cheaper");
+ /// `None` when the client sends no context.
+ context: Option<AiFiltersContext>,
+}
+
+/// A travel-time filter in the response; `min` and `max` are left out of the JSON when they are `None`.
+// NOTE(review): the type parameter of `Option` on `min`/`max` appears to have been lost from this patch text — confirm against the repository.
+#[derive(Serialize)]
+pub struct TravelTimeFilter {
+ mode: String,
+ slug: String,
+ label: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ min: Option,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ max: Option,
}
#[derive(Serialize)]
pub struct AiFiltersResponse {
filters: Value,
+ /// Travel-time filters in the response; omitted from the JSON when empty.
+ #[serde(skip_serializing_if = "Vec::is_empty")]
+ travel_time_filters: Vec<TravelTimeFilter>,
/// What the LLM couldn't map to existing filters (empty if everything matched)
#[serde(skip_serializing_if = "String::is_empty")]
notes: String,
}
-/// Build a JSON schema for Ollama structured output.
-///
-/// Uses two arrays (`numeric_filters` and `enum_filters`) instead of one property
-/// per feature, because Ollama converts JSON schema to GBNF grammar and a schema
-/// with 50+ optional keys causes a combinatorial explosion that crashes the parser.
-/// Array-based schema keeps the grammar small and constant-size.
-pub fn build_ollama_schema(_features: &FeaturesResponse) -> Value {
- json!({
- "type": "object",
- "properties": {
- "numeric_filters": {
- "type": "array",
- "items": {
- "type": "object",
- "properties": {
- "name": { "type": "string" },
- "bound": { "type": "string", "enum": ["min", "max"] },
- "value": { "type": "number" }
- },
- "required": ["name", "bound", "value"]
- }
- },
- "enum_filters": {
- "type": "array",
- "items": {
- "type": "object",
- "properties": {
- "name": { "type": "string" },
- "values": { "type": "array", "items": { "type": "string" } }
- },
- "required": ["name", "values"]
- }
- },
- "notes": {
- "type": "string"
- }
+/// Remove a surrounding markdown code fence from LLM output, if present.
+///
+/// Handles both "```json\n...\n```" and "```\n...\n```": everything up to the
+/// first newline after the opening fence (the optional language tag) is
+/// dropped and the closing fence is stripped. When the text has no opening
+/// fence, no newline after it, or no closing fence, the trimmed input is
+/// returned as-is.
+fn strip_markdown_fences(text: &str) -> &str {
+    let trimmed = text.trim();
+
+    // Opening fence plus the remainder of its line (the language tag, if any).
+    let body = match trimmed
+        .strip_prefix("```")
+        .and_then(|after_fence| after_fence.split_once('\n'))
+    {
+        Some((_language_tag, body)) => body,
+        None => return trimmed,
+    };
+
+    match body.strip_suffix("```") {
+        Some(inner) => inner.trim(),
+        None => trimmed,
+    }
+}
+
+/// Assemble the Gemini tool declarations: a single `search_destinations`
+/// function whose `mode` argument is restricted to the transport modes listed
+/// in the travel time store.
+fn build_tool_declarations(state: &AppState) -> Value {
+    let mode_names: Vec<&str> = state
+        .travel_time_store
+        .available_modes
+        .iter()
+        .map(|name| name.as_str())
+        .collect();
+
+    let query_param = json!({
+        "type": "string",
+        "description": "Place name to search for (e.g. 'Manchester', 'Kings Cross', 'Heathrow')"
+    });
+    let mode_param = json!({
+        "type": "string",
+        "enum": mode_names,
+        "description": "Transport mode to search destinations for"
+    });
+    let search_destinations = json!({
+        "name": "search_destinations",
+        "description": "Search for available travel time destinations (cities, stations, towns) that have precomputed travel time data. Call this when the user mentions wanting to be near, close to, or within a certain travel time of a specific place.",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "query": query_param,
+                "mode": mode_param
+            },
+            "required": ["query", "mode"]
+        }
+    });
+
+    json!([{ "functionDeclarations": [search_destinations] }])
+}
+
+/// Execute a destination search against PlaceData + TravelTimeStore.
+/// Returns matching destinations as a JSON value with `results` and optional `message`.
+///
+/// Uses word-based matching: all words in the query must appear somewhere in the
+/// place name (order-independent). A place also matches when its slug contains
+/// the query slug or vice versa. Only places whose slug has travel time data for
+/// `mode` are returned, at most 10, ordered by type rank (ascending) and then
+/// population (descending).
+///
+/// An unknown `mode` or a blank `query` yields an empty `results` array with an
+/// explanatory `message`.
+fn execute_destination_search(state: &AppState, query: &str, mode: &str) -> Value {
+    let query_lower = query.to_lowercase();
+    let query_words: Vec<&str> = query_lower.split_whitespace().collect();
+    let query_slug = slugify(query);
+    let tt_store = &state.travel_time_store;
+    let pd = &state.place_data;
+
+    let slug_set = match tt_store.destinations.get(mode) {
+        Some(slugs) => slugs,
+        None => return json!({ "results": [], "message": format!("No travel data available for mode '{}'", mode) }),
+    };
+
+    // A blank query would otherwise match every place: `all` over zero words is
+    // vacuously true, so the search would return ten unrelated destinations.
+    if query_words.is_empty() {
+        return json!({ "results": [], "message": "Empty destination query" });
+    }
+
+    // Find places matching the query that have travel time data.
+    // A place matches if ALL query words appear in its name, OR its slug matches the query slug.
+    let mut matches: Vec<(usize, String, u8, u32)> = pd
+        .name_lower
+        .iter()
+        .enumerate()
+        .filter_map(|(idx, name_lower)| {
+            let slug = slugify(&pd.name[idx]);
+            // Only destinations with travel time data can ever be returned, so
+            // check membership before doing the string matching.
+            if !slug_set.contains(&slug) {
+                return None;
+            }
+            let words_match = query_words.iter().all(|word| name_lower.contains(word));
+            // `contains("")` is always true, so an empty slug on either side must
+            // not count as a slug match.
+            let slug_match = !query_slug.is_empty()
+                && !slug.is_empty()
+                && (slug.contains(&query_slug) || query_slug.contains(&slug));
+            if words_match || slug_match {
+                Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    // Sort: type rank asc, population desc
+    matches.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(b.3.cmp(&a.3)));
+    matches.truncate(10);
+
+    if matches.is_empty() {
+        info!(query = query, mode = mode, "Destination search returned no results");
+        return json!({
+            "results": [],
+            "message": format!("No travel time data available for '{}' by {}. This destination cannot be used as a travel time filter.", query, mode)
+        });
+    }
+
+    let results: Vec<Value> = matches
+        .into_iter()
+        .map(|(idx, slug, ..)| {
+            json!({
+                "name": pd.name[idx],
+                "slug": slug,
+                "place_type": pd.place_type.get(idx).to_string(),
+            })
+        })
+        .collect();
+
+    json!({ "results": results })
+}
/// Build the complete system prompt for AI filters.
///
-/// Contains: role instructions, feature catalogue, few-shot examples, output rules.
+/// Contains: role instructions, feature catalogue, travel time info,
+/// few-shot examples, output rules.
/// Precomputed at startup and cached in AppState.
-pub fn build_system_prompt(features: &FeaturesResponse) -> String {
+pub fn build_system_prompt(
+ features: &FeaturesResponse,
+ mode_destinations: &[(String, usize)],
+) -> String {
let mut parts = Vec::new();
- // Role and task description
parts.push(
"You are a UK property search assistant. \
The user describes their ideal property or area in natural language. \
@@ -91,10 +205,61 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
(note: this counts bedrooms + living rooms combined, so 3 bed ~ min 4).\n\
- If the user mentions something that has no matching filter, put it in \"notes\" \
as a short phrase (e.g. \"No filter for: garden, sea view\"). \
- If everything was matched, set \"notes\" to an empty string."
+ If everything was matched, set \"notes\" to an empty string.\n\
+ \n\
+ CONVERSATIONAL REFINEMENT:\n\
+ The user's message may include their currently active filters as context. \
+ When context is provided:\n\
+ - \"make it cheaper\" / \"lower the price\" = adjust the existing price filter down\n\
+ - \"also add ...\" / \"and good schools\" = keep existing filters and add new ones\n\
+ - \"remove the ...\" / \"drop the ...\" = return filters WITHOUT the mentioned one\n\
+ - If the request is a completely new search (not a refinement), ignore the context \
+ and build filters from scratch.\n\
+ - Always output the COMPLETE set of filters (existing + modified), not just the changes."
.to_string(),
);
+ // Travel time section with available modes
+ let modes_list = mode_destinations
+ .iter()
+ .map(|(mode, count)| format!("- {} ({} destinations available)", mode, count))
+ .collect::>()
+ .join("\n");
+
+ parts.push(format!(
+ "\n--- TRAVEL TIME FILTERS ---\n\
+ You can add travel time filters when the user mentions commute times, \
+ proximity to places, or wanting to be near/within X minutes of somewhere.\n\
+ \n\
+ Available transport modes (only use modes that have destinations):\n\
+ {}\n\
+ - \"car\" / \"drive\" / \"driving\" = car mode\n\
+ - \"cycle\" / \"bike\" / \"cycling\" = bicycle mode\n\
+ - \"walk\" / \"walking\" / \"on foot\" = walking mode\n\
+ - \"train\" / \"tube\" / \"bus\" / \"public transport\" / \"commute\" = transit mode\n\
+ \n\
+ When the user mentions a specific place, you MUST call the search_destinations \
+ tool to find the exact slug. Use the name and slug from the search results.\n\
+ If search_destinations returns an empty array, the destination is not available — \
+ mention it in \"notes\" (e.g. \"No travel data for: Gatwick Airport\") and do NOT \
+ include a travel_time_filter for it.\n\
+ \n\
+ Travel time values are in MINUTES (0-120 range).\n\
+ - \"within 30 minutes\" = max 30\n\
+ - \"at least 10 minutes\" = min 10\n\
+ - \"30-45 minute commute\" = min 30, max 45\n\
+ - If only a max is given, omit min (and vice versa).\n\
+ \n\
+ INFERRING TRANSPORT MODE (when the user does not specify one explicitly):\n\
+ - \"commute\" to a major city centre or station = transit\n\
+ - \"near\" / \"close to\" a city centre or station = transit\n\
+ - \"near\" / \"close to\" a smaller town, village, or rural area = car\n\
+ - \"drive\" / \"driving distance\" / \"driving time\" = always car\n\
+ - If multiple modes are plausible, prefer transit for urban destinations \
+ (London, Manchester, Birmingham, Leeds, etc.) and car for everything else.",
+ modes_list,
+ ));
+
// Feature catalogue
parts.push("\n--- AVAILABLE FEATURES ---\n".to_string());
for group in &features.groups {
@@ -148,6 +313,7 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
Output: {\"numeric_filters\": [{\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 400000}], \
\"enum_filters\": [{\"name\": \"Leasehold/Freehold\", \"values\": [\"Freehold\"]}, \
{\"name\": \"Property type\", \"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
+ \"travel_time_filters\": [], \
\"notes\": \"\"}"
.to_string(),
);
@@ -161,7 +327,7 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
{\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \
{\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}, \
{\"name\": \"Number of parks within 2km\", \"bound\": \"min\", \"value\": 3}], \
- \"enum_filters\": [], \"notes\": \"\"}"
+ \"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
.to_string(),
);
@@ -172,18 +338,37 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 4}], \
\"enum_filters\": [{\"name\": \"Property type\", \"values\": [\"Flats/Maisonettes\"]}, \
{\"name\": \"Max available download speed (Mbps)\", \"values\": [\"100\", \"300\", \"1000\"]}], \
+ \"travel_time_filters\": [], \
\"notes\": \"No filter for: beach proximity\"}"
.to_string(),
);
parts.push(
- "\nUser: \"large family home with a garden near restaurants\"\n\
+ "\nUser: \"within 30 minutes commute of Kings Cross, under 500k\"\n\
+ (After calling search_destinations for \"Kings Cross\" with mode \"transit\" \
+ and getting [{\"name\": \"Kings Cross\", \"slug\": \"kings-cross\", \"place_type\": \"station\"}])\n\
+ Output: {\"numeric_filters\": [\
+ {\"name\": \"Last known price\", \"bound\": \"max\", \"value\": 500000}], \
+ \"enum_filters\": [], \
+ \"travel_time_filters\": [{\"mode\": \"transit\", \"slug\": \"kings-cross\", \
+ \"label\": \"Kings Cross\", \"bound\": \"max\", \"value\": 30}], \
+ \"notes\": \"\"}"
+ .to_string(),
+ );
+
+ parts.push(
+ "\nUser: \"family home with garden, 45 min drive from Manchester, good schools\"\n\
+ (After calling search_destinations for \"Manchester\" with mode \"car\" \
+ and getting [{\"name\": \"Manchester\", \"slug\": \"manchester\", \"place_type\": \"city\"}])\n\
Output: {\"numeric_filters\": [\
{\"name\": \"Total floor area (sqm)\", \"bound\": \"min\", \"value\": 100}, \
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 5}, \
- {\"name\": \"Number of restaurants within 2km\", \"bound\": \"min\", \"value\": 10}], \
+ {\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \
+ {\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}], \
\"enum_filters\": [{\"name\": \"Property type\", \
\"values\": [\"Detached\", \"Semi-Detached\"]}], \
+ \"travel_time_filters\": [{\"mode\": \"car\", \"slug\": \"manchester\", \
+ \"label\": \"Manchester\", \"bound\": \"max\", \"value\": 45}], \
\"notes\": \"No filter for: garden\"}"
.to_string(),
);
@@ -191,7 +376,8 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
// Output format reminder
parts.push(
"\n--- OUTPUT FORMAT ---\n\
- {\"numeric_filters\": [{\"name\": \"...\", \"bound\": \"min\"|\"max\", \"value\": N}, ...], \"enum_filters\": [...], \"notes\": \"...\"}\n\
+ {\"numeric_filters\": [...], \"enum_filters\": [...], \"travel_time_filters\": [{\"mode\": \"...\", \"slug\": \"...\", \"label\": \"...\", \"bound\": \"min\"|\"max\", \"value\": N}, ...], \"notes\": \"...\"}\n\
+ - travel_time_filters: use ONLY slugs returned by search_destinations. If a place isn't found, mention it in notes.\n\
Respond with ONLY the JSON object. No explanation."
.to_string(),
);
@@ -199,86 +385,393 @@ pub fn build_system_prompt(features: &FeaturesResponse) -> String {
parts.join("\n")
}
-pub async fn post_ai_filters(
- state: Arc,
- Json(req): Json,
-) -> Result, (StatusCode, String)> {
- info!(query = %req.query, "POST /api/ai-filters");
-
- let url = format!("{}/api/chat", state.ollama_url);
- let body = json!({
- "model": state.ollama_model,
- "messages": [
- { "role": "system", "content": state.ai_filters_system_prompt },
- { "role": "user", "content": req.query }
- ],
- "stream": false,
- "format": state.ai_filters_schema,
- "options": {
- "temperature": AI_FILTERS_TEMPERATURE,
- "num_predict": AI_FILTERS_MAX_TOKENS,
- }
- });
-
- // Try up to 2 attempts — LLMs occasionally return empty content (e.g. only
- // blocks with no JSON output), which is transient and usually
- // succeeds on retry.
- let mut last_err = None;
- for attempt in 0..2 {
- let raw = call_ollama_and_parse(&state.http_client, &url, &body).await;
- match raw {
- Ok(raw) => {
- let filters = validate_and_convert(&raw, &state.features_response);
- let notes = raw
- .get("notes")
- .and_then(|val| val.as_str())
- .unwrap_or("")
- .to_string();
- return Ok(Json(AiFiltersResponse { filters, notes }));
- }
- Err(err) => {
- if attempt == 0 {
- warn!("LLM attempt 1 failed, retrying: {}", err.1);
- }
- last_err = Some(err);
- }
- }
- }
-
- Err(last_err.unwrap())
+/// Index of the current 7-day window, counted from the Unix epoch.
+///
+/// The value grows by one every 604800 seconds. Weekly rate limiting compares it
+/// with the week number stored alongside a usage counter.
+///
+/// # Panics
+/// Panics if the system clock reports a time before the Unix epoch.
+fn current_week_number() -> u64 {
+    const SECONDS_PER_WEEK: u64 = 604_800;
+    let since_epoch = std::time::UNIX_EPOCH
+        .elapsed()
+        .expect("system time before epoch");
+    since_epoch.as_secs() / SECONDS_PER_WEEK
}
-/// Call Ollama and parse the response content as JSON.
-///
-/// Returns an error if: the HTTP call fails, the response is malformed,
-/// the content is empty after stripping think blocks, or the content is
-/// not valid JSON.
-async fn call_ollama_and_parse(
- client: &reqwest::Client,
- url: &str,
- body: &Value,
-) -> Result {
- let json_resp = ollama_chat(client, url, body).await?;
- let content = extract_ollama_content(&json_resp)?;
+/// Read how many AI tokens a user has consumed from their PocketBase record.
+///
+/// Authenticates as the PocketBase superuser, fetches `users/{user_id}` and
+/// returns `(ai_tokens_used, ai_tokens_week)`. Either field falls back to `0`
+/// when it is absent or not an unsigned integer.
+///
+/// # Errors
+/// Every failure (superuser auth, transport error, non-success status,
+/// unparsable body) is logged and reported as `BAD_GATEWAY` with the message
+/// "Internal error".
+async fn fetch_ai_usage(
+    state: &AppState,
+    user_id: &str,
+) -> Result<(u64, u64), (StatusCode, String)> {
+    // All failure paths answer identically; only the log line differs.
+    let bad_gateway = || (StatusCode::BAD_GATEWAY, String::from("Internal error"));
+
+    let base = state.pocketbase_url.trim_end_matches('/');
+    let superuser_token = match auth_superuser(
+        &state.http_client,
+        base,
+        &state.pocketbase_admin_email,
+        &state.pocketbase_admin_password,
+    )
+    .await
+    {
+        Ok(tk) => tk,
+        Err(err) => {
+            warn!("Failed to auth superuser for AI usage check: {err}");
+            return Err(bad_gateway());
+        }
+    };
- let content = strip_think_blocks(content);
- let content = content.trim();
+    let record_url = format!("{base}/api/collections/users/records/{user_id}");
+    let response = match state
+        .http_client
+        .get(&record_url)
+        .header("Authorization", format!("Bearer {superuser_token}"))
+        .send()
+        .await
+    {
+        Ok(response) => response,
+        Err(err) => {
+            warn!("Failed to fetch user record for AI usage: {err}");
+            return Err(bad_gateway());
+        }
+    };
- if content.is_empty() {
- warn!("LLM returned empty content after stripping think blocks");
+    let status = response.status();
+    if !status.is_success() {
+        warn!("PocketBase user fetch failed ({status})");
+        return Err(bad_gateway());
+    }
+
+    let record = match response.json::<Value>().await {
+        Ok(record) => record,
+        Err(err) => {
+            warn!("Failed to parse user record: {err}");
+            return Err(bad_gateway());
+        }
+    };
+
+    // Missing or non-integer counters are read as zero.
+    let read_counter = |field: &str| record.get(field).and_then(Value::as_u64).unwrap_or(0);
+
+    Ok((read_counter("ai_tokens_used"), read_counter("ai_tokens_week")))
+}
+
+/// Persist a user's AI token counter and its week number to PocketBase.
+///
+/// Best-effort: any failure (superuser auth, transport error, non-success
+/// status) is logged with `warn!` and otherwise ignored, so nothing is returned.
+async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week: u64) {
+    let base = state.pocketbase_url.trim_end_matches('/');
+
+    let auth = auth_superuser(
+        &state.http_client,
+        base,
+        &state.pocketbase_admin_email,
+        &state.pocketbase_admin_password,
+    )
+    .await;
+    let superuser_token = match auth {
+        Ok(tk) => tk,
+        Err(err) => {
+            warn!("Failed to auth superuser for AI usage update: {err}");
+            return;
+        }
+    };
+
+    let record_url = format!("{base}/api/collections/users/records/{user_id}");
+    let payload = json!({
+        "ai_tokens_used": tokens_used,
+        "ai_tokens_week": week,
+    });
+    let outcome = state
+        .http_client
+        .patch(&record_url)
+        .header("Authorization", format!("Bearer {superuser_token}"))
+        .json(&payload)
+        .send()
+        .await;
+
+    match outcome {
+        Err(err) => warn!("Failed to update AI usage: {err}"),
+        Ok(resp) => {
+            let status = resp.status();
+            if !status.is_success() {
+                warn!("Failed to update AI usage ({status})");
+            }
+        }
+    }
+}
+
+/// Maximum number of round trips (function calls + retries) before giving up.
+const MAX_TOOL_ROUNDS: usize = 5;
+
+/// Handler for `POST /api/ai-filters`: turns a natural-language query into filters.
+///
+/// Steps, in order:
+/// 1. Rejects anonymous callers (401) and, outside dev mode, unverified emails (403).
+/// 2. Loads the caller's token counter via `fetch_ai_usage`; a counter stored for a
+/// different week counts as 0. At or above `AI_FILTERS_WEEKLY_TOKEN_LIMIT` the
+/// request is refused with 429.
+/// 3. Sends up to `MAX_TOOL_ROUNDS` Gemini requests. A `functionCall` part is executed
+/// and its result appended to the conversation; empty or non-JSON text triggers a
+/// corrective follow-up message; parseable JSON ends the loop.
+/// 4. Validates the JSON, stores the new token total and returns the response.
+///
+/// # Errors
+/// `BAD_GATEWAY` for a malformed Gemini response or when the rounds run out; errors
+/// from `fetch_ai_usage` and `gemini_chat` are propagated with `?`.
+///
+/// NOTE(review): usage is only written on the success path, so tokens spent by a
+/// request that fails or runs out of rounds are not counted against the limit.
+/// NOTE(review): the usage read and the usage write are separate calls; concurrent
+/// requests from one user can overwrite each other's totals.
+pub async fn post_ai_filters(
state: Arc,
+ Extension(user): Extension,
Json(req): Json,
) -> Result, (StatusCode, String)> {
+ // Auth check
+ let user = user
+ .0
+ .ok_or((StatusCode::UNAUTHORIZED, "Login required".into()))?;
+
+ // Email verification check (skipped in dev mode)
+ if !user.verified && !state.is_dev {
return Err((
- StatusCode::BAD_GATEWAY,
- "LLM returned empty content (no JSON output)".into(),
+ StatusCode::FORBIDDEN,
+ "Please verify your email to use AI filters".into(),
));
}
- serde_json::from_str(content).map_err(|err| {
- warn!(error = %err, content = %content, "Failed to parse LLM JSON output");
- (
- StatusCode::BAD_GATEWAY,
- format!("Failed to parse LLM output as JSON: {}", err),
+ // Check weekly token usage
+ let current_week = current_week_number();
+ let (stored_tokens, stored_week) = fetch_ai_usage(&state, &user.id).await?;
+ // A counter stored for a different week is stale and counts as zero.
+ let tokens_used = if stored_week == current_week {
+ stored_tokens
+ } else {
+ 0
+ };
+
+ if tokens_used >= AI_FILTERS_WEEKLY_TOKEN_LIMIT {
+ return Err((
+ StatusCode::TOO_MANY_REQUESTS,
+ "Weekly AI usage limit reached. Resets next week.".into(),
+ ));
+ }
+
+ info!(query = %req.query, user_id = %user.id, "POST /api/ai-filters");
+
+ let tools = build_tool_declarations(&state);
+
+ // Build user message with optional context for conversational refinement
+ let user_text = if let Some(ref ctx) = req.context {
+ let mut msg = String::new();
+ msg.push_str("Currently active filters:\n");
+ msg.push_str(&serde_json::to_string(&ctx.filters).unwrap_or_default());
+ if !ctx.travel_time.is_empty() {
+ msg.push_str("\nCurrently active travel time filters:\n");
+ for tt in &ctx.travel_time {
+ let bounds = match (tt.min, tt.max) {
+ (Some(min), Some(max)) => format!("{}-{} min", min, max),
+ (Some(min), None) => format!("min {} min", min),
+ (None, Some(max)) => format!("max {} min", max),
+ (None, None) => "no range".to_string(),
+ };
+ msg.push_str(&format!("- {} to {} ({})\n", tt.mode, tt.label, bounds));
+ }
+ }
+ msg.push_str(&format!("\nUser request: {}", req.query));
+ msg
+ } else {
+ req.query.clone()
+ };
+
+ let mut contents = vec![json!({
+ "role": "user",
+ "parts": [{ "text": user_text }]
+ })];
+
+ let mut total_tokens_accumulated: u64 = 0;
+
+ // Function calling loop: model may call search_destinations, we execute and feed back
+ for round in 0..MAX_TOOL_ROUNDS {
+ // NOTE(review): thinkingLevel looks like a Gemini 3 option — confirm the configured model accepts it.
+ let body = json!({
+ "systemInstruction": {
+ "parts": [{ "text": state.ai_filters_system_prompt }]
+ },
+ "contents": contents,
+ "tools": tools,
+ "generationConfig": {
+ "temperature": AI_FILTERS_TEMPERATURE,
+ "maxOutputTokens": AI_FILTERS_MAX_TOKENS,
+ "thinkingConfig": { "thinkingLevel": "LOW" },
+ }
+ });
+
+ let json_resp = gemini_chat(
+ &state.http_client,
+ &state.gemini_api_key,
+ &state.gemini_model,
+ &body,
)
- })
+ .await?;
+
+ // Accumulate token usage
+ total_tokens_accumulated += json_resp
+ .get("usageMetadata")
+ .and_then(|md| md.get("totalTokenCount"))
+ .and_then(|tc| tc.as_u64())
+ .unwrap_or(0);
+
+ let candidate = json_resp
+ .get("candidates")
+ .and_then(|cs| cs.get(0))
+ .and_then(|c| c.get("content"))
+ .ok_or_else(|| {
+ warn!("Malformed Gemini response: missing candidates[0].content");
+ (
+ StatusCode::BAD_GATEWAY,
+ "Malformed Gemini response".into(),
+ )
+ })?;
+
+ let parts = candidate
+ .get("parts")
+ .and_then(|p| p.as_array())
+ .ok_or_else(|| {
+ warn!("Malformed Gemini response: missing parts array");
+ (
+ StatusCode::BAD_GATEWAY,
+ "Malformed Gemini response".into(),
+ )
+ })?;
+
+ // Check if the model made a function call.
+ // Find the full part (includes thoughtSignature required by Gemini 3 models).
+ // NOTE(review): only the first functionCall part is answered; further calls in
+ // the same turn would get no functionResponse.
+ if let Some(fc_part) = parts.iter().find(|part| part.get("functionCall").is_some()) {
+ let fc = fc_part.get("functionCall").unwrap();
+ let fn_name = fc.get("name").and_then(|n| n.as_str()).unwrap_or("");
+ let fn_args = fc.get("args").cloned().unwrap_or(json!({}));
+
+ info!(
+ function = fn_name,
+ round = round,
+ "AI called tool"
+ );
+
+ let fn_result = if fn_name == "search_destinations" {
+ let query = fn_args
+ .get("query")
+ .and_then(|q| q.as_str())
+ .unwrap_or("");
+ // Defaults to "transit" when the model omits the mode argument.
+ let mode = fn_args
+ .get("mode")
+ .and_then(|m| m.as_str())
+ .unwrap_or("transit");
+ execute_destination_search(&state, query, mode)
+ } else {
+ json!({"error": "unknown function"})
+ };
+
+ // Append the model's full response (preserves thoughtSignature) + our function result
+ contents.push(candidate.clone());
+ contents.push(json!({
+ "role": "user",
+ "parts": [{
+ "functionResponse": {
+ "name": fn_name,
+ "response": { "results": fn_result }
+ }
+ }]
+ }));
+
+ // Continue the loop — model will process the results
+ continue;
+ }
+
+ // Model returned text — extract and parse as JSON
+ let text = parts
+ .iter()
+ .find_map(|part| part.get("text").and_then(|t| t.as_str()))
+ .unwrap_or("");
+ let text = strip_markdown_fences(text);
+ let text = text.trim();
+
+ if text.is_empty() {
+ warn!("Gemini returned empty text content (round {})", round);
+ // Retry by continuing the loop
+ contents.push(candidate.clone());
+ contents.push(json!({
+ "role": "user",
+ "parts": [{ "text": "Your response was empty. Please output the JSON object." }]
+ }));
+ continue;
+ }
+
+ let raw: Value = match serde_json::from_str(text) {
+ Ok(val) => val,
+ Err(err) => {
+ warn!(error = %err, round = round, "Failed to parse Gemini JSON output, retrying");
+ // Ask the model to fix its output
+ contents.push(candidate.clone());
+ contents.push(json!({
+ "role": "user",
+ "parts": [{ "text": "That was not valid JSON. Please output ONLY the JSON object with numeric_filters, enum_filters, travel_time_filters, and notes." }]
+ }));
+ continue;
+ }
+ };
+
+ let filters = validate_and_convert(&raw, &state.features_response);
+ let travel_time_filters = validate_travel_time_filters(&raw, &state);
+ let notes = raw
+ .get("notes")
+ .and_then(|val| val.as_str())
+ .unwrap_or("")
+ .to_string();
+
+ // Update usage with total accumulated tokens
+ // tokens_used is already 0 when the stored week was stale, so a new week starts
+ // from this request's tokens only.
+ let new_total = tokens_used + total_tokens_accumulated;
+ update_ai_usage(&state, &user.id, new_total, current_week).await;
+
+ return Ok(Json(AiFiltersResponse {
+ filters,
+ travel_time_filters,
+ notes,
+ }));
+ }
+
+ // Exhausted tool rounds without getting a final text response
+ warn!("AI exhausted {} tool-calling rounds without final response", MAX_TOOL_ROUNDS);
+ Err((
+ StatusCode::BAD_GATEWAY,
+ "AI could not complete the request".into(),
+ ))
+}
+
+/// Validate travel time filters from LLM output against available destinations.
+///
+/// Reads the `travel_time_filters` array of `raw` (a missing or non-array field
+/// yields an empty list) and keeps only entries that
+/// - carry string `mode` and `slug` fields,
+/// - name a destination known to `state.travel_time_store`, and
+/// - have `bound` equal to `"min"` or `"max"` together with a numeric `value`.
+///
+/// Values are clamped to the 0-120 minute range. `label` defaults to the slug.
+///
+/// The model emits one entry per bound, so a range such as "30-45 minutes"
+/// arrives as two entries for the same destination. Those are merged into one
+/// filter carrying both `min` and `max` (the latest value per bound wins), kept
+/// in order of first appearance.
+fn validate_travel_time_filters(raw: &Value, state: &AppState) -> Vec<TravelTimeFilter> {
+    let items = match raw
+        .get("travel_time_filters")
+        .and_then(|val| val.as_array())
+    {
+        Some(items) => items,
+        None => return Vec::new(),
+    };
+
+    let tt_store = &state.travel_time_store;
+    let mut results: Vec<TravelTimeFilter> = Vec::new();
+
+    for item in items {
+        let (mode, slug) = match (
+            item.get("mode").and_then(|val| val.as_str()),
+            item.get("slug").and_then(|val| val.as_str()),
+        ) {
+            (Some(mode), Some(slug)) => (mode, slug),
+            _ => continue,
+        };
+
+        // Verify this destination actually exists
+        if !tt_store.has_destination(mode, slug) {
+            warn!(mode = mode, slug = slug, "AI suggested non-existent destination");
+            continue;
+        }
+
+        // An entry without a usable bound/value pair contributes nothing.
+        let minutes = match item.get("value").and_then(|val| val.as_f64()) {
+            Some(val) => val.clamp(0.0, 120.0) as f32,
+            None => continue,
+        };
+        let (min, max) = match item.get("bound").and_then(|val| val.as_str()) {
+            Some("min") => (Some(minutes), None),
+            Some("max") => (None, Some(minutes)),
+            _ => continue,
+        };
+
+        // Fold a further bound for an already-seen destination into its filter.
+        if let Some(existing) = results
+            .iter_mut()
+            .find(|filter| filter.mode == mode && filter.slug == slug)
+        {
+            existing.min = min.or(existing.min);
+            existing.max = max.or(existing.max);
+            continue;
+        }
+
+        let label = item
+            .get("label")
+            .and_then(|val| val.as_str())
+            .unwrap_or(slug);
+
+        results.push(TravelTimeFilter {
+            mode: mode.to_string(),
+            slug: slug.to_string(),
+            label: label.to_string(),
+            min,
+            max,
+        });
+    }
+
+    results
}
/// Validate LLM output against feature metadata and convert to FeatureFilters format.
@@ -374,3 +867,32 @@ fn validate_and_convert(raw: &Value, features: &FeaturesResponse) -> Value {
Value::Object(result)
}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// The bare JSON payload every input below is expected to reduce to.
+    const BARE: &str = "{\"a\": 1}";
+
+    /// Runs `strip_markdown_fences` on `input` and checks the result equals `BARE`.
+    #[track_caller]
+    fn assert_unwraps(input: &str) {
+        assert_eq!(strip_markdown_fences(input), BARE);
+    }
+
+    /// A fence tagged with a language name is removed.
+    #[test]
+    fn strip_fences_json_tag() {
+        assert_unwraps("```json\n{\"a\": 1}\n```");
+    }
+
+    /// An untagged fence is removed.
+    #[test]
+    fn strip_fences_no_tag() {
+        assert_unwraps("```\n{\"a\": 1}\n```");
+    }
+
+    /// Text without any fence comes back unchanged.
+    #[test]
+    fn strip_fences_passthrough() {
+        assert_unwraps("{\"a\": 1}");
+    }
+
+    /// Whitespace around the fence and around the payload is dropped.
+    #[test]
+    fn strip_fences_whitespace() {
+        assert_unwraps(" ```json\n {\"a\": 1} \n``` ");
+    }
+}
diff --git a/server-rs/src/routes/invites.rs b/server-rs/src/routes/invites.rs
index d04ed73..1d773ae 100644
--- a/server-rs/src/routes/invites.rs
+++ b/server-rs/src/routes/invites.rs
@@ -178,7 +178,7 @@ pub async fn get_invite(
}
// Dev-only: return a fake valid admin invite without hitting PocketBase
- if state.index_html.is_none() && code == DEV_INVITE_CODE {
+ if state.is_dev && code == DEV_INVITE_CODE {
return Json(InviteValidation {
valid: true,
invite_type: "admin".to_string(),
@@ -294,7 +294,7 @@ pub async fn post_redeem_invite(
}
// Dev-only: fake redeem — just return "licensed" without touching PocketBase
- if state.index_html.is_none() && req.code == DEV_INVITE_CODE {
+ if state.is_dev && req.code == DEV_INVITE_CODE {
info!(user_id = %user.id, "Dev invite redeemed (no-op)");
return Json(RedeemResponse {
result: "licensed".to_string(),
diff --git a/server-rs/src/state.rs b/server-rs/src/state.rs
index 4130a14..99ebd58 100644
--- a/server-rs/src/state.rs
+++ b/server-rs/src/state.rs
@@ -34,6 +34,8 @@ pub struct AppState {
pub screenshot_url: String,
/// Public-facing URL for absolute og:image URLs (e.g. https://perfectpostcodes.schmelczer.dev)
pub public_url: String,
+ /// True when --dist is not provided (no static serving, relaxed auth checks)
+ pub is_dev: bool,
/// Contents of index.html read at startup, used for crawler OG injection (None when --dist omitted)
pub index_html: Option,
/// Shared HTTP client for proxying to the screenshot service and PocketBase
@@ -44,16 +46,14 @@ pub struct AppState {
pub pocketbase_admin_email: String,
/// PocketBase superuser password
pub pocketbase_admin_password: String,
- /// Ollama server URL for AI area summaries (e.g. http://ollama:11434)
- pub ollama_url: String,
- /// Ollama model name for area summaries (e.g. gemma3:12b)
- pub ollama_model: String,
+ /// Gemini API key for AI filters
+ pub gemini_api_key: String,
+ /// Gemini model name (e.g. gemini-2.0-flash)
+ pub gemini_model: String,
/// Precomputed travel time data store
pub travel_time_store: Arc,
/// Token validation cache (60s TTL)
pub token_cache: Arc,
- /// JSON schema for Ollama structured output in AI filters
- pub ai_filters_schema: serde_json::Value,
/// Complete system prompt for AI filters (features + examples + instructions)
pub ai_filters_system_prompt: String,
/// Google Maps API key for Street View metadata lookups
diff --git a/server-rs/src/utils.rs b/server-rs/src/utils.rs
index 6961b84..2dcc1ec 100644
--- a/server-rs/src/utils.rs
+++ b/server-rs/src/utils.rs
@@ -6,7 +6,7 @@ mod llm;
pub use grid_index::GridIndex;
pub use hash::{generate_priorities, splitmix64_hash};
pub use interned_column::InternedColumn;
-pub use llm::{extract_ollama_content, ollama_chat, strip_think_blocks};
+pub use llm::{extract_gemini_content, gemini_chat};
/// Normalize a UK postcode: uppercase, strip spaces, insert canonical space before inward code.
/// e.g. "e142dg" → "E14 2DG", "E14 2DG" → "E14 2DG", "EC1A1BB" → "EC1A 1BB"
diff --git a/server-rs/src/utils/llm.rs b/server-rs/src/utils/llm.rs
index 8625ded..58a74ae 100644
--- a/server-rs/src/utils/llm.rs
+++ b/server-rs/src/utils/llm.rs
@@ -4,68 +4,62 @@ use tracing::warn;
pub type LlmError = (StatusCode, String);
-/// Send a chat request to Ollama and return the parsed JSON response.
+/// Send a generateContent request to the Gemini API and return the parsed JSON response.
///
/// Handles connection errors, non-success status codes, and JSON parse failures
/// uniformly as `BAD_GATEWAY` errors.
-pub async fn ollama_chat(
+pub async fn gemini_chat(
client: &reqwest::Client,
- url: &str,
+ api_key: &str,
+ model: &str,
body: &Value,
) -> Result {
- let response = client.post(url).json(body).send().await.map_err(|err| {
- warn!(error = %err, "Failed to connect to Ollama");
+ let url = format!(
+ "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
+ model, api_key
+ );
+
+ let response = client.post(&url).json(body).send().await.map_err(|err| {
+ warn!(error = %err, "Failed to connect to Gemini API");
(
StatusCode::BAD_GATEWAY,
- format!("Failed to connect to Ollama: {}", err),
+ format!("Failed to connect to Gemini API: {}", err),
)
})?;
if !response.status().is_success() {
let status = response.status();
let body_text = response.text().await.unwrap_or_default();
- warn!(status = %status, body = %body_text, "Ollama returned error");
+ warn!(status = %status, body = %body_text, "Gemini API returned error");
return Err((
StatusCode::BAD_GATEWAY,
- format!("Ollama error {}: {}", status, body_text),
+ format!("Gemini API error {}: {}", status, body_text),
));
}
response.json().await.map_err(|err| {
- warn!(error = %err, "Failed to parse Ollama response");
+ warn!(error = %err, "Failed to parse Gemini API response");
(
StatusCode::BAD_GATEWAY,
- format!("Failed to parse Ollama response: {}", err),
+ format!("Failed to parse Gemini API response: {}", err),
)
})
}
-/// Extract content from Ollama native response (`message.content`)
-pub fn extract_ollama_content(json: &Value) -> Result<&str, LlmError> {
- json.get("message")
- .and_then(|msg| msg.get("content"))
- .and_then(|ct| ct.as_str())
+/// Extract the text of a Gemini response: the first part of
+/// `candidates[0].content.parts` that has a string `text` field.
+///
+/// Scanning the parts instead of reading only `parts[0]` keeps this working when
+/// the text is preceded by a part of another kind (e.g. a `functionCall`).
+///
+/// # Errors
+/// Returns `BAD_GATEWAY` when no such part exists.
+pub fn extract_gemini_content(json: &Value) -> Result<&str, LlmError> {
+ json.get("candidates")
+ .and_then(|c| c.get(0))
+ .and_then(|c| c.get("content"))
+ .and_then(|c| c.get("parts"))
+ .and_then(|p| p.as_array())
+ .and_then(|parts| {
+ parts
+ .iter()
+ .find_map(|part| part.get("text").and_then(|t| t.as_str()))
+ })
.ok_or_else(|| {
- warn!("Malformed Ollama response: missing message.content");
+ warn!("Malformed Gemini response: no text part in candidates[0].content.parts");
(
StatusCode::BAD_GATEWAY,
- "Malformed LLM response: missing message.content".into(),
+ "Malformed Gemini response: missing content".into(),
)
})
}
-
-/// Strip `...` blocks from model output
-pub fn strip_think_blocks(text: &str) -> String {
- let mut result = String::new();
- let mut remaining = text;
- while let Some(start) = remaining.find("") {
- result.push_str(&remaining[..start]);
- if let Some(end) = remaining[start..].find("") {
- remaining = &remaining[start + end + 8..];
- } else {
- return result;
- }
- }
- result.push_str(remaining);
- result
-}