// vex_llm/ollama.rs

1//! Ollama LLM provider for local inference
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::time::Instant;
6
7use crate::provider::{LlmError, LlmProvider, LlmRequest, LlmResponse};
8
/// Ollama API request format
///
/// Serialized as the JSON body of a `POST /api/generate` call.
#[derive(Debug, Serialize)]
struct OllamaRequest {
    // Model name to run (e.g. "llama2", "mistral").
    model: String,
    // User prompt text.
    prompt: String,
    // Optional system prompt; serialized as JSON `null` when `None`.
    system: Option<String>,
    // `false` asks Ollama for one complete response instead of a token stream.
    stream: bool,
    // Generation tuning parameters (nested under the `options` key).
    options: OllamaOptions,
}
18
/// Generation options forwarded to Ollama under the request's `options` key.
#[derive(Debug, Serialize)]
struct OllamaOptions {
    // Sampling temperature.
    temperature: f32,
    // Maximum number of tokens to generate (Ollama's `num_predict` knob).
    num_predict: u32,
}
24
/// Ollama API response format
///
/// Deserialized from the JSON body returned by `POST /api/generate`
/// in non-streaming mode.
#[derive(Debug, Deserialize)]
struct OllamaApiResponse {
    // Generated completion text.
    response: String,
    // Name of the model that produced the completion.
    model: String,
    // Token count for the generation; `None` when the server omits the field.
    #[serde(default)]
    eval_count: Option<u32>,
}
33
/// Ollama provider for local LLM inference
///
/// Talks to an Ollama server over HTTP; construct via [`OllamaProvider::new`]
/// (default local endpoint) or [`OllamaProvider::with_url`].
#[derive(Debug)]
pub struct OllamaProvider {
    /// Base URL for Ollama API
    base_url: String,
    /// Model to use (e.g., "llama2", "mistral", "codellama")
    model: String,
    /// HTTP client
    client: reqwest::Client,
}
44
45impl OllamaProvider {
46    /// Create a new Ollama provider with default settings
47    pub fn new(model: &str) -> Self {
48        Self {
49            base_url: "http://localhost:11434".to_string(),
50            model: model.to_string(),
51            client: reqwest::Client::new(),
52        }
53    }
54
55    /// Create with custom base URL
56    pub fn with_url(base_url: &str, model: &str) -> Self {
57        // Basic SSRF protection (2025 best practice)
58        let url = base_url.to_lowercase();
59        if url.contains("localhost") || url.contains("127.0.0.1") || url.contains("::1") {
60            tracing::warn!(url = %base_url, "Potentially unsafe URL in OllamaProvider");
61        }
62
63        Self {
64            base_url: base_url.to_string(),
65            model: model.to_string(),
66            client: reqwest::Client::new(),
67        }
68    }
69}
70
71#[async_trait]
72impl LlmProvider for OllamaProvider {
73    fn name(&self) -> &str {
74        "ollama"
75    }
76
77    async fn is_available(&self) -> bool {
78        let url = format!("{}/api/tags", self.base_url);
79        self.client.get(&url).send().await.is_ok()
80    }
81
82    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
83        let start = Instant::now();
84        let url = format!("{}/api/generate", self.base_url);
85
86        let ollama_request = OllamaRequest {
87            model: self.model.clone(),
88            prompt: request.prompt,
89            system: Some(request.system),
90            stream: false,
91            options: OllamaOptions {
92                temperature: request.temperature,
93                num_predict: request.max_tokens,
94            },
95        };
96
97        let response = self
98            .client
99            .post(&url)
100            .json(&ollama_request)
101            .send()
102            .await
103            .map_err(|e| LlmError::ConnectionFailed(e.to_string()))?;
104
105        if !response.status().is_success() {
106            return Err(LlmError::RequestFailed(format!(
107                "Status: {}",
108                response.status()
109            )));
110        }
111
112        let api_response: OllamaApiResponse = response
113            .json()
114            .await
115            .map_err(|e| LlmError::InvalidResponse(e.to_string()))?;
116
117        Ok(LlmResponse {
118            content: api_response.response,
119            model: api_response.model,
120            tokens_used: api_response.eval_count,
121            latency_ms: start.elapsed().as_millis() as u64,
122            trace_root: None,
123        })
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test against a live local Ollama server.
    ///
    /// Marked `#[ignore]` because it requires a running Ollama instance;
    /// run with `cargo test -- --ignored` when one is available. If the
    /// server is reachable, asks for a one-word reply via `ask` (a trait
    /// helper — presumably a provided method on `LlmProvider`; confirm in
    /// provider.rs) and checks the reply is non-empty.
    #[tokio::test]
    #[ignore] // Requires Ollama running locally
    async fn test_ollama_available() {
        let provider = OllamaProvider::new("llama2");
        // This test is ignored by default since it requires Ollama
        if provider.is_available().await {
            let response = provider.ask("Say hello in one word").await.unwrap();
            assert!(!response.is_empty());
        }
    }
}