vex_llm/
resilient_provider.rs

1//! Resilient LLM provider wrapper with circuit breaker pattern
2//!
3//! Provides fault tolerance for LLM providers by implementing the circuit breaker
4//! pattern, preventing cascading failures when external providers are unavailable.
5//!
6//! # 2025 Best Practices
7//! - Three states: Closed (normal), Open (failing fast), Half-Open (testing recovery)
8//! - Configurable thresholds and timeouts
9//! - Automatic recovery testing after cooldown period
10
11use async_trait::async_trait;
12use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15use tokio::sync::RwLock;
16
17use crate::{LlmError, LlmProvider, LlmRequest, LlmResponse};
18
/// Circuit breaker state.
///
/// Transitions (driven by `ResilientProvider`): `Closed -> Open` when the
/// failure threshold is exceeded, `Open -> HalfOpen` after the reset timeout
/// elapses, `HalfOpen -> Closed` once enough successes accumulate, and
/// `HalfOpen -> Open` on any failure during the recovery test.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation - requests pass through
    Closed,
    /// Circuit tripped - requests fail immediately
    Open,
    /// Testing recovery - limited requests allowed
    HalfOpen,
}
29
/// Configuration for the LLM circuit breaker.
///
/// See [`LlmCircuitConfig::default`] (5 failures / 2 successes / 30s) and
/// [`LlmCircuitConfig::conservative`] (3 failures / 3 successes / 60s) for
/// ready-made presets.
#[derive(Debug, Clone)]
pub struct LlmCircuitConfig {
    /// Number of consecutive failures before opening circuit
    pub failure_threshold: u32,
    /// Number of successes in half-open to close circuit
    pub success_threshold: u32,
    /// Time to wait after the last failure before testing recovery
    pub reset_timeout: Duration,
}
40
41impl Default for LlmCircuitConfig {
42    fn default() -> Self {
43        Self {
44            failure_threshold: 5,
45            success_threshold: 2,
46            reset_timeout: Duration::from_secs(30),
47        }
48    }
49}
50
51impl LlmCircuitConfig {
52    /// Conservative settings for production LLM providers
53    pub fn conservative() -> Self {
54        Self {
55            failure_threshold: 3,
56            success_threshold: 3,
57            reset_timeout: Duration::from_secs(60),
58        }
59    }
60}
61
/// Internal circuit breaker state, guarded by the `RwLock` in
/// `ResilientProvider`. All fields are reset/advanced together under the
/// write lock, so they are always mutually consistent.
#[derive(Debug)]
struct CircuitBreakerState {
    // Current position in the Closed/Open/HalfOpen state machine.
    state: CircuitState,
    // Consecutive failures; cleared on every success.
    failure_count: u32,
    // Successes observed while half-open; cleared when entering half-open.
    success_count: u32,
    // Timestamp of the most recent failure; drives the reset timeout.
    last_failure: Option<Instant>,
}
70
/// Resilient LLM provider that wraps any provider with circuit breaker resilience
#[derive(Debug)]
pub struct ResilientProvider<P: LlmProvider> {
    // The wrapped provider that performs the actual LLM calls.
    inner: Arc<P>,
    // Thresholds and timeout governing the breaker.
    config: LlmCircuitConfig,
    // Mutable breaker state (tokio RwLock: safe to hold across .await).
    cb_state: RwLock<CircuitBreakerState>,
    // Lifetime counter of all complete() attempts (lock-free).
    total_requests: AtomicU64,
    // Lifetime counter of failures counted against the breaker (lock-free).
    total_failures: AtomicU64,
    // Number of times the circuit transitioned to Open (lock-free).
    circuit_opens: AtomicU32,
}
81
82impl<P: LlmProvider> ResilientProvider<P> {
83    /// Create a resilient wrapper around an LLM provider
84    pub fn new(provider: P, config: LlmCircuitConfig) -> Self {
85        Self {
86            inner: Arc::new(provider),
87            config,
88            cb_state: RwLock::new(CircuitBreakerState {
89                state: CircuitState::Closed,
90                failure_count: 0,
91                success_count: 0,
92                last_failure: None,
93            }),
94            total_requests: AtomicU64::new(0),
95            total_failures: AtomicU64::new(0),
96            circuit_opens: AtomicU32::new(0),
97        }
98    }
99
100    /// Create with default (conservative) config
101    pub fn wrap(provider: P) -> Self {
102        Self::new(provider, LlmCircuitConfig::conservative())
103    }
104
105    /// Get current circuit state
106    pub async fn circuit_state(&self) -> CircuitState {
107        self.cb_state.read().await.state
108    }
109
110    /// Get circuit statistics
111    pub fn stats(&self) -> (u64, u64, u32) {
112        (
113            self.total_requests.load(Ordering::Relaxed),
114            self.total_failures.load(Ordering::Relaxed),
115            self.circuit_opens.load(Ordering::Relaxed),
116        )
117    }
118
119    async fn record_success(&self) {
120        let mut state = self.cb_state.write().await;
121        state.failure_count = 0;
122
123        if state.state == CircuitState::HalfOpen {
124            state.success_count += 1;
125            if state.success_count >= self.config.success_threshold {
126                state.state = CircuitState::Closed;
127                state.success_count = 0;
128                tracing::info!(provider = %self.inner.name(), "Circuit closed - provider recovered");
129            }
130        }
131    }
132
133    async fn record_failure(&self) {
134        self.total_failures.fetch_add(1, Ordering::Relaxed);
135        let mut state = self.cb_state.write().await;
136        state.failure_count += 1;
137        state.last_failure = Some(Instant::now());
138
139        if state.state == CircuitState::HalfOpen {
140            // Any failure in half-open goes back to open
141            state.state = CircuitState::Open;
142            self.circuit_opens.fetch_add(1, Ordering::Relaxed);
143            tracing::warn!(provider = %self.inner.name(), "Circuit re-opened - recovery test failed");
144        } else if state.failure_count >= self.config.failure_threshold {
145            state.state = CircuitState::Open;
146            self.circuit_opens.fetch_add(1, Ordering::Relaxed);
147            tracing::warn!(
148                provider = %self.inner.name(),
149                failures = state.failure_count,
150                "Circuit opened - failure threshold exceeded"
151            );
152        }
153    }
154
155    async fn check_circuit(&self) -> Result<(), LlmError> {
156        let mut state = self.cb_state.write().await;
157
158        match state.state {
159            CircuitState::Closed => Ok(()),
160            CircuitState::Open => {
161                // Check if reset timeout has passed
162                if let Some(last_failure) = state.last_failure {
163                    if last_failure.elapsed() >= self.config.reset_timeout {
164                        state.state = CircuitState::HalfOpen;
165                        state.success_count = 0;
166                        tracing::info!(provider = %self.inner.name(), "Circuit half-open - testing recovery");
167                        return Ok(());
168                    }
169                }
170                Err(LlmError::NotAvailable)
171            }
172            CircuitState::HalfOpen => Ok(()),
173        }
174    }
175}
176
177#[async_trait]
178impl<P: LlmProvider + 'static> LlmProvider for ResilientProvider<P> {
179    fn name(&self) -> &str {
180        // Return a static descriptor since we can't easily compose names
181        "resilient"
182    }
183
184    async fn is_available(&self) -> bool {
185        self.check_circuit().await.is_ok() && self.inner.is_available().await
186    }
187
188    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
189        self.total_requests.fetch_add(1, Ordering::Relaxed);
190
191        // Check circuit state
192        self.check_circuit().await?;
193
194        // Execute request
195        match self.inner.complete(request).await {
196            Ok(response) => {
197                self.record_success().await;
198                Ok(response)
199            }
200            Err(e) => {
201                // Only count as failure for connection/availability issues, not validation
202                match &e {
203                    LlmError::ConnectionFailed(_)
204                    | LlmError::NotAvailable
205                    | LlmError::RateLimited => {
206                        self.record_failure().await;
207                    }
208                    _ => {}
209                }
210                Err(e)
211            }
212        }
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockProvider;

    /// A healthy wrapped provider should answer normally and leave the
    /// circuit closed.
    #[tokio::test]
    async fn test_resilient_provider_passes_through() {
        let provider = ResilientProvider::wrap(MockProvider::smart());

        let answer = provider.ask("test").await;
        assert!(answer.is_ok());
        assert_eq!(provider.circuit_state().await, CircuitState::Closed);
    }
}