1use serde::{Deserialize, Serialize};
4#[allow(unused_imports)]
5use std::sync::Arc;
6use thiserror::Error;
7
8use crate::classifier::{QueryClassifier, QueryComplexity};
9use crate::compress::CompressionLevel;
10use crate::models::{Model, ModelPool};
11use crate::observability::Observability;
12
13pub use crate::config::RoutingStrategy;
15
/// The outcome of routing a single request: which model to call, plus
/// cost/latency/savings estimates and a human-readable rationale.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingDecision {
    /// Identifier of the selected model.
    pub model_id: String,
    /// Estimated request cost (taken from the model's configured input cost).
    pub estimated_cost: f64,
    /// Estimated response latency in milliseconds.
    pub estimated_latency_ms: u64,
    /// Estimated savings as a percentage (formatted with `%` by callers).
    pub estimated_savings: f64,
    /// Human-readable explanation of why this model was chosen.
    pub reason: String,
}
25
/// Tunable knobs controlling how the router selects models.
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Which routing strategy to apply (auto, cost-, quality-, or latency-optimized).
    pub strategy: RoutingStrategy,
    /// Minimum model quality score required by cost-optimized routing.
    pub quality_threshold: f64,
    /// Upper bound on spend per request.
    /// NOTE(review): not enforced by the routing paths visible in this file — confirm.
    pub max_cost_per_request: f64,
    /// Upper bound on acceptable latency in milliseconds.
    /// NOTE(review): not enforced by the routing paths visible in this file — confirm.
    pub max_latency_ms: u64,
    /// Whether response caching is enabled.
    pub cache_enabled: bool,
    /// Whether guardrail checks are enabled.
    pub guardrails_enabled: bool,
    /// How aggressively prompts are compressed.
    pub compression_level: CompressionLevel,
}
37
38impl Default for RouterConfig {
39 fn default() -> Self {
40 Self {
41 strategy: RoutingStrategy::Auto,
42 quality_threshold: 0.85,
43 max_cost_per_request: 1.0,
44 max_latency_ms: 10000,
45 cache_enabled: true,
46 guardrails_enabled: true,
47 compression_level: CompressionLevel::Balanced,
48 }
49 }
50}
51
/// Errors the router can return while routing or executing a request.
#[derive(Debug, Error)]
pub enum RouterError {
    /// The pool is empty, or no candidate model met the active constraints.
    #[error("No models available")]
    NoModelsAvailable,
    /// A request to a model failed; carries the underlying error message.
    #[error("Request failed: {0}")]
    RequestFailed(String),
    /// Every model in the pool was tried and failed.
    #[error("All models failed")]
    AllModelsFailed,
    /// Guardrail checks rejected the request before routing.
    #[error("Guardrails blocked request")]
    GuardrailsBlocked,
}
64
/// Routes requests to models from a pool based on query complexity and the
/// configured strategy.
#[derive(Debug)]
pub struct Router {
    // Candidate models to route between.
    pool: ModelPool,
    // Scores prompt complexity to drive model selection.
    classifier: QueryClassifier,
    // Strategy and thresholds governing selection.
    config: RouterConfig,
    // Metrics/tracing sink; only exposed read-only via `observability()` here.
    observability: Observability,
}
73
74impl Router {
75 pub fn new() -> Self {
77 Self {
78 pool: ModelPool::default(),
79 classifier: QueryClassifier::new(),
80 config: RouterConfig::default(),
81 observability: Observability::default(),
82 }
83 }
84
85 pub fn with_config(config: RouterConfig) -> Self {
87 Self {
88 pool: ModelPool::default(),
89 classifier: QueryClassifier::new(),
90 config,
91 observability: Observability::default(),
92 }
93 }
94
95 pub fn builder() -> RouterBuilder {
97 RouterBuilder::new()
98 }
99
100 pub fn route(&self, prompt: &str, system: &str) -> Result<RoutingDecision, RouterError> {
102 let mut complexity = self.classifier.classify(prompt);
103
104 let system_lower = system.to_lowercase();
107 if system_lower.contains("shadow")
108 || system_lower.contains("adversarial")
109 || system_lower.contains("red agent")
110 {
111 complexity.score = (complexity.score + 0.4).min(1.0);
112 complexity.capabilities.push("adversarial".to_string());
113 }
114
115 self.route_with_complexity(&complexity)
116 }
117
118 pub fn route_with_complexity(
120 &self,
121 complexity: &QueryComplexity,
122 ) -> Result<RoutingDecision, RouterError> {
123 if self.pool.is_empty() {
124 return Err(RouterError::NoModelsAvailable);
125 }
126
127 match self.config.strategy {
128 RoutingStrategy::Auto | RoutingStrategy::Balanced => self.route_auto(complexity),
129 RoutingStrategy::CostOptimized => self.route_cost_optimized(complexity),
130 RoutingStrategy::QualityOptimized => self.route_quality_optimized(complexity),
131 RoutingStrategy::LatencyOptimized => self.route_latency_optimized(complexity),
132 RoutingStrategy::Custom => {
133 self.route_auto(complexity)
135 }
136 }
137 }
138
139 pub async fn execute(&self, prompt: &str, system: &str) -> Result<String, RouterError> {
141 let decision = self.route(prompt, system)?;
142
143 Ok(format!(
146 "[vex-router: {}] Query routed based on complexity: {:.2}, Role: {}, Estimated savings: {:.0}%",
147 decision.model_id,
148 0.5,
149 if system.to_lowercase().contains("shadow") { "Adversarial" } else { "Primary" },
150 decision.estimated_savings
151 ))
152 }
153
154 pub async fn ask(&self, prompt: &str) -> Result<String, RouterError> {
156 self.execute(prompt, "").await
157 }
158
159 fn route_auto(&self, complexity: &QueryComplexity) -> Result<RoutingDecision, RouterError> {
164 let model = if complexity.score < 0.3 {
166 self.pool.get_cheapest()
167 } else if complexity.score < 0.7 {
168 self.pool.get_medium()
169 } else {
170 self.pool.get_best()
171 };
172
173 let model = model.ok_or(RouterError::NoModelsAvailable)?;
174
175 let savings = if complexity.score < 0.3 {
176 95.0
177 } else if complexity.score < 0.7 {
178 60.0
179 } else {
180 20.0
181 };
182
183 Ok(RoutingDecision {
184 model_id: model.id.clone(),
185 estimated_cost: model.config.input_cost,
186 estimated_latency_ms: model.config.latency_ms,
187 estimated_savings: savings,
188 reason: format!(
189 "Auto-selected based on complexity score: {:.2}",
190 complexity.score
191 ),
192 })
193 }
194
195 fn route_cost_optimized(
196 &self,
197 _complexity: &QueryComplexity,
198 ) -> Result<RoutingDecision, RouterError> {
199 let mut models: Vec<&Model> = self.pool.models.iter().collect();
201 models.sort_by(|a, b| {
202 a.config
203 .input_cost
204 .partial_cmp(&b.config.input_cost)
205 .unwrap()
206 });
207
208 for model in models {
209 let meets_quality = model.config.quality_score >= self.config.quality_threshold;
210 if meets_quality {
211 return Ok(RoutingDecision {
212 model_id: model.id.clone(),
213 estimated_cost: model.config.input_cost,
214 estimated_latency_ms: model.config.latency_ms,
215 estimated_savings: 80.0,
216 reason: "Cost-optimized: cheapest model meeting quality threshold".to_string(),
217 });
218 }
219 }
220
221 Err(RouterError::NoModelsAvailable)
222 }
223
224 fn route_quality_optimized(
225 &self,
226 _complexity: &QueryComplexity,
227 ) -> Result<RoutingDecision, RouterError> {
228 let model = self.pool.get_best().ok_or(RouterError::NoModelsAvailable)?;
229
230 Ok(RoutingDecision {
231 model_id: model.id.clone(),
232 estimated_cost: model.config.input_cost,
233 estimated_latency_ms: model.config.latency_ms,
234 estimated_savings: 0.0,
235 reason: "Quality-optimized: selected best available model".to_string(),
236 })
237 }
238
239 fn route_latency_optimized(
240 &self,
241 _complexity: &QueryComplexity,
242 ) -> Result<RoutingDecision, RouterError> {
243 let mut models: Vec<&Model> = self.pool.models.iter().collect();
244 models.sort_by(|a, b| a.config.latency_ms.cmp(&b.config.latency_ms));
245
246 let model = models.first().ok_or(RouterError::NoModelsAvailable)?;
247
248 Ok(RoutingDecision {
249 model_id: model.id.clone(),
250 estimated_cost: model.config.input_cost,
251 estimated_latency_ms: model.config.latency_ms,
252 estimated_savings: 50.0,
253 reason: "Latency-optimized: fastest model".to_string(),
254 })
255 }
256
257 pub fn config(&self) -> &RouterConfig {
259 &self.config
260 }
261
262 pub fn pool(&self) -> &ModelPool {
264 &self.pool
265 }
266
267 pub fn observability(&self) -> &Observability {
269 &self.observability
270 }
271}
272
273impl Default for Router {
274 fn default() -> Self {
275 Self::new()
276 }
277}
278
/// Fluent builder for `Router`, allowing custom configuration and models.
#[derive(Debug)]
pub struct RouterBuilder {
    // Configuration accumulated by the setter methods.
    config: RouterConfig,
    // User-supplied models; when empty, `build` falls back to the default pool.
    custom_models: Vec<crate::config::ModelConfig>,
}
285
286impl RouterBuilder {
287 pub fn new() -> Self {
288 Self {
289 config: RouterConfig::default(),
290 custom_models: Vec::new(),
291 }
292 }
293
294 pub fn strategy(mut self, strategy: RoutingStrategy) -> Self {
295 self.config.strategy = strategy;
296 self
297 }
298
299 pub fn quality_threshold(mut self, threshold: f64) -> Self {
300 self.config.quality_threshold = threshold;
301 self
302 }
303
304 pub fn max_cost(mut self, cost: f64) -> Self {
305 self.config.max_cost_per_request = cost;
306 self
307 }
308
309 pub fn cache_enabled(mut self, enabled: bool) -> Self {
310 self.config.cache_enabled = enabled;
311 self
312 }
313
314 pub fn guardrails_enabled(mut self, enabled: bool) -> Self {
315 self.config.guardrails_enabled = enabled;
316 self
317 }
318
319 pub fn compression_level(mut self, level: crate::compress::CompressionLevel) -> Self {
320 self.config.compression_level = level;
321 self
322 }
323
324 pub fn add_model(mut self, model: crate::config::ModelConfig) -> Self {
325 self.custom_models.push(model);
326 self
327 }
328
329 pub fn build(self) -> Router {
330 let pool = if self.custom_models.is_empty() {
331 ModelPool::default()
332 } else {
333 ModelPool::new(self.custom_models)
334 };
335
336 Router {
337 pool,
338 classifier: QueryClassifier::new(),
339 config: self.config,
340 observability: Observability::new(1000),
341 }
342 }
343}
344
345impl Default for RouterBuilder {
346 fn default() -> Self {
347 Self::new()
348 }
349}
350
351use async_trait::async_trait;
357use vex_llm::{LlmError, LlmProvider, LlmRequest, LlmResponse};
358
359#[async_trait]
360impl LlmProvider for Router {
361 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
363 let start = std::time::Instant::now();
364
365 let response = self
366 .execute(&request.prompt, &request.system)
367 .await
368 .map_err(|e| LlmError::RequestFailed(e.to_string()))?;
369
370 let response_len = response.len();
371 let latency = start.elapsed().as_millis() as u64;
372
373 let decision = self
374 .route(&request.prompt, &request.system)
375 .map_err(|e| LlmError::RequestFailed(e.to_string()))?;
376
377 Ok(LlmResponse {
378 content: response,
379 model: decision.model_id,
380 tokens_used: Some(((request.prompt.len() + response_len) as f64 / 4.0) as u32),
381 latency_ms: latency,
382 trace_root: None,
383 })
384 }
385
386 async fn is_available(&self) -> bool {
388 !self.pool.is_empty()
389 }
390
391 fn name(&self) -> &str {
393 "vex-router"
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    // Auto strategy should produce a decision with a non-empty model id
    // for a simple prompt.
    #[tokio::test]
    async fn test_router_auto() {
        let router = Router::builder().strategy(RoutingStrategy::Auto).build();

        let decision = router.route("What is 2+2?", "").unwrap();
        assert!(!decision.model_id.is_empty());
    }

    // `ask` should succeed with the default router and include the
    // provider tag in its summary string.
    #[tokio::test]
    async fn test_router_execute() {
        let router = Router::new();
        let response = router.ask("Hello").await.unwrap();
        assert!(response.contains("vex-router"));
    }

    // Builder setters should be reflected in the built router's config.
    #[test]
    fn test_router_builder() {
        let router = Router::builder()
            .strategy(RoutingStrategy::CostOptimized)
            .quality_threshold(0.9)
            .cache_enabled(false)
            .build();

        assert_eq!(router.config().strategy, RoutingStrategy::CostOptimized);
        assert_eq!(router.config().quality_threshold, 0.9);
        assert!(!router.config().cache_enabled);
    }

    // `LlmRequest::simple` should fill in the default system prompt.
    #[tokio::test]
    async fn test_llm_request() {
        let request = LlmRequest::simple("test");
        assert_eq!(request.system, "You are a helpful assistant.");
        assert_eq!(request.prompt, "test");
    }
}