1use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::time::Instant;
6
7use crate::provider::{LlmError, LlmProvider, LlmRequest, LlmResponse};
8
9#[derive(Debug, Serialize)]
11struct MistralRequest {
12 model: String,
13 messages: Vec<Message>,
14 temperature: f32,
15 max_tokens: u32,
16}
17
18#[derive(Debug, Serialize)]
19struct Message {
20 role: String,
21 content: String,
22}
23
24#[derive(Debug, Deserialize)]
26struct MistralResponse {
27 choices: Vec<Choice>,
28 model: String,
29 usage: Option<Usage>,
30}
31
32#[derive(Debug, Deserialize)]
33struct Choice {
34 message: MessageContent,
35}
36
37#[derive(Debug, Deserialize)]
38struct MessageContent {
39 content: String,
40}
41
42#[derive(Debug, Deserialize)]
43struct Usage {
44 total_tokens: u32,
45}
46
47#[derive(Debug)]
49pub struct MistralProvider {
50 api_key: String,
52 model: String,
54 client: reqwest::Client,
56 base_url: String,
58}
59
60impl MistralProvider {
61 pub fn new(api_key: &str, model: &str) -> Self {
63 Self {
64 api_key: api_key.to_string(),
65 model: model.to_string(),
66 client: reqwest::Client::new(),
67 base_url: "https://api.mistral.ai".to_string(),
68 }
69 }
70
71 pub fn large(api_key: &str) -> Self {
73 Self::new(api_key, "mistral-large-latest")
74 }
75
76 pub fn medium(api_key: &str) -> Self {
78 Self::new(api_key, "mistral-medium-latest")
79 }
80
81 pub fn small(api_key: &str) -> Self {
83 Self::new(api_key, "mistral-small-latest")
84 }
85
86 pub fn codestral(api_key: &str) -> Self {
88 Self::new(api_key, "codestral-latest")
89 }
90
91 pub fn devstral(api_key: &str) -> Self {
93 Self::new(api_key, "devstral-small-latest")
94 }
95
96 pub fn ministral_8b(api_key: &str) -> Self {
98 Self::new(api_key, "ministral-8b-latest")
99 }
100
101 pub fn ministral_3b(api_key: &str) -> Self {
103 Self::new(api_key, "ministral-3b-latest")
104 }
105
106 pub fn pixtral(api_key: &str) -> Self {
108 Self::new(api_key, "pixtral-large-latest")
109 }
110
111 pub fn nemo(api_key: &str) -> Self {
113 Self::new(api_key, "open-mistral-nemo")
114 }
115
116 pub fn with_base_url(mut self, base_url: &str) -> Self {
118 self.base_url = base_url.to_string();
119 self
120 }
121}
122
123#[async_trait]
124impl LlmProvider for MistralProvider {
125 fn name(&self) -> &str {
126 "mistral"
127 }
128
129 async fn is_available(&self) -> bool {
130 self.client
132 .get(format!("{}/v1/models", self.base_url))
133 .bearer_auth(&self.api_key)
134 .send()
135 .await
136 .is_ok()
137 }
138
139 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
140 let start = Instant::now();
141 let url = format!("{}/v1/chat/completions", self.base_url);
142
143 let messages = vec![
144 Message {
145 role: "system".to_string(),
146 content: request.system,
147 },
148 Message {
149 role: "user".to_string(),
150 content: request.prompt,
151 },
152 ];
153
154 let mistral_request = MistralRequest {
155 model: self.model.clone(),
156 messages,
157 temperature: request.temperature,
158 max_tokens: request.max_tokens,
159 };
160
161 let response = self
162 .client
163 .post(&url)
164 .bearer_auth(&self.api_key)
165 .json(&mistral_request)
166 .send()
167 .await
168 .map_err(|e| LlmError::ConnectionFailed(e.to_string()))?;
169
170 if !response.status().is_success() {
171 let status = response.status();
172 let body = response.text().await.unwrap_or_default();
173
174 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
176 return Err(LlmError::RateLimited);
177 }
178
179 return Err(LlmError::RequestFailed(format!(
180 "Status: {}, Body: {}",
181 status, body
182 )));
183 }
184
185 let api_response: MistralResponse = response
186 .json()
187 .await
188 .map_err(|e| LlmError::InvalidResponse(e.to_string()))?;
189
190 let content = api_response
191 .choices
192 .first()
193 .map(|c| c.message.content.clone())
194 .unwrap_or_default();
195
196 Ok(LlmResponse {
197 content,
198 model: api_response.model,
199 tokens_used: api_response.usage.map(|u| u.total_tokens),
200 latency_ms: start.elapsed().as_millis() as u64,
201 trace_root: None,
202 })
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the API key from the environment, panicking when it is missing.
    fn api_key_from_env() -> String {
        std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set")
    }

    /// Asks the small model a trivial question; does nothing further when the
    /// availability probe fails.
    #[tokio::test]
    #[ignore]
    async fn test_mistral() {
        let provider = MistralProvider::small(&api_key_from_env());
        if !provider.is_available().await {
            return;
        }

        let response = provider.ask("Say hello in one word").await.unwrap();
        assert!(!response.is_empty());
        println!("Mistral response: {}", response);
    }

    /// Same check against the large model.
    #[tokio::test]
    #[ignore]
    async fn test_mistral_large() {
        let provider = MistralProvider::large(&api_key_from_env());
        if !provider.is_available().await {
            return;
        }

        let response = provider.ask("What is 2+2?").await.unwrap();
        assert!(!response.is_empty());
        println!("Mistral Large response: {}", response);
    }

    /// Same check against the Codestral model with a code-generation prompt.
    #[tokio::test]
    #[ignore]
    async fn test_codestral() {
        let provider = MistralProvider::codestral(&api_key_from_env());
        if !provider.is_available().await {
            return;
        }

        let prompt = "Write a simple hello world function in Rust";
        let response = provider.ask(prompt).await.unwrap();
        assert!(!response.is_empty());
        println!("Codestral response: {}", response);
    }
}