vex_llm/provider.rs

1//! LLM Provider trait and common types
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
/// Errors from LLM providers
#[derive(Debug, Error)]
pub enum LlmError {
    /// Could not reach the provider at all (transport-level failure).
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),
    /// The provider was reachable but the request itself failed.
    #[error("Request failed: {0}")]
    RequestFailed(String),
    /// The provider responded, but the payload could not be parsed or used.
    #[error("Invalid response: {0}")]
    InvalidResponse(String),
    /// The provider rejected the request due to rate limiting.
    #[error("Rate limited")]
    RateLimited,
    /// The provider is not available (e.g. not configured or offline).
    #[error("Provider not available")]
    NotAvailable,
    /// An input exceeded a size cap: fields are (actual bytes, maximum bytes).
    /// Produced by `LlmRequest::validate`.
    #[error("Input too large: {0} bytes exceeds maximum {1} bytes")]
    InputTooLarge(usize, usize),
}
23
/// Maximum allowed prompt size in bytes (100KB default - prevents DoS).
/// Enforced by `LlmRequest::validate`.
pub const MAX_PROMPT_SIZE: usize = 100 * 1024;
/// Maximum allowed system prompt size in bytes (10KB).
/// Enforced by `LlmRequest::validate`.
pub const MAX_SYSTEM_SIZE: usize = 10 * 1024;
28
/// A request to an LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Tenant ID (for cache isolation); omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
    /// System prompt (role/persona). Size-capped by `MAX_SYSTEM_SIZE` in `validate`.
    pub system: String,
    /// User message. Size-capped by `MAX_PROMPT_SIZE` in `validate`.
    pub prompt: String,
    /// Temperature (0.0 = deterministic, 1.0 = creative)
    pub temperature: f32,
    /// Maximum tokens to generate
    pub max_tokens: u32,
}
44
45impl LlmRequest {
46    /// Create a simple request with default settings
47    pub fn simple(prompt: &str) -> Self {
48        Self {
49            tenant_id: None,
50            system: "You are a helpful assistant.".to_string(),
51            prompt: prompt.to_string(),
52            temperature: 0.7,
53            max_tokens: 1024,
54        }
55    }
56
57    /// Create a request with a specific role
58    pub fn with_role(system: &str, prompt: &str) -> Self {
59        Self {
60            system: system.to_string(),
61            prompt: prompt.to_string(),
62            temperature: 0.7,
63            max_tokens: 1024,
64            tenant_id: None,
65        }
66    }
67
68    /// Validate request sizes to prevent DoS attacks
69    pub fn validate(&self) -> Result<(), LlmError> {
70        if self.prompt.len() > MAX_PROMPT_SIZE {
71            return Err(LlmError::InputTooLarge(self.prompt.len(), MAX_PROMPT_SIZE));
72        }
73        if self.system.len() > MAX_SYSTEM_SIZE {
74            return Err(LlmError::InputTooLarge(self.system.len(), MAX_SYSTEM_SIZE));
75        }
76        Ok(())
77    }
78}
79
/// Response from an LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// The generated text
    pub content: String,
    /// Model used to produce this response (provider-reported name).
    pub model: String,
    /// Tokens used (if available); `None` when the provider does not report usage.
    pub tokens_used: Option<u32>,
    /// Time taken in milliseconds
    pub latency_ms: u64,
    /// Merkle root of logit hashes (for cryptographic verification);
    /// omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trace_root: Option<String>,
}
95
96/// Trait for LLM providers
97#[async_trait]
98pub trait LlmProvider: Send + Sync + std::fmt::Debug {
99    /// Get the provider name
100    fn name(&self) -> &str;
101
102    /// Check if the provider is available
103    async fn is_available(&self) -> bool;
104
105    /// Generate a completion
106    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError>;
107
108    /// Generate with a simple prompt (convenience method)
109    async fn ask(&self, prompt: &str) -> Result<String, LlmError> {
110        let response = self.complete(LlmRequest::simple(prompt)).await?;
111        Ok(response.content)
112    }
113}
114
/// Trait for embedding providers (text-to-vector)
#[async_trait]
pub trait EmbeddingProvider: Send + Sync + std::fmt::Debug {
    /// Generate an embedding vector for the given text.
    ///
    /// # Errors
    ///
    /// Returns an [`LlmError`] when the provider cannot produce an embedding
    /// (connection, request, or response failures).
    async fn embed(&self, text: &str) -> Result<Vec<f32>, LlmError>;
}