// vex_llm/deepseek.rs

1//! DeepSeek LLM provider (OpenAI-compatible API)
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::time::Instant;
6
7use crate::provider::{LlmError, LlmProvider, LlmRequest, LlmResponse};
8
9/// DeepSeek API request format (OpenAI-compatible)
/// DeepSeek API request format (OpenAI-compatible)
#[derive(Debug, Serialize)]
struct DeepSeekRequest {
    /// Model identifier, e.g. "deepseek-chat" or "deepseek-coder".
    model: String,
    /// Conversation messages; built as system prompt first, then user prompt.
    messages: Vec<Message>,
    /// Sampling temperature, forwarded verbatim from the `LlmRequest`.
    temperature: f32,
    /// Maximum number of tokens to generate.
    max_tokens: u32,
}
17
/// A single chat message in the OpenAI-compatible wire format.
#[derive(Debug, Serialize)]
struct Message {
    /// Message role; this file only sends "system" and "user".
    role: String,
    /// Message text.
    content: String,
}
23
24/// DeepSeek API response format
/// DeepSeek API response format
#[derive(Debug, Deserialize)]
struct DeepSeekResponse {
    /// Generated completions; only the first choice is consumed.
    choices: Vec<Choice>,
    /// Model name echoed back by the API.
    model: String,
    /// Token accounting; `Option` because the field may be absent.
    usage: Option<Usage>,
}
31
/// One completion choice returned by the API.
#[derive(Debug, Deserialize)]
struct Choice {
    // Only the message payload is deserialized; other fields are ignored.
    message: MessageContent,
}
36
/// The assistant message payload inside a [`Choice`].
#[derive(Debug, Deserialize)]
struct MessageContent {
    // Generated text of the completion.
    content: String,
}
41
/// Token usage statistics reported by the API.
#[derive(Debug, Deserialize)]
struct Usage {
    /// Total tokens consumed (prompt + completion).
    total_tokens: u32,
}
46
47/// DeepSeek provider for inference
/// DeepSeek provider for inference
///
/// Talks to the OpenAI-compatible DeepSeek REST API. `Clone` is cheap in
/// practice: `reqwest::Client` is internally reference-counted.
#[derive(Debug, Clone)]
pub struct DeepSeekProvider {
    /// API key, sent as a Bearer token on every request
    api_key: String,
    /// Model to use (e.g., "deepseek-chat", "deepseek-coder")
    model: String,
    /// HTTP client, shared across requests
    client: reqwest::Client,
    /// Base URL (no trailing slash); endpoint paths are appended to it
    base_url: String,
}
59
60impl DeepSeekProvider {
61    /// Create a new DeepSeek provider
62    pub fn new(api_key: &str, model: &str) -> Self {
63        Self {
64            api_key: api_key.to_string(),
65            model: model.to_string(),
66            client: reqwest::Client::new(),
67            base_url: "https://api.deepseek.com".to_string(),
68        }
69    }
70
71    /// Create with default chat model
72    pub fn chat(api_key: &str) -> Self {
73        Self::new(api_key, "deepseek-chat")
74    }
75
76    /// Create with coder model
77    pub fn coder(api_key: &str) -> Self {
78        Self::new(api_key, "deepseek-coder")
79    }
80}
81
82#[async_trait]
83impl LlmProvider for DeepSeekProvider {
84    fn name(&self) -> &str {
85        "deepseek"
86    }
87
88    async fn is_available(&self) -> bool {
89        // Simple check - try to reach the API
90        self.client
91            .get(format!("{}/v1/models", self.base_url))
92            .bearer_auth(&self.api_key)
93            .send()
94            .await
95            .is_ok()
96    }
97
98    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
99        let start = Instant::now();
100        let url = format!("{}/v1/chat/completions", self.base_url);
101
102        let messages = vec![
103            Message {
104                role: "system".to_string(),
105                content: request.system,
106            },
107            Message {
108                role: "user".to_string(),
109                content: request.prompt,
110            },
111        ];
112
113        let deepseek_request = DeepSeekRequest {
114            model: self.model.clone(),
115            messages,
116            temperature: request.temperature,
117            max_tokens: request.max_tokens,
118        };
119
120        let response = self
121            .client
122            .post(&url)
123            .bearer_auth(&self.api_key)
124            .json(&deepseek_request)
125            .send()
126            .await
127            .map_err(|e| LlmError::ConnectionFailed(e.to_string()))?;
128
129        if !response.status().is_success() {
130            let status = response.status();
131            let body = response.text().await.unwrap_or_default();
132            return Err(LlmError::RequestFailed(format!(
133                "Status: {}, Body: {}",
134                status, body
135            )));
136        }
137
138        let api_response: DeepSeekResponse = response
139            .json()
140            .await
141            .map_err(|e| LlmError::InvalidResponse(e.to_string()))?;
142
143        let content = api_response
144            .choices
145            .first()
146            .map(|c| c.message.content.clone())
147            .unwrap_or_default();
148
149        Ok(LlmResponse {
150            content,
151            model: api_response.model,
152            tokens_used: api_response.usage.map(|u| u.total_tokens),
153            latency_ms: start.elapsed().as_millis() as u64,
154            trace_root: None,
155        })
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Live-API smoke test; skipped by default and only meaningful when a
    /// valid `DEEPSEEK_API_KEY` is present in the environment.
    #[tokio::test]
    #[ignore] // Requires valid API key
    async fn test_deepseek() {
        let key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
        let provider = DeepSeekProvider::chat(&key);

        // Bail out quietly when the API is unreachable, mirroring the
        // best-effort nature of this smoke test.
        if !provider.is_available().await {
            return;
        }

        let reply = provider.ask("Say hello in one word").await.unwrap();
        assert!(!reply.is_empty());
        println!("DeepSeek response: {}", reply);
    }
}
175}