1use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::time::Instant;
6
7use crate::provider::{LlmError, LlmProvider, LlmRequest, LlmResponse};
8
9#[derive(Debug, Serialize)]
11struct DeepSeekRequest {
12 model: String,
13 messages: Vec<Message>,
14 temperature: f32,
15 max_tokens: u32,
16}
17
18#[derive(Debug, Serialize)]
19struct Message {
20 role: String,
21 content: String,
22}
23
24#[derive(Debug, Deserialize)]
26struct DeepSeekResponse {
27 choices: Vec<Choice>,
28 model: String,
29 usage: Option<Usage>,
30}
31
32#[derive(Debug, Deserialize)]
33struct Choice {
34 message: MessageContent,
35}
36
37#[derive(Debug, Deserialize)]
38struct MessageContent {
39 content: String,
40}
41
42#[derive(Debug, Deserialize)]
43struct Usage {
44 total_tokens: u32,
45}
46
47#[derive(Debug, Clone)]
49pub struct DeepSeekProvider {
50 api_key: String,
52 model: String,
54 client: reqwest::Client,
56 base_url: String,
58}
59
60impl DeepSeekProvider {
61 pub fn new(api_key: &str, model: &str) -> Self {
63 Self {
64 api_key: api_key.to_string(),
65 model: model.to_string(),
66 client: reqwest::Client::new(),
67 base_url: "https://api.deepseek.com".to_string(),
68 }
69 }
70
71 pub fn chat(api_key: &str) -> Self {
73 Self::new(api_key, "deepseek-chat")
74 }
75
76 pub fn coder(api_key: &str) -> Self {
78 Self::new(api_key, "deepseek-coder")
79 }
80}
81
82#[async_trait]
83impl LlmProvider for DeepSeekProvider {
84 fn name(&self) -> &str {
85 "deepseek"
86 }
87
88 async fn is_available(&self) -> bool {
89 self.client
91 .get(format!("{}/v1/models", self.base_url))
92 .bearer_auth(&self.api_key)
93 .send()
94 .await
95 .is_ok()
96 }
97
98 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
99 let start = Instant::now();
100 let url = format!("{}/v1/chat/completions", self.base_url);
101
102 let messages = vec![
103 Message {
104 role: "system".to_string(),
105 content: request.system,
106 },
107 Message {
108 role: "user".to_string(),
109 content: request.prompt,
110 },
111 ];
112
113 let deepseek_request = DeepSeekRequest {
114 model: self.model.clone(),
115 messages,
116 temperature: request.temperature,
117 max_tokens: request.max_tokens,
118 };
119
120 let response = self
121 .client
122 .post(&url)
123 .bearer_auth(&self.api_key)
124 .json(&deepseek_request)
125 .send()
126 .await
127 .map_err(|e| LlmError::ConnectionFailed(e.to_string()))?;
128
129 if !response.status().is_success() {
130 let status = response.status();
131 let body = response.text().await.unwrap_or_default();
132 return Err(LlmError::RequestFailed(format!(
133 "Status: {}, Body: {}",
134 status, body
135 )));
136 }
137
138 let api_response: DeepSeekResponse = response
139 .json()
140 .await
141 .map_err(|e| LlmError::InvalidResponse(e.to_string()))?;
142
143 let content = api_response
144 .choices
145 .first()
146 .map(|c| c.message.content.clone())
147 .unwrap_or_default();
148
149 Ok(LlmResponse {
150 content,
151 model: api_response.model,
152 tokens_used: api_response.usage.map(|u| u.total_tokens),
153 latency_ms: start.elapsed().as_millis() as u64,
154 trace_root: None,
155 })
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Live round-trip against the real DeepSeek API.
    ///
    /// Marked `#[ignore]`, so it only runs when ignored tests are requested.
    /// It panics if `DEEPSEEK_API_KEY` is unset. If the availability probe
    /// fails, it returns early without asserting anything.
    #[tokio::test]
    #[ignore]
    async fn test_deepseek() {
        let key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
        let provider = DeepSeekProvider::chat(&key);

        if !provider.is_available().await {
            return;
        }

        let answer = provider.ask("Say hello in one word").await.unwrap();
        assert!(!answer.is_empty());
        println!("DeepSeek response: {}", answer);
    }
}