vex_llm/
cached_provider.rs

1//! Cached LLM provider wrapper using Moka
2//!
3//! Provides in-memory caching of LLM responses to reduce latency and API costs.
4//! Uses the Moka cache library for high-performance concurrent caching with
5//! TTL-based expiration.
6//!
7//! # 2025 Best Practices
8//! - Uses SHA-256 hash of request as cache key
9//! - Configurable TTL and max entries
10//! - Thread-safe concurrent access
11//! - Does NOT cache streaming responses
12
13use async_trait::async_trait;
14use moka::future::Cache;
15use sha2::{Digest, Sha256};
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::time::Duration;
18
19use crate::{LlmError, LlmProvider, LlmRequest, LlmResponse};
20
/// Configuration for the LLM cache
///
/// Controls capacity and expiration of the in-memory response cache used by
/// [`CachedProvider`]. Presets: [`Default`] for balanced behavior,
/// [`LlmCacheConfig::aggressive`] for cost savings, and
/// [`LlmCacheConfig::conservative`] for freshness.
// PartialEq/Eq derived so configs can be compared in tests and assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlmCacheConfig {
    /// Maximum number of cached responses
    pub max_entries: u64,
    /// Time-to-live for cached entries
    pub ttl: Duration,
    /// Time-to-idle (evict if not accessed)
    pub tti: Option<Duration>,
}

impl Default for LlmCacheConfig {
    /// Balanced defaults: 1000 entries, 1-hour TTL, 30-minute idle eviction.
    fn default() -> Self {
        Self {
            max_entries: 1000,
            ttl: Duration::from_secs(3600),       // 1 hour
            tti: Some(Duration::from_secs(1800)), // 30 minutes idle
        }
    }
}

impl LlmCacheConfig {
    /// Aggressive caching for cost savings
    ///
    /// Large capacity, 24-hour TTL, and no idle eviction — entries live out
    /// their full TTL regardless of access pattern.
    pub fn aggressive() -> Self {
        Self {
            max_entries: 10000,
            ttl: Duration::from_secs(86400), // 24 hours
            tti: None,
        }
    }

    /// Conservative caching for freshness
    ///
    /// Small capacity, 5-minute TTL, and 1-minute idle eviction so stale
    /// responses are dropped quickly.
    pub fn conservative() -> Self {
        Self {
            max_entries: 100,
            ttl: Duration::from_secs(300), // 5 minutes
            tti: Some(Duration::from_secs(60)),
        }
    }
}
61
62/// Calculate cache key from request
63fn cache_key(request: &LlmRequest) -> String {
64    let mut hasher = Sha256::new();
65    if let Some(tenant_id) = &request.tenant_id {
66        hasher.update(tenant_id.as_bytes());
67        hasher.update(b"|");
68    }
69    hasher.update(request.system.as_bytes());
70    hasher.update(b"|");
71    hasher.update(request.prompt.as_bytes());
72    hasher.update(b"|");
73    hasher.update(request.temperature.to_be_bytes());
74    hasher.update(b"|");
75    hasher.update(request.max_tokens.to_be_bytes());
76    hex::encode(hasher.finalize())
77}
78
/// Cached LLM provider wrapper
///
/// Decorates any [`LlmProvider`] with an in-memory Moka cache keyed by a
/// SHA-256 hash of the request (see `cache_key`), plus hit/miss counters
/// for observability via [`CachedProvider::stats`].
#[derive(Debug)]
pub struct CachedProvider<P: LlmProvider> {
    /// Wrapped provider; consulted only on cache misses.
    inner: P,
    /// Response cache; keys are hex-encoded SHA-256 digests of the request.
    cache: Cache<String, LlmResponse>,
    /// Cache-hit counter (relaxed atomics — approximate ordering under
    /// contention is acceptable for statistics).
    hits: AtomicU64,
    /// Cache-miss counter.
    misses: AtomicU64,
}
87
88impl<P: LlmProvider> CachedProvider<P> {
89    /// Create a cached wrapper around an LLM provider
90    pub fn new(provider: P, config: LlmCacheConfig) -> Self {
91        let mut builder = Cache::builder()
92            .max_capacity(config.max_entries)
93            .time_to_live(config.ttl);
94
95        if let Some(tti) = config.tti {
96            builder = builder.time_to_idle(tti);
97        }
98
99        Self {
100            inner: provider,
101            cache: builder.build(),
102            hits: AtomicU64::new(0),
103            misses: AtomicU64::new(0),
104        }
105    }
106
107    /// Create with default configuration
108    pub fn wrap(provider: P) -> Self {
109        Self::new(provider, LlmCacheConfig::default())
110    }
111
112    /// Get cache statistics
113    pub fn stats(&self) -> (u64, u64, f64) {
114        let hits = self.hits.load(Ordering::Relaxed);
115        let misses = self.misses.load(Ordering::Relaxed);
116        let total = hits + misses;
117        let hit_rate = if total > 0 {
118            hits as f64 / total as f64
119        } else {
120            0.0
121        };
122        (hits, misses, hit_rate)
123    }
124
125    /// Clear the cache
126    pub fn clear(&self) {
127        self.cache.invalidate_all();
128    }
129
130    /// Get current cache size
131    pub fn size(&self) -> u64 {
132        self.cache.entry_count()
133    }
134}
135
136#[async_trait]
137impl<P: LlmProvider + 'static> LlmProvider for CachedProvider<P> {
138    fn name(&self) -> &str {
139        "cached"
140    }
141
142    async fn is_available(&self) -> bool {
143        self.inner.is_available().await
144    }
145
146    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
147        let key = cache_key(&request);
148
149        // Check cache first
150        if let Some(cached) = self.cache.get(&key).await {
151            self.hits.fetch_add(1, Ordering::Relaxed);
152            tracing::debug!(cache_key = %key, "LLM cache hit");
153            return Ok(cached);
154        }
155
156        // Cache miss - call provider
157        self.misses.fetch_add(1, Ordering::Relaxed);
158        let response = self.inner.complete(request).await?;
159
160        // Cache the response
161        self.cache.insert(key.clone(), response.clone()).await;
162        tracing::debug!(cache_key = %key, "LLM cache miss - stored");
163
164        Ok(response)
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockProvider;

    /// Identical requests must miss once, then hit.
    #[tokio::test]
    async fn test_cached_provider_caches_responses() {
        let provider = CachedProvider::wrap(MockProvider::constant("cached response"));
        let request = LlmRequest::simple("test prompt");

        // First call populates the cache; the second is served from it.
        let first = provider.complete(request.clone()).await.unwrap();
        let second = provider.complete(request).await.unwrap();

        assert_eq!(first.content, second.content);

        let (hits, misses, _) = provider.stats();
        assert_eq!((hits, misses), (1, 1));
    }

    /// Distinct prompts must produce distinct cache keys (two misses, no hits).
    #[tokio::test]
    async fn test_different_prompts_different_cache_keys() {
        let provider = CachedProvider::wrap(MockProvider::smart());

        for prompt in ["prompt 1", "prompt 2"] {
            provider.complete(LlmRequest::simple(prompt)).await.unwrap();
        }

        let (hits, misses, _) = provider.stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 2);
    }
}
207}