vex_llm/openai.rs

//! OpenAI LLM provider
//!
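//! # Example
//!
//! A usage sketch; it assumes the crate exposes this module as
//! `vex_llm::openai` and that `LlmRequest` can be built from the four
//! fields used below. Adjust to the actual definitions in `crate::provider`.
//!
//! ```no_run
//! # async fn demo() {
//! use vex_llm::openai::OpenAIProvider;
//! use vex_llm::provider::{LlmProvider, LlmRequest};
//!
//! let provider = OpenAIProvider::gpt4("sk-...");
//! let request = LlmRequest {
//!     system: "You are a helpful assistant.".to_string(),
//!     prompt: "Say hello.".to_string(),
//!     temperature: 0.7,
//!     max_tokens: 256,
//! };
//! let response = provider.complete(request).await.expect("completion failed");
//! println!("{} ({} ms)", response.content, response.latency_ms);
//! # }
//! ```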
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Instant;

use crate::provider::{LlmError, LlmProvider, LlmRequest, LlmResponse};

/// OpenAI API request format
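///
/// Serializes to the chat-completions request body, e.g.
/// `{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}],"temperature":0.7,"max_tokens":256}`.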
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<Message>,
    temperature: f32,
    max_tokens: u32,
}

#[derive(Debug, Serialize)]
struct Message {
    role: String,
    content: String,
}

/// OpenAI API response format
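///
/// Models only the fields read below, e.g.
/// `{"choices":[{"message":{"content":"Hello!"}}],"model":"gpt-4","usage":{"total_tokens":42}}`.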
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<Choice>,
    model: String,
    usage: Option<Usage>,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: MessageContent,
}

#[derive(Debug, Deserialize)]
struct MessageContent {
    content: String,
}

#[derive(Debug, Deserialize)]
struct Usage {
    total_tokens: u32,
}

/// OpenAI provider
#[derive(Debug)]
pub struct OpenAIProvider {
    /// API key
    api_key: String,
    /// Model to use (e.g., "gpt-4", "gpt-3.5-turbo")
    model: String,
    /// HTTP client
    client: reqwest::Client,
    /// Base URL
    base_url: String,
}

impl OpenAIProvider {
    /// Create a new OpenAI provider
    pub fn new(api_key: &str, model: &str) -> Self {
        Self {
            api_key: api_key.to_string(),
            model: model.to_string(),
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com".to_string(),
        }
    }

    /// Create with GPT-4
    pub fn gpt4(api_key: &str) -> Self {
        Self::new(api_key, "gpt-4")
    }

    /// Create with GPT-4 Turbo
    pub fn gpt4_turbo(api_key: &str) -> Self {
        Self::new(api_key, "gpt-4-turbo-preview")
    }

    /// Create with GPT-3.5 Turbo
    pub fn gpt35(api_key: &str) -> Self {
        Self::new(api_key, "gpt-3.5-turbo")
    }
}
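
// Optional builder-style extension, shown as a sketch (it is not required by
// the `LlmProvider` trait): overriding `base_url` lets tests point at a local
// mock server and supports OpenAI-compatible gateways.
impl OpenAIProvider {
    /// Override the base URL (e.g. "http://localhost:8080" in tests).
    pub fn with_base_url(mut self, base_url: &str) -> Self {
        self.base_url = base_url.trim_end_matches('/').to_string();
        self
    }
}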

#[async_trait]
impl LlmProvider for OpenAIProvider {
    fn name(&self) -> &str {
        "openai"
    }

    async fn is_available(&self) -> bool {
        // `send()` succeeds for any HTTP response, even 401/500, so also
        // require a success status; otherwise a bad API key would still
        // report the provider as available.
        self.client
            .get(format!("{}/v1/models", self.base_url))
            .bearer_auth(&self.api_key)
            .send()
            .await
            .map(|r| r.status().is_success())
            .unwrap_or(false)
    }

    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
        let start = Instant::now();
        let url = format!("{}/v1/chat/completions", self.base_url);

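        // Send the system prompt as the first message, then the user prompt.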
        let messages = vec![
            Message {
                role: "system".to_string(),
                content: request.system,
            },
            Message {
                role: "user".to_string(),
                content: request.prompt,
            },
        ];

        let openai_request = OpenAIRequest {
            model: self.model.clone(),
            messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
        };

        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| LlmError::ConnectionFailed(e.to_string()))?;

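        // reqwest does not treat HTTP error statuses as `Err`, so map
        // non-2xx responses to `RequestFailed` ourselves, keeping the body
        // for context.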
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::RequestFailed(format!(
                "Status: {}, Body: {}",
                status, body
            )));
        }

        let api_response: OpenAIResponse = response
            .json()
            .await
            .map_err(|e| LlmError::InvalidResponse(e.to_string()))?;

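        // The API returns a single choice unless `n` > 1 is requested;
        // take the first.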
        let content = api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default();

        Ok(LlmResponse {
            content,
            model: api_response.model,
            tokens_used: api_response.usage.map(|u| u.total_tokens),
            latency_ms: start.elapsed().as_millis() as u64,
            trace_root: None,
        })
    }
}
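
#[cfg(test)]
mod tests {
    use super::*;

    // Offline smoke tests; exercising `complete` end to end would need a
    // mock HTTP server that `base_url` points at.
    #[test]
    fn convenience_constructors_pick_expected_models() {
        assert_eq!(OpenAIProvider::gpt4("key").model, "gpt-4");
        assert_eq!(OpenAIProvider::gpt4_turbo("key").model, "gpt-4-turbo-preview");
        assert_eq!(OpenAIProvider::gpt35("key").model, "gpt-3.5-turbo");
    }

    #[test]
    fn provider_reports_its_name() {
        assert_eq!(OpenAIProvider::gpt4("key").name(), "openai");
    }
}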