good changes

This commit is contained in:
Andras Schmelczer 2026-03-25 08:04:48 +00:00
parent 160283f1a1
commit c997ea46a5
26 changed files with 991 additions and 288 deletions

View file

@ -12,6 +12,8 @@ use tracing::{info, warn};
use crate::auth::OptionalUser;
use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WEEKLY_TOKEN_LIMIT};
use crate::data::slugify;
use crate::data::travel_time::TravelData;
use crate::parsing::{parse_filters, row_passes_filters};
use crate::pocketbase::{get_superuser_token, log_ai_query};
use crate::routes::{FeatureInfo, FeaturesResponse};
use crate::state::{AppState, SharedState};
@ -62,6 +64,8 @@ pub struct AiFiltersResponse {
notes: String,
/// The listing mode used for this response (historical/buy/rent)
listing_type: String,
/// Number of properties matching the proposed filters (property and travel time filters)
match_count: usize,
}
/// Strip markdown code fences (```json ... ``` or ``` ... ```) from LLM output.
@ -556,6 +560,117 @@ async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week
}
}
/// Serialise a validated filter object into the `;;`-separated string that
/// `parse_filters` consumes.
///
/// Each entry of the JSON object becomes one segment:
///
/// * a two-element array of numbers `[min, max]` becomes `name:min:max`;
/// * an array whose first element is a string becomes `name:val1|val2|...`
///   (non-string elements in such an array are ignored).
///
/// Entries of any other shape are skipped. If `filters` is not a JSON object
/// the result is an empty string.
fn filters_to_filter_string(filters: &Value) -> String {
    filters
        .as_object()
        .map(|entries| {
            entries
                .iter()
                .filter_map(|(name, value)| {
                    let items = value.as_array()?;
                    match items.as_slice() {
                        // Numeric range: exactly two numbers.
                        [lower, upper] if lower.is_number() && upper.is_number() => {
                            let lower = lower.as_f64().unwrap_or(0.0);
                            let upper = upper.as_f64().unwrap_or(0.0);
                            Some(format!("{name}:{lower}:{upper}"))
                        }
                        // Enum selection: identified by a leading string element.
                        [first, ..] if first.is_string() => {
                            let labels: Vec<&str> =
                                items.iter().filter_map(|item| item.as_str()).collect();
                            if labels.is_empty() {
                                None
                            } else {
                                Some(format!("{name}:{}", labels.join("|")))
                            }
                        }
                        _ => None,
                    }
                })
                .collect::<Vec<String>>()
                .join(";;")
        })
        .unwrap_or_default()
}
/// Count the rows of the property dataset that satisfy both the property
/// filters and the travel time filters.
///
/// `filters` is turned into a filter string via `filters_to_filter_string`
/// and parsed with `parse_filters`; when parsing fails a warning is logged
/// and `0` is returned. An empty filter string means no property filters.
///
/// Travel time data is fetched from `state.travel_time_store` once per
/// filter entry. Entries whose data cannot be fetched are dropped and do not
/// constrain the count. For the remaining entries a row passes only if its
/// postcode has a record in every dataset and the recorded minutes lie within
/// the optional `min`/`max` bounds; a postcode without a record fails.
fn count_matching_rows(
    state: &AppState,
    filters: &Value,
    travel_time_filters: &[TravelTimeFilter],
) -> usize {
    let filter_str = filters_to_filter_string(filters);
    let quant = state.data.quant_ref();

    let (numeric_filters, enum_filters) = if filter_str.is_empty() {
        (Vec::new(), Vec::new())
    } else {
        match parse_filters(
            Some(&filter_str),
            &state.feature_name_to_index,
            &state.data.enum_values,
            &quant,
        ) {
            Ok(parsed) => parsed,
            Err(err) => {
                warn!("Failed to parse filters for match count: {err}");
                return 0;
            }
        }
    };

    // Resolve each travel time filter to its dataset plus optional bounds,
    // keeping only the ones whose data could be fetched.
    let mut loaded_travel: Vec<(TravelData, Option<f32>, Option<f32>)> = Vec::new();
    for ttf in travel_time_filters {
        if let Ok(data) = state.travel_time_store.get(&ttf.mode, &ttf.slug) {
            loaded_travel.push((data, ttf.min, ttf.max));
        }
    }

    let feature_data = &state.data.feature_data;
    let num_features = state.data.num_features;
    let (pc_interner, pc_keys) = state.data.postcode_parts();

    (0..state.data.lat.len())
        .filter(|&row| {
            row_passes_filters(
                row,
                &numeric_filters,
                &enum_filters,
                feature_data,
                num_features,
            )
        })
        .filter(|&row| {
            // Without any loaded travel data there is nothing more to check,
            // so the postcode is not even resolved.
            if loaded_travel.is_empty() {
                return true;
            }
            let postcode = pc_interner.resolve(&pc_keys[row]);
            loaded_travel
                .iter()
                .all(|(data, lower, upper)| match data.get(postcode) {
                    Some(record) => {
                        let minutes = record.minutes as f32;
                        lower.map_or(true, |min| minutes >= min)
                            && upper.map_or(true, |max| minutes <= max)
                    }
                    // no travel data → postcode not reachable
                    None => false,
                })
        })
        .count()
}
/// Maximum number of round trips (function calls + retries) before giving up.
///
/// This bounds the `for round in 0..MAX_TOOL_ROUNDS` loop in
/// `post_ai_filters`. A zero-match refinement request (`continue` after
/// pushing feedback to the model) consumes a round as well.
const MAX_TOOL_ROUNDS: usize = 5;
@ -631,6 +746,7 @@ pub async fn post_ai_filters(
})];
let mut total_tokens_accumulated: u64 = 0;
let mut refinement_attempts = 0u32;
// Function calling loop: model may call search_destinations, we execute and feed back
for round in 0..MAX_TOOL_ROUNDS {
@ -776,6 +892,42 @@ pub async fn post_ai_filters(
map.insert("Listing status".to_string(), json!([listing_value]));
}
// Count matching properties and refine if too restrictive
let match_count = count_matching_rows(&state, &filters, &travel_time_filters);
info!(match_count = match_count, round = round, "AI filter match count");
if match_count == 0 {
refinement_attempts += 1;
let total_rows = state.data.lat.len();
info!(
attempt = refinement_attempts,
"0 matches out of {total_rows} — asking AI to relax filters"
);
let feedback = match refinement_attempts {
1 => format!(
"Your proposed filters matched 0 properties out of {total_rows} total. \
The combination is too restrictive. Please widen some numeric ranges \
or add more enum values while keeping the user's intent. \
Output the adjusted JSON."
),
2 => format!(
"Still 0 matches out of {total_rows}. Please widen ranges further. \
Output the adjusted JSON."
),
_ => format!(
"Still 0 matches out of {total_rows}. Please remove additional filters \
until some properties match, keeping the user's core priority. \
Output the adjusted JSON."
),
};
contents.push(candidate.clone());
contents.push(json!({
"role": "user",
"parts": [{ "text": feedback }]
}));
continue;
}
// Update usage with total accumulated tokens
let new_total = tokens_used + total_tokens_accumulated;
update_ai_usage(&state, &user.id, new_total, current_week).await;
@ -810,6 +962,7 @@ pub async fn post_ai_filters(
travel_time_filters,
notes,
listing_type: listing_type.to_string(),
match_count,
}));
}
@ -902,8 +1055,10 @@ fn validate_travel_time_filters(raw: &Value, state: &AppState) -> Vec<TravelTime
fn validate_and_convert(raw: &Value, features: &FeaturesResponse, listing_type: &str) -> Value {
let mut result = serde_json::Map::new();
// Build lookup maps from feature metadata, filtering by listing mode
let mut numeric_features: rustc_hash::FxHashMap<&str, (f32, f32)> =
// Build lookup maps from feature metadata, filtering by listing mode.
// Store both slider bounds (min/max from percentiles) and true data bounds
// (histogram.min/max) so one-sided AI filters use the full data range.
let mut numeric_features: rustc_hash::FxHashMap<&str, (f32, f32, f32, f32)> =
rustc_hash::FxHashMap::default();
let mut enum_features: rustc_hash::FxHashMap<&str, &[String]> =
rustc_hash::FxHashMap::default();
@ -915,12 +1070,14 @@ fn validate_and_convert(raw: &Value, features: &FeaturesResponse, listing_type:
name,
min,
max,
histogram,
modes,
..
} => {
// Only include features valid for the chosen listing mode
if modes.is_empty() || modes.contains(&listing_type) {
numeric_features.insert(name, (*min, *max));
numeric_features
.insert(name, (*min, *max, histogram.min, histogram.max));
}
}
FeatureInfo::Enum { name, values, .. } => {
@ -933,32 +1090,37 @@ fn validate_and_convert(raw: &Value, features: &FeaturesResponse, listing_type:
}
}
// Process numeric filters — each sets one bound (min or max)
// Process numeric filters — each sets one bound (min or max).
// The unset side uses the true data min/max (from histogram), not
// the slider bounds (percentile-based), so a "max" filter for crime
// produces [0, value] rather than [2nd-percentile, value].
if let Some(arr) = raw.get("numeric_filters").and_then(|val| val.as_array()) {
for item in arr {
let name = match item.get("name").and_then(|val| val.as_str()) {
Some(name) => name,
None => continue,
};
let (feat_min, feat_max) = match numeric_features.get(name) {
Some(range) => *range,
None => continue,
};
let (slider_min, slider_max, data_min, data_max) =
match numeric_features.get(name) {
Some(range) => *range,
None => continue,
};
let bound = match item.get("bound").and_then(|val| val.as_str()) {
Some(b) => b,
None => continue,
};
// Clamp value to true data range (not slider range)
let value = match item.get("value").and_then(|val| val.as_f64()) {
Some(v) => v.max(feat_min as f64).min(feat_max as f64) as f32,
Some(v) => v.max(data_min as f64).min(data_max as f64) as f32,
None => continue,
};
let (filter_min, filter_max) = match bound {
"min" => (value, feat_max),
"max" => (feat_min, value),
"min" => (value, data_max),
"max" => (data_min, value),
_ => continue,
};
// Only include if range is narrower than full range
if filter_min > feat_min || filter_max < feat_max {
// Only include if range is narrower than full slider range
if filter_min > slider_min || filter_max < slider_max {
result.insert(name.to_string(), json!([filter_min, filter_max]));
}
}