vex_router/models/
mod.rs

1//! Models module - Model pool and backend integrations
2
3use crate::config::ModelConfig;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8/// A model in our pool
9#[derive(Debug, Clone)]
10pub struct Model {
11    pub config: Arc<ModelConfig>,
12    pub id: String,
13}
14
15impl Model {
16    pub fn new(config: ModelConfig) -> Self {
17        let id = config.id.clone();
18        Self {
19            config: Arc::new(config),
20            id,
21        }
22    }
23}
24
25/// Model pool - manages available models
26#[derive(Debug, Clone)]
27pub struct ModelPool {
28    pub models: Vec<Model>,
29    by_id: HashMap<String, usize>,
30}
31
32impl ModelPool {
33    pub fn new(configs: Vec<ModelConfig>) -> Self {
34        let by_id: HashMap<String, usize> = configs
35            .iter()
36            .enumerate()
37            .map(|(i, c)| (c.id.clone(), i))
38            .collect();
39
40        let models = configs.into_iter().map(Model::new).collect();
41
42        Self { models, by_id }
43    }
44
45    pub fn get(&self, id: &str) -> Option<&Model> {
46        self.by_id.get(id).and_then(|&i| self.models.get(i))
47    }
48
49    pub fn get_all(&self) -> &[Model] {
50        &self.models
51    }
52
53    pub fn is_empty(&self) -> bool {
54        self.models.is_empty()
55    }
56
57    pub fn len(&self) -> usize {
58        self.models.len()
59    }
60
61    pub fn get_by_capability(&self, capability: &str) -> Vec<&Model> {
62        self.models
63            .iter()
64            .filter(|m| {
65                m.config
66                    .capabilities
67                    .iter()
68                    .any(|c| format!("{:?}", c).contains(capability))
69            })
70            .collect()
71    }
72
73    pub fn get_cheapest(&self) -> Option<&Model> {
74        self.models.iter().min_by(|a, b| {
75            a.config
76                .input_cost
77                .partial_cmp(&b.config.input_cost)
78                .unwrap()
79        })
80    }
81
82    pub fn get_medium(&self) -> Option<&Model> {
83        let mut models: Vec<_> = self.models.iter().collect();
84        models.sort_by(|a, b| {
85            a.config
86                .input_cost
87                .partial_cmp(&b.config.input_cost)
88                .unwrap()
89        });
90        models.get(models.len() / 2).copied()
91    }
92
93    pub fn get_best(&self) -> Option<&Model> {
94        self.models.iter().max_by(|a, b| {
95            a.config
96                .quality_score
97                .partial_cmp(&b.config.quality_score)
98                .unwrap()
99        })
100    }
101
102    pub fn get_fastest(&self) -> Option<&Model> {
103        self.models.iter().min_by_key(|m| m.config.latency_ms)
104    }
105
106    pub fn get_best_quality(&self) -> Option<&Model> {
107        self.get_best()
108    }
109
110    /// Get models sorted by cost (ascending)
111    pub fn get_sorted_by_cost(&self) -> Vec<&Model> {
112        let mut models: Vec<_> = self.models.iter().collect();
113        models.sort_by(|a, b| {
114            a.config
115                .input_cost
116                .partial_cmp(&b.config.input_cost)
117                .unwrap()
118        });
119        models
120    }
121
122    /// Get models sorted by quality (descending)
123    pub fn get_sorted_by_quality(&self) -> Vec<&Model> {
124        let mut models: Vec<_> = self.models.iter().collect();
125        models.sort_by(|a, b| {
126            b.config
127                .quality_score
128                .partial_cmp(&a.config.quality_score)
129                .unwrap()
130        });
131        models
132    }
133}
134
135impl Default for ModelPool {
136    fn default() -> Self {
137        use crate::config::default_models;
138        Self::new(default_models())
139    }
140}
141
/// Request to a model.
///
/// Optional fields have no `skip_serializing_if`, so a `None` is still emitted
/// (as a null/none value) when this struct is serialized with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatRequest {
    /// Model identifier (presumably a pool `Model::id` — confirm against callers).
    pub model: String,
    /// Conversation messages.
    pub messages: Vec<Message>,
    /// Sampling temperature, if set.
    pub temperature: Option<f64>,
    /// Token limit for the completion, if set.
    pub max_tokens: Option<u32>,
    /// Whether a streamed response is requested, if set.
    pub stream: Option<bool>,
}
151
/// A single chat message: a role plus its text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Speaker role. This is a free-form string and nothing here validates it
    /// (presumably OpenAI-style values such as "user"/"assistant" — confirm).
    pub role: String,
    /// Message text.
    pub content: String,
}
157
/// Response from a model.
///
/// Every field is non-`Option`, so deserializing a payload that lacks any of
/// them (including `usage`) fails.
/// NOTE(review): some streaming or OpenAI-compatible backends omit `usage` —
/// confirm that every backend used here always sends it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    /// Response identifier as reported by the backend.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Generated alternatives.
    pub choices: Vec<Choice>,
    /// Token accounting for this request.
    pub usage: Usage,
    /// Creation time (presumably Unix seconds — confirm).
    pub created: u64,
}
167
/// One generated alternative within a `ChatResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    /// Position of this choice in the response.
    pub index: u32,
    /// The generated message.
    pub message: Message,
    /// Why generation stopped, if the backend reported it.
    pub finish_reason: Option<String>,
}
174
/// Token counts reported for a request.
///
/// NOTE(review): values are taken as-is; nothing here checks that
/// `total_tokens == prompt_tokens + completion_tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion.
    pub completion_tokens: u32,
    /// Total tokens as reported.
    pub total_tokens: u32,
}