1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use sqlx::{Row, SqlitePool};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6use thiserror::Error;
7
8#[derive(Error, Debug)]
9pub enum VectorError {
10 #[error("Dimension mismatch: expected {0}, got {1}")]
11 DimensionMismatch(usize, usize),
12 #[error("Serialization error: {0}")]
13 SerializationError(String),
14 #[error("Database error: {0}")]
15 DatabaseError(String),
16 #[error("Storage full: capacity exceeded")]
17 StorageFull,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct VectorEmbedding {
22 pub id: String,
23 pub vector: Vec<f32>,
24 pub metadata: HashMap<String, String>,
25}
26
27#[async_trait]
29pub trait VectorStoreBackend: Send + Sync + std::fmt::Debug {
30 async fn add(
31 &self,
32 id: String,
33 tenant_id: String,
34 vector: Vec<f32>,
35 metadata: HashMap<String, String>,
36 ) -> Result<(), VectorError>;
37
38 async fn search(
39 &self,
40 tenant_id: &str,
41 query: &[f32],
42 k: usize,
43 ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError>;
44}
45
46#[derive(Debug, Clone)]
48pub struct MemoryVectorStore {
49 dimension: usize,
50 embeddings: Arc<RwLock<Vec<(String, String, VectorEmbedding)>>>, }
52
53impl MemoryVectorStore {
54 pub fn new(dimension: usize) -> Self {
55 Self {
56 dimension,
57 embeddings: Arc::new(RwLock::new(Vec::new())),
58 }
59 }
60}
61
62#[async_trait]
63impl VectorStoreBackend for MemoryVectorStore {
64 async fn add(
65 &self,
66 id: String,
67 tenant_id: String,
68 vector: Vec<f32>,
69 metadata: HashMap<String, String>,
70 ) -> Result<(), VectorError> {
71 if vector.len() != self.dimension {
72 return Err(VectorError::DimensionMismatch(self.dimension, vector.len()));
73 }
74
75 let mut data = self.embeddings.write().unwrap();
76
77 if data.len() >= 100_000 {
79 return Err(VectorError::StorageFull);
80 }
81
82 data.push((
83 id.clone(),
84 tenant_id,
85 VectorEmbedding {
86 id,
87 vector,
88 metadata,
89 },
90 ));
91
92 Ok(())
93 }
94
95 async fn search(
96 &self,
97 tenant_id: &str,
98 query: &[f32],
99 k: usize,
100 ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError> {
101 if query.len() != self.dimension {
102 return Err(VectorError::DimensionMismatch(self.dimension, query.len()));
103 }
104
105 let data = self.embeddings.read().unwrap();
106 let mut scores: Vec<(f32, VectorEmbedding)> = data
107 .iter()
108 .filter(|(_, tid, _)| tid == tenant_id)
109 .map(|(_, _, emb)| {
110 let score = cosine_similarity(query, &emb.vector);
111 (score, emb.clone())
112 })
113 .collect();
114
115 scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
116 scores.truncate(k);
117
118 Ok(scores)
119 }
120}
121
122#[derive(Debug, Clone)]
124pub struct SqliteVectorStore {
125 dimension: usize,
126 pool: SqlitePool,
127}
128
129impl SqliteVectorStore {
130 pub fn new(dimension: usize, pool: SqlitePool) -> Self {
131 Self { dimension, pool }
132 }
133}
134
135#[async_trait]
136impl VectorStoreBackend for SqliteVectorStore {
137 async fn add(
138 &self,
139 id: String,
140 tenant_id: String,
141 vector: Vec<f32>,
142 metadata: HashMap<String, String>,
143 ) -> Result<(), VectorError> {
144 if vector.len() != self.dimension {
145 return Err(VectorError::DimensionMismatch(self.dimension, vector.len()));
146 }
147
148 let mut vector_bytes = Vec::with_capacity(vector.len() * 4);
150 for &val in &vector {
151 vector_bytes.extend_from_slice(&val.to_le_bytes());
152 }
153
154 let metadata_json = serde_json::to_string(&metadata)
155 .map_err(|e| VectorError::SerializationError(e.to_string()))?;
156
157 sqlx::query(
158 "INSERT OR REPLACE INTO vector_embeddings (id, tenant_id, vector, metadata, created_at) VALUES (?, ?, ?, ?, ?)"
159 )
160 .bind(id)
161 .bind(tenant_id)
162 .bind(vector_bytes)
163 .bind(metadata_json)
164 .bind(chrono::Utc::now().timestamp())
165 .execute(&self.pool)
166 .await
167 .map_err(|e| VectorError::DatabaseError(e.to_string()))?;
168
169 Ok(())
170 }
171
172 async fn search(
173 &self,
174 tenant_id: &str,
175 query: &[f32],
176 k: usize,
177 ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError> {
178 if query.len() != self.dimension {
179 return Err(VectorError::DimensionMismatch(self.dimension, query.len()));
180 }
181
182 let rows =
185 sqlx::query("SELECT id, vector, metadata FROM vector_embeddings WHERE tenant_id = ?")
186 .bind(tenant_id)
187 .fetch_all(&self.pool)
188 .await
189 .map_err(|e| VectorError::DatabaseError(e.to_string()))?;
190
191 let mut scores = Vec::new();
192
193 for row in rows {
194 let id: String = row.get("id");
195 let vector_bytes: Vec<u8> = row.get("vector");
196 let metadata_str: String = row.get("metadata");
197
198 if vector_bytes.len() != self.dimension * 4 {
200 continue; }
202
203 let mut vector = Vec::with_capacity(self.dimension);
204 for chunk in vector_bytes.chunks_exact(4) {
205 let arr: [u8; 4] = chunk.try_into().unwrap();
206 vector.push(f32::from_le_bytes(arr));
207 }
208
209 let metadata: HashMap<String, String> = serde_json::from_str(&metadata_str)
210 .map_err(|e| VectorError::SerializationError(e.to_string()))?;
211
212 let score = cosine_similarity(query, &vector);
213 scores.push((
214 score,
215 VectorEmbedding {
216 id,
217 vector,
218 metadata,
219 },
220 ));
221 }
222
223 scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
224 scores.truncate(k);
225
226 Ok(scores)
227 }
228}
229
/// Cosine similarity of `a` and `b`, `a·b / (|a| |b|)`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when either vector has zero magnitude.
///
/// Sums are accumulated in `f64`. In pure `f32` arithmetic, components around `1e20`
/// already square to infinity, which turns the result into `inf / inf = NaN`. The
/// final clamp stops rounding from pushing the result past ±1. Non-finite components
/// can still produce NaN.
///
/// The dot product covers only the common prefix (`zip`), while each norm covers its
/// whole slice. The callers in this file ensure both slices have the store's dimension.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_a_sq: f64 = a.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
    let norm_b_sq: f64 = b.iter().map(|&y| f64::from(y) * f64::from(y)).sum();

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    // A single sqrt of the product avoids compounding two rounded square roots; the
    // clamp absorbs any remaining rounding error.
    let similarity = dot / (norm_a_sq * norm_b_sq).sqrt();
    similarity.clamp(-1.0, 1.0) as f32
}