vex_router/cache/
mod.rs

1//! Semantic Caching - Cache responses using vector embeddings
2
3use parking_lot::RwLock;
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::sync::Arc;
8
/// A cached response as returned to callers of [`SemanticCache::get`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResponse {
    /// The cached response text.
    pub response: String,
    /// Cosine similarity between the lookup query and the query this response
    /// was stored under. Set to `1.0` when the entry is stored and overwritten
    /// with the measured similarity on every cache hit.
    pub similarity: f32,
    /// Unix timestamp (seconds, UTC) at which the response was stored.
    pub cached_at: i64,
    /// Token count supplied by the caller when the response was stored.
    pub token_count: u32,
}
16
/// A stored response together with the embedding used to match queries against it.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// The response payload and its bookkeeping data.
    pub response: CachedResponse,
    /// Embedding of the query the response was stored under.
    pub embedding: Vec<f32>,
}
22
23pub struct SemanticCache {
24    entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
25    similarity_threshold: f32,
26    max_cache_size: usize,
27    ttl_seconds: i64,
28}
29
30impl SemanticCache {
31    pub fn new(similarity_threshold: f32, max_cache_size: usize, ttl_seconds: i64) -> Self {
32        Self {
33            entries: Arc::new(RwLock::new(HashMap::new())),
34            similarity_threshold,
35            max_cache_size,
36            ttl_seconds,
37        }
38    }
39
40    pub fn get(&self, query: &str) -> Option<CachedResponse> {
41        let query_embedding = self.compute_embedding(query);
42        let entries = self.entries.read();
43
44        let mut best_match: Option<(f32, &CacheEntry)> = None;
45
46        for (_key, entry) in entries.iter() {
47            let similarity = cosine_similarity(&query_embedding, &entry.embedding);
48
49            if similarity >= self.similarity_threshold
50                && (best_match.is_none() || similarity > best_match.as_ref().unwrap().0)
51            {
52                best_match = Some((similarity, entry));
53            }
54        }
55
56        if let Some((similarity, entry)) = best_match {
57            let now = chrono::Utc::now().timestamp();
58            if now - entry.response.cached_at < self.ttl_seconds {
59                let mut response = entry.response.clone();
60                response.similarity = similarity;
61                return Some(response);
62            }
63        }
64
65        None
66    }
67
68    pub fn store(&self, query: &str, response: String, token_count: u32) {
69        let key = self.compute_key(query);
70        let embedding = self.compute_embedding(query);
71
72        let mut entries = self.entries.write();
73
74        if entries.len() >= self.max_cache_size {
75            if let Some(oldest_key) = entries
76                .iter()
77                .min_by_key(|(_, e)| e.response.cached_at)
78                .map(|(k, _)| k.clone())
79            {
80                entries.remove(&oldest_key);
81            }
82        }
83
84        entries.insert(
85            key,
86            CacheEntry {
87                response: CachedResponse {
88                    response,
89                    similarity: 1.0,
90                    cached_at: chrono::Utc::now().timestamp(),
91                    token_count,
92                },
93                embedding,
94            },
95        );
96    }
97
98    fn compute_key(&self, query: &str) -> String {
99        let mut hasher = Sha256::new();
100        hasher.update(query.as_bytes());
101        hex::encode(hasher.finalize())
102    }
103
104    fn compute_embedding(&self, query: &str) -> Vec<f32> {
105        simple_embedding(query)
106    }
107
108    pub fn stats(&self) -> CacheStats {
109        let entries = self.entries.read();
110        let now = chrono::Utc::now().timestamp();
111
112        let valid_entries = entries
113            .values()
114            .filter(|e| now - e.response.cached_at < self.ttl_seconds)
115            .count();
116
117        CacheStats {
118            total_entries: entries.len(),
119            valid_entries,
120            cache_size_bytes: entries
121                .values()
122                .map(|e| e.response.response.len() + e.embedding.len() * 4)
123                .sum(),
124        }
125    }
126
127    pub fn clear(&self) {
128        let mut entries = self.entries.write();
129        entries.clear();
130    }
131}
132
/// Computes the cosine similarity of two vectors.
///
/// Returns `0.0` when the slices are empty, differ in length, or either one
/// has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean length of a vector.
    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();

    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let inner_product: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    inner_product / (mag_a * mag_b)
}
148
149fn simple_embedding(text: &str) -> Vec<f32> {
150    let text_lower = text.to_lowercase();
151    let words: Vec<&str> = text_lower.split_whitespace().collect();
152
153    let mut embedding = vec![0.0f32; 64];
154
155    for (i, word) in words.iter().take(64).enumerate() {
156        let hash = simple_hash(word);
157        embedding[i % 64] += (hash as f32) / (words.len() as f32).sqrt();
158    }
159
160    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
161    if norm > 0.0 {
162        for x in &mut embedding {
163            *x /= norm;
164        }
165    }
166
167    embedding
168}
169
/// Hashes the bytes of `s` with a multiply-by-33-and-add scheme seeded at 5381,
/// wrapping on `u32` overflow.
fn simple_hash(s: &str) -> u32 {
    s.bytes().fold(5381u32, |acc, byte| {
        acc.wrapping_mul(33).wrapping_add(u32::from(byte))
    })
}
177
/// Snapshot of cache occupancy as reported by `SemanticCache::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of entries currently stored, including expired ones.
    pub total_entries: usize,
    /// Number of stored entries that are younger than the cache TTL.
    pub valid_entries: usize,
    /// Estimated payload size: response text bytes plus 4 bytes per embedding
    /// component. Keys and per-entry overhead are not counted.
    pub cache_size_bytes: usize,
}
184
185impl Default for SemanticCache {
186    fn default() -> Self {
187        Self::new(0.85, 10000, 86400)
188    }
189}