vex_llm/
mistral.rs

1//! Mistral AI LLM provider (OpenAI-compatible API)
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::time::Instant;
6
7use crate::provider::{LlmError, LlmProvider, LlmRequest, LlmResponse};
8
/// Request body for Mistral's chat-completions endpoint (OpenAI-compatible schema).
#[derive(Debug, Serialize)]
struct MistralRequest {
    // Model identifier, e.g. "mistral-small-latest".
    model: String,
    // Conversation turns; this provider always sends [system, user].
    messages: Vec<Message>,
    // Sampling temperature, forwarded unchanged from `LlmRequest`.
    temperature: f32,
    // Generation cap in tokens, forwarded unchanged from `LlmRequest`.
    max_tokens: u32,
}
17
/// A single chat turn in the request payload.
#[derive(Debug, Serialize)]
struct Message {
    // Chat role: "system" or "user" in this provider.
    role: String,
    // The text content of the turn.
    content: String,
}
23
/// Response body from Mistral's chat-completions endpoint.
///
/// Only the fields this provider reads are deserialized; unknown JSON
/// fields are ignored by serde's default behavior.
#[derive(Debug, Deserialize)]
struct MistralResponse {
    // Candidate completions; only the first is used.
    choices: Vec<Choice>,
    // Model name echoed back by the API (may differ from the requested alias).
    model: String,
    // Token accounting; optional in case the API omits it.
    usage: Option<Usage>,
}
31
/// One completion candidate within the response.
#[derive(Debug, Deserialize)]
struct Choice {
    // The assistant message produced for this choice.
    message: MessageContent,
}
36
/// The assistant message inside a [`Choice`]; only the text is needed.
#[derive(Debug, Deserialize)]
struct MessageContent {
    // Generated text of the completion.
    content: String,
}
41
/// Token-usage accounting reported by the API.
#[derive(Debug, Deserialize)]
struct Usage {
    // Prompt + completion tokens combined.
    total_tokens: u32,
}
46
/// Mistral AI provider for inference.
///
/// Talks to the OpenAI-compatible Mistral REST API over HTTPS. Construct
/// via [`MistralProvider::new`] or one of the model-specific shortcuts
/// (`large`, `small`, `codestral`, ...).
#[derive(Debug)]
pub struct MistralProvider {
    /// API key, sent as a bearer token on every request
    api_key: String,
    /// Model to use (e.g., "mistral-large-latest", "mistral-small-latest", "codestral-latest")
    model: String,
    /// HTTP client, reused across requests for connection pooling
    client: reqwest::Client,
    /// Base URL; defaults to the public Mistral endpoint, overridable via `with_base_url`
    base_url: String,
}
59
60impl MistralProvider {
61    /// Create a new Mistral provider
62    pub fn new(api_key: &str, model: &str) -> Self {
63        Self {
64            api_key: api_key.to_string(),
65            model: model.to_string(),
66            client: reqwest::Client::new(),
67            base_url: "https://api.mistral.ai".to_string(),
68        }
69    }
70
71    /// Create with Mistral Large 3 (state-of-the-art, open-weight, general-purpose multimodal model)
72    pub fn large(api_key: &str) -> Self {
73        Self::new(api_key, "mistral-large-latest")
74    }
75
76    /// Create with Mistral Medium 3.1 (frontier-class multimodal model)
77    pub fn medium(api_key: &str) -> Self {
78        Self::new(api_key, "mistral-medium-latest")
79    }
80
81    /// Create with Mistral Small 3.2 (fast and efficient)
82    pub fn small(api_key: &str) -> Self {
83        Self::new(api_key, "mistral-small-latest")
84    }
85
86    /// Create with Codestral (cutting-edge code generation model)
87    pub fn codestral(api_key: &str) -> Self {
88        Self::new(api_key, "codestral-latest")
89    }
90
91    /// Create with Devstral (excels at software engineering use cases)
92    pub fn devstral(api_key: &str) -> Self {
93        Self::new(api_key, "devstral-small-latest")
94    }
95
96    /// Create with Ministral 8B (lightweight model with best-in-class text and vision)
97    pub fn ministral_8b(api_key: &str) -> Self {
98        Self::new(api_key, "ministral-8b-latest")
99    }
100
101    /// Create with Ministral 3B (tiny and efficient model)
102    pub fn ministral_3b(api_key: &str) -> Self {
103        Self::new(api_key, "ministral-3b-latest")
104    }
105
106    /// Create with Pixtral Large (frontier-class multimodal vision model)
107    pub fn pixtral(api_key: &str) -> Self {
108        Self::new(api_key, "pixtral-large-latest")
109    }
110
111    /// Create with Mistral Nemo 12B (multilingual open source model)
112    pub fn nemo(api_key: &str) -> Self {
113        Self::new(api_key, "open-mistral-nemo")
114    }
115
116    /// Set a custom base URL (useful for self-hosted or proxy setups)
117    pub fn with_base_url(mut self, base_url: &str) -> Self {
118        self.base_url = base_url.to_string();
119        self
120    }
121}
122
123#[async_trait]
124impl LlmProvider for MistralProvider {
125    fn name(&self) -> &str {
126        "mistral"
127    }
128
129    async fn is_available(&self) -> bool {
130        // Simple check - try to reach the API
131        self.client
132            .get(format!("{}/v1/models", self.base_url))
133            .bearer_auth(&self.api_key)
134            .send()
135            .await
136            .is_ok()
137    }
138
139    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
140        let start = Instant::now();
141        let url = format!("{}/v1/chat/completions", self.base_url);
142
143        let messages = vec![
144            Message {
145                role: "system".to_string(),
146                content: request.system,
147            },
148            Message {
149                role: "user".to_string(),
150                content: request.prompt,
151            },
152        ];
153
154        let mistral_request = MistralRequest {
155            model: self.model.clone(),
156            messages,
157            temperature: request.temperature,
158            max_tokens: request.max_tokens,
159        };
160
161        let response = self
162            .client
163            .post(&url)
164            .bearer_auth(&self.api_key)
165            .json(&mistral_request)
166            .send()
167            .await
168            .map_err(|e| LlmError::ConnectionFailed(e.to_string()))?;
169
170        if !response.status().is_success() {
171            let status = response.status();
172            let body = response.text().await.unwrap_or_default();
173
174            // Handle rate limiting specifically
175            if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
176                return Err(LlmError::RateLimited);
177            }
178
179            return Err(LlmError::RequestFailed(format!(
180                "Status: {}, Body: {}",
181                status, body
182            )));
183        }
184
185        let api_response: MistralResponse = response
186            .json()
187            .await
188            .map_err(|e| LlmError::InvalidResponse(e.to_string()))?;
189
190        let content = api_response
191            .choices
192            .first()
193            .map(|c| c.message.content.clone())
194            .unwrap_or_default();
195
196        Ok(LlmResponse {
197            content,
198            model: api_response.model,
199            tokens_used: api_response.usage.map(|u| u.total_tokens),
200            latency_ms: start.elapsed().as_millis() as u64,
201            trace_root: None,
202        })
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    // Fetch the API key from the environment, panicking with the standard
    // message when it is missing.
    fn api_key() -> String {
        std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set")
    }

    #[tokio::test]
    #[ignore] // Requires valid API key
    async fn test_mistral() {
        let key = api_key();
        let provider = MistralProvider::small(&key);

        // Skip silently when the API cannot be reached.
        if provider.is_available().await {
            let reply = provider.ask("Say hello in one word").await.unwrap();
            assert!(!reply.is_empty());
            println!("Mistral response: {}", reply);
        }
    }

    #[tokio::test]
    #[ignore] // Requires valid API key
    async fn test_mistral_large() {
        let key = api_key();
        let provider = MistralProvider::large(&key);

        if provider.is_available().await {
            let reply = provider.ask("What is 2+2?").await.unwrap();
            assert!(!reply.is_empty());
            println!("Mistral Large response: {}", reply);
        }
    }

    #[tokio::test]
    #[ignore] // Requires valid API key
    async fn test_codestral() {
        let key = api_key();
        let provider = MistralProvider::codestral(&key);

        if provider.is_available().await {
            let reply = provider
                .ask("Write a simple hello world function in Rust")
                .await
                .unwrap();
            assert!(!reply.is_empty());
            println!("Codestral response: {}", reply);
        }
    }
}