1use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::time::Instant;
6
7use crate::provider::{LlmError, LlmProvider, LlmRequest, LlmResponse};
8
/// JSON request body for OpenAI's `POST /v1/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    // Model identifier, e.g. "gpt-4" or "gpt-3.5-turbo".
    model: String,
    // Conversation so far; this provider sends a system + user pair.
    messages: Vec<Message>,
    // Sampling temperature forwarded from the caller's `LlmRequest`.
    temperature: f32,
    // Upper bound on generated tokens, forwarded from `LlmRequest`.
    max_tokens: u32,
}
17
/// One chat message in the request payload.
#[derive(Debug, Serialize)]
struct Message {
    // "system" or "user" in this provider's usage.
    role: String,
    content: String,
}
23
/// Subset of the chat-completions response body that this provider reads.
/// Unknown fields in the API reply are ignored by serde's default behavior.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    // Candidate completions; only the first is used.
    choices: Vec<Choice>,
    // Model that actually served the request (echoed back by the API).
    model: String,
    // Token accounting; optional in case the API omits it.
    usage: Option<Usage>,
}
31
/// One completion candidate in the response.
#[derive(Debug, Deserialize)]
struct Choice {
    message: MessageContent,
}
36
/// Assistant message inside a [`Choice`]; only the text content is kept.
#[derive(Debug, Deserialize)]
struct MessageContent {
    content: String,
}
41
/// Token-usage accounting block from the API response.
#[derive(Debug, Deserialize)]
struct Usage {
    // Prompt + completion tokens combined.
    total_tokens: u32,
}
46
47#[derive(Debug)]
49pub struct OpenAIProvider {
50 api_key: String,
52 model: String,
54 client: reqwest::Client,
56 base_url: String,
58}
59
60impl OpenAIProvider {
61 pub fn new(api_key: &str, model: &str) -> Self {
63 Self {
64 api_key: api_key.to_string(),
65 model: model.to_string(),
66 client: reqwest::Client::new(),
67 base_url: "https://api.openai.com".to_string(),
68 }
69 }
70
71 pub fn gpt4(api_key: &str) -> Self {
73 Self::new(api_key, "gpt-4")
74 }
75
76 pub fn gpt4_turbo(api_key: &str) -> Self {
78 Self::new(api_key, "gpt-4-turbo-preview")
79 }
80
81 pub fn gpt35(api_key: &str) -> Self {
83 Self::new(api_key, "gpt-3.5-turbo")
84 }
85}
86
87#[async_trait]
88impl LlmProvider for OpenAIProvider {
89 fn name(&self) -> &str {
90 "openai"
91 }
92
93 async fn is_available(&self) -> bool {
94 self.client
95 .get(format!("{}/v1/models", self.base_url))
96 .bearer_auth(&self.api_key)
97 .send()
98 .await
99 .is_ok()
100 }
101
102 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
103 let start = Instant::now();
104 let url = format!("{}/v1/chat/completions", self.base_url);
105
106 let messages = vec![
107 Message {
108 role: "system".to_string(),
109 content: request.system,
110 },
111 Message {
112 role: "user".to_string(),
113 content: request.prompt,
114 },
115 ];
116
117 let openai_request = OpenAIRequest {
118 model: self.model.clone(),
119 messages,
120 temperature: request.temperature,
121 max_tokens: request.max_tokens,
122 };
123
124 let response = self
125 .client
126 .post(&url)
127 .bearer_auth(&self.api_key)
128 .json(&openai_request)
129 .send()
130 .await
131 .map_err(|e| LlmError::ConnectionFailed(e.to_string()))?;
132
133 if !response.status().is_success() {
134 let status = response.status();
135 let body = response.text().await.unwrap_or_default();
136 return Err(LlmError::RequestFailed(format!(
137 "Status: {}, Body: {}",
138 status, body
139 )));
140 }
141
142 let api_response: OpenAIResponse = response
143 .json()
144 .await
145 .map_err(|e| LlmError::InvalidResponse(e.to_string()))?;
146
147 let content = api_response
148 .choices
149 .first()
150 .map(|c| c.message.content.clone())
151 .unwrap_or_default();
152
153 Ok(LlmResponse {
154 content,
155 model: api_response.model,
156 tokens_used: api_response.usage.map(|u| u.total_tokens),
157 latency_ms: start.elapsed().as_millis() as u64,
158 trace_root: None,
159 })
160 }
161}