vex_persist/
sqlite.rs

1//! SQLite backend implementation
2
3use async_trait::async_trait;
4use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
5use std::str::FromStr;
6use tracing::{info, warn};
7
8use crate::backend::{StorageBackend, StorageError};
9
/// SQLite configuration options.
///
/// `Debug` is implemented by hand rather than derived so that the SQLCipher
/// key is never written out when a configuration value is logged or printed
/// with `{:?}`.
#[derive(Clone)]
pub struct SqliteConfig {
    /// Database URL (e.g., "sqlite:data.db" or "sqlite::memory:")
    pub url: String,
    /// Maximum number of connections in the pool
    pub max_connections: u32,
    /// Encryption key for SQLCipher (None = unencrypted)
    /// Note: Requires SQLite compiled with SQLCipher extension
    pub encryption_key: Option<String>,
    /// Enable WAL journal mode for better concurrency
    pub wal_mode: bool,
    /// Enable foreign key enforcement
    pub foreign_keys: bool,
    /// Busy timeout in seconds
    pub busy_timeout_secs: u32,
}

impl std::fmt::Debug for SqliteConfig {
    /// Formats every field like the derived implementation would, except that
    /// a present encryption key is replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the `Some`/`None` shape visible so logs still show whether
        // encryption was requested, without revealing the secret itself.
        let redacted_key = self.encryption_key.as_ref().map(|_| "<redacted>");

