1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7
/// Limits enforced over a fixed time window.
///
/// The limiters in this module count requests (and optionally tokens) per key and
/// start counting from zero again once `window` has elapsed.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests admitted per window.
    pub max_requests: u32,
    /// Length of one counting window.
    pub window: Duration,
    /// Optional cap on tokens charged per window; `None` disables token limiting.
    ///
    /// Despite the name, the cap applies per `window`, not per wall-clock minute.
    pub max_tokens_per_minute: Option<u32>,
}

impl Default for RateLimitConfig {
    /// 60 requests and 100 000 tokens per 60-second window.
    fn default() -> Self {
        Self::per_minute(60, 100_000)
    }
}

impl RateLimitConfig {
    /// Shared constructor for the presets: a 60-second window with both caps set.
    fn per_minute(max_requests: u32, max_tokens: u32) -> Self {
        Self {
            max_requests,
            window: Duration::from_secs(60),
            max_tokens_per_minute: Some(max_tokens),
        }
    }

    /// Cautious preset: 10 requests and 10 000 tokens per 60-second window.
    pub fn conservative() -> Self {
        Self::per_minute(10, 10_000)
    }

    /// High-throughput preset: 500 requests and 1 000 000 tokens per 60-second window.
    pub fn aggressive() -> Self {
        Self::per_minute(500, 1_000_000)
    }
}
48
/// Usage counters for one key within its current fixed window.
#[derive(Debug, Clone)]
struct RequestWindow {
    /// Requests admitted since `window_start`.
    count: u32,
    /// Tokens charged since `window_start`.
    tokens: u32,
    /// Moment the current window began.
    window_start: Instant,
}

impl Default for RequestWindow {
    /// An empty window that begins now.
    fn default() -> Self {
        let window_start = Instant::now();
        RequestWindow {
            window_start,
            count: 0,
            tokens: 0,
        }
    }
}
66
/// Fixed-window rate limiter that keeps a separate usage window per provider name.
#[derive(Debug)]
pub struct RateLimiter {
    /// Limits applied identically to every provider.
    config: RateLimitConfig,
    /// Usage per provider, keyed by the provider string passed to the methods.
    windows: RwLock<HashMap<String, RequestWindow>>,
}
74
75impl RateLimiter {
76 pub fn new(config: RateLimitConfig) -> Self {
78 Self {
79 config,
80 windows: RwLock::new(HashMap::new()),
81 }
82 }
83
84 pub async fn check(&self, provider: &str) -> RateLimitResult {
86 let windows = self.windows.read().await;
87 let window = windows.get(provider).cloned().unwrap_or_default();
88
89 self.evaluate(&window)
90 }
91
92 pub async fn acquire(&self, provider: &str) -> Result<(), RateLimitError> {
94 loop {
95 let result = self.try_acquire(provider).await;
96 match result {
97 Ok(()) => return Ok(()),
98 Err(RateLimitError::Limited { retry_after }) => {
99 tokio::time::sleep(retry_after).await;
100 }
101 }
102 }
103 }
104
105 pub async fn try_acquire(&self, provider: &str) -> Result<(), RateLimitError> {
107 self.try_acquire_with_tokens(provider, 0).await
108 }
109
110 pub async fn try_acquire_with_tokens(
113 &self,
114 provider: &str,
115 estimated_tokens: u32,
116 ) -> Result<(), RateLimitError> {
117 let mut windows = self.windows.write().await;
118 let window = windows.entry(provider.to_string()).or_default();
119
120 let elapsed = window.window_start.elapsed();
122 if elapsed >= self.config.window {
123 *window = RequestWindow::default();
125 }
126
127 if window.count >= self.config.max_requests {
129 let retry_after = self.config.window - elapsed;
130 return Err(RateLimitError::Limited { retry_after });
131 }
132
133 if let Some(max_tokens) = self.config.max_tokens_per_minute {
135 if estimated_tokens > 0 && window.tokens + estimated_tokens > max_tokens {
136 let retry_after = self.config.window - elapsed;
137 return Err(RateLimitError::Limited { retry_after });
138 }
139 }
140
141 window.count += 1;
143 window.tokens += estimated_tokens;
144 Ok(())
145 }
146
147 pub async fn record_tokens(&self, provider: &str, additional_tokens: u32) {
150 let mut windows = self.windows.write().await;
151 if let Some(window) = windows.get_mut(provider) {
152 window.tokens += additional_tokens;
153 }
154 }
155
156 pub async fn stats(&self, provider: &str) -> RateLimitStats {
158 let windows = self.windows.read().await;
159 let window = windows.get(provider).cloned().unwrap_or_default();
160
161 RateLimitStats {
162 requests_used: window.count,
163 requests_limit: self.config.max_requests,
164 tokens_used: window.tokens,
165 tokens_limit: self.config.max_tokens_per_minute,
166 window_remaining: self
167 .config
168 .window
169 .saturating_sub(window.window_start.elapsed()),
170 }
171 }
172
173 fn evaluate(&self, window: &RequestWindow) -> RateLimitResult {
174 let elapsed = window.window_start.elapsed();
175 if elapsed >= self.config.window {
176 return RateLimitResult::Allowed;
177 }
178
179 if window.count >= self.config.max_requests {
180 let retry_after = self.config.window - elapsed;
181 return RateLimitResult::Limited { retry_after };
182 }
183
184 if let Some(max_tokens) = self.config.max_tokens_per_minute {
185 if window.tokens >= max_tokens {
186 let retry_after = self.config.window - elapsed;
187 return RateLimitResult::Limited { retry_after };
188 }
189 }
190
191 RateLimitResult::Allowed
192 }
193}
194
/// Outcome of a non-consuming [`RateLimiter::check`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitResult {
    /// A request would be admitted now.
    Allowed,
    /// A request would be rejected; `retry_after` is the time left in the current window.
    Limited { retry_after: Duration },
}
201
/// Error returned when a rate limiter rejects a request.
#[derive(Debug, thiserror::Error)]
pub enum RateLimitError {
    /// The request or token cap would be exceeded; `retry_after` is the time left in
    /// the current window.
    #[error("Rate limited, retry after {retry_after:?}")]
    Limited { retry_after: Duration },
}
208
/// Usage snapshot for one provider, as returned by [`RateLimiter::stats`].
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Requests admitted in the most recent window.
    pub requests_used: u32,
    /// Configured maximum requests per window.
    pub requests_limit: u32,
    /// Tokens charged in the most recent window.
    pub tokens_used: u32,
    /// Configured token cap per window, if any.
    pub tokens_limit: Option<u32>,
    /// Time left in the most recent window; zero once it has expired.
    pub window_remaining: Duration,
}
218
/// Pairs a provider value with a [`RateLimiter`] and the key it is tracked under.
pub struct RateLimitedProvider<P> {
    /// The wrapped provider.
    inner: P,
    /// Limiter, held by `Arc` so it can be shared with other holders.
    limiter: Arc<RateLimiter>,
    /// Key under which this provider's usage is tracked in `limiter`.
    provider_name: String,
}
225
226impl<P> RateLimitedProvider<P> {
227 pub fn new(inner: P, limiter: Arc<RateLimiter>, provider_name: &str) -> Self {
228 Self {
229 inner,
230 limiter,
231 provider_name: provider_name.to_string(),
232 }
233 }
234
235 pub fn inner(&self) -> &P {
236 &self.inner
237 }
238
239 pub async fn acquire(&self) -> Result<(), RateLimitError> {
240 self.limiter.acquire(&self.provider_name).await
241 }
242
243 pub async fn stats(&self) -> RateLimitStats {
244 self.limiter.stats(&self.provider_name).await
245 }
246}
247
/// Service tier that selects a user's limits (see [`UserTier::rate_limit_config`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum UserTier {
    /// Lowest limits; this is the `Default` tier.
    #[default]
    Free,
    /// Intermediate limits.
    Pro,
    /// Highest limits.
    Enterprise,
}
259
260impl UserTier {
261 pub fn rate_limit_config(&self) -> RateLimitConfig {
263 match self {
264 UserTier::Free => RateLimitConfig {
265 max_requests: 10,
266 window: Duration::from_secs(60),
267 max_tokens_per_minute: Some(5_000),
268 },
269 UserTier::Pro => RateLimitConfig {
270 max_requests: 100,
271 window: Duration::from_secs(60),
272 max_tokens_per_minute: Some(50_000),
273 },
274 UserTier::Enterprise => RateLimitConfig {
275 max_requests: 1000,
276 window: Duration::from_secs(60),
277 max_tokens_per_minute: Some(500_000),
278 },
279 }
280 }
281}
282
/// Fixed-window rate limiter keyed by user id, with limits chosen by each user's [`UserTier`].
#[derive(Debug)]
pub struct UserRateLimiter {
    /// Per-user usage state, keyed by user id.
    user_windows: RwLock<HashMap<String, UserRateLimitState>>,
    /// Tier used for users without an explicit override.
    default_tier: UserTier,
    /// Explicit per-user tier assignments made via `set_user_tier`.
    tier_overrides: RwLock<HashMap<String, UserTier>>,
}
293
/// Per-user limiter state.
#[derive(Debug, Clone)]
struct UserRateLimitState {
    /// Tier in effect at the user's most recent acquire attempt.
    // NOTE(review): written on every acquire but not read anywhere in this file.
    tier: UserTier,
    /// Usage in the user's current window.
    window: RequestWindow,
}
299
300impl Default for UserRateLimitState {
301 fn default() -> Self {
302 Self {
303 tier: UserTier::Free,
304 window: RequestWindow::default(),
305 }
306 }
307}
308
309impl UserRateLimiter {
310 pub fn new(default_tier: UserTier) -> Self {
312 Self {
313 user_windows: RwLock::new(HashMap::new()),
314 default_tier,
315 tier_overrides: RwLock::new(HashMap::new()),
316 }
317 }
318
319 pub async fn set_user_tier(&self, user_id: &str, tier: UserTier) {
321 let mut overrides = self.tier_overrides.write().await;
322 overrides.insert(user_id.to_string(), tier);
323 }
324
325 pub async fn get_user_tier(&self, user_id: &str) -> UserTier {
327 let overrides = self.tier_overrides.read().await;
328 overrides.get(user_id).copied().unwrap_or(self.default_tier)
329 }
330
331 pub async fn try_acquire(&self, user_id: &str) -> Result<(), RateLimitError> {
333 self.try_acquire_with_tokens(user_id, 0).await
334 }
335
336 pub async fn try_acquire_with_tokens(
338 &self,
339 user_id: &str,
340 estimated_tokens: u32,
341 ) -> Result<(), RateLimitError> {
342 let tier = self.get_user_tier(user_id).await;
343 let config = tier.rate_limit_config();
344
345 let mut windows = self.user_windows.write().await;
346 let state = windows
347 .entry(user_id.to_string())
348 .or_insert_with(|| UserRateLimitState {
349 tier,
350 window: RequestWindow::default(),
351 });
352
353 state.tier = tier;
355
356 let elapsed = state.window.window_start.elapsed();
358 if elapsed >= config.window {
359 state.window = RequestWindow::default();
360 }
361
362 if state.window.count >= config.max_requests {
364 let retry_after = config.window - elapsed;
365 return Err(RateLimitError::Limited { retry_after });
366 }
367
368 if let Some(max_tokens) = config.max_tokens_per_minute {
370 if estimated_tokens > 0 && state.window.tokens + estimated_tokens > max_tokens {
371 let retry_after = config.window - elapsed;
372 return Err(RateLimitError::Limited { retry_after });
373 }
374 }
375
376 state.window.count += 1;
378 state.window.tokens += estimated_tokens;
379 Ok(())
380 }
381
382 pub async fn user_stats(&self, user_id: &str) -> UserRateLimitStats {
384 let tier = self.get_user_tier(user_id).await;
385 let config = tier.rate_limit_config();
386
387 let windows = self.user_windows.read().await;
388 let state = windows.get(user_id).cloned().unwrap_or_default();
389
390 let elapsed = state.window.window_start.elapsed();
391 let window_remaining = if elapsed >= config.window {
392 config.window
393 } else {
394 config.window - elapsed
395 };
396
397 UserRateLimitStats {
398 user_id: user_id.to_string(),
399 tier,
400 requests_used: state.window.count,
401 requests_limit: config.max_requests,
402 tokens_used: state.window.tokens,
403 tokens_limit: config.max_tokens_per_minute,
404 window_remaining,
405 }
406 }
407}
408
/// Usage snapshot for one user, as returned by [`UserRateLimiter::user_stats`].
#[derive(Debug, Clone)]
pub struct UserRateLimitStats {
    /// The user the snapshot describes.
    pub user_id: String,
    /// Tier currently resolved for the user (override or default).
    pub tier: UserTier,
    /// Requests admitted in the user's window.
    pub requests_used: u32,
    /// The tier's maximum requests per window.
    pub requests_limit: u32,
    /// Tokens charged in the user's window.
    pub tokens_used: u32,
    /// The tier's token cap per window, if any.
    pub tokens_limit: Option<u32>,
    /// Time left in the user's window.
    pub window_remaining: Duration,
}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// The fourth request in a 3-request window is rejected until the window expires.
    #[tokio::test]
    async fn test_rate_limiter() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_requests: 3,
            window: Duration::from_millis(100),
            max_tokens_per_minute: None,
        });

        for _ in 0..3 {
            assert!(limiter.try_acquire("test").await.is_ok());
        }
        let rejected = limiter.try_acquire("test").await;
        assert!(matches!(rejected, Err(RateLimitError::Limited { .. })));

        // Sleep past the 100 ms window so the counters reset.
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(limiter.try_acquire("test").await.is_ok());
    }

    /// `stats` reports how many requests were admitted in the current window.
    #[tokio::test]
    async fn test_stats() {
        let limiter = RateLimiter::new(RateLimitConfig::default());

        for _ in 0..2 {
            limiter.try_acquire("provider1").await.unwrap();
        }

        assert_eq!(limiter.stats("provider1").await.requests_used, 2);
    }

    /// Tier overrides are honored and each tier's request cap applies.
    #[tokio::test]
    async fn test_user_rate_limiter_tiers() {
        let limiter = UserRateLimiter::new(UserTier::Free);

        // Without an override a user falls back to the default tier.
        assert_eq!(limiter.get_user_tier("user1").await, UserTier::Free);
        limiter.set_user_tier("user2", UserTier::Pro).await;
        assert_eq!(limiter.get_user_tier("user2").await, UserTier::Pro);

        // Free tier admits exactly 10 requests per window.
        for _ in 0..10 {
            assert!(limiter.try_acquire("free_user").await.is_ok());
        }
        let rejected = limiter.try_acquire("free_user").await;
        assert!(matches!(rejected, Err(RateLimitError::Limited { .. })));

        // Pro tier: 50 requests stay under the 100-request cap.
        limiter.set_user_tier("pro_user", UserTier::Pro).await;
        for _ in 0..50 {
            assert!(limiter.try_acquire("pro_user").await.is_ok());
        }
        let stats = limiter.user_stats("pro_user").await;
        assert_eq!(stats.tier, UserTier::Pro);
        assert_eq!(stats.requests_used, 50);
        assert_eq!(stats.requests_limit, 100);
    }
}