1use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::time::Instant;
6
7use crate::provider::{LlmError, LlmProvider, LlmRequest, LlmResponse};
8
9#[derive(Debug, Serialize)]
11struct OllamaRequest {
12 model: String,
13 prompt: String,
14 system: Option<String>,
15 stream: bool,
16 options: OllamaOptions,
17}
18
/// Generation options forwarded in the `options` object of an
/// `/api/generate` request.
#[derive(Debug, Serialize)]
struct OllamaOptions {
    // Sampling temperature.
    temperature: f32,
    // Maximum number of tokens to generate (Ollama calls this `num_predict`).
    num_predict: u32,
}
24
/// Subset of the JSON body returned by Ollama's `/api/generate` endpoint;
/// fields not listed here are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct OllamaApiResponse {
    // The generated completion text.
    response: String,
    // Model name echoed back by the daemon.
    model: String,
    // Token count for the evaluation, when reported. NOTE(review):
    // `#[serde(default)]` is redundant on an `Option` field — serde already
    // treats a missing `Option` field as `None` — but it is harmless.
    #[serde(default)]
    eval_count: Option<u32>,
}
33
/// LLM provider backed by an Ollama daemon reached over HTTP.
#[derive(Debug)]
pub struct OllamaProvider {
    // Base URL of the daemon, e.g. "http://localhost:11434" (no trailing slash).
    base_url: String,
    // Model name sent with every request.
    model: String,
    // Reused HTTP client (keeps connection pooling across calls).
    client: reqwest::Client,
}
44
45impl OllamaProvider {
46 pub fn new(model: &str) -> Self {
48 Self {
49 base_url: "http://localhost:11434".to_string(),
50 model: model.to_string(),
51 client: reqwest::Client::new(),
52 }
53 }
54
55 pub fn with_url(base_url: &str, model: &str) -> Self {
57 let url = base_url.to_lowercase();
59 if url.contains("localhost") || url.contains("127.0.0.1") || url.contains("::1") {
60 tracing::warn!(url = %base_url, "Potentially unsafe URL in OllamaProvider");
61 }
62
63 Self {
64 base_url: base_url.to_string(),
65 model: model.to_string(),
66 client: reqwest::Client::new(),
67 }
68 }
69}
70
71#[async_trait]
72impl LlmProvider for OllamaProvider {
73 fn name(&self) -> &str {
74 "ollama"
75 }
76
77 async fn is_available(&self) -> bool {
78 let url = format!("{}/api/tags", self.base_url);
79 self.client.get(&url).send().await.is_ok()
80 }
81
82 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
83 let start = Instant::now();
84 let url = format!("{}/api/generate", self.base_url);
85
86 let ollama_request = OllamaRequest {
87 model: self.model.clone(),
88 prompt: request.prompt,
89 system: Some(request.system),
90 stream: false,
91 options: OllamaOptions {
92 temperature: request.temperature,
93 num_predict: request.max_tokens,
94 },
95 };
96
97 let response = self
98 .client
99 .post(&url)
100 .json(&ollama_request)
101 .send()
102 .await
103 .map_err(|e| LlmError::ConnectionFailed(e.to_string()))?;
104
105 if !response.status().is_success() {
106 return Err(LlmError::RequestFailed(format!(
107 "Status: {}",
108 response.status()
109 )));
110 }
111
112 let api_response: OllamaApiResponse = response
113 .json()
114 .await
115 .map_err(|e| LlmError::InvalidResponse(e.to_string()))?;
116
117 Ok(LlmResponse {
118 content: api_response.response,
119 model: api_response.model,
120 tokens_used: api_response.eval_count,
121 latency_ms: start.elapsed().as_millis() as u64,
122 trace_root: None,
123 })
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    // Integration test against a live local Ollama daemon. Marked
    // `#[ignore]` so it only runs on demand (`cargo test -- --ignored`)
    // and silently does nothing when no daemon is listening.
    #[tokio::test]
    #[ignore] async fn test_ollama_available() {
        let provider = OllamaProvider::new("llama2");
        if provider.is_available().await {
            // NOTE(review): `ask` is not defined in this file — presumably a
            // provided method on the `LlmProvider` trait returning a String;
            // verify against the trait definition in crate::provider.
            let response = provider.ask("Say hello in one word").await.unwrap();
            assert!(!response.is_empty());
        }
    }
}