        f.debug_struct("SqliteConfig")
            .field("url", &self.url)
            .field("max_connections", &self.max_connections)
            .field("encryption_key", &redacted_key)
            .field("wal_mode", &self.wal_mode)
            .field("foreign_keys", &self.foreign_keys)
            .field("busy_timeout_secs", &self.busy_timeout_secs)
            .finish()
    }
}
27
28impl Default for SqliteConfig {
29    fn default() -> Self {
30        Self {
31            url: "sqlite:vex.db?mode=rwc".to_string(),
32            max_connections: 5,
33            encryption_key: None,
34            wal_mode: true,
35            foreign_keys: true,
36            busy_timeout_secs: 30,
37        }
38    }
39}
40
41impl SqliteConfig {
42    /// Create config for in-memory database (testing)
43    pub fn memory() -> Self {
44        Self {
45            url: "sqlite::memory:".to_string(),
46            max_connections: 1,
47            encryption_key: None,
48            wal_mode: false,
49            foreign_keys: true,
50            busy_timeout_secs: 5,
51        }
52    }
53
54    /// Create secure config with encryption
55    pub fn secure(url: &str, encryption_key: &str) -> Self {
56        Self {
57            url: url.to_string(),
58            max_connections: 5,
59            encryption_key: Some(encryption_key.to_string()),
60            wal_mode: true,
61            foreign_keys: true,
62            busy_timeout_secs: 30,
63        }
64    }
65}
66
67/// SQLite storage backend
68#[derive(Debug)]
69pub struct SqliteBackend {
70    pool: SqlitePool,
71    encrypted: bool,
72}
73
74impl SqliteBackend {
75    /// Create a new SQLite backend with default config
76    pub async fn new(url: &str) -> Result<Self, StorageError> {
77        let config = SqliteConfig {
78            url: url.to_string(),
79            ..Default::default()
80        };
81        Self::new_with_config(config).await
82    }
83
84    /// Create a new SQLite backend with full configuration
85    pub async fn new_with_config(config: SqliteConfig) -> Result<Self, StorageError> {
86        let mut options = SqliteConnectOptions::from_str(&config.url)
87            .map_err(|e| StorageError::Connection(e.to_string()))?;
88
89        // Set pragmas for security and performance
90        if config.foreign_keys {
91            options = options.pragma("foreign_keys", "ON");
92        }
93        options = options.pragma("busy_timeout", config.busy_timeout_secs.to_string());
94
95        if config.wal_mode {
96            options = options.pragma("journal_mode", "WAL");
97        }
98
99        // Handle encryption key (requires SQLCipher)
100        let encrypted = if let Some(ref key) = config.encryption_key {
101            // SQLCipher pragma - will fail silently if not compiled with SQLCipher
102            // Safe: escape single quotes to prevent injection
103            let escaped_key = key.replace("'", "''");
104            options = options.pragma("key", format!("'{}'", escaped_key));
105            warn!("SQLite encryption enabled - ensure SQLCipher is available");
106            true
107        } else {
108            false
109        };
110
111        let pool = SqlitePoolOptions::new()
112            .max_connections(config.max_connections)
113            .connect_with(options)
114            .await
115            .map_err(|e| StorageError::Connection(e.to_string()))?;
116
117        info!(
118            url = %config.url,
119            encrypted = encrypted,
120            wal = config.wal_mode,
121            "Connected to SQLite"
122        );
123
124        // Verify SQLCipher is actually active if encryption was requested
125        if encrypted {
126            use sqlx::Row;
127            let _result = sqlx::query("SELECT sqlite3_version()")
128                .fetch_one(&pool)
129                .await
130                .map_err(|e| {
131                    StorageError::Connection(format!("SQLCipher verification failed: {}", e))
132                })?;
133
134            // Try to verify cipher_version pragma - if it fails, SQLCipher is not available
135            let cipher_check = sqlx::query("PRAGMA cipher_version")
136                .fetch_optional(&pool)
137                .await;
138
139            match cipher_check {
140                Ok(Some(row)) => {
141                    let version: Option<String> = row.try_get(0).ok();
142                    if version.is_none() || version.as_ref().map(|v| v.is_empty()).unwrap_or(true) {
143                        return Err(StorageError::Internal(
144                            "SQLCipher encryption requested but cipher_version returned empty. \
145                             SQLite may not be compiled with SQLCipher support."
146                                .to_string(),
147                        ));
148                    }
149                    info!(cipher_version = ?version, "SQLCipher encryption verified");
150                }
151                Ok(None) | Err(_) => {
152                    return Err(StorageError::Internal(
153                        "SQLCipher encryption requested but not available. \
154                         Database will NOT be encrypted! Aborting for security."
155                            .to_string(),
156                    ));
157                }
158            }
159        }
160
161        // Run migrations
162        sqlx::migrate!("./migrations")
163            .run(&pool)
164            .await
165            .map_err(|e| StorageError::Internal(format!("Migration failed: {}", e)))?;
166
167        Ok(Self { pool, encrypted })
168    }
169
170    /// Get the connection pool
171    pub fn pool(&self) -> &SqlitePool {
172        &self.pool
173    }
174
175    /// Check if database is encrypted
176    pub fn is_encrypted(&self) -> bool {
177        self.encrypted
178    }
179}
180
181#[async_trait]
182impl StorageBackend for SqliteBackend {
183    fn name(&self) -> &str {
184        "sqlite"
185    }
186
187    async fn is_healthy(&self) -> bool {
188        !self.pool.is_closed()
189    }
190
191    async fn set_value(&self, key: &str, value: serde_json::Value) -> Result<(), StorageError> {
192        let json = serde_json::to_string(&value)
193            .map_err(|e| StorageError::Serialization(e.to_string()))?;
194
195        let now = chrono::Utc::now().timestamp();
196
197        sqlx::query(
198            "INSERT OR REPLACE INTO kv_store (key, value, created_at, updated_at) VALUES (?, ?, ?, ?)"
199        )
200        .bind(key)
201        .bind(json)
202        .bind(now)
203        .bind(now)
204        .execute(&self.pool)
205        .await
206        .map_err(|e| StorageError::Query(e.to_string()))?;
207
208        Ok(())
209    }
210
211    async fn get_value(&self, key: &str) -> Result<Option<serde_json::Value>, StorageError> {
212        use sqlx::Row;
213        let result = sqlx::query("SELECT value FROM kv_store WHERE key = ?")
214            .bind(key)
215            .fetch_optional(&self.pool)
216            .await
217            .map_err(|e| StorageError::Query(e.to_string()))?;
218
219        match result {
220            Some(row) => {
221                let value_str: String = row
222                    .try_get("value")
223                    .map_err(|e| StorageError::Query(e.to_string()))?;
224                let value = serde_json::from_str(&value_str)
225                    .map_err(|e| StorageError::Serialization(e.to_string()))?;
226                Ok(Some(value))
227            }
228            None => Ok(None),
229        }
230    }
231
232    async fn delete(&self, key: &str) -> Result<bool, StorageError> {
233        let result = sqlx::query("DELETE FROM kv_store WHERE key = ?")
234            .bind(key)
235            .execute(&self.pool)
236            .await
237            .map_err(|e| StorageError::Query(e.to_string()))?;
238
239        Ok(result.rows_affected() > 0)
240    }
241
242    async fn exists(&self, key: &str) -> Result<bool, StorageError> {
243        let result = sqlx::query("SELECT 1 FROM kv_store WHERE key = ?")
244            .bind(key)
245            .fetch_optional(&self.pool)
246            .await
247            .map_err(|e| StorageError::Query(e.to_string()))?;
248
249        Ok(result.is_some())
250    }
251
252    async fn list_keys(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
253        use sqlx::Row;
254        let pattern = format!("{}%", prefix);
255        let rows = sqlx::query("SELECT key FROM kv_store WHERE key LIKE ?")
256            .bind(pattern)
257            .fetch_all(&self.pool)
258            .await
259            .map_err(|e| StorageError::Query(e.to_string()))?;
260
261        let mut keys = Vec::new();
262        for row in rows {
263            let key: String = row
264                .try_get("key")
265                .map_err(|e| StorageError::Query(e.to_string()))?;
266            keys.push(key);
267        }
268        Ok(keys)
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::StorageExt;
    use serde::{Deserialize, Serialize};

    /// Small serializable payload used to exercise the typed get/set helpers.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestData {
        name: String,
        value: i32,
    }

    /// Round-trips one record through an in-memory database:
    /// set, exists, get, list, delete.
    #[tokio::test]
    async fn test_sqlite_backend() {
        const KEY: &str = "sql:1";

        let store = SqliteBackend::new("sqlite::memory:").await.unwrap();
        let payload = TestData {
            name: String::from("test_sql"),
            value: 99,
        };

        // Write the record and confirm it is visible.
        store.set(KEY, &payload).await.unwrap();
        assert!(store.exists(KEY).await.unwrap());

        // Reading it back yields the same value.
        let fetched: Option<TestData> = store.get(KEY).await.unwrap();
        assert_eq!(fetched, Some(payload));

        // The key is the only one under its prefix.
        let listed = store.list_keys("sql:").await.unwrap();
        assert_eq!(listed, vec![KEY]);

        // Deleting reports success and the key is gone afterwards.
        assert!(store.delete(KEY).await.unwrap());
        assert!(!store.exists(KEY).await.unwrap());
    }
}