1use crate::config::ModelConfig;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8#[derive(Debug, Clone)]
10pub struct Model {
11 pub config: Arc<ModelConfig>,
12 pub id: String,
13}
14
15impl Model {
16 pub fn new(config: ModelConfig) -> Self {
17 let id = config.id.clone();
18 Self {
19 config: Arc::new(config),
20 id,
21 }
22 }
23}
24
25#[derive(Debug, Clone)]
27pub struct ModelPool {
28 pub models: Vec<Model>,
29 by_id: HashMap<String, usize>,
30}
31
32impl ModelPool {
33 pub fn new(configs: Vec<ModelConfig>) -> Self {
34 let by_id: HashMap<String, usize> = configs
35 .iter()
36 .enumerate()
37 .map(|(i, c)| (c.id.clone(), i))
38 .collect();
39
40 let models = configs.into_iter().map(Model::new).collect();
41
42 Self { models, by_id }
43 }
44
45 pub fn get(&self, id: &str) -> Option<&Model> {
46 self.by_id.get(id).and_then(|&i| self.models.get(i))
47 }
48
49 pub fn get_all(&self) -> &[Model] {
50 &self.models
51 }
52
53 pub fn is_empty(&self) -> bool {
54 self.models.is_empty()
55 }
56
57 pub fn len(&self) -> usize {
58 self.models.len()
59 }
60
61 pub fn get_by_capability(&self, capability: &str) -> Vec<&Model> {
62 self.models
63 .iter()
64 .filter(|m| {
65 m.config
66 .capabilities
67 .iter()
68 .any(|c| format!("{:?}", c).contains(capability))
69 })
70 .collect()
71 }
72
73 pub fn get_cheapest(&self) -> Option<&Model> {
74 self.models.iter().min_by(|a, b| {
75 a.config
76 .input_cost
77 .partial_cmp(&b.config.input_cost)
78 .unwrap()
79 })
80 }
81
82 pub fn get_medium(&self) -> Option<&Model> {
83 let mut models: Vec<_> = self.models.iter().collect();
84 models.sort_by(|a, b| {
85 a.config
86 .input_cost
87 .partial_cmp(&b.config.input_cost)
88 .unwrap()
89 });
90 models.get(models.len() / 2).copied()
91 }
92
93 pub fn get_best(&self) -> Option<&Model> {
94 self.models.iter().max_by(|a, b| {
95 a.config
96 .quality_score
97 .partial_cmp(&b.config.quality_score)
98 .unwrap()
99 })
100 }
101
102 pub fn get_fastest(&self) -> Option<&Model> {
103 self.models.iter().min_by_key(|m| m.config.latency_ms)
104 }
105
106 pub fn get_best_quality(&self) -> Option<&Model> {
107 self.get_best()
108 }
109
110 pub fn get_sorted_by_cost(&self) -> Vec<&Model> {
112 let mut models: Vec<_> = self.models.iter().collect();
113 models.sort_by(|a, b| {
114 a.config
115 .input_cost
116 .partial_cmp(&b.config.input_cost)
117 .unwrap()
118 });
119 models
120 }
121
122 pub fn get_sorted_by_quality(&self) -> Vec<&Model> {
124 let mut models: Vec<_> = self.models.iter().collect();
125 models.sort_by(|a, b| {
126 b.config
127 .quality_score
128 .partial_cmp(&a.config.quality_score)
129 .unwrap()
130 });
131 models
132 }
133}
134
135impl Default for ModelPool {
136 fn default() -> Self {
137 use crate::config::default_models;
138 Self::new(default_models())
139 }
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct ChatRequest {
145 pub model: String,
146 pub messages: Vec<Message>,
147 pub temperature: Option<f64>,
148 pub max_tokens: Option<u32>,
149 pub stream: Option<bool>,
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct Message {
154 pub role: String,
155 pub content: String,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct ChatResponse {
161 pub id: String,
162 pub model: String,
163 pub choices: Vec<Choice>,
164 pub usage: Usage,
165 pub created: u64,
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct Choice {
170 pub index: u32,
171 pub message: Message,
172 pub finish_reason: Option<String>,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct Usage {
177 pub prompt_tokens: u32,
178 pub completion_tokens: u32,
179 pub total_tokens: u32,
180}