// vex_llm/cached_provider.rs
use async_trait::async_trait;
14use moka::future::Cache;
15use sha2::{Digest, Sha256};
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::time::Duration;
18
19use crate::{LlmError, LlmProvider, LlmRequest, LlmResponse};
20
/// Configuration for the in-memory LLM response cache used by `CachedProvider`.
#[derive(Debug, Clone)]
pub struct LlmCacheConfig {
 /// Maximum number of responses the cache may hold before evicting.
 pub max_entries: u64,
 /// Time-to-live: how long an entry remains valid after insertion.
 pub ttl: Duration,
 /// Optional time-to-idle: entries not read for this long are evicted.
 pub tti: Option<Duration>,
}
31
32impl Default for LlmCacheConfig {
33 fn default() -> Self {
34 Self {
35 max_entries: 1000,
36 ttl: Duration::from_secs(3600), tti: Some(Duration::from_secs(1800)), }
39 }
40}
41
42impl LlmCacheConfig {
43 pub fn aggressive() -> Self {
45 Self {
46 max_entries: 10000,
47 ttl: Duration::from_secs(86400), tti: None,
49 }
50 }
51
52 pub fn conservative() -> Self {
54 Self {
55 max_entries: 100,
56 ttl: Duration::from_secs(300), tti: Some(Duration::from_secs(60)),
58 }
59 }
60}
61
62fn cache_key(request: &LlmRequest) -> String {
64 let mut hasher = Sha256::new();
65 if let Some(tenant_id) = &request.tenant_id {
66 hasher.update(tenant_id.as_bytes());
67 hasher.update(b"|");
68 }
69 hasher.update(request.system.as_bytes());
70 hasher.update(b"|");
71 hasher.update(request.prompt.as_bytes());
72 hasher.update(b"|");
73 hasher.update(request.temperature.to_be_bytes());
74 hasher.update(b"|");
75 hasher.update(request.max_tokens.to_be_bytes());
76 hex::encode(hasher.finalize())
77}
78
/// Caching decorator around an `LlmProvider`: requests with an identical
/// cache key are answered from an in-memory `moka` cache instead of
/// being forwarded to the wrapped provider.
#[derive(Debug)]
pub struct CachedProvider<P: LlmProvider> {
 // Wrapped provider, consulted only on cache misses.
 inner: P,
 // Response cache keyed by `cache_key(&request)`.
 cache: Cache<String, LlmResponse>,
 // Number of requests served from cache.
 hits: AtomicU64,
 // Number of requests forwarded to `inner`.
 misses: AtomicU64,
}
87
88impl<P: LlmProvider> CachedProvider<P> {
89 pub fn new(provider: P, config: LlmCacheConfig) -> Self {
91 let mut builder = Cache::builder()
92 .max_capacity(config.max_entries)
93 .time_to_live(config.ttl);
94
95 if let Some(tti) = config.tti {
96 builder = builder.time_to_idle(tti);
97 }
98
99 Self {
100 inner: provider,
101 cache: builder.build(),
102 hits: AtomicU64::new(0),
103 misses: AtomicU64::new(0),
104 }
105 }
106
107 pub fn wrap(provider: P) -> Self {
109 Self::new(provider, LlmCacheConfig::default())
110 }
111
112 pub fn stats(&self) -> (u64, u64, f64) {
114 let hits = self.hits.load(Ordering::Relaxed);
115 let misses = self.misses.load(Ordering::Relaxed);
116 let total = hits + misses;
117 let hit_rate = if total > 0 {
118 hits as f64 / total as f64
119 } else {
120 0.0
121 };
122 (hits, misses, hit_rate)
123 }
124
125 pub fn clear(&self) {
127 self.cache.invalidate_all();
128 }
129
130 pub fn size(&self) -> u64 {
132 self.cache.entry_count()
133 }
134}
135
136#[async_trait]
137impl<P: LlmProvider + 'static> LlmProvider for CachedProvider<P> {
138 fn name(&self) -> &str {
139 "cached"
140 }
141
142 async fn is_available(&self) -> bool {
143 self.inner.is_available().await
144 }
145
146 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
147 let key = cache_key(&request);
148
149 if let Some(cached) = self.cache.get(&key).await {
151 self.hits.fetch_add(1, Ordering::Relaxed);
152 tracing::debug!(cache_key = %key, "LLM cache hit");
153 return Ok(cached);
154 }
155
156 self.misses.fetch_add(1, Ordering::Relaxed);
158 let response = self.inner.complete(request).await?;
159
160 self.cache.insert(key.clone(), response.clone()).await;
162 tracing::debug!(cache_key = %key, "LLM cache miss - stored");
163
164 Ok(response)
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockProvider;

    #[tokio::test]
    async fn test_cached_provider_caches_responses() {
        let provider = CachedProvider::wrap(MockProvider::constant("cached response"));
        let request = LlmRequest::simple("test prompt");

        let first = provider.complete(request.clone()).await.unwrap();
        let second = provider.complete(request).await.unwrap();

        // Identical requests must yield identical content.
        assert_eq!(first.content, second.content);

        // The first call misses and populates the cache; the second hits.
        let (hits, misses, _) = provider.stats();
        assert_eq!(hits, 1);
        assert_eq!(misses, 1);
    }

    #[tokio::test]
    async fn test_different_prompts_different_cache_keys() {
        let provider = CachedProvider::wrap(MockProvider::smart());

        let _ = provider.complete(LlmRequest::simple("prompt 1")).await.unwrap();
        let _ = provider.complete(LlmRequest::simple("prompt 2")).await.unwrap();

        // Distinct prompts must never share a cache entry.
        let (hits, misses, _) = provider.stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 2);
    }
}
207}