1use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::VecDeque;
6
7use crate::compression::TemporalCompressor;
8use crate::horizon::HorizonConfig;
9use vex_persist::VectorStoreBackend;
10
/// A single remembered event held by [`EpisodicMemory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Episode {
    /// Identifier assigned by [`Episode::new`] from a process-wide counter that
    /// starts at 0. It is only unique among episodes created in this process;
    /// cloned or deserialized episodes may share an id with another episode.
    pub id: u64,
    /// The remembered text. It may be rewritten in place by the compression methods.
    pub content: String,
    /// Creation timestamp (set to `Utc::now()` by [`Episode::new`]). It drives
    /// horizon filtering, decay, and age-based eviction.
    pub created_at: DateTime<Utc>,
    /// Importance before time decay. [`Episode::new`] clamps it to `[0.0, 1.0]`.
    pub base_importance: f64,
    /// Pinned episodes are skipped by eviction, compression, and `clear`.
    pub pinned: bool,
    /// Free-form labels, added via [`Episode::with_tag`] and queried with
    /// `EpisodicMemory::by_tag`.
    pub tags: Vec<String>,
}
27
28impl Episode {
29 pub fn new(content: &str, importance: f64) -> Self {
31 static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
32 Self {
33 id: COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
34 content: content.to_string(),
35 created_at: Utc::now(),
36 base_importance: importance.clamp(0.0, 1.0),
37 pinned: false,
38 tags: Vec::new(),
39 }
40 }
41
42 pub fn pinned(content: &str) -> Self {
44 let mut ep = Self::new(content, 1.0);
45 ep.pinned = true;
46 ep
47 }
48
49 pub fn with_tag(mut self, tag: &str) -> Self {
51 self.tags.push(tag.to_string());
52 self
53 }
54}
55
/// A bounded, time-aware store of [`Episode`]s.
#[derive(Debug, Clone)]
pub struct EpisodicMemory {
    /// Horizon, capacity (`max_entries`), and the `auto_evict` / `auto_compress`
    /// switches that govern this memory.
    pub config: HorizonConfig,
    /// Decay model used for importance scoring, age-based eviction, and compression.
    pub compressor: TemporalCompressor,
    /// Stored episodes. `add` pushes to the front, so the most recently added come first.
    episodes: VecDeque<Episode>,
}
66
67impl EpisodicMemory {
68 pub fn new(config: HorizonConfig) -> Self {
70 let max_age = config
71 .horizon
72 .duration()
73 .unwrap_or(chrono::Duration::weeks(52));
74 Self {
75 config,
76 compressor: TemporalCompressor::new(
77 crate::compression::DecayStrategy::Exponential,
78 max_age,
79 ),
80 episodes: VecDeque::new(),
81 }
82 }
83
84 pub fn add(&mut self, episode: Episode) {
86 self.episodes.push_front(episode);
87 self.maybe_evict();
88 }
89
90 pub fn remember(&mut self, content: &str, importance: f64) {
92 self.add(Episode::new(content, importance));
93 }
94
95 pub fn episodes(&self) -> impl Iterator<Item = &Episode> {
97 self.episodes.iter()
98 }
99
100 pub fn by_tag(&self, tag: &str) -> Vec<&Episode> {
102 self.episodes
103 .iter()
104 .filter(|e| e.tags.contains(&tag.to_string()))
105 .collect()
106 }
107
108 pub fn recent(&self) -> Vec<&Episode> {
110 self.episodes
111 .iter()
112 .filter(|e| self.config.horizon.contains(e.created_at))
113 .collect()
114 }
115
116 pub fn by_importance(&self) -> Vec<(&Episode, f64)> {
118 let mut episodes: Vec<_> = self
119 .episodes
120 .iter()
121 .map(|e| {
122 let importance = if e.pinned {
123 1.0
124 } else {
125 self.compressor.importance(e.created_at, e.base_importance)
126 };
127 (e, importance)
128 })
129 .collect();
130 episodes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
131 episodes
132 }
133
134 pub fn len(&self) -> usize {
136 self.episodes.len()
137 }
138
139 pub fn is_empty(&self) -> bool {
141 self.episodes.is_empty()
142 }
143
144 pub fn clear(&mut self) {
146 self.episodes.retain(|e| e.pinned);
147 }
148
149 fn maybe_evict(&mut self) {
152 if !self.config.auto_evict {
153 return;
154 }
155
156 let max_age_ids: std::collections::HashSet<u64> = self
164 .episodes
165 .iter()
166 .filter(|e| !e.pinned && self.compressor.should_evict(e.created_at))
167 .map(|e| e.id)
168 .collect();
169
170 if !max_age_ids.is_empty() {
171 self.episodes.retain(|e| !max_age_ids.contains(&e.id));
172 }
173
174 let current_len = self.episodes.len();
176 if current_len <= self.config.max_entries {
177 return;
178 }
179
180 let pinned_count = self.episodes.iter().filter(|e| e.pinned).count();
183 if pinned_count >= self.config.max_entries {
184 self.episodes.retain(|e| e.pinned);
185 return;
186 }
187
188 let slots_for_non_pinned = self.config.max_entries - pinned_count;
189
190 let mut candidates: Vec<(f64, DateTime<Utc>, u64)> = self
193 .episodes
194 .iter()
195 .filter(|e| !e.pinned)
196 .map(|e| {
197 (
198 self.compressor.importance(e.created_at, e.base_importance),
199 e.created_at,
200 e.id,
201 )
202 })
203 .collect();
204
205 if candidates.len() > slots_for_non_pinned {
208 let target_idx = candidates.len() - slots_for_non_pinned;
211
212 candidates.select_nth_unstable_by(target_idx, |a, b| {
214 a.0.partial_cmp(&b.0)
215 .unwrap_or(std::cmp::Ordering::Equal)
216 .then_with(|| a.1.cmp(&b.1))
217 });
218
219 let keep_ids: std::collections::HashSet<u64> =
222 candidates[target_idx..].iter().map(|c| c.2).collect();
223
224 self.episodes
226 .retain(|e| e.pinned || keep_ids.contains(&e.id));
227 }
228 }
229
230 pub fn compress_old(&mut self) -> usize {
232 if !self.config.auto_compress {
233 return 0;
234 }
235
236 let mut count = 0;
237 for episode in &mut self.episodes {
238 if episode.pinned {
239 continue;
240 }
241
242 let ratio = self.compressor.compression_ratio(episode.created_at);
243 if ratio > 0.1 {
244 episode.content = self.compressor.compress(&episode.content, ratio);
245 count += 1;
246 }
247 }
248 count
249 }
250
251 pub async fn compress_old_with_llm<L: vex_llm::LlmProvider + vex_llm::EmbeddingProvider>(
254 &mut self,
255 llm: &L,
256 vector_store: Option<&dyn VectorStoreBackend>,
257 tenant_id: Option<&str>,
258 ) -> usize {
259 if !self.config.auto_compress {
260 return 0;
261 }
262
263 let mut count = 0;
264 for episode in &mut self.episodes {
265 if episode.pinned {
266 continue;
267 }
268
269 let ratio = self.compressor.compression_ratio(episode.created_at);
270 if ratio > 0.1 {
271 match self
272 .compressor
273 .compress_with_llm(&episode.content, ratio, llm, vector_store, tenant_id)
274 .await
275 {
276 Ok(compressed) => {
277 tracing::debug!(
278 episode_id = %episode.id,
279 original_len = episode.content.len(),
280 compressed_len = compressed.len(),
281 ratio = ratio,
282 "Compressed episode with LLM"
283 );
284 episode.content = compressed;
285 count += 1;
286 }
287 Err(e) => {
288 tracing::warn!("LLM compression failed for episode {}: {}", episode.id, e);
289 episode.content = self.compressor.compress(&episode.content, ratio);
291 count += 1;
292 }
293 }
294 }
295 }
296 count
297 }
298
299 pub async fn summarize_all_with_llm<L: vex_llm::LlmProvider>(
302 &self,
303 llm: &L,
304 ) -> Result<String, vex_llm::LlmError> {
305 if self.episodes.is_empty() {
306 return Ok(String::from("No memories recorded."));
307 }
308
309 let all_content: String = self
311 .episodes
312 .iter()
313 .map(|e| {
314 format!(
315 "[{}] (importance: {:.1}): {}",
316 e.created_at.format("%Y-%m-%d %H:%M"),
317 e.base_importance,
318 e.content
319 )
320 })
321 .collect::<Vec<_>>()
322 .join("\n\n");
323
324 let prompt = format!(
325 "You are a memory consolidation system. Summarize the following episodic memories \
326 into a coherent narrative that preserves the most important information, decisions, \
327 and context. Focus on factual content and key events.\n\n\
328 MEMORIES:\n{}\n\n\
329 CONSOLIDATED SUMMARY:",
330 all_content
331 );
332
333 llm.ask(&prompt).await.map(|s| s.trim().to_string())
334 }
335
336 pub fn summarize(&self) -> String {
338 let total = self.len();
339 let pinned = self.episodes.iter().filter(|e| e.pinned).count();
340 let recent = self.recent().len();
341
342 format!(
343 "Memory: {} total ({} pinned, {} recent within {:?})",
344 total, pinned, recent, self.config.horizon
345 )
346 }
347}
348
349impl Default for EpisodicMemory {
350 fn default() -> Self {
351 Self::new(HorizonConfig::default())
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_episodic_memory() {
        let mut memory = EpisodicMemory::default();

        memory.remember("First event", 0.8);
        memory.remember("Second event", 0.5);
        memory.add(Episode::pinned("Important system info"));

        assert_eq!(memory.len(), 3);
        assert_eq!(memory.recent().len(), 3);
    }

    #[test]
    fn test_by_importance() {
        let mut memory = EpisodicMemory::default();

        memory.remember("Low importance", 0.2);
        memory.remember("High importance", 0.9);

        let sorted = memory.by_importance();
        assert!(sorted[0].1 > sorted[1].1);
    }

    #[test]
    fn test_pinned_not_evicted() {
        let config = HorizonConfig {
            max_entries: 2,
            ..Default::default()
        };

        let mut memory = EpisodicMemory::new(config);
        memory.add(Episode::pinned("System"));
        memory.remember("Event 1", 0.5);
        memory.remember("Event 2", 0.5);
        memory.remember("Event 3", 0.5);

        assert!(memory.episodes().any(|e| e.content == "System"));
    }

    #[test]
    fn test_episode_importance_is_clamped() {
        assert_eq!(Episode::new("too high", 1.5).base_importance, 1.0);
        assert_eq!(Episode::new("too low", -0.5).base_importance, 0.0);

        let pinned = Episode::pinned("system");
        assert!(pinned.pinned);
        assert_eq!(pinned.base_importance, 1.0);
    }

    #[test]
    fn test_by_tag() {
        let mut memory = EpisodicMemory::default();
        memory.add(Episode::new("Standup", 0.5).with_tag("work"));
        memory.add(Episode::new("Deploy", 0.7).with_tag("work").with_tag("ops"));
        memory.remember("Lunch", 0.3);

        assert_eq!(memory.by_tag("work").len(), 2);
        assert_eq!(memory.by_tag("ops").len(), 1);
        assert!(memory.by_tag("missing").is_empty());
    }

    #[test]
    fn test_clear_keeps_pinned() {
        let mut memory = EpisodicMemory::default();
        memory.remember("Transient", 0.5);
        memory.add(Episode::pinned("Keep me"));

        memory.clear();

        assert_eq!(memory.len(), 1);
        assert!(memory.episodes().all(|e| e.pinned));
    }
}
398}