// vex_temporal/memory.rs

1//! Episodic memory management
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::VecDeque;
6
7use crate::compression::TemporalCompressor;
8use crate::horizon::HorizonConfig;
9use vex_persist::VectorStoreBackend;
10
/// An episode in memory.
///
/// Episodes are the unit of storage in `EpisodicMemory`: a piece of content
/// stamped with its creation time and a base importance that the compressor
/// decays over time unless the episode is pinned.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Episode {
    /// Unique identifier (process-local, monotonically increasing)
    pub id: u64,
    /// Content of the episode
    pub content: String,
    /// When it was created (UTC)
    pub created_at: DateTime<Utc>,
    /// Base importance (0.0 - 1.0); clamped on construction
    pub base_importance: f64,
    /// Whether this episode is pinned (never evicted, always scored 1.0)
    pub pinned: bool,
    /// Tags for categorization
    pub tags: Vec<String>,
}
27
28impl Episode {
29    /// Create a new episode
30    pub fn new(content: &str, importance: f64) -> Self {
31        static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
32        Self {
33            id: COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
34            content: content.to_string(),
35            created_at: Utc::now(),
36            base_importance: importance.clamp(0.0, 1.0),
37            pinned: false,
38            tags: Vec::new(),
39        }
40    }
41
42    /// Create a pinned episode (never evicted)
43    pub fn pinned(content: &str) -> Self {
44        let mut ep = Self::new(content, 1.0);
45        ep.pinned = true;
46        ep
47    }
48
49    /// Add a tag
50    pub fn with_tag(mut self, tag: &str) -> Self {
51        self.tags.push(tag.to_string());
52        self
53    }
54}
55
/// Episodic memory store.
///
/// Holds episodes in a deque (most recent first) and applies time-based
/// importance decay, age/capacity eviction, and content compression via
/// a `TemporalCompressor`.
#[derive(Debug, Clone)]
pub struct EpisodicMemory {
    /// Configuration (horizon, max_entries, auto_evict / auto_compress flags)
    pub config: HorizonConfig,
    /// Compressor used for importance decay, age eviction, and compression
    pub compressor: TemporalCompressor,
    /// Episodes (most recent first)
    episodes: VecDeque<Episode>,
}
66
67impl EpisodicMemory {
68    /// Create new episodic memory with config
69    pub fn new(config: HorizonConfig) -> Self {
70        let max_age = config
71            .horizon
72            .duration()
73            .unwrap_or(chrono::Duration::weeks(52));
74        Self {
75            config,
76            compressor: TemporalCompressor::new(
77                crate::compression::DecayStrategy::Exponential,
78                max_age,
79            ),
80            episodes: VecDeque::new(),
81        }
82    }
83
84    /// Add a new episode
85    pub fn add(&mut self, episode: Episode) {
86        self.episodes.push_front(episode);
87        self.maybe_evict();
88    }
89
90    /// Add simple content
91    pub fn remember(&mut self, content: &str, importance: f64) {
92        self.add(Episode::new(content, importance));
93    }
94
95    /// Get all episodes
96    pub fn episodes(&self) -> impl Iterator<Item = &Episode> {
97        self.episodes.iter()
98    }
99
100    /// Get episodes by tag
101    pub fn by_tag(&self, tag: &str) -> Vec<&Episode> {
102        self.episodes
103            .iter()
104            .filter(|e| e.tags.contains(&tag.to_string()))
105            .collect()
106    }
107
108    /// Get recent episodes (within horizon)
109    pub fn recent(&self) -> Vec<&Episode> {
110        self.episodes
111            .iter()
112            .filter(|e| self.config.horizon.contains(e.created_at))
113            .collect()
114    }
115
116    /// Get episodes sorted by current importance
117    pub fn by_importance(&self) -> Vec<(&Episode, f64)> {
118        let mut episodes: Vec<_> = self
119            .episodes
120            .iter()
121            .map(|e| {
122                let importance = if e.pinned {
123                    1.0
124                } else {
125                    self.compressor.importance(e.created_at, e.base_importance)
126                };
127                (e, importance)
128            })
129            .collect();
130        episodes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
131        episodes
132    }
133
134    /// Get total episode count
135    pub fn len(&self) -> usize {
136        self.episodes.len()
137    }
138
139    /// Check if empty
140    pub fn is_empty(&self) -> bool {
141        self.episodes.is_empty()
142    }
143
144    /// Clear all non-pinned episodes
145    pub fn clear(&mut self) {
146        self.episodes.retain(|e| e.pinned);
147    }
148
149    /// Evict old episodes if over capacity
150    /// Evict old episodes if over capacity (Optimized O(N) bulk eviction)
151    fn maybe_evict(&mut self) {
152        if !self.config.auto_evict {
153            return;
154        }
155
156        // 1. Evict by age (O(N)) - using a collected list of IDs to avoid borrow conflicts
157        // if we needed to access self.compressor inside retain.
158        // Actually, for age validation we need compressor.
159        // We can't access self.compressor inside self.episodes.retain().
160        // So we must use the ID collection strategy for BOTH checks or perform age check separately.
161
162        // Age check typically simple, let's just do it first with ID collection
163        let max_age_ids: std::collections::HashSet<u64> = self
164            .episodes
165            .iter()
166            .filter(|e| !e.pinned && self.compressor.should_evict(e.created_at))
167            .map(|e| e.id)
168            .collect();
169
170        if !max_age_ids.is_empty() {
171            self.episodes.retain(|e| !max_age_ids.contains(&e.id));
172        }
173
174        // 2. Check overlap for Count eviction
175        let current_len = self.episodes.len();
176        if current_len <= self.config.max_entries {
177            return;
178        }
179
180        // We need to reduce to max_entries.
181        // Pinned items are protected.
182        let pinned_count = self.episodes.iter().filter(|e| e.pinned).count();
183        if pinned_count >= self.config.max_entries {
184            self.episodes.retain(|e| e.pinned);
185            return;
186        }
187
188        let slots_for_non_pinned = self.config.max_entries - pinned_count;
189
190        // 3. Collect scores for all non-pinned items: (Importance, Time, ID)
191        // We calculate importance ONCE per pass.
192        let mut candidates: Vec<(f64, DateTime<Utc>, u64)> = self
193            .episodes
194            .iter()
195            .filter(|e| !e.pinned)
196            .map(|e| {
197                (
198                    self.compressor.importance(e.created_at, e.base_importance),
199                    e.created_at,
200                    e.id,
201                )
202            })
203            .collect();
204
205        // 4. Find threshold to KEEP top N items
206        // We want to keep the `slots_for_non_pinned` items with HIGHEST scores.
207        if candidates.len() > slots_for_non_pinned {
208            // We want the pivot at index (len - slots).
209            // Items AFTER pivot will be the largest (to keep).
210            let target_idx = candidates.len() - slots_for_non_pinned;
211
212            // Sort such that smallest are at beginning, largest at end
213            candidates.select_nth_unstable_by(target_idx, |a, b| {
214                a.0.partial_cmp(&b.0)
215                    .unwrap_or(std::cmp::Ordering::Equal)
216                    .then_with(|| a.1.cmp(&b.1))
217            });
218
219            // Collect IDs of items to KEEP (those >= threshold)
220            // The items from target_idx onwards are the ones to keep.
221            let keep_ids: std::collections::HashSet<u64> =
222                candidates[target_idx..].iter().map(|c| c.2).collect();
223
224            // 5. Bulk retain
225            self.episodes
226                .retain(|e| e.pinned || keep_ids.contains(&e.id));
227        }
228    }
229
230    /// Compress old episodes (returns number compressed) - sync fallback using truncation
231    pub fn compress_old(&mut self) -> usize {
232        if !self.config.auto_compress {
233            return 0;
234        }
235
236        let mut count = 0;
237        for episode in &mut self.episodes {
238            if episode.pinned {
239                continue;
240            }
241
242            let ratio = self.compressor.compression_ratio(episode.created_at);
243            if ratio > 0.1 {
244                episode.content = self.compressor.compress(&episode.content, ratio);
245                count += 1;
246            }
247        }
248        count
249    }
250
251    /// Compress old episodes using LLM for intelligent summarization
252    /// Returns the number of episodes that were compressed
253    pub async fn compress_old_with_llm<L: vex_llm::LlmProvider + vex_llm::EmbeddingProvider>(
254        &mut self,
255        llm: &L,
256        vector_store: Option<&dyn VectorStoreBackend>,
257        tenant_id: Option<&str>,
258    ) -> usize {
259        if !self.config.auto_compress {
260            return 0;
261        }
262
263        let mut count = 0;
264        for episode in &mut self.episodes {
265            if episode.pinned {
266                continue;
267            }
268
269            let ratio = self.compressor.compression_ratio(episode.created_at);
270            if ratio > 0.1 {
271                match self
272                    .compressor
273                    .compress_with_llm(&episode.content, ratio, llm, vector_store, tenant_id)
274                    .await
275                {
276                    Ok(compressed) => {
277                        tracing::debug!(
278                            episode_id = %episode.id,
279                            original_len = episode.content.len(),
280                            compressed_len = compressed.len(),
281                            ratio = ratio,
282                            "Compressed episode with LLM"
283                        );
284                        episode.content = compressed;
285                        count += 1;
286                    }
287                    Err(e) => {
288                        tracing::warn!("LLM compression failed for episode {}: {}", episode.id, e);
289                        // Fallback to truncation
290                        episode.content = self.compressor.compress(&episode.content, ratio);
291                        count += 1;
292                    }
293                }
294            }
295        }
296        count
297    }
298
299    /// Summarize all episodes into a single context string using LLM
300    /// Useful for providing memory context to agents
301    pub async fn summarize_all_with_llm<L: vex_llm::LlmProvider>(
302        &self,
303        llm: &L,
304    ) -> Result<String, vex_llm::LlmError> {
305        if self.episodes.is_empty() {
306            return Ok(String::from("No memories recorded."));
307        }
308
309        // Combine all episodes into a single text
310        let all_content: String = self
311            .episodes
312            .iter()
313            .map(|e| {
314                format!(
315                    "[{}] (importance: {:.1}): {}",
316                    e.created_at.format("%Y-%m-%d %H:%M"),
317                    e.base_importance,
318                    e.content
319                )
320            })
321            .collect::<Vec<_>>()
322            .join("\n\n");
323
324        let prompt = format!(
325            "You are a memory consolidation system. Summarize the following episodic memories \
326             into a coherent narrative that preserves the most important information, decisions, \
327             and context. Focus on factual content and key events.\n\n\
328             MEMORIES:\n{}\n\n\
329             CONSOLIDATED SUMMARY:",
330            all_content
331        );
332
333        llm.ask(&prompt).await.map(|s| s.trim().to_string())
334    }
335
336    /// Get a summary of memory contents
337    pub fn summarize(&self) -> String {
338        let total = self.len();
339        let pinned = self.episodes.iter().filter(|e| e.pinned).count();
340        let recent = self.recent().len();
341
342        format!(
343            "Memory: {} total ({} pinned, {} recent within {:?})",
344            total, pinned, recent, self.config.horizon
345        )
346    }
347}
348
349impl Default for EpisodicMemory {
350    fn default() -> Self {
351        Self::new(HorizonConfig::default())
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_episodic_memory() {
        let mut mem = EpisodicMemory::default();

        mem.remember("First event", 0.8);
        mem.remember("Second event", 0.5);
        mem.add(Episode::pinned("Important system info"));

        assert_eq!(mem.len(), 3);
        assert_eq!(mem.recent().len(), 3);
    }

    #[test]
    fn test_by_importance() {
        let mut mem = EpisodicMemory::default();

        mem.remember("Low importance", 0.2);
        mem.remember("High importance", 0.9);

        let ranked = mem.by_importance();
        assert!(ranked[0].1 > ranked[1].1);
    }

    #[test]
    fn test_pinned_not_evicted() {
        let config = HorizonConfig {
            max_entries: 2,
            ..Default::default()
        };
        let mut mem = EpisodicMemory::new(config);

        mem.add(Episode::pinned("System"));
        for n in 1..=3 {
            mem.remember(&format!("Event {}", n), 0.5);
        }

        // The pinned episode must survive capacity eviction.
        assert!(mem.episodes().any(|e| e.content == "System"));
    }
}