vex_persist/
vector_store.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use sqlx::{Row, SqlitePool};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6use thiserror::Error;
7
8#[derive(Error, Debug)]
9pub enum VectorError {
10    #[error("Dimension mismatch: expected {0}, got {1}")]
11    DimensionMismatch(usize, usize),
12    #[error("Serialization error: {0}")]
13    SerializationError(String),
14    #[error("Database error: {0}")]
15    DatabaseError(String),
16    #[error("Storage full: capacity exceeded")]
17    StorageFull,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct VectorEmbedding {
22    pub id: String,
23    pub vector: Vec<f32>,
24    pub metadata: HashMap<String, String>,
25}
26
27/// Generic trait for vector storage
28#[async_trait]
29pub trait VectorStoreBackend: Send + Sync + std::fmt::Debug {
30    async fn add(
31        &self,
32        id: String,
33        tenant_id: String,
34        vector: Vec<f32>,
35        metadata: HashMap<String, String>,
36    ) -> Result<(), VectorError>;
37
38    async fn search(
39        &self,
40        tenant_id: &str,
41        query: &[f32],
42        k: usize,
43    ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError>;
44}
45
46/// In-memory vector store implementation (for testing and small contexts)
47#[derive(Debug, Clone)]
48pub struct MemoryVectorStore {
49    dimension: usize,
50    embeddings: Arc<RwLock<Vec<(String, String, VectorEmbedding)>>>, // (id, tenant_id, embedding)
51}
52
53impl MemoryVectorStore {
54    pub fn new(dimension: usize) -> Self {
55        Self {
56            dimension,
57            embeddings: Arc::new(RwLock::new(Vec::new())),
58        }
59    }
60}
61
62#[async_trait]
63impl VectorStoreBackend for MemoryVectorStore {
64    async fn add(
65        &self,
66        id: String,
67        tenant_id: String,
68        vector: Vec<f32>,
69        metadata: HashMap<String, String>,
70    ) -> Result<(), VectorError> {
71        if vector.len() != self.dimension {
72            return Err(VectorError::DimensionMismatch(self.dimension, vector.len()));
73        }
74
75        let mut data = self.embeddings.write().unwrap();
76
77        // Limit capacity to prevent memory DoS (Fix #12)
78        if data.len() >= 100_000 {
79            return Err(VectorError::StorageFull);
80        }
81
82        data.push((
83            id.clone(),
84            tenant_id,
85            VectorEmbedding {
86                id,
87                vector,
88                metadata,
89            },
90        ));
91
92        Ok(())
93    }
94
95    async fn search(
96        &self,
97        tenant_id: &str,
98        query: &[f32],
99        k: usize,
100    ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError> {
101        if query.len() != self.dimension {
102            return Err(VectorError::DimensionMismatch(self.dimension, query.len()));
103        }
104
105        let data = self.embeddings.read().unwrap();
106        let mut scores: Vec<(f32, VectorEmbedding)> = data
107            .iter()
108            .filter(|(_, tid, _)| tid == tenant_id)
109            .map(|(_, _, emb)| {
110                let score = cosine_similarity(query, &emb.vector);
111                (score, emb.clone())
112            })
113            .collect();
114
115        scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
116        scores.truncate(k);
117
118        Ok(scores)
119    }
120}
121
122/// SQLite-backed persistent vector store
123#[derive(Debug, Clone)]
124pub struct SqliteVectorStore {
125    dimension: usize,
126    pool: SqlitePool,
127}
128
129impl SqliteVectorStore {
130    pub fn new(dimension: usize, pool: SqlitePool) -> Self {
131        Self { dimension, pool }
132    }
133}
134
135#[async_trait]
136impl VectorStoreBackend for SqliteVectorStore {
137    async fn add(
138        &self,
139        id: String,
140        tenant_id: String,
141        vector: Vec<f32>,
142        metadata: HashMap<String, String>,
143    ) -> Result<(), VectorError> {
144        if vector.len() != self.dimension {
145            return Err(VectorError::DimensionMismatch(self.dimension, vector.len()));
146        }
147
148        // Convert f32 vector to bytes (Little Endian)
149        let mut vector_bytes = Vec::with_capacity(vector.len() * 4);
150        for &val in &vector {
151            vector_bytes.extend_from_slice(&val.to_le_bytes());
152        }
153
154        let metadata_json = serde_json::to_string(&metadata)
155            .map_err(|e| VectorError::SerializationError(e.to_string()))?;
156
157        sqlx::query(
158            "INSERT OR REPLACE INTO vector_embeddings (id, tenant_id, vector, metadata, created_at) VALUES (?, ?, ?, ?, ?)"
159        )
160        .bind(id)
161        .bind(tenant_id)
162        .bind(vector_bytes)
163        .bind(metadata_json)
164        .bind(chrono::Utc::now().timestamp())
165        .execute(&self.pool)
166        .await
167        .map_err(|e| VectorError::DatabaseError(e.to_string()))?;
168
169        Ok(())
170    }
171
172    async fn search(
173        &self,
174        tenant_id: &str,
175        query: &[f32],
176        k: usize,
177    ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError> {
178        if query.len() != self.dimension {
179            return Err(VectorError::DimensionMismatch(self.dimension, query.len()));
180        }
181
182        // In a real high-perf vector DB we'd use HNSW/IVF.
183        // For VEX P2, we perform a brute-force scan of the tenant's embeddings.
184        let rows =
185            sqlx::query("SELECT id, vector, metadata FROM vector_embeddings WHERE tenant_id = ?")
186                .bind(tenant_id)
187                .fetch_all(&self.pool)
188                .await
189                .map_err(|e| VectorError::DatabaseError(e.to_string()))?;
190
191        let mut scores = Vec::new();
192
193        for row in rows {
194            let id: String = row.get("id");
195            let vector_bytes: Vec<u8> = row.get("vector");
196            let metadata_str: String = row.get("metadata");
197
198            // Convert bytes back to f32 vector
199            if vector_bytes.len() != self.dimension * 4 {
200                continue; // Skip corrupted entry
201            }
202
203            let mut vector = Vec::with_capacity(self.dimension);
204            for chunk in vector_bytes.chunks_exact(4) {
205                let arr: [u8; 4] = chunk.try_into().unwrap();
206                vector.push(f32::from_le_bytes(arr));
207            }
208
209            let metadata: HashMap<String, String> = serde_json::from_str(&metadata_str)
210                .map_err(|e| VectorError::SerializationError(e.to_string()))?;
211
212            let score = cosine_similarity(query, &vector);
213            scores.push((
214                score,
215                VectorEmbedding {
216                    id,
217                    vector,
218                    metadata,
219                },
220            ));
221        }
222
223        scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
224        scores.truncate(k);
225
226        Ok(scores)
227    }
228}
229
/// Cosine similarity of `a` and `b`, always a finite value in `[-1.0, 1.0]`.
///
/// The dot product is taken over the common prefix of the two slices, while
/// each norm is taken over its whole slice; callers are expected to pass
/// slices of equal length.
///
/// Sums are accumulated in `f64` so that components with very large or very
/// small magnitudes do not overflow to infinity or underflow to zero when
/// squared, which in `f32` would yield NaN or a spurious `0.0`.
///
/// Returns `0.0` when either vector has zero norm, and also when the result
/// is not finite (an input contains NaN or an infinity), so callers can sort
/// on the score without meeting NaN.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_a = a
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    let norm_b = b
        .iter()
        .map(|&y| f64::from(y) * f64::from(y))
        .sum::<f64>()
        .sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let cosine = dot / (norm_a * norm_b);
    if !cosine.is_finite() {
        return 0.0;
    }

    // Rounding can push the ratio marginally outside the mathematical range.
    cosine.clamp(-1.0, 1.0) as f32
}