1use parking_lot::RwLock;
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::sync::Arc;
8
/// A response returned from (or stored in) the semantic cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResponse {
    /// The cached response text.
    pub response: String,
    /// Similarity score: `SemanticCache::store` records `1.0`, and
    /// `SemanticCache::get` overwrites it on the returned copy with the cosine
    /// similarity between the lookup query and the cached query.
    pub similarity: f32,
    /// Unix timestamp in seconds (`chrono::Utc::now().timestamp()`) taken when
    /// the entry was stored; used for TTL expiry and oldest-first eviction.
    pub cached_at: i64,
    /// Caller-supplied token count passed to `SemanticCache::store`; the cache
    /// does not interpret it (presumably the response's token count — confirm
    /// with callers).
    pub token_count: u32,
}
16
/// One stored cache item: the response plus the embedding of the query it answers.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// The cached response and its metadata.
    pub response: CachedResponse,
    /// Embedding of the original query text, computed at store time and compared
    /// against lookup queries with cosine similarity.
    pub embedding: Vec<f32>,
}
22
23pub struct SemanticCache {
24 entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
25 similarity_threshold: f32,
26 max_cache_size: usize,
27 ttl_seconds: i64,
28}
29
30impl SemanticCache {
31 pub fn new(similarity_threshold: f32, max_cache_size: usize, ttl_seconds: i64) -> Self {
32 Self {
33 entries: Arc::new(RwLock::new(HashMap::new())),
34 similarity_threshold,
35 max_cache_size,
36 ttl_seconds,
37 }
38 }
39
40 pub fn get(&self, query: &str) -> Option<CachedResponse> {
41 let query_embedding = self.compute_embedding(query);
42 let entries = self.entries.read();
43
44 let mut best_match: Option<(f32, &CacheEntry)> = None;
45
46 for (_key, entry) in entries.iter() {
47 let similarity = cosine_similarity(&query_embedding, &entry.embedding);
48
49 if similarity >= self.similarity_threshold
50 && (best_match.is_none() || similarity > best_match.as_ref().unwrap().0)
51 {
52 best_match = Some((similarity, entry));
53 }
54 }
55
56 if let Some((similarity, entry)) = best_match {
57 let now = chrono::Utc::now().timestamp();
58 if now - entry.response.cached_at < self.ttl_seconds {
59 let mut response = entry.response.clone();
60 response.similarity = similarity;
61 return Some(response);
62 }
63 }
64
65 None
66 }
67
68 pub fn store(&self, query: &str, response: String, token_count: u32) {
69 let key = self.compute_key(query);
70 let embedding = self.compute_embedding(query);
71
72 let mut entries = self.entries.write();
73
74 if entries.len() >= self.max_cache_size {
75 if let Some(oldest_key) = entries
76 .iter()
77 .min_by_key(|(_, e)| e.response.cached_at)
78 .map(|(k, _)| k.clone())
79 {
80 entries.remove(&oldest_key);
81 }
82 }
83
84 entries.insert(
85 key,
86 CacheEntry {
87 response: CachedResponse {
88 response,
89 similarity: 1.0,
90 cached_at: chrono::Utc::now().timestamp(),
91 token_count,
92 },
93 embedding,
94 },
95 );
96 }
97
98 fn compute_key(&self, query: &str) -> String {
99 let mut hasher = Sha256::new();
100 hasher.update(query.as_bytes());
101 hex::encode(hasher.finalize())
102 }
103
104 fn compute_embedding(&self, query: &str) -> Vec<f32> {
105 simple_embedding(query)
106 }
107
108 pub fn stats(&self) -> CacheStats {
109 let entries = self.entries.read();
110 let now = chrono::Utc::now().timestamp();
111
112 let valid_entries = entries
113 .values()
114 .filter(|e| now - e.response.cached_at < self.ttl_seconds)
115 .count();
116
117 CacheStats {
118 total_entries: entries.len(),
119 valid_entries,
120 cache_size_bytes: entries
121 .values()
122 .map(|e| e.response.response.len() + e.embedding.len() * 4)
123 .sum(),
124 }
125 }
126
127 pub fn clear(&self) {
128 let mut entries = self.entries.write();
129 entries.clear();
130 }
131}
132
/// Cosine similarity of two equal-length vectors: `dot(a, b) / (|a| * |b|)`.
///
/// Returns `0.0` when the lengths differ, when the inputs are empty, or when
/// either vector has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };
    let (len_a, len_b) = (magnitude(a), magnitude(b));
    if len_a == 0.0 || len_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (len_a * len_b)
}
148
/// Maps `text` to a 64-dimensional, L2-normalised bag-of-words vector using
/// signed feature hashing (the "hashing trick").
///
/// Each lowercase, whitespace-separated word adds `±1` to bucket
/// `simple_hash(word) % 64`, with the sign taken from the hash's top bit. Word
/// order is ignored, and two distinct words share a bucket with probability ~1/64.
/// Empty or whitespace-only text yields the all-zero vector.
///
/// The previous positional scheme stored the non-negative raw hash at index `i`.
/// Every one-word query then normalised to the same unit vector (similarity 1.0),
/// so unrelated queries hit each other's cache entries.
fn simple_embedding(text: &str) -> Vec<f32> {
    const DIM: usize = 64;

    let lowered = text.to_lowercase();
    let mut embedding = vec![0.0f32; DIM];

    for word in lowered.split_whitespace() {
        let hash = simple_hash(word);
        let bucket = (hash % DIM as u32) as usize;
        // A signed contribution keeps the expected dot product of unrelated
        // words near zero, instead of biasing every pair towards positive similarity.
        let sign = if hash >> 31 == 0 { 1.0 } else { -1.0 };
        embedding[bucket] += sign;
    }

    let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        embedding.iter_mut().for_each(|x| *x /= norm);
    }

    embedding
}

/// djb2 string hash over the UTF-8 bytes of `s`: `h = h * 33 + byte`, starting
/// from 5381, with 32-bit wrapping arithmetic.
fn simple_hash(s: &str) -> u32 {
    s.bytes()
        .fold(5381u32, |h, byte| h.wrapping_mul(33).wrapping_add(u32::from(byte)))
}
177
/// Snapshot of cache occupancy returned by `SemanticCache::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of entries currently held, including expired ones not yet removed.
    pub total_entries: usize,
    /// Number of entries still within the TTL at the time of the snapshot.
    pub valid_entries: usize,
    /// Approximate size: response text bytes plus 4 bytes per embedding element.
    /// Map keys and container overhead are not counted.
    pub cache_size_bytes: usize,
}
184
185impl Default for SemanticCache {
186 fn default() -> Self {
187 Self::new(0.85, 10000, 86400)
188 }
189}