vex_llm/
rate_limit.rs

1//! Rate limiting for LLM API calls
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7
/// Limits applied by a rate limiter to each tracked key.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// How many requests may be admitted within one `window`.
    pub max_requests: u32,
    /// Length of the accounting window.
    pub window: Duration,
    /// Optional token budget; `None` turns token-based limiting off.
    ///
    /// Despite the name, usage is accumulated over the same `window` as the
    /// request count, so this is a per-minute budget only when `window` is
    /// 60 seconds.
    pub max_tokens_per_minute: Option<u32>,
}
18
19impl Default for RateLimitConfig {
20    fn default() -> Self {
21        Self {
22            max_requests: 60, // 60 requests per minute
23            window: Duration::from_secs(60),
24            max_tokens_per_minute: Some(100_000),
25        }
26    }
27}
28
29impl RateLimitConfig {
30    /// Conservative rate limit for free tiers
31    pub fn conservative() -> Self {
32        Self {
33            max_requests: 10,
34            window: Duration::from_secs(60),
35            max_tokens_per_minute: Some(10_000),
36        }
37    }
38
39    /// Aggressive rate limit for paid tiers
40    pub fn aggressive() -> Self {
41        Self {
42            max_requests: 500,
43            window: Duration::from_secs(60),
44            max_tokens_per_minute: Some(1_000_000),
45        }
46    }
47}
48
/// Usage counters for one tracked key within its current window.
#[derive(Debug, Clone)]
struct RequestWindow {
    /// Requests admitted since `window_start`.
    count: u32,
    /// Tokens charged since `window_start`.
    tokens: u32,
    /// Moment at which the current window began.
    window_start: Instant,
}
56
57impl Default for RequestWindow {
58    fn default() -> Self {
59        Self {
60            count: 0,
61            tokens: 0,
62            window_start: Instant::now(),
63        }
64    }
65}
66
67/// Rate limiter for API calls
68#[derive(Debug)]
69pub struct RateLimiter {
70    config: RateLimitConfig,
71    /// Per-provider rate tracking
72    windows: RwLock<HashMap<String, RequestWindow>>,
73}
74
75impl RateLimiter {
76    /// Create a new rate limiter
77    pub fn new(config: RateLimitConfig) -> Self {
78        Self {
79            config,
80            windows: RwLock::new(HashMap::new()),
81        }
82    }
83
84    /// Check if a request is allowed (doesn't consume)
85    pub async fn check(&self, provider: &str) -> RateLimitResult {
86        let windows = self.windows.read().await;
87        let window = windows.get(provider).cloned().unwrap_or_default();
88
89        self.evaluate(&window)
90    }
91
92    /// Acquire a permit (blocks if rate limited)
93    pub async fn acquire(&self, provider: &str) -> Result<(), RateLimitError> {
94        loop {
95            let result = self.try_acquire(provider).await;
96            match result {
97                Ok(()) => return Ok(()),
98                Err(RateLimitError::Limited { retry_after }) => {
99                    tokio::time::sleep(retry_after).await;
100                }
101            }
102        }
103    }
104
105    /// Try to acquire a permit (non-blocking)
106    pub async fn try_acquire(&self, provider: &str) -> Result<(), RateLimitError> {
107        self.try_acquire_with_tokens(provider, 0).await
108    }
109
110    /// Try to acquire a permit with estimated token usage (atomic check)
111    /// This prevents race conditions between request counting and token tracking
112    pub async fn try_acquire_with_tokens(
113        &self,
114        provider: &str,
115        estimated_tokens: u32,
116    ) -> Result<(), RateLimitError> {
117        let mut windows = self.windows.write().await;
118        let window = windows.entry(provider.to_string()).or_default();
119
120        // Check if window expired
121        let elapsed = window.window_start.elapsed();
122        if elapsed >= self.config.window {
123            // Reset window
124            *window = RequestWindow::default();
125        }
126
127        // Check request limit
128        if window.count >= self.config.max_requests {
129            let retry_after = self.config.window - elapsed;
130            return Err(RateLimitError::Limited { retry_after });
131        }
132
133        // Check token limit (if configured and tokens provided)
134        if let Some(max_tokens) = self.config.max_tokens_per_minute {
135            if estimated_tokens > 0 && window.tokens + estimated_tokens > max_tokens {
136                let retry_after = self.config.window - elapsed;
137                return Err(RateLimitError::Limited { retry_after });
138            }
139        }
140
141        // Acquire atomically - update both counters under same lock
142        window.count += 1;
143        window.tokens += estimated_tokens;
144        Ok(())
145    }
146
147    /// Record additional token usage after completion
148    /// Use this to adjust for actual tokens used vs estimated
149    pub async fn record_tokens(&self, provider: &str, additional_tokens: u32) {
150        let mut windows = self.windows.write().await;
151        if let Some(window) = windows.get_mut(provider) {
152            window.tokens += additional_tokens;
153        }
154    }
155
156    /// Get current usage stats
157    pub async fn stats(&self, provider: &str) -> RateLimitStats {
158        let windows = self.windows.read().await;
159        let window = windows.get(provider).cloned().unwrap_or_default();
160
161        RateLimitStats {
162            requests_used: window.count,
163            requests_limit: self.config.max_requests,
164            tokens_used: window.tokens,
165            tokens_limit: self.config.max_tokens_per_minute,
166            window_remaining: self
167                .config
168                .window
169                .saturating_sub(window.window_start.elapsed()),
170        }
171    }
172
173    fn evaluate(&self, window: &RequestWindow) -> RateLimitResult {
174        let elapsed = window.window_start.elapsed();
175        if elapsed >= self.config.window {
176            return RateLimitResult::Allowed;
177        }
178
179        if window.count >= self.config.max_requests {
180            let retry_after = self.config.window - elapsed;
181            return RateLimitResult::Limited { retry_after };
182        }
183
184        if let Some(max_tokens) = self.config.max_tokens_per_minute {
185            if window.tokens >= max_tokens {
186                let retry_after = self.config.window - elapsed;
187                return RateLimitResult::Limited { retry_after };
188            }
189        }
190
191        RateLimitResult::Allowed
192    }
193}
194
/// Outcome of a non-consuming rate limit check.
#[derive(Debug, Clone)]
pub enum RateLimitResult {
    /// A request would be admitted.
    Allowed,
    /// A request would be rejected.
    Limited {
        /// Time left until the current window ends.
        retry_after: Duration,
    },
}
201
202/// Rate limit error
203#[derive(Debug, thiserror::Error)]
204pub enum RateLimitError {
205    #[error("Rate limited, retry after {retry_after:?}")]
206    Limited { retry_after: Duration },
207}
208
/// Snapshot of the usage recorded for one provider.
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Requests counted in the recorded window.
    pub requests_used: u32,
    /// Configured request ceiling per window.
    pub requests_limit: u32,
    /// Tokens counted in the recorded window.
    pub tokens_used: u32,
    /// Configured token ceiling, if any.
    pub tokens_limit: Option<u32>,
    /// Time left before the recorded window ends.
    pub window_remaining: Duration,
}
218
219/// Rate-limited LLM provider wrapper
220pub struct RateLimitedProvider<P> {
221    inner: P,
222    limiter: Arc<RateLimiter>,
223    provider_name: String,
224}
225
226impl<P> RateLimitedProvider<P> {
227    pub fn new(inner: P, limiter: Arc<RateLimiter>, provider_name: &str) -> Self {
228        Self {
229            inner,
230            limiter,
231            provider_name: provider_name.to_string(),
232        }
233    }
234
235    pub fn inner(&self) -> &P {
236        &self.inner
237    }
238
239    pub async fn acquire(&self) -> Result<(), RateLimitError> {
240        self.limiter.acquire(&self.provider_name).await
241    }
242
243    pub async fn stats(&self) -> RateLimitStats {
244        self.limiter.stats(&self.provider_name).await
245    }
246}
247
/// Subscription level that selects a user's quota.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum UserTier {
    /// Lowest quota; this is the `Default` variant.
    #[default]
    Free,
    /// Mid-level quota.
    Pro,
    /// Largest quota.
    Enterprise,
}
259
260impl UserTier {
261    /// Get rate limit config for this tier
262    pub fn rate_limit_config(&self) -> RateLimitConfig {
263        match self {
264            UserTier::Free => RateLimitConfig {
265                max_requests: 10,
266                window: Duration::from_secs(60),
267                max_tokens_per_minute: Some(5_000),
268            },
269            UserTier::Pro => RateLimitConfig {
270                max_requests: 100,
271                window: Duration::from_secs(60),
272                max_tokens_per_minute: Some(50_000),
273            },
274            UserTier::Enterprise => RateLimitConfig {
275                max_requests: 1000,
276                window: Duration::from_secs(60),
277                max_tokens_per_minute: Some(500_000),
278            },
279        }
280    }
281}
282
283/// Per-user rate limiter with tier-based quotas
284#[derive(Debug)]
285pub struct UserRateLimiter {
286    /// Per-user windows (keyed by user_id)
287    user_windows: RwLock<HashMap<String, UserRateLimitState>>,
288    /// Default tier for unknown users
289    default_tier: UserTier,
290    /// User tier overrides
291    tier_overrides: RwLock<HashMap<String, UserTier>>,
292}
293
294#[derive(Debug, Clone)]
295struct UserRateLimitState {
296    tier: UserTier,
297    window: RequestWindow,
298}
299
300impl Default for UserRateLimitState {
301    fn default() -> Self {
302        Self {
303            tier: UserTier::Free,
304            window: RequestWindow::default(),
305        }
306    }
307}
308
309impl UserRateLimiter {
310    /// Create a new per-user rate limiter
311    pub fn new(default_tier: UserTier) -> Self {
312        Self {
313            user_windows: RwLock::new(HashMap::new()),
314            default_tier,
315            tier_overrides: RwLock::new(HashMap::new()),
316        }
317    }
318
319    /// Set a user's tier
320    pub async fn set_user_tier(&self, user_id: &str, tier: UserTier) {
321        let mut overrides = self.tier_overrides.write().await;
322        overrides.insert(user_id.to_string(), tier);
323    }
324
325    /// Get a user's current tier
326    pub async fn get_user_tier(&self, user_id: &str) -> UserTier {
327        let overrides = self.tier_overrides.read().await;
328        overrides.get(user_id).copied().unwrap_or(self.default_tier)
329    }
330
331    /// Try to acquire a permit for a user
332    pub async fn try_acquire(&self, user_id: &str) -> Result<(), RateLimitError> {
333        self.try_acquire_with_tokens(user_id, 0).await
334    }
335
336    /// Try to acquire a permit with estimated tokens for a user
337    pub async fn try_acquire_with_tokens(
338        &self,
339        user_id: &str,
340        estimated_tokens: u32,
341    ) -> Result<(), RateLimitError> {
342        let tier = self.get_user_tier(user_id).await;
343        let config = tier.rate_limit_config();
344
345        let mut windows = self.user_windows.write().await;
346        let state = windows
347            .entry(user_id.to_string())
348            .or_insert_with(|| UserRateLimitState {
349                tier,
350                window: RequestWindow::default(),
351            });
352
353        // Update tier if it changed
354        state.tier = tier;
355
356        // Check if window expired
357        let elapsed = state.window.window_start.elapsed();
358        if elapsed >= config.window {
359            state.window = RequestWindow::default();
360        }
361
362        // Check request limit
363        if state.window.count >= config.max_requests {
364            let retry_after = config.window - elapsed;
365            return Err(RateLimitError::Limited { retry_after });
366        }
367
368        // Check token limit
369        if let Some(max_tokens) = config.max_tokens_per_minute {
370            if estimated_tokens > 0 && state.window.tokens + estimated_tokens > max_tokens {
371                let retry_after = config.window - elapsed;
372                return Err(RateLimitError::Limited { retry_after });
373            }
374        }
375
376        // Acquire
377        state.window.count += 1;
378        state.window.tokens += estimated_tokens;
379        Ok(())
380    }
381
382    /// Get usage stats for a user
383    pub async fn user_stats(&self, user_id: &str) -> UserRateLimitStats {
384        let tier = self.get_user_tier(user_id).await;
385        let config = tier.rate_limit_config();
386
387        let windows = self.user_windows.read().await;
388        let state = windows.get(user_id).cloned().unwrap_or_default();
389
390        let elapsed = state.window.window_start.elapsed();
391        let window_remaining = if elapsed >= config.window {
392            config.window
393        } else {
394            config.window - elapsed
395        };
396
397        UserRateLimitStats {
398            user_id: user_id.to_string(),
399            tier,
400            requests_used: state.window.count,
401            requests_limit: config.max_requests,
402            tokens_used: state.window.tokens,
403            tokens_limit: config.max_tokens_per_minute,
404            window_remaining,
405        }
406    }
407}
408
409/// Per-user rate limit statistics
410#[derive(Debug, Clone)]
411pub struct UserRateLimitStats {
412    pub user_id: String,
413    pub tier: UserTier,
414    pub requests_used: u32,
415    pub requests_limit: u32,
416    pub tokens_used: u32,
417    pub tokens_limit: Option<u32>,
418    pub window_remaining: Duration,
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// The request quota is enforced within a window and restored afterwards.
    #[tokio::test]
    async fn test_rate_limiter() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_requests: 3,
            window: Duration::from_millis(100),
            max_tokens_per_minute: None,
        });

        // The first three requests fit into the quota.
        for _ in 0..3 {
            assert!(limiter.try_acquire("test").await.is_ok());
        }

        // The fourth one exceeds it.
        let fourth = limiter.try_acquire("test").await;
        assert!(matches!(fourth, Err(RateLimitError::Limited { .. })));

        // Let the window lapse.
        tokio::time::sleep(Duration::from_millis(150)).await;

        // A new window admits requests again.
        assert!(limiter.try_acquire("test").await.is_ok());
    }

    /// Stats reflect the number of permits taken.
    #[tokio::test]
    async fn test_stats() {
        let limiter = RateLimiter::new(RateLimitConfig::default());

        for _ in 0..2 {
            limiter.try_acquire("provider1").await.unwrap();
        }

        let snapshot = limiter.stats("provider1").await;
        assert_eq!(snapshot.requests_used, 2);
    }

    /// Tier overrides select the quota that is applied per user.
    #[tokio::test]
    async fn test_user_rate_limiter_tiers() {
        let limiter = UserRateLimiter::new(UserTier::Free);

        // Users without an override fall back to Free (10 requests/min).
        assert_eq!(limiter.get_user_tier("user1").await, UserTier::Free);

        // An override is reported back.
        limiter.set_user_tier("user2", UserTier::Pro).await;
        assert_eq!(limiter.get_user_tier("user2").await, UserTier::Pro);

        // A Free user is cut off after 10 requests.
        for _ in 0..10 {
            assert!(limiter.try_acquire("free_user").await.is_ok());
        }
        let eleventh = limiter.try_acquire("free_user").await;
        assert!(matches!(eleventh, Err(RateLimitError::Limited { .. })));

        // A Pro user has a 100 request limit.
        limiter.set_user_tier("pro_user", UserTier::Pro).await;
        for _ in 0..50 {
            assert!(limiter.try_acquire("pro_user").await.is_ok());
        }
        let snapshot = limiter.user_stats("pro_user").await;
        assert_eq!(snapshot.tier, UserTier::Pro);
        assert_eq!(snapshot.requests_used, 50);
        assert_eq!(snapshot.requests_limit, 100);
    }
}