good changes
This commit is contained in:
parent
160283f1a1
commit
c997ea46a5
26 changed files with 991 additions and 288 deletions
|
|
@ -12,6 +12,8 @@ use tracing::{info, warn};
|
|||
use crate::auth::OptionalUser;
|
||||
use crate::consts::{AI_FILTERS_MAX_TOKENS, AI_FILTERS_TEMPERATURE, AI_FILTERS_WEEKLY_TOKEN_LIMIT};
|
||||
use crate::data::slugify;
|
||||
use crate::data::travel_time::TravelData;
|
||||
use crate::parsing::{parse_filters, row_passes_filters};
|
||||
use crate::pocketbase::{get_superuser_token, log_ai_query};
|
||||
use crate::routes::{FeatureInfo, FeaturesResponse};
|
||||
use crate::state::{AppState, SharedState};
|
||||
|
|
@ -62,6 +64,8 @@ pub struct AiFiltersResponse {
|
|||
notes: String,
|
||||
/// The listing mode used for this response (historical/buy/rent)
|
||||
listing_type: String,
|
||||
/// Number of properties matching the proposed filters (property filters and travel time filters)
|
||||
match_count: usize,
|
||||
}
|
||||
|
||||
/// Strip markdown code fences (```json ... ``` or ``` ... ```) from LLM output.
|
||||
|
|
@ -556,6 +560,117 @@ async fn update_ai_usage(state: &AppState, user_id: &str, tokens_used: u64, week
|
|||
}
|
||||
}
|
||||
|
||||
/// Convert validated filter JSON back to the `;;`-separated filter string format
|
||||
/// that `parse_filters` expects.
|
||||
///
|
||||
/// Numeric: `{"name": [min, max]}` → `name:min:max`
|
||||
/// Enum: `{"name": ["val1", "val2"]}` → `name:val1|val2`
|
||||
fn filters_to_filter_string(filters: &Value) -> String {
|
||||
let obj = match filters.as_object() {
|
||||
Some(obj) => obj,
|
||||
None => return String::new(),
|
||||
};
|
||||
|
||||
let mut parts = Vec::new();
|
||||
for (name, value) in obj {
|
||||
if let Some(arr) = value.as_array() {
|
||||
if arr.len() == 2 && arr[0].is_number() && arr[1].is_number() {
|
||||
let min = arr[0].as_f64().unwrap_or(0.0);
|
||||
let max = arr[1].as_f64().unwrap_or(0.0);
|
||||
parts.push(format!("{name}:{min}:{max}"));
|
||||
} else if !arr.is_empty() && arr[0].is_string() {
|
||||
let values: Vec<&str> = arr.iter().filter_map(|v| v.as_str()).collect();
|
||||
if !values.is_empty() {
|
||||
parts.push(format!("{name}:{}", values.join("|")));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parts.join(";;")
|
||||
}
|
||||
|
||||
/// Count how many rows in the property dataset pass the given property filters
|
||||
/// AND travel time filters. Travel time data is loaded from the TravelTimeStore
|
||||
/// and checked per-postcode (same logic as hexagons.rs).
|
||||
fn count_matching_rows(
|
||||
state: &AppState,
|
||||
filters: &Value,
|
||||
travel_time_filters: &[TravelTimeFilter],
|
||||
) -> usize {
|
||||
let filter_str = filters_to_filter_string(filters);
|
||||
|
||||
let quant = state.data.quant_ref();
|
||||
let (parsed_filters, parsed_enum_filters) = if filter_str.is_empty() {
|
||||
(Vec::new(), Vec::new())
|
||||
} else {
|
||||
match parse_filters(
|
||||
Some(&filter_str),
|
||||
&state.feature_name_to_index,
|
||||
&state.data.enum_values,
|
||||
&quant,
|
||||
) {
|
||||
Ok(f) => f,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse filters for match count: {err}");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Load travel time data for each filter entry
|
||||
let travel_data: Vec<(TravelData, Option<f32>, Option<f32>)> = travel_time_filters
|
||||
.iter()
|
||||
.filter_map(|ttf| {
|
||||
let data = state.travel_time_store.get(&ttf.mode, &ttf.slug).ok()?;
|
||||
Some((data, ttf.min, ttf.max))
|
||||
})
|
||||
.collect();
|
||||
let has_travel = !travel_data.is_empty();
|
||||
|
||||
let feature_data = &state.data.feature_data;
|
||||
let num_features = state.data.num_features;
|
||||
let num_rows = state.data.lat.len();
|
||||
let (pc_interner, pc_keys) = state.data.postcode_parts();
|
||||
|
||||
let mut count = 0usize;
|
||||
for row in 0..num_rows {
|
||||
if !row_passes_filters(
|
||||
row,
|
||||
&parsed_filters,
|
||||
&parsed_enum_filters,
|
||||
feature_data,
|
||||
num_features,
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if has_travel {
|
||||
let postcode = pc_interner.resolve(&pc_keys[row]);
|
||||
let mut passes_travel = true;
|
||||
for (data, fmin, fmax) in &travel_data {
|
||||
let pass = if let Some(mins) = data.get(postcode).map(|r| r.minutes as f32) {
|
||||
fmin.map_or(true, |min| mins >= min)
|
||||
&& fmax.map_or(true, |max| mins <= max)
|
||||
} else {
|
||||
false // no travel data → postcode not reachable
|
||||
};
|
||||
if !pass {
|
||||
passes_travel = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !passes_travel {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
count += 1;
|
||||
}
|
||||
|
||||
count
|
||||
}
|
||||
|
||||
/// Maximum number of round trips (function calls + retries) before giving up.
|
||||
const MAX_TOOL_ROUNDS: usize = 5;
|
||||
|
||||
|
|
@ -631,6 +746,7 @@ pub async fn post_ai_filters(
|
|||
})];
|
||||
|
||||
let mut total_tokens_accumulated: u64 = 0;
|
||||
let mut refinement_attempts = 0u32;
|
||||
|
||||
// Function calling loop: model may call search_destinations, we execute and feed back
|
||||
for round in 0..MAX_TOOL_ROUNDS {
|
||||
|
|
@ -776,6 +892,42 @@ pub async fn post_ai_filters(
|
|||
map.insert("Listing status".to_string(), json!([listing_value]));
|
||||
}
|
||||
|
||||
// Count matching properties and refine if too restrictive
|
||||
let match_count = count_matching_rows(&state, &filters, &travel_time_filters);
|
||||
info!(match_count = match_count, round = round, "AI filter match count");
|
||||
|
||||
if match_count == 0 {
|
||||
refinement_attempts += 1;
|
||||
let total_rows = state.data.lat.len();
|
||||
info!(
|
||||
attempt = refinement_attempts,
|
||||
"0 matches out of {total_rows} — asking AI to relax filters"
|
||||
);
|
||||
let feedback = match refinement_attempts {
|
||||
1 => format!(
|
||||
"Your proposed filters matched 0 properties out of {total_rows} total. \
|
||||
The combination is too restrictive. Please widen some numeric ranges \
|
||||
or add more enum values while keeping the user's intent. \
|
||||
Output the adjusted JSON."
|
||||
),
|
||||
2 => format!(
|
||||
"Still 0 matches out of {total_rows}. Please widen ranges further. \
|
||||
Output the adjusted JSON."
|
||||
),
|
||||
_ => format!(
|
||||
"Still 0 matches out of {total_rows}. Please remove additional filters \
|
||||
until some properties match, keeping the user's core priority. \
|
||||
Output the adjusted JSON."
|
||||
),
|
||||
};
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": feedback }]
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Update usage with total accumulated tokens
|
||||
let new_total = tokens_used + total_tokens_accumulated;
|
||||
update_ai_usage(&state, &user.id, new_total, current_week).await;
|
||||
|
|
@ -810,6 +962,7 @@ pub async fn post_ai_filters(
|
|||
travel_time_filters,
|
||||
notes,
|
||||
listing_type: listing_type.to_string(),
|
||||
match_count,
|
||||
}));
|
||||
}
|
||||
|
||||
|
|
@ -902,8 +1055,10 @@ fn validate_travel_time_filters(raw: &Value, state: &AppState) -> Vec<TravelTime
|
|||
fn validate_and_convert(raw: &Value, features: &FeaturesResponse, listing_type: &str) -> Value {
|
||||
let mut result = serde_json::Map::new();
|
||||
|
||||
// Build lookup maps from feature metadata, filtering by listing mode
|
||||
let mut numeric_features: rustc_hash::FxHashMap<&str, (f32, f32)> =
|
||||
// Build lookup maps from feature metadata, filtering by listing mode.
|
||||
// Store both slider bounds (min/max from percentiles) and true data bounds
|
||||
// (histogram.min/max) so one-sided AI filters use the full data range.
|
||||
let mut numeric_features: rustc_hash::FxHashMap<&str, (f32, f32, f32, f32)> =
|
||||
rustc_hash::FxHashMap::default();
|
||||
let mut enum_features: rustc_hash::FxHashMap<&str, &[String]> =
|
||||
rustc_hash::FxHashMap::default();
|
||||
|
|
@ -915,12 +1070,14 @@ fn validate_and_convert(raw: &Value, features: &FeaturesResponse, listing_type:
|
|||
name,
|
||||
min,
|
||||
max,
|
||||
histogram,
|
||||
modes,
|
||||
..
|
||||
} => {
|
||||
// Only include features valid for the chosen listing mode
|
||||
if modes.is_empty() || modes.contains(&listing_type) {
|
||||
numeric_features.insert(name, (*min, *max));
|
||||
numeric_features
|
||||
.insert(name, (*min, *max, histogram.min, histogram.max));
|
||||
}
|
||||
}
|
||||
FeatureInfo::Enum { name, values, .. } => {
|
||||
|
|
@ -933,32 +1090,37 @@ fn validate_and_convert(raw: &Value, features: &FeaturesResponse, listing_type:
|
|||
}
|
||||
}
|
||||
|
||||
// Process numeric filters — each sets one bound (min or max)
|
||||
// Process numeric filters — each sets one bound (min or max).
|
||||
// The unset side uses the true data min/max (from histogram), not
|
||||
// the slider bounds (percentile-based), so a "max" filter for crime
|
||||
// produces [0, value] rather than [2nd-percentile, value].
|
||||
if let Some(arr) = raw.get("numeric_filters").and_then(|val| val.as_array()) {
|
||||
for item in arr {
|
||||
let name = match item.get("name").and_then(|val| val.as_str()) {
|
||||
Some(name) => name,
|
||||
None => continue,
|
||||
};
|
||||
let (feat_min, feat_max) = match numeric_features.get(name) {
|
||||
Some(range) => *range,
|
||||
None => continue,
|
||||
};
|
||||
let (slider_min, slider_max, data_min, data_max) =
|
||||
match numeric_features.get(name) {
|
||||
Some(range) => *range,
|
||||
None => continue,
|
||||
};
|
||||
let bound = match item.get("bound").and_then(|val| val.as_str()) {
|
||||
Some(b) => b,
|
||||
None => continue,
|
||||
};
|
||||
// Clamp value to true data range (not slider range)
|
||||
let value = match item.get("value").and_then(|val| val.as_f64()) {
|
||||
Some(v) => v.max(feat_min as f64).min(feat_max as f64) as f32,
|
||||
Some(v) => v.max(data_min as f64).min(data_max as f64) as f32,
|
||||
None => continue,
|
||||
};
|
||||
let (filter_min, filter_max) = match bound {
|
||||
"min" => (value, feat_max),
|
||||
"max" => (feat_min, value),
|
||||
"min" => (value, data_max),
|
||||
"max" => (data_min, value),
|
||||
_ => continue,
|
||||
};
|
||||
// Only include if range is narrower than full range
|
||||
if filter_min > feat_min || filter_max < feat_max {
|
||||
// Only include if range is narrower than full slider range
|
||||
if filter_min > slider_min || filter_max < slider_max {
|
||||
result.insert(name.to_string(), json!([filter_min, filter_max]));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue