Same
This commit is contained in:
parent
9cd2b8849c
commit
bbc2fcb86c
3 changed files with 173 additions and 18 deletions
|
|
@ -166,6 +166,76 @@ fn execute_destination_search(state: &AppState, query: &str, mode: &str) -> Valu
|
|||
matches.truncate(10);
|
||||
|
||||
if matches.is_empty() {
|
||||
// Check if the query matched a city that lacks its own travel data.
|
||||
// If so, return nearby stations within that city as suggestions.
|
||||
let matched_city_name: Option<&str> =
|
||||
pd.name_lower
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find_map(|(idx, name_lower)| {
|
||||
let words_match = query_words.iter().all(|word| name_lower.contains(word));
|
||||
let slug = slugify(&pd.name[idx]);
|
||||
let slug_match =
|
||||
slug.contains(&query_slug) || query_slug.contains(&slug);
|
||||
if (words_match || slug_match) && pd.type_rank[idx] == 0 {
|
||||
Some(pd.name[idx].as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(city_name) = matched_city_name {
|
||||
let city_lower = city_name.to_lowercase();
|
||||
let mut city_matches: Vec<(usize, String, u8, u32)> = pd
|
||||
.city
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, city_opt)| {
|
||||
let city = city_opt.as_deref()?;
|
||||
if city.to_lowercase() != city_lower {
|
||||
return None;
|
||||
}
|
||||
let slug = slugify(&pd.name[idx]);
|
||||
if slug_set.contains(&slug) {
|
||||
Some((idx, slug, pd.type_rank[idx], pd.population[idx]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
city_matches.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(b.3.cmp(&a.3)));
|
||||
city_matches.truncate(10);
|
||||
|
||||
if !city_matches.is_empty() {
|
||||
let results: Vec<Value> = city_matches
|
||||
.into_iter()
|
||||
.map(|(idx, slug, ..)| {
|
||||
json!({
|
||||
"name": pd.name[idx],
|
||||
"slug": slug,
|
||||
"place_type": pd.place_type.get(idx).to_string(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
info!(
|
||||
query = query,
|
||||
city = city_name,
|
||||
results = results.len(),
|
||||
"Destination search fell back to city stations"
|
||||
);
|
||||
|
||||
return json!({
|
||||
"results": results,
|
||||
"message": format!(
|
||||
"No travel data for '{}' directly. Pick one of these nearby stations:",
|
||||
city_name
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
query = query,
|
||||
mode = mode,
|
||||
|
|
@ -381,8 +451,8 @@ pub fn build_system_prompt(
|
|||
{\"name\": \"Serious crime (avg/yr)\", \"bound\": \"max\", \"value\": 20}, \
|
||||
{\"name\": \"Minor crime (avg/yr)\", \"bound\": \"max\", \"value\": 50}, \
|
||||
{\"name\": \"Noise (dB)\", \"bound\": \"max\", \"value\": 55}, \
|
||||
{\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \
|
||||
{\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}, \
|
||||
{\"name\": \"Good+ primary schools within 2km\", \"bound\": \"min\", \"value\": 2}, \
|
||||
{\"name\": \"Good+ secondary schools within 2km\", \"bound\": \"min\", \"value\": 1}, \
|
||||
{\"name\": \"Number of parks within 2km\", \"bound\": \"min\", \"value\": 3}], \
|
||||
\"enum_filters\": [], \"travel_time_filters\": [], \"notes\": \"\"}"
|
||||
.to_string(),
|
||||
|
|
@ -420,8 +490,8 @@ pub fn build_system_prompt(
|
|||
Output: {\"numeric_filters\": [\
|
||||
{\"name\": \"Total floor area (sqm)\", \"bound\": \"min\", \"value\": 100}, \
|
||||
{\"name\": \"Number of bedrooms & living rooms\", \"bound\": \"min\", \"value\": 5}, \
|
||||
{\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}, \
|
||||
{\"name\": \"Good+ secondary schools within 5km\", \"bound\": \"min\", \"value\": 2}], \
|
||||
{\"name\": \"Good+ primary schools within 2km\", \"bound\": \"min\", \"value\": 2}, \
|
||||
{\"name\": \"Good+ secondary schools within 2km\", \"bound\": \"min\", \"value\": 1}], \
|
||||
\"enum_filters\": [{\"name\": \"Property type\", \
|
||||
\"values\": [\"Detached\", \"Semi-Detached\"]}], \
|
||||
\"travel_time_filters\": [{\"mode\": \"car\", \"slug\": \"manchester\", \
|
||||
|
|
@ -445,7 +515,7 @@ pub fn build_system_prompt(
|
|||
"\nUser: \"3 bed house to buy under 500k with good schools\"\n\
|
||||
Output: {\"listing_type\": \"buy\", \
|
||||
\"numeric_filters\": [{\"name\": \"Asking price\", \"bound\": \"max\", \"value\": 500000}, \
|
||||
{\"name\": \"Good+ primary schools within 5km\", \"bound\": \"min\", \"value\": 5}], \
|
||||
{\"name\": \"Good+ primary schools within 2km\", \"bound\": \"min\", \"value\": 2}], \
|
||||
\"enum_filters\": [{\"name\": \"Property type\", \
|
||||
\"values\": [\"Detached\", \"Semi-Detached\", \"Terraced\"]}], \
|
||||
\"travel_time_filters\": [], \
|
||||
|
|
@ -671,8 +741,13 @@ fn count_matching_rows(
|
|||
count
|
||||
}
|
||||
|
||||
/// Maximum number of round trips (function calls + retries) before giving up.
|
||||
const MAX_TOOL_ROUNDS: usize = 5;
|
||||
/// Budget limits for the Gemini conversation loop. Separate counters prevent
|
||||
/// tool calls (destination searches) from starving JSON retries or zero-match
|
||||
/// refinements.
|
||||
const MAX_TOOL_CALLS: usize = 4;
|
||||
const MAX_RETRIES: usize = 3;
|
||||
const MAX_REFINEMENTS: u32 = 3;
|
||||
const MAX_TOTAL_ROUNDS: usize = 10;
|
||||
|
||||
pub async fn post_ai_filters(
|
||||
State(shared): State<Arc<SharedState>>,
|
||||
|
|
@ -746,10 +821,12 @@ pub async fn post_ai_filters(
|
|||
})];
|
||||
|
||||
let mut total_tokens_accumulated: u64 = 0;
|
||||
let mut tool_call_count = 0usize;
|
||||
let mut retry_count = 0usize;
|
||||
let mut refinement_attempts = 0u32;
|
||||
|
||||
// Function calling loop: model may call search_destinations, we execute and feed back
|
||||
for round in 0..MAX_TOOL_ROUNDS {
|
||||
for round in 0..MAX_TOTAL_ROUNDS {
|
||||
let body = json!({
|
||||
"systemInstruction": {
|
||||
"parts": [{ "text": state.ai_filters_system_prompt }]
|
||||
|
|
@ -802,7 +879,18 @@ pub async fn post_ai_filters(
|
|||
let fn_name = fc.get("name").and_then(|n| n.as_str()).unwrap_or("");
|
||||
let fn_args = fc.get("args").cloned().unwrap_or(json!({}));
|
||||
|
||||
info!(function = fn_name, round = round, "AI called tool");
|
||||
tool_call_count += 1;
|
||||
info!(function = fn_name, round = round, tool_call = tool_call_count, "AI called tool");
|
||||
|
||||
if tool_call_count > MAX_TOOL_CALLS {
|
||||
warn!("Tool call budget exhausted, forcing text output");
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
"parts": [{ "text": "Tool call limit reached. Output your best JSON now using the destinations you already found. Do not call any more tools." }]
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
|
||||
let fn_result = if fn_name == "search_destinations" {
|
||||
let query = fn_args.get("query").and_then(|q| q.as_str()).unwrap_or("");
|
||||
|
|
@ -840,8 +928,11 @@ pub async fn post_ai_filters(
|
|||
let text = text.trim();
|
||||
|
||||
if text.is_empty() {
|
||||
warn!("Gemini returned empty text content (round {})", round);
|
||||
// Retry by continuing the loop
|
||||
retry_count += 1;
|
||||
warn!("Gemini returned empty text content (round {}, retry {})", round, retry_count);
|
||||
if retry_count > MAX_RETRIES {
|
||||
return Err((StatusCode::BAD_GATEWAY, "AI returned empty responses".into()));
|
||||
}
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
|
|
@ -853,8 +944,11 @@ pub async fn post_ai_filters(
|
|||
let raw: Value = match serde_json::from_str(text) {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
warn!(error = %err, round = round, "Failed to parse Gemini JSON output, retrying");
|
||||
// Ask the model to fix its output
|
||||
retry_count += 1;
|
||||
warn!(error = %err, round = round, retry = retry_count, "Failed to parse Gemini JSON output");
|
||||
if retry_count > MAX_RETRIES {
|
||||
return Err((StatusCode::BAD_GATEWAY, "AI returned invalid JSON".into()));
|
||||
}
|
||||
contents.push(candidate.clone());
|
||||
contents.push(json!({
|
||||
"role": "user",
|
||||
|
|
@ -903,6 +997,29 @@ pub async fn post_ai_filters(
|
|||
attempt = refinement_attempts,
|
||||
"0 matches out of {total_rows} — asking AI to relax filters"
|
||||
);
|
||||
|
||||
if refinement_attempts > MAX_REFINEMENTS {
|
||||
warn!("Refinement budget exhausted, returning filters with 0 matches");
|
||||
let new_total = tokens_used + total_tokens_accumulated;
|
||||
update_ai_usage(&state, &user.id, new_total, current_week).await;
|
||||
counter!("ai_tokens_total").increment(total_tokens_accumulated);
|
||||
counter!("ai_requests_total", "status" => "zero_matches").increment(1);
|
||||
|
||||
let notes = if notes.is_empty() {
|
||||
"No properties match these filters. Try relaxing some constraints.".to_string()
|
||||
} else {
|
||||
format!("{}. No properties match — try relaxing some constraints.", notes)
|
||||
};
|
||||
|
||||
return Ok(Json(AiFiltersResponse {
|
||||
filters,
|
||||
travel_time_filters,
|
||||
notes,
|
||||
listing_type: listing_type.to_string(),
|
||||
match_count: 0,
|
||||
}));
|
||||
}
|
||||
|
||||
let feedback = match refinement_attempts {
|
||||
1 => format!(
|
||||
"Your proposed filters matched 0 properties out of {total_rows} total. \
|
||||
|
|
@ -966,10 +1083,10 @@ pub async fn post_ai_filters(
|
|||
}));
|
||||
}
|
||||
|
||||
// Exhausted tool rounds without getting a final text response
|
||||
// Exhausted total round budget without getting a valid response
|
||||
warn!(
|
||||
"AI exhausted {} tool-calling rounds without final response",
|
||||
MAX_TOOL_ROUNDS
|
||||
"AI exhausted {} total rounds without final response (tools={}, retries={}, refinements={})",
|
||||
MAX_TOTAL_ROUNDS, tool_call_count, retry_count, refinement_attempts
|
||||
);
|
||||
Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ pub struct DestinationResult {
|
|||
place_type: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
city: Option<String>,
|
||||
lat: f32,
|
||||
lon: f32,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
|
@ -76,6 +78,8 @@ pub async fn get_travel_destinations(
|
|||
slug,
|
||||
place_type: pd.place_type.get(idx).to_string(),
|
||||
city: pd.city[idx].clone(),
|
||||
lat: pd.lat[idx],
|
||||
lon: pd.lon[idx],
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue