// vex_router/router/mod.rs

1//! Router - Core routing logic for VEX
2
3use serde::{Deserialize, Serialize};
4#[allow(unused_imports)]
5use std::sync::Arc;
6use thiserror::Error;
7
8use crate::classifier::{QueryClassifier, QueryComplexity};
9use crate::compress::CompressionLevel;
10use crate::models::{Model, ModelPool};
11use crate::observability::Observability;
12
13/// Routing strategy (re-exported from config)
14pub use crate::config::RoutingStrategy;
15
/// A routing decision
///
/// Describes which model was chosen for a request and the estimates that
/// justified the choice. Produced by `Router::route` and the per-strategy
/// helpers; serializable so it can be logged or returned to callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingDecision {
    /// Identifier of the selected model.
    pub model_id: String,
    /// Estimated cost, taken from the model's configured input cost.
    pub estimated_cost: f64,
    /// Expected latency in milliseconds, from the model's config.
    pub estimated_latency_ms: u64,
    /// Estimated savings as a percentage (e.g. 95.0 ≈ 95% saved) relative
    /// to always choosing the premium model.
    pub estimated_savings: f64,
    /// Human-readable explanation of why this model was chosen.
    pub reason: String,
}
25
/// Router configuration
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Which routing strategy `route_with_complexity` dispatches on.
    pub strategy: RoutingStrategy,
    /// Minimum model quality score required by cost-optimized routing.
    pub quality_threshold: f64,
    /// Per-request cost cap.
    /// NOTE(review): not enforced by the routing paths visible in this file — confirm consumer.
    pub max_cost_per_request: f64,
    /// Latency cap in milliseconds.
    /// NOTE(review): not enforced by the routing paths visible in this file — confirm consumer.
    pub max_latency_ms: u64,
    /// Whether response caching is enabled (presumably consumed elsewhere).
    pub cache_enabled: bool,
    /// Whether guardrail checks are enabled (presumably consumed elsewhere).
    pub guardrails_enabled: bool,
    /// Prompt-compression level (presumably consumed by the compress module).
    pub compression_level: CompressionLevel,
}
37
impl Default for RouterConfig {
    /// Conservative defaults: automatic routing, a 0.85 quality floor,
    /// a 1.0 cost cap and 10-second latency cap, caching and guardrails
    /// enabled, balanced compression.
    fn default() -> Self {
        Self {
            strategy: RoutingStrategy::Auto,
            quality_threshold: 0.85,
            max_cost_per_request: 1.0,
            max_latency_ms: 10000,
            cache_enabled: true,
            guardrails_enabled: true,
            compression_level: CompressionLevel::Balanced,
        }
    }
}
51
/// Router errors
#[derive(Debug, Error)]
pub enum RouterError {
    /// The model pool is empty, or no model satisfied the routing constraints.
    #[error("No models available")]
    NoModelsAvailable,
    /// A request to a model failed; the payload carries the underlying message.
    #[error("Request failed: {0}")]
    RequestFailed(String),
    /// Every model attempted for this request failed.
    #[error("All models failed")]
    AllModelsFailed,
    /// The request was rejected by guardrail checks.
    #[error("Guardrails blocked request")]
    GuardrailsBlocked,
}
64
/// The main Router - implements LlmProvider trait for VEX
#[derive(Debug)]
pub struct Router {
    // Models available for routing decisions.
    pool: ModelPool,
    // Scores query complexity to drive model selection.
    classifier: QueryClassifier,
    // Strategy and thresholds governing routing.
    config: RouterConfig,
    // Metrics collection; not consulted by the routing paths in this file.
    observability: Observability,
}
73
74impl Router {
75    /// Create a new router with default settings
76    pub fn new() -> Self {
77        Self {
78            pool: ModelPool::default(),
79            classifier: QueryClassifier::new(),
80            config: RouterConfig::default(),
81            observability: Observability::default(),
82        }
83    }
84
85    /// Create a router with a custom configuration
86    pub fn with_config(config: RouterConfig) -> Self {
87        Self {
88            pool: ModelPool::default(),
89            classifier: QueryClassifier::new(),
90            config,
91            observability: Observability::default(),
92        }
93    }
94
95    /// Get a builder for configuration
96    pub fn builder() -> RouterBuilder {
97        RouterBuilder::new()
98    }
99
100    /// Route a query and return a decision (without executing)
101    pub fn route(&self, prompt: &str, system: &str) -> Result<RoutingDecision, RouterError> {
102        let mut complexity = self.classifier.classify(prompt);
103
104        // ADVERSARIAL ROUTING: If system prompt implies an attacker/shadow role,
105        // bump the complexity/quality requirements to ensure a strong adversary.
106        let system_lower = system.to_lowercase();
107        if system_lower.contains("shadow")
108            || system_lower.contains("adversarial")
109            || system_lower.contains("red agent")
110        {
111            complexity.score = (complexity.score + 0.4).min(1.0);
112            complexity.capabilities.push("adversarial".to_string());
113        }
114
115        self.route_with_complexity(&complexity)
116    }
117
118    /// Route with pre-computed complexity
119    pub fn route_with_complexity(
120        &self,
121        complexity: &QueryComplexity,
122    ) -> Result<RoutingDecision, RouterError> {
123        if self.pool.is_empty() {
124            return Err(RouterError::NoModelsAvailable);
125        }
126
127        match self.config.strategy {
128            RoutingStrategy::Auto | RoutingStrategy::Balanced => self.route_auto(complexity),
129            RoutingStrategy::CostOptimized => self.route_cost_optimized(complexity),
130            RoutingStrategy::QualityOptimized => self.route_quality_optimized(complexity),
131            RoutingStrategy::LatencyOptimized => self.route_latency_optimized(complexity),
132            RoutingStrategy::Custom => {
133                // Fall back to auto for custom
134                self.route_auto(complexity)
135            }
136        }
137    }
138
139    /// Execute a query through the router
140    pub async fn execute(&self, prompt: &str, system: &str) -> Result<String, RouterError> {
141        let decision = self.route(prompt, system)?;
142
143        // For now, return a mock response
144        // In VEX integration, this would call the actual LLM
145        Ok(format!(
146            "[vex-router: {}] Query routed based on complexity: {:.2}, Role: {}, Estimated savings: {:.0}%",
147            decision.model_id,
148            0.5,
149            if system.to_lowercase().contains("shadow") { "Adversarial" } else { "Primary" },
150            decision.estimated_savings
151        ))
152    }
153
    /// Convenience method - ask a question
    ///
    /// Delegates to [`Router::execute`] with an empty system prompt.
    pub async fn ask(&self, prompt: &str) -> Result<String, RouterError> {
        self.execute(prompt, "").await
    }
158
159    // =========================================================================
160    // Routing Strategies
161    // =========================================================================
162
163    fn route_auto(&self, complexity: &QueryComplexity) -> Result<RoutingDecision, RouterError> {
164        // Simple heuristic: low complexity = cheap model, high complexity = premium
165        let model = if complexity.score < 0.3 {
166            self.pool.get_cheapest()
167        } else if complexity.score < 0.7 {
168            self.pool.get_medium()
169        } else {
170            self.pool.get_best()
171        };
172
173        let model = model.ok_or(RouterError::NoModelsAvailable)?;
174
175        let savings = if complexity.score < 0.3 {
176            95.0
177        } else if complexity.score < 0.7 {
178            60.0
179        } else {
180            20.0
181        };
182
183        Ok(RoutingDecision {
184            model_id: model.id.clone(),
185            estimated_cost: model.config.input_cost,
186            estimated_latency_ms: model.config.latency_ms,
187            estimated_savings: savings,
188            reason: format!(
189                "Auto-selected based on complexity score: {:.2}",
190                complexity.score
191            ),
192        })
193    }
194
195    fn route_cost_optimized(
196        &self,
197        _complexity: &QueryComplexity,
198    ) -> Result<RoutingDecision, RouterError> {
199        // Find cheapest model that meets quality threshold
200        let mut models: Vec<&Model> = self.pool.models.iter().collect();
201        models.sort_by(|a, b| {
202            a.config
203                .input_cost
204                .partial_cmp(&b.config.input_cost)
205                .unwrap()
206        });
207
208        for model in models {
209            let meets_quality = model.config.quality_score >= self.config.quality_threshold;
210            if meets_quality {
211                return Ok(RoutingDecision {
212                    model_id: model.id.clone(),
213                    estimated_cost: model.config.input_cost,
214                    estimated_latency_ms: model.config.latency_ms,
215                    estimated_savings: 80.0,
216                    reason: "Cost-optimized: cheapest model meeting quality threshold".to_string(),
217                });
218            }
219        }
220
221        Err(RouterError::NoModelsAvailable)
222    }
223
224    fn route_quality_optimized(
225        &self,
226        _complexity: &QueryComplexity,
227    ) -> Result<RoutingDecision, RouterError> {
228        let model = self.pool.get_best().ok_or(RouterError::NoModelsAvailable)?;
229
230        Ok(RoutingDecision {
231            model_id: model.id.clone(),
232            estimated_cost: model.config.input_cost,
233            estimated_latency_ms: model.config.latency_ms,
234            estimated_savings: 0.0,
235            reason: "Quality-optimized: selected best available model".to_string(),
236        })
237    }
238
239    fn route_latency_optimized(
240        &self,
241        _complexity: &QueryComplexity,
242    ) -> Result<RoutingDecision, RouterError> {
243        let mut models: Vec<&Model> = self.pool.models.iter().collect();
244        models.sort_by(|a, b| a.config.latency_ms.cmp(&b.config.latency_ms));
245
246        let model = models.first().ok_or(RouterError::NoModelsAvailable)?;
247
248        Ok(RoutingDecision {
249            model_id: model.id.clone(),
250            estimated_cost: model.config.input_cost,
251            estimated_latency_ms: model.config.latency_ms,
252            estimated_savings: 50.0,
253            reason: "Latency-optimized: fastest model".to_string(),
254        })
255    }
256
    /// Get the current configuration (borrowed; clone if you need to keep it).
    pub fn config(&self) -> &RouterConfig {
        &self.config
    }

    /// Get the model pool backing routing decisions.
    pub fn pool(&self) -> &ModelPool {
        &self.pool
    }

    /// Get the observability metrics sink.
    pub fn observability(&self) -> &Observability {
        &self.observability
    }
271}
272
273impl Default for Router {
274    fn default() -> Self {
275        Self::new()
276    }
277}
278
/// Builder for Router
///
/// Collects configuration tweaks and optional custom models before
/// constructing a [`Router`] via [`RouterBuilder::build`].
#[derive(Debug)]
pub struct RouterBuilder {
    // Configuration under construction; starts from RouterConfig::default().
    config: RouterConfig,
    // Models to seed the pool with; empty means use ModelPool::default().
    custom_models: Vec<crate::config::ModelConfig>,
}
285
286impl RouterBuilder {
287    pub fn new() -> Self {
288        Self {
289            config: RouterConfig::default(),
290            custom_models: Vec::new(),
291        }
292    }
293
294    pub fn strategy(mut self, strategy: RoutingStrategy) -> Self {
295        self.config.strategy = strategy;
296        self
297    }
298
299    pub fn quality_threshold(mut self, threshold: f64) -> Self {
300        self.config.quality_threshold = threshold;
301        self
302    }
303
304    pub fn max_cost(mut self, cost: f64) -> Self {
305        self.config.max_cost_per_request = cost;
306        self
307    }
308
309    pub fn cache_enabled(mut self, enabled: bool) -> Self {
310        self.config.cache_enabled = enabled;
311        self
312    }
313
314    pub fn guardrails_enabled(mut self, enabled: bool) -> Self {
315        self.config.guardrails_enabled = enabled;
316        self
317    }
318
319    pub fn compression_level(mut self, level: crate::compress::CompressionLevel) -> Self {
320        self.config.compression_level = level;
321        self
322    }
323
324    pub fn add_model(mut self, model: crate::config::ModelConfig) -> Self {
325        self.custom_models.push(model);
326        self
327    }
328
329    pub fn build(self) -> Router {
330        let pool = if self.custom_models.is_empty() {
331            ModelPool::default()
332        } else {
333            ModelPool::new(self.custom_models)
334        };
335
336        Router {
337            pool,
338            classifier: QueryClassifier::new(),
339            config: self.config,
340            observability: Observability::new(1000),
341        }
342    }
343}
344
345impl Default for RouterBuilder {
346    fn default() -> Self {
347        Self::new()
348    }
349}
350
351// =============================================================================
352// VEX LlmProvider Trait Implementation (for VEX integration)
353// =============================================================================
354
355// Re-using official VEX LLM types
356use async_trait::async_trait;
357use vex_llm::{LlmError, LlmProvider, LlmRequest, LlmResponse};
358
359#[async_trait]
360impl LlmProvider for Router {
361    /// Complete a request (implements vex_llm::LlmProvider::complete)
362    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
363        let start = std::time::Instant::now();
364
365        let response = self
366            .execute(&request.prompt, &request.system)
367            .await
368            .map_err(|e| LlmError::RequestFailed(e.to_string()))?;
369
370        let response_len = response.len();
371        let latency = start.elapsed().as_millis() as u64;
372
373        let decision = self
374            .route(&request.prompt, &request.system)
375            .map_err(|e| LlmError::RequestFailed(e.to_string()))?;
376
377        Ok(LlmResponse {
378            content: response,
379            model: decision.model_id,
380            tokens_used: Some(((request.prompt.len() + response_len) as f64 / 4.0) as u32),
381            latency_ms: latency,
382            trace_root: None,
383        })
384    }
385
386    /// Check if router is available
387    async fn is_available(&self) -> bool {
388        !self.pool.is_empty()
389    }
390
391    /// Get provider name
392    fn name(&self) -> &str {
393        "vex-router"
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    // Auto strategy should always produce a decision with a model id.
    #[tokio::test]
    async fn test_router_auto() {
        let router = Router::builder().strategy(RoutingStrategy::Auto).build();

        let decision = router.route("What is 2+2?", "").unwrap();
        assert!(!decision.model_id.is_empty());
    }

    // The mock execute path tags responses with the provider name.
    #[tokio::test]
    async fn test_router_execute() {
        let router = Router::new();
        let response = router.ask("Hello").await.unwrap();
        assert!(response.contains("vex-router"));
    }

    // Builder setters must be reflected in the resulting config.
    #[test]
    fn test_router_builder() {
        let router = Router::builder()
            .strategy(RoutingStrategy::CostOptimized)
            .quality_threshold(0.9)
            .cache_enabled(false)
            .build();

        assert_eq!(router.config().strategy, RoutingStrategy::CostOptimized);
        assert_eq!(router.config().quality_threshold, 0.9);
        assert!(!router.config().cache_enabled);
    }

    // LlmRequest::simple supplies a default system prompt.
    #[tokio::test]
    async fn test_llm_request() {
        let request = LlmRequest::simple("test");
        assert_eq!(request.system, "You are a helpful assistant.");
        assert_eq!(request.prompt, "test");
    }
}
435}