// vex_api/tenant_rate_limiter.rs

1//! Tenant-scoped rate limiting using governor (GCRA algorithm)
2//!
3//! Provides per-tenant rate limiting for API endpoints using the GCRA
4//! (Generic Cell Rate Algorithm) which is efficient and doesn't require
5//! background maintenance threads.
6//!
7//! # 2025 Best Practices
8//! - Uses governor crate for efficient GCRA implementation
9//! - Per-tenant keyed rate limiting (by header or API key)
10//! - Configurable quotas per tier
11
12use governor::{
13    clock::{Clock, DefaultClock},
14    state::{InMemoryState, NotKeyed},
15    Quota, RateLimiter as Governor,
16};
17use std::collections::HashMap;
18use std::num::NonZeroU32;
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::RwLock;
22
23/// Rate limit tier for different tenant types
24#[derive(
25    Debug,
26    Clone,
27    Copy,
28    PartialEq,
29    Eq,
30    Hash,
31    serde::Serialize,
32    serde::Deserialize,
33    utoipa::ToSchema,
34    Default,
35)]
36pub enum RateLimitTier {
37    /// Free tier: 10 requests/minute
38    #[default]
39    Free,
40    /// Standard tier: 100 requests/minute
41    Standard,
42    /// Pro tier: 1000 requests/minute
43    Pro,
44    /// Unlimited (for internal services)
45    Unlimited,
46}
47
48impl RateLimitTier {
49    /// Get the quota for this tier
50    pub fn quota(&self) -> Option<Quota> {
51        match self {
52            Self::Free => Some(Quota::per_minute(NonZeroU32::new(10).unwrap())),
53            Self::Standard => Some(Quota::per_minute(NonZeroU32::new(100).unwrap())),
54            Self::Pro => Some(Quota::per_minute(NonZeroU32::new(1000).unwrap())),
55            Self::Unlimited => None, // No limiting
56        }
57    }
58}
59
/// Per-tenant rate limiter state.
///
/// A direct (un-keyed) governor limiter; keying by tenant is handled by the
/// `HashMap` in [`TenantRateLimiter`] rather than governor's keyed state.
type TenantLimiter = Governor<NotKeyed, InMemoryState, DefaultClock>;
62
/// Tenant-scoped rate limiter.
///
/// Limiters are created lazily per tenant id, so memory grows with the number
/// of distinct tenants seen (see [`TenantRateLimiter::cleanup`]).
#[derive(Debug)]
pub struct TenantRateLimiter {
    /// Per-tenant limiters, created on first use and keyed by tenant id.
    limiters: RwLock<HashMap<String, Arc<TenantLimiter>>>,
    /// Default tier for tenants without an explicit assignment.
    default_tier: RateLimitTier,
    /// Explicit tier assignments per tenant id.
    tier_assignments: RwLock<HashMap<String, RateLimitTier>>,
}
73
74impl Default for TenantRateLimiter {
75    fn default() -> Self {
76        Self::new(RateLimitTier::Free)
77    }
78}
79
80impl TenantRateLimiter {
81    /// Create a new tenant rate limiter with a default tier
82    pub fn new(default_tier: RateLimitTier) -> Self {
83        Self {
84            limiters: RwLock::new(HashMap::new()),
85            default_tier,
86            tier_assignments: RwLock::new(HashMap::new()),
87        }
88    }
89
90    /// Assign a tier to a tenant
91    pub async fn set_tier(&self, tenant_id: &str, tier: RateLimitTier) {
92        let mut assignments = self.tier_assignments.write().await;
93        assignments.insert(tenant_id.to_string(), tier);
94
95        // Remove cached limiter so it gets recreated with new tier
96        let mut limiters = self.limiters.write().await;
97        limiters.remove(tenant_id);
98    }
99
100    /// Get a tenant's tier
101    pub async fn get_tier(&self, tenant_id: &str) -> RateLimitTier {
102        let assignments = self.tier_assignments.read().await;
103        assignments
104            .get(tenant_id)
105            .copied()
106            .unwrap_or(self.default_tier)
107    }
108
109    /// Check if a request is allowed for a tenant
110    pub async fn check(&self, tenant_id: &str) -> Result<(), Duration> {
111        // Prevent bypass via empty tenant_id (Fix #11)
112        if tenant_id.trim().is_empty() {
113            return Err(Duration::from_secs(3600)); // Block for 1 hour
114        }
115
116        let tier = self.get_tier(tenant_id).await;
117
118        // Unlimited tier always passes
119        let quota = match tier.quota() {
120            Some(q) => q,
121            None => return Ok(()),
122        };
123
124        // Get or create limiter for this tenant
125        let limiter = self.get_or_create_limiter(tenant_id, quota).await;
126
127        match limiter.check() {
128            Ok(_) => Ok(()),
129            Err(not_until) => {
130                let wait = not_until.wait_time_from(DefaultClock::default().now());
131                Err(wait)
132            }
133        }
134    }
135
136    /// Get or create a limiter for a tenant
137    async fn get_or_create_limiter(&self, tenant_id: &str, quota: Quota) -> Arc<TenantLimiter> {
138        // Fast path: check if exists
139        {
140            let limiters = self.limiters.read().await;
141            if let Some(limiter) = limiters.get(tenant_id) {
142                return limiter.clone();
143            }
144        }
145
146        // Slow path: create new limiter
147        let mut limiters = self.limiters.write().await;
148
149        // Double-check after acquiring write lock
150        if let Some(limiter) = limiters.get(tenant_id) {
151            return limiter.clone();
152        }
153
154        let limiter = Arc::new(Governor::direct(quota));
155        limiters.insert(tenant_id.to_string(), limiter.clone());
156        limiter
157    }
158
159    /// Cleanup stale limiters (call periodically)
160    pub async fn cleanup(&self) {
161        let limiters = self.limiters.write().await;
162        // In a production system, you'd track last activity and remove inactive ones
163        // For now, just log the count
164        tracing::debug!(limiter_count = limiters.len(), "Tenant limiter cleanup");
165
166        // Could add logic here to remove limiters unused for > X minutes
167        // but governor's memory footprint is tiny, so not critical
168        let _ = limiters; // Suppress unused warning
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_rate_limiter_allows_within_quota() {
        let rl = TenantRateLimiter::new(RateLimitTier::Standard);

        // Ten requests sit well inside the 100/minute Standard quota.
        for _ in 0..10 {
            assert!(rl.check("tenant1").await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_blocks_over_quota() {
        let rl = TenantRateLimiter::new(RateLimitTier::Free);

        // Burn through the Free tier's 10-request quota.
        for _ in 0..10 {
            let _ = rl.check("tenant1").await;
        }

        // The 11th request must be rejected.
        assert!(rl.check("tenant1").await.is_err());
    }

    #[tokio::test]
    async fn test_different_tenants_independent() {
        let rl = TenantRateLimiter::new(RateLimitTier::Free);

        // Exhaust tenant1's quota (and then some).
        for _ in 0..15 {
            let _ = rl.check("tenant1").await;
        }

        // tenant2's quota is untouched.
        assert!(rl.check("tenant2").await.is_ok());
    }

    #[tokio::test]
    async fn test_unlimited_tier() {
        let rl = TenantRateLimiter::new(RateLimitTier::Free);
        rl.set_tier("vip", RateLimitTier::Unlimited).await;

        // Unlimited tenants are never rate limited.
        for _ in 0..1000 {
            assert!(rl.check("vip").await.is_ok());
        }
    }
}