vex_api/
tenant_rate_limiter.rs1use governor::{
13 clock::{Clock, DefaultClock},
14 state::{InMemoryState, NotKeyed},
15 Quota, RateLimiter as Governor,
16};
17use std::collections::HashMap;
18use std::num::NonZeroU32;
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::RwLock;
22
23#[derive(
25 Debug,
26 Clone,
27 Copy,
28 PartialEq,
29 Eq,
30 Hash,
31 serde::Serialize,
32 serde::Deserialize,
33 utoipa::ToSchema,
34 Default,
35)]
36pub enum RateLimitTier {
37 #[default]
39 Free,
40 Standard,
42 Pro,
44 Unlimited,
46}
47
48impl RateLimitTier {
49 pub fn quota(&self) -> Option<Quota> {
51 match self {
52 Self::Free => Some(Quota::per_minute(NonZeroU32::new(10).unwrap())),
53 Self::Standard => Some(Quota::per_minute(NonZeroU32::new(100).unwrap())),
54 Self::Pro => Some(Quota::per_minute(NonZeroU32::new(1000).unwrap())),
55 Self::Unlimited => None, }
57 }
58}
59
/// Un-keyed governor limiter with in-memory state; one instance is cached per tenant id.
type TenantLimiter = Governor<NotKeyed, InMemoryState, DefaultClock>;
62
/// Rate limiter that applies a per-tenant quota derived from each tenant's [`RateLimitTier`].
#[derive(Debug)]
pub struct TenantRateLimiter {
    /// Lazily created limiter per tenant id, built from that tenant's tier quota.
    limiters: RwLock<HashMap<String, Arc<TenantLimiter>>>,
    /// Tier used for tenants that have no explicit assignment.
    default_tier: RateLimitTier,
    /// Explicit tier assignments keyed by tenant id.
    tier_assignments: RwLock<HashMap<String, RateLimitTier>>,
}
73
74impl Default for TenantRateLimiter {
75 fn default() -> Self {
76 Self::new(RateLimitTier::Free)
77 }
78}
79
80impl TenantRateLimiter {
81 pub fn new(default_tier: RateLimitTier) -> Self {
83 Self {
84 limiters: RwLock::new(HashMap::new()),
85 default_tier,
86 tier_assignments: RwLock::new(HashMap::new()),
87 }
88 }
89
90 pub async fn set_tier(&self, tenant_id: &str, tier: RateLimitTier) {
92 let mut assignments = self.tier_assignments.write().await;
93 assignments.insert(tenant_id.to_string(), tier);
94
95 let mut limiters = self.limiters.write().await;
97 limiters.remove(tenant_id);
98 }
99
100 pub async fn get_tier(&self, tenant_id: &str) -> RateLimitTier {
102 let assignments = self.tier_assignments.read().await;
103 assignments
104 .get(tenant_id)
105 .copied()
106 .unwrap_or(self.default_tier)
107 }
108
109 pub async fn check(&self, tenant_id: &str) -> Result<(), Duration> {
111 if tenant_id.trim().is_empty() {
113 return Err(Duration::from_secs(3600)); }
115
116 let tier = self.get_tier(tenant_id).await;
117
118 let quota = match tier.quota() {
120 Some(q) => q,
121 None => return Ok(()),
122 };
123
124 let limiter = self.get_or_create_limiter(tenant_id, quota).await;
126
127 match limiter.check() {
128 Ok(_) => Ok(()),
129 Err(not_until) => {
130 let wait = not_until.wait_time_from(DefaultClock::default().now());
131 Err(wait)
132 }
133 }
134 }
135
136 async fn get_or_create_limiter(&self, tenant_id: &str, quota: Quota) -> Arc<TenantLimiter> {
138 {
140 let limiters = self.limiters.read().await;
141 if let Some(limiter) = limiters.get(tenant_id) {
142 return limiter.clone();
143 }
144 }
145
146 let mut limiters = self.limiters.write().await;
148
149 if let Some(limiter) = limiters.get(tenant_id) {
151 return limiter.clone();
152 }
153
154 let limiter = Arc::new(Governor::direct(quota));
155 limiters.insert(tenant_id.to_string(), limiter.clone());
156 limiter
157 }
158
159 pub async fn cleanup(&self) {
161 let limiters = self.limiters.write().await;
162 tracing::debug!(limiter_count = limiters.len(), "Tenant limiter cleanup");
165
166 let _ = limiters; }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_rate_limiter_allows_within_quota() {
        let limiter = TenantRateLimiter::new(RateLimitTier::Standard);

        for _ in 0..10 {
            assert!(limiter.check("tenant1").await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_blocks_over_quota() {
        let limiter = TenantRateLimiter::new(RateLimitTier::Free);

        // Exhaust the Free tier's burst of 10.
        for _ in 0..10 {
            let _ = limiter.check("tenant1").await;
        }

        let result = limiter.check("tenant1").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_different_tenants_independent() {
        let limiter = TenantRateLimiter::new(RateLimitTier::Free);

        for _ in 0..15 {
            let _ = limiter.check("tenant1").await;
        }

        assert!(limiter.check("tenant2").await.is_ok());
    }

    #[tokio::test]
    async fn test_unlimited_tier() {
        let limiter = TenantRateLimiter::new(RateLimitTier::Free);
        limiter.set_tier("vip", RateLimitTier::Unlimited).await;

        for _ in 0..1000 {
            assert!(limiter.check("vip").await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_blank_tenant_id_rejected() {
        // Even an unlimited default must not let blank tenant ids through.
        let limiter = TenantRateLimiter::new(RateLimitTier::Unlimited);

        assert_eq!(limiter.check("").await, Err(Duration::from_secs(3600)));
        assert_eq!(limiter.check("   ").await, Err(Duration::from_secs(3600)));
    }

    #[tokio::test]
    async fn test_get_tier_falls_back_to_default() {
        let limiter = TenantRateLimiter::new(RateLimitTier::Standard);

        assert_eq!(limiter.get_tier("tenant1").await, RateLimitTier::Standard);
        limiter.set_tier("tenant1", RateLimitTier::Pro).await;
        assert_eq!(limiter.get_tier("tenant1").await, RateLimitTier::Pro);
        assert_eq!(limiter.get_tier("tenant2").await, RateLimitTier::Standard);
    }

    #[tokio::test]
    async fn test_set_tier_resets_budget() {
        let limiter = TenantRateLimiter::new(RateLimitTier::Free);

        for _ in 0..10 {
            let _ = limiter.check("tenant1").await;
        }
        assert!(limiter.check("tenant1").await.is_err());

        // Upgrading discards the exhausted limiter; the new quota applies immediately.
        limiter.set_tier("tenant1", RateLimitTier::Standard).await;
        assert!(limiter.check("tenant1").await.is_ok());
    }
}