1use async_trait::async_trait;
4use serde::{de::DeserializeOwned, Serialize};
5use std::fmt::Debug;
6
7#[derive(Debug, thiserror::Error)]
9pub enum StorageError {
10 #[error("Not found: {0}")]
11 NotFound(String),
12
13 #[error("Already exists: {0}")]
14 AlreadyExists(String),
15
16 #[error("Serialization error: {0}")]
17 Serialization(String),
18
19 #[error("Connection error: {0}")]
20 Connection(String),
21
22 #[error("Query error: {0}")]
23 Query(String),
24
25 #[error("Internal error: {0}")]
26 Internal(String),
27}
28
29#[async_trait]
31pub trait StorageBackend: Send + Sync + Debug {
32 fn name(&self) -> &str;
34
35 async fn is_healthy(&self) -> bool;
37
38 async fn set_value(&self, key: &str, value: serde_json::Value) -> Result<(), StorageError>;
40
41 async fn get_value(&self, key: &str) -> Result<Option<serde_json::Value>, StorageError>;
43
44 async fn delete(&self, key: &str) -> Result<bool, StorageError>;
46
47 async fn exists(&self, key: &str) -> Result<bool, StorageError>;
49
50 async fn list_keys(&self, prefix: &str) -> Result<Vec<String>, StorageError>;
52}
53
54#[async_trait]
56pub trait StorageExt {
57 async fn set<T: Serialize + Send + Sync>(
58 &self,
59 key: &str,
60 value: &T,
61 ) -> Result<(), StorageError>;
62 async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StorageError>;
63}
64
65#[async_trait]
66impl<S: StorageBackend + ?Sized> StorageExt for S {
67 async fn set<T: Serialize + Send + Sync>(
68 &self,
69 key: &str,
70 value: &T,
71 ) -> Result<(), StorageError> {
72 let json =
73 serde_json::to_value(value).map_err(|e| StorageError::Serialization(e.to_string()))?;
74 self.set_value(key, json).await
75 }
76
77 async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StorageError> {
78 match self.get_value(key).await? {
79 Some(json) => {
80 let value = serde_json::from_value(json)
81 .map_err(|e| StorageError::Serialization(e.to_string()))?;
82 Ok(Some(value))
83 }
84 None => Ok(None),
85 }
86 }
87}
88
89#[derive(Debug, Default)]
91pub struct MemoryBackend {
92 data: tokio::sync::RwLock<std::collections::HashMap<String, serde_json::Value>>,
93}
94
95impl MemoryBackend {
96 pub fn new() -> Self {
97 Self::default()
98 }
99}
100
101#[async_trait]
102impl StorageBackend for MemoryBackend {
103 fn name(&self) -> &str {
104 "memory"
105 }
106
107 async fn is_healthy(&self) -> bool {
108 true
109 }
110
111 async fn set_value(&self, key: &str, value: serde_json::Value) -> Result<(), StorageError> {
112 self.data.write().await.insert(key.to_string(), value);
113 Ok(())
114 }
115
116 async fn get_value(&self, key: &str) -> Result<Option<serde_json::Value>, StorageError> {
117 let data = self.data.read().await;
118 Ok(data.get(key).cloned())
119 }
120
121 async fn delete(&self, key: &str) -> Result<bool, StorageError> {
122 Ok(self.data.write().await.remove(key).is_some())
123 }
124
125 async fn exists(&self, key: &str) -> Result<bool, StorageError> {
126 Ok(self.data.read().await.contains_key(key))
127 }
128
129 async fn list_keys(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
130 let data = self.data.read().await;
131 let keys: Vec<String> = data
132 .keys()
133 .filter(|k| k.starts_with(prefix))
134 .cloned()
135 .collect();
136 Ok(keys)
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Small payload that round-trips through serde.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestData {
        name: String,
        value: i32,
    }

    /// Runs a full set -> get -> exists -> list_keys -> delete cycle on `MemoryBackend`.
    #[tokio::test]
    async fn test_memory_backend() {
        let store = MemoryBackend::new();
        let original = TestData {
            name: String::from("test"),
            value: 42,
        };

        // Typed round trip through the `StorageExt` helpers.
        store.set("test:1", &original).await.unwrap();
        let loaded: Option<TestData> = store.get("test:1").await.unwrap();
        assert_eq!(loaded, Some(original));

        // Presence checks for a stored key and a never-stored key.
        assert!(store.exists("test:1").await.unwrap());
        assert!(!store.exists("test:2").await.unwrap());

        // Prefix listing sees exactly the one stored key.
        let listed = store.list_keys("test:").await.unwrap();
        assert_eq!(listed, vec!["test:1"]);

        // Delete reports success, and the key is gone afterwards.
        assert!(store.delete("test:1").await.unwrap());
        assert!(!store.exists("test:1").await.unwrap());
    }
}