// vex_core/fitness.rs

1//! Multi-dimensional fitness evaluation
2//!
3//! Provides rich fitness scoring beyond simple confidence values.
4//! Evaluates task completion, accuracy, coherence, efficiency, and calibration.
5
6use async_trait::async_trait;
7use std::collections::HashMap;
8
/// Result of fitness evaluation with multiple metrics
#[derive(Debug, Clone, Default)]
pub struct FitnessReport {
    /// Overall fitness score (0.0-1.0)
    pub overall: f64,
    /// Individual metric scores
    pub metrics: HashMap<String, f64>,
}

impl FitnessReport {
    /// Build a report carrying only an overall score (clamped into 0.0-1.0)
    /// and no per-metric breakdown.
    pub fn simple(overall: f64) -> Self {
        FitnessReport {
            metrics: HashMap::new(),
            overall: overall.clamp(0.0, 1.0),
        }
    }

    /// Combine individual metric scores into a weighted overall score.
    ///
    /// Metrics missing from `weights` count with a default weight of 1.0.
    /// When no metrics are supplied, the overall score falls back to a
    /// neutral 0.5.
    ///
    /// # Example
    /// ```
    /// use vex_core::fitness::FitnessReport;
    /// use std::collections::HashMap;
    ///
    /// let metrics: HashMap<String, f64> = [("accuracy", 0.9), ("coherence", 0.8)]
    ///     .into_iter()
    ///     .map(|(k, v)| (k.to_string(), v))
    ///     .collect();
    /// let weights: HashMap<String, f64> = [("accuracy", 0.6), ("coherence", 0.4)]
    ///     .into_iter()
    ///     .map(|(k, v)| (k.to_string(), v))
    ///     .collect();
    ///
    /// let report = FitnessReport::from_weighted(metrics, &weights);
    /// assert!((report.overall - 0.86).abs() < 0.01);
    /// ```
    pub fn from_weighted(metrics: HashMap<String, f64>, weights: &HashMap<String, f64>) -> Self {
        // Single pass accumulating (weighted sum, total weight).
        let (weighted_sum, total_weight) =
            metrics
                .iter()
                .fold((0.0_f64, 0.0_f64), |(sum, total), (name, score)| {
                    // Unlisted metrics count fully rather than being dropped.
                    let w = weights.get(name).copied().unwrap_or(1.0);
                    (sum + score * w, total + w)
                });

        let overall = if total_weight > 0.0 {
            weighted_sum / total_weight
        } else {
            0.5 // neutral score when there is nothing to average
        };

        FitnessReport {
            overall: overall.clamp(0.0, 1.0),
            metrics,
        }
    }

    /// Insert (or overwrite) a metric score, clamped into 0.0-1.0.
    /// Does not touch `overall`; call [`FitnessReport::recalculate_overall`]
    /// if the overall score should reflect the change.
    pub fn add_metric(&mut self, name: &str, score: f64) {
        self.metrics.insert(name.to_string(), score.clamp(0.0, 1.0));
    }

    /// Recompute `overall` as the unweighted mean of all stored metrics.
    /// Leaves `overall` untouched when no metrics are present.
    pub fn recalculate_overall(&mut self) {
        let count = self.metrics.len();
        if count == 0 {
            return;
        }
        self.overall = self.metrics.values().sum::<f64>() / count as f64;
    }
}
81
/// Context for fitness evaluation
#[derive(Debug, Clone, Default)]
pub struct EvaluationContext {
    /// The original task/prompt
    pub task: String,
    /// Expected outcome (if known)
    pub expected_outcome: Option<String>,
    /// Additional context
    pub metadata: HashMap<String, String>,
}

impl EvaluationContext {
    /// Build a context for `task` with no expected outcome and no metadata.
    pub fn new(task: &str) -> Self {
        EvaluationContext {
            task: task.to_string(),
            // Remaining fields start at their defaults (None / empty map).
            ..Default::default()
        }
    }

    /// Builder-style setter recording the expected outcome.
    pub fn with_expected(mut self, expected: &str) -> Self {
        self.expected_outcome = Some(expected.to_string());
        self
    }
}
109
/// Trait for fitness evaluators
///
/// Implementations can use LLM-as-judge, heuristics, or other methods
/// to evaluate response quality. The `Send + Sync` bounds let an
/// evaluator be shared across threads and async tasks; `#[async_trait]`
/// makes the async method object-safe.
#[async_trait]
pub trait FitnessEvaluator: Send + Sync {
    /// Evaluate an agent response against `context` and return a
    /// `FitnessReport` with per-metric scores and a combined overall score.
    async fn evaluate(&self, response: &str, context: &EvaluationContext) -> FitnessReport;
}
119
/// Default metric weights (sum to 1.0)
pub fn default_weights() -> HashMap<String, f64> {
    // Completion and accuracy dominate; the remaining three metrics
    // share equal weight.
    [
        ("task_completion", 0.30),
        ("factual_accuracy", 0.25),
        ("coherence", 0.15),
        ("efficiency", 0.15),
        ("confidence_calibration", 0.15),
    ]
    .into_iter()
    .map(|(name, weight)| (name.to_string(), weight))
    .collect()
}
130
/// Simple heuristic-based evaluator (no LLM required)
///
/// Stateless unit struct; all scoring logic lives in its
/// `FitnessEvaluator` implementation below.
#[derive(Debug, Clone, Default)]
pub struct HeuristicEvaluator;
134
135#[async_trait]
136impl FitnessEvaluator for HeuristicEvaluator {
137    async fn evaluate(&self, response: &str, context: &EvaluationContext) -> FitnessReport {
138        let mut metrics = HashMap::new();
139
140        // Task completion: check if response is non-empty and substantial
141        let completion = if response.len() > 50 {
142            0.8
143        } else if response.len() > 10 {
144            0.5
145        } else {
146            0.2
147        };
148        metrics.insert("task_completion".to_string(), completion);
149
150        // Coherence: check sentence structure (simple heuristic)
151        let sentences = response.matches('.').count();
152        let words = response.split_whitespace().count();
153        let avg_sentence_len = if sentences > 0 {
154            words / sentences
155        } else {
156            words
157        };
158        let coherence = if (10..40).contains(&avg_sentence_len) {
159            0.8
160        } else if avg_sentence_len < 60 {
161            0.6
162        } else {
163            0.4
164        };
165        metrics.insert("coherence".to_string(), coherence);
166
167        // Efficiency: penalize overly verbose responses
168        let task_words = context.task.split_whitespace().count();
169        let response_ratio = words as f64 / (task_words.max(10) as f64);
170        let efficiency = if response_ratio < 5.0 {
171            0.9
172        } else if response_ratio < 10.0 {
173            0.7
174        } else {
175            0.5
176        };
177        metrics.insert("efficiency".to_string(), efficiency);
178
179        // Expected match (if available)
180        if let Some(expected) = &context.expected_outcome {
181            let expected_lower = expected.to_lowercase();
182            let response_lower = response.to_lowercase();
183            let accuracy = if response_lower.contains(&expected_lower)
184                || expected_lower.contains(&response_lower)
185            {
186                0.9
187            } else {
188                // Check word overlap
189                let expected_words: std::collections::HashSet<_> =
190                    expected_lower.split_whitespace().collect();
191                let response_words: std::collections::HashSet<_> =
192                    response_lower.split_whitespace().collect();
193                let overlap = expected_words.intersection(&response_words).count();
194                let total = expected_words.len().max(1);
195                (overlap as f64 / total as f64).min(0.9)
196            };
197            metrics.insert("factual_accuracy".to_string(), accuracy);
198        }
199
200        FitnessReport::from_weighted(metrics, &default_weights())
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fitness_report_simple() {
        let report = FitnessReport::simple(0.85);
        // A simple report carries only the overall score.
        assert!(report.metrics.is_empty());
        assert_eq!(report.overall, 0.85);
    }

    #[test]
    fn test_fitness_report_weighted() {
        // Two equally weighted metrics should average to 0.75.
        let metrics: HashMap<String, f64> = [("a", 1.0), ("b", 0.5)]
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect();
        let weights: HashMap<String, f64> = [("a", 0.5), ("b", 0.5)]
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect();

        let report = FitnessReport::from_weighted(metrics, &weights);
        assert!((report.overall - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_add_metric() {
        let mut report = FitnessReport::simple(0.5);
        report.add_metric("test", 0.9);
        assert_eq!(report.metrics.get("test"), Some(&0.9));
    }

    #[tokio::test]
    async fn test_heuristic_evaluator() {
        let context = EvaluationContext::new("Explain quantum computing");
        let response = "Quantum computing uses quantum bits or qubits. \
            Unlike classical bits that are 0 or 1, qubits can be in superposition. \
            This allows quantum computers to process many possibilities simultaneously.";

        let report = HeuristicEvaluator.evaluate(response, &context).await;

        assert!(report.metrics.contains_key("task_completion"));
        assert!(report.metrics.contains_key("coherence"));
        assert!(report.overall > 0.5);
    }

    #[tokio::test]
    async fn test_evaluator_with_expected() {
        let context = EvaluationContext::new("What is 2+2?").with_expected("4");
        let report = HeuristicEvaluator.evaluate("The answer is 4", &context).await;
        // Response contains the expected answer, so accuracy should be high.
        assert!(*report.metrics.get("factual_accuracy").unwrap_or(&0.0) > 0.5);
    }
}