1use async_trait::async_trait;
7use std::collections::HashMap;
8
/// Aggregated fitness scores for a single evaluated response:
/// an overall score plus a per-metric breakdown.
#[derive(Debug, Clone, Default)]
pub struct FitnessReport {
    // Overall fitness, kept within [0.0, 1.0].
    pub overall: f64,
    // Individual named metric scores.
    pub metrics: HashMap<String, f64>,
}

impl FitnessReport {
    /// Creates a report with only an overall score (clamped to `[0.0, 1.0]`)
    /// and no per-metric breakdown.
    pub fn simple(overall: f64) -> Self {
        Self {
            overall: overall.clamp(0.0, 1.0),
            metrics: HashMap::new(),
        }
    }

    /// Builds a report whose overall score is the weighted mean of `metrics`.
    ///
    /// Metrics absent from `weights` default to a weight of `1.0`. When the
    /// total weight is zero (e.g. `metrics` is empty) the overall score falls
    /// back to a neutral `0.5`. The metric values themselves are stored as
    /// given, without clamping.
    pub fn from_weighted(metrics: HashMap<String, f64>, weights: &HashMap<String, f64>) -> Self {
        let mut weighted_sum = 0.0;
        let mut total_weight = 0.0;

        for (name, score) in &metrics {
            let weight = weights.get(name).copied().unwrap_or(1.0);
            weighted_sum += score * weight;
            total_weight += weight;
        }

        let overall = if total_weight > 0.0 {
            weighted_sum / total_weight
        } else {
            0.5 // neutral score when there is nothing to average
        };

        Self {
            overall: overall.clamp(0.0, 1.0),
            metrics,
        }
    }

    /// Inserts (or replaces) a metric, clamping the score to `[0.0, 1.0]`.
    ///
    /// Does not refresh `overall`; call [`FitnessReport::recalculate_overall`]
    /// afterwards if the aggregate should reflect the new metric.
    pub fn add_metric(&mut self, name: &str, score: f64) {
        self.metrics.insert(name.to_string(), score.clamp(0.0, 1.0));
    }

    /// Recomputes `overall` as the unweighted mean of the stored metrics.
    /// Leaves `overall` untouched when no metrics are present.
    pub fn recalculate_overall(&mut self) {
        if self.metrics.is_empty() {
            return;
        }
        let sum: f64 = self.metrics.values().sum();
        // Clamp for consistency with `simple` and `from_weighted`: metrics
        // stored via `from_weighted` are not individually clamped, so the
        // raw mean could otherwise leave [0.0, 1.0].
        self.overall = (sum / self.metrics.len() as f64).clamp(0.0, 1.0);
    }
}
81
/// Describes the task a response is evaluated against.
#[derive(Debug, Clone, Default)]
pub struct EvaluationContext {
    // The task prompt the response is answering.
    pub task: String,
    // Optional reference answer used for accuracy scoring.
    pub expected_outcome: Option<String>,
    // Arbitrary key/value annotations attached by the caller.
    pub metadata: HashMap<String, String>,
}

impl EvaluationContext {
    /// Builds a context for `task` with no expected outcome and no metadata.
    pub fn new(task: &str) -> Self {
        Self {
            task: task.to_string(),
            ..Self::default()
        }
    }

    /// Consumes the context and attaches `expected` as the reference outcome.
    pub fn with_expected(mut self, expected: &str) -> Self {
        self.expected_outcome = Some(expected.to_string());
        self
    }
}
109
/// A scorer that grades a model response against an [`EvaluationContext`].
#[async_trait]
pub trait FitnessEvaluator: Send + Sync {
    /// Scores `response` for the task in `context`, returning a
    /// [`FitnessReport`] with an overall score and per-metric breakdown.
    async fn evaluate(&self, response: &str, context: &EvaluationContext) -> FitnessReport;
}
119
/// Default weighting for the standard fitness metrics.
///
/// The five weights sum to `1.0`; metrics absent from this map are weighted
/// `1.0` by [`FitnessReport::from_weighted`].
pub fn default_weights() -> HashMap<String, f64> {
    let entries = [
        ("task_completion", 0.30),
        ("factual_accuracy", 0.25),
        ("coherence", 0.15),
        ("efficiency", 0.15),
        ("confidence_calibration", 0.15),
    ];
    entries
        .iter()
        .map(|&(name, weight)| (name.to_string(), weight))
        .collect()
}
130
131#[derive(Debug, Clone, Default)]
133pub struct HeuristicEvaluator;
134
135#[async_trait]
136impl FitnessEvaluator for HeuristicEvaluator {
137 async fn evaluate(&self, response: &str, context: &EvaluationContext) -> FitnessReport {
138 let mut metrics = HashMap::new();
139
140 let completion = if response.len() > 50 {
142 0.8
143 } else if response.len() > 10 {
144 0.5
145 } else {
146 0.2
147 };
148 metrics.insert("task_completion".to_string(), completion);
149
150 let sentences = response.matches('.').count();
152 let words = response.split_whitespace().count();
153 let avg_sentence_len = if sentences > 0 {
154 words / sentences
155 } else {
156 words
157 };
158 let coherence = if (10..40).contains(&avg_sentence_len) {
159 0.8
160 } else if avg_sentence_len < 60 {
161 0.6
162 } else {
163 0.4
164 };
165 metrics.insert("coherence".to_string(), coherence);
166
167 let task_words = context.task.split_whitespace().count();
169 let response_ratio = words as f64 / (task_words.max(10) as f64);
170 let efficiency = if response_ratio < 5.0 {
171 0.9
172 } else if response_ratio < 10.0 {
173 0.7
174 } else {
175 0.5
176 };
177 metrics.insert("efficiency".to_string(), efficiency);
178
179 if let Some(expected) = &context.expected_outcome {
181 let expected_lower = expected.to_lowercase();
182 let response_lower = response.to_lowercase();
183 let accuracy = if response_lower.contains(&expected_lower)
184 || expected_lower.contains(&response_lower)
185 {
186 0.9
187 } else {
188 let expected_words: std::collections::HashSet<_> =
190 expected_lower.split_whitespace().collect();
191 let response_words: std::collections::HashSet<_> =
192 response_lower.split_whitespace().collect();
193 let overlap = expected_words.intersection(&response_words).count();
194 let total = expected_words.len().max(1);
195 (overlap as f64 / total as f64).min(0.9)
196 };
197 metrics.insert("factual_accuracy".to_string(), accuracy);
198 }
199
200 FitnessReport::from_weighted(metrics, &default_weights())
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fitness_report_simple() {
        // An in-range score is stored verbatim, with no metric breakdown.
        let simple = FitnessReport::simple(0.85);
        assert_eq!(simple.overall, 0.85);
        assert!(simple.metrics.is_empty());
    }

    #[test]
    fn test_fitness_report_weighted() {
        // Equal weights of 0.5 reduce to a plain average: (1.0 + 0.5) / 2.
        let mut scores = HashMap::new();
        scores.insert("a".to_string(), 1.0);
        scores.insert("b".to_string(), 0.5);

        let mut weighting = HashMap::new();
        weighting.insert("a".to_string(), 0.5);
        weighting.insert("b".to_string(), 0.5);

        let weighted = FitnessReport::from_weighted(scores, &weighting);
        assert!((weighted.overall - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_add_metric() {
        // add_metric stores the (clamped) score under the given name.
        let mut report = FitnessReport::simple(0.5);
        report.add_metric("test", 0.9);
        assert_eq!(report.metrics.get("test"), Some(&0.9));
    }

    #[tokio::test]
    async fn test_heuristic_evaluator() {
        // A multi-sentence, reasonably sized response should beat neutral.
        let ctx = EvaluationContext::new("Explain quantum computing");
        let response = "Quantum computing uses quantum bits or qubits. \
            Unlike classical bits that are 0 or 1, qubits can be in superposition. \
            This allows quantum computers to process many possibilities simultaneously.";

        let result = HeuristicEvaluator.evaluate(response, &ctx).await;

        assert!(result.overall > 0.5);
        assert!(result.metrics.contains_key("task_completion"));
        assert!(result.metrics.contains_key("coherence"));
    }

    #[tokio::test]
    async fn test_evaluator_with_expected() {
        // Containing the expected answer should yield high accuracy.
        let ctx = EvaluationContext::new("What is 2+2?").with_expected("4");

        let result = HeuristicEvaluator.evaluate("The answer is 4", &ctx).await;

        assert!(*result.metrics.get("factual_accuracy").unwrap_or(&0.0) > 0.5);
    }
}