Add LLM summary
This commit is contained in:
parent
f5e6894c0f
commit
9e71ed77df
3 changed files with 275 additions and 9 deletions
166
server-rs/src/routes/area_summary.rs
Normal file
166
server-rs/src/routes/area_summary.rs
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Json;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::consts::{
|
||||
AREA_SUMMARY_MAX_TOKENS, AREA_SUMMARY_SYSTEM_PROMPT, AREA_SUMMARY_TEMPERATURE,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct NumericStat {
|
||||
name: String,
|
||||
mean: f64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct EnumStat {
|
||||
name: String,
|
||||
counts: std::collections::HashMap<String, u64>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AreaSummaryRequest {
|
||||
count: usize,
|
||||
location: String,
|
||||
is_postcode: bool,
|
||||
#[serde(default)]
|
||||
filters: Vec<String>,
|
||||
#[serde(default)]
|
||||
numeric_stats: Vec<NumericStat>,
|
||||
#[serde(default)]
|
||||
enum_stats: Vec<EnumStat>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct AreaSummaryResponse {
|
||||
summary: String,
|
||||
}
|
||||
|
||||
fn build_prompt(req: &AreaSummaryRequest) -> String {
|
||||
let mut parts = Vec::new();
|
||||
|
||||
let area_type = if req.is_postcode { "postcode" } else { "area" };
|
||||
parts.push(format!(
|
||||
"Summarise this {} of England ({}) which contain {} properties matching the filters.",
|
||||
area_type, req.location, req.count
|
||||
));
|
||||
|
||||
if !req.filters.is_empty() {
|
||||
parts.push(format!("Active filters: {}.", req.filters.join(", ")));
|
||||
}
|
||||
|
||||
if !req.numeric_stats.is_empty() {
|
||||
let stats: Vec<String> = req
|
||||
.numeric_stats
|
||||
.iter()
|
||||
.map(|stat| format!("{}: {:.1}", stat.name, stat.mean))
|
||||
.collect();
|
||||
parts.push(format!("Average values: {}.", stats.join(", ")));
|
||||
}
|
||||
|
||||
for es in &req.enum_stats {
|
||||
let total: u64 = es.counts.values().sum();
|
||||
if total == 0 {
|
||||
continue;
|
||||
}
|
||||
let mut sorted: Vec<_> = es.counts.iter().collect();
|
||||
sorted.sort_by(|lhs, rhs| rhs.1.cmp(lhs.1));
|
||||
let top: Vec<String> = sorted
|
||||
.iter()
|
||||
.take(3)
|
||||
.map(|(val, count)| {
|
||||
let pct = **count as f64 / total as f64 * 100.0;
|
||||
format!("{} ({:.0}%)", val, pct)
|
||||
})
|
||||
.collect();
|
||||
parts.push(format!("{}: {}.", es.name, top.join(", ")));
|
||||
}
|
||||
|
||||
let result = parts.join(" ");
|
||||
info!(prompt = %result, "Built prompt for area summary");
|
||||
result
|
||||
}
|
||||
|
||||
/// Strip `<think>...</think>` blocks from model output
|
||||
fn strip_think_blocks(text: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut remaining = text;
|
||||
while let Some(start) = remaining.find("<think>") {
|
||||
result.push_str(&remaining[..start]);
|
||||
if let Some(end) = remaining[start..].find("</think>") {
|
||||
remaining = &remaining[start + end + 8..];
|
||||
} else {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
result.push_str(remaining);
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn post_area_summary(
|
||||
state: Arc<AppState>,
|
||||
Json(req): Json<AreaSummaryRequest>,
|
||||
) -> Result<Json<AreaSummaryResponse>, (StatusCode, String)> {
|
||||
let prompt = build_prompt(&req);
|
||||
info!(location = %req.location, count = req.count, "POST /api/area-summary");
|
||||
|
||||
let url = format!("{}/v1/chat/completions", state.ollama_url);
|
||||
let body = serde_json::json!({
|
||||
"model": state.ollama_model,
|
||||
"messages": [
|
||||
{ "role": "system", "content": AREA_SUMMARY_SYSTEM_PROMPT },
|
||||
{ "role": "user", "content": prompt }
|
||||
],
|
||||
"stream": false,
|
||||
"temperature": AREA_SUMMARY_TEMPERATURE,
|
||||
"max_tokens": AREA_SUMMARY_MAX_TOKENS,
|
||||
});
|
||||
|
||||
let response = state
|
||||
.http_client
|
||||
.post(&url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
warn!(error = %err, "Failed to connect to Ollama");
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Failed to connect to Ollama: {}", err),
|
||||
)
|
||||
})?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body_text = response.text().await.unwrap_or_default();
|
||||
warn!(status = %status, body = %body_text, "Ollama returned error");
|
||||
return Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Ollama error {}: {}", status, body_text),
|
||||
));
|
||||
}
|
||||
|
||||
let json: serde_json::Value = response.json().await.map_err(|err| {
|
||||
warn!(error = %err, "Failed to parse Ollama response");
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Failed to parse Ollama response: {}", err),
|
||||
)
|
||||
})?;
|
||||
|
||||
let content = json
|
||||
.get("choices")
|
||||
.and_then(|ch| ch.get(0))
|
||||
.and_then(|ch| ch.get("message"))
|
||||
.and_then(|msg| msg.get("content"))
|
||||
.and_then(|ct| ct.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let summary = strip_think_blocks(content).trim().to_string();
|
||||
|
||||
Ok(Json(AreaSummaryResponse { summary }))
|
||||
}
|
||||
|
|
@ -9,15 +9,7 @@ use tracing::warn;
|
|||
use crate::state::AppState;
|
||||
|
||||
pub async fn proxy_to_pocketbase(state: Arc<AppState>, req: Request) -> impl IntoResponse {
|
||||
let pb_url = match &state.pocketbase_url {
|
||||
Some(url) => url.trim_end_matches('/'),
|
||||
None => {
|
||||
return Response::builder()
|
||||
.status(StatusCode::SERVICE_UNAVAILABLE)
|
||||
.body(Body::from("PocketBase not configured"))
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
let pb_url = state.pocketbase_url.trim_end_matches('/');
|
||||
|
||||
let path = req.uri().path();
|
||||
let target_path = path.strip_prefix("/pb").unwrap_or(path);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue