vex_llm/
resilient_provider.rs1use async_trait::async_trait;
12use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15use tokio::sync::RwLock;
16
17use crate::{LlmError, LlmProvider, LlmRequest, LlmResponse};
18
/// State of the circuit breaker guarding the wrapped provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are forwarded to the wrapped provider.
    Closed,
    /// Requests are rejected until the configured reset timeout has elapsed
    /// since the last recorded failure.
    Open,
    /// Requests are forwarded as recovery probes; successes are counted
    /// towards closing the circuit and a counted failure re-opens it.
    HalfOpen,
}
29
/// Tuning knobs for the circuit breaker.
#[derive(Debug, Clone)]
pub struct LlmCircuitConfig {
    /// Number of counted failures (the counter is reset by any success) at
    /// which a closed circuit opens.
    pub failure_threshold: u32,
    /// Number of successes required while half-open before the circuit closes.
    pub success_threshold: u32,
    /// Time that must pass since the last recorded failure before an open
    /// circuit lets a probe request through.
    pub reset_timeout: Duration,
}

impl Default for LlmCircuitConfig {
    /// Default preset: open after 5 failures, close after 2 successful
    /// probes, retry after 30 seconds.
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            success_threshold: 2,
            reset_timeout: Duration::from_secs(30),
        }
    }
}

impl LlmCircuitConfig {
    /// Stricter preset: open after 3 failures, require 3 successful probes
    /// to close, and wait 60 seconds before retrying.
    pub fn conservative() -> Self {
        Self {
            failure_threshold: 3,
            success_threshold: 3,
            reset_timeout: Duration::from_secs(60),
        }
    }
}
61
62#[derive(Debug)]
64struct CircuitBreakerState {
65 state: CircuitState,
66 failure_count: u32,
67 success_count: u32,
68 last_failure: Option<Instant>,
69}
70
71#[derive(Debug)]
73pub struct ResilientProvider<P: LlmProvider> {
74 inner: Arc<P>,
75 config: LlmCircuitConfig,
76 cb_state: RwLock<CircuitBreakerState>,
77 total_requests: AtomicU64,
78 total_failures: AtomicU64,
79 circuit_opens: AtomicU32,
80}
81
82impl<P: LlmProvider> ResilientProvider<P> {
83 pub fn new(provider: P, config: LlmCircuitConfig) -> Self {
85 Self {
86 inner: Arc::new(provider),
87 config,
88 cb_state: RwLock::new(CircuitBreakerState {
89 state: CircuitState::Closed,
90 failure_count: 0,
91 success_count: 0,
92 last_failure: None,
93 }),
94 total_requests: AtomicU64::new(0),
95 total_failures: AtomicU64::new(0),
96 circuit_opens: AtomicU32::new(0),
97 }
98 }
99
100 pub fn wrap(provider: P) -> Self {
102 Self::new(provider, LlmCircuitConfig::conservative())
103 }
104
105 pub async fn circuit_state(&self) -> CircuitState {
107 self.cb_state.read().await.state
108 }
109
110 pub fn stats(&self) -> (u64, u64, u32) {
112 (
113 self.total_requests.load(Ordering::Relaxed),
114 self.total_failures.load(Ordering::Relaxed),
115 self.circuit_opens.load(Ordering::Relaxed),
116 )
117 }
118
119 async fn record_success(&self) {
120 let mut state = self.cb_state.write().await;
121 state.failure_count = 0;
122
123 if state.state == CircuitState::HalfOpen {
124 state.success_count += 1;
125 if state.success_count >= self.config.success_threshold {
126 state.state = CircuitState::Closed;
127 state.success_count = 0;
128 tracing::info!(provider = %self.inner.name(), "Circuit closed - provider recovered");
129 }
130 }
131 }
132
133 async fn record_failure(&self) {
134 self.total_failures.fetch_add(1, Ordering::Relaxed);
135 let mut state = self.cb_state.write().await;
136 state.failure_count += 1;
137 state.last_failure = Some(Instant::now());
138
139 if state.state == CircuitState::HalfOpen {
140 state.state = CircuitState::Open;
142 self.circuit_opens.fetch_add(1, Ordering::Relaxed);
143 tracing::warn!(provider = %self.inner.name(), "Circuit re-opened - recovery test failed");
144 } else if state.failure_count >= self.config.failure_threshold {
145 state.state = CircuitState::Open;
146 self.circuit_opens.fetch_add(1, Ordering::Relaxed);
147 tracing::warn!(
148 provider = %self.inner.name(),
149 failures = state.failure_count,
150 "Circuit opened - failure threshold exceeded"
151 );
152 }
153 }
154
155 async fn check_circuit(&self) -> Result<(), LlmError> {
156 let mut state = self.cb_state.write().await;
157
158 match state.state {
159 CircuitState::Closed => Ok(()),
160 CircuitState::Open => {
161 if let Some(last_failure) = state.last_failure {
163 if last_failure.elapsed() >= self.config.reset_timeout {
164 state.state = CircuitState::HalfOpen;
165 state.success_count = 0;
166 tracing::info!(provider = %self.inner.name(), "Circuit half-open - testing recovery");
167 return Ok(());
168 }
169 }
170 Err(LlmError::NotAvailable)
171 }
172 CircuitState::HalfOpen => Ok(()),
173 }
174 }
175}
176
177#[async_trait]
178impl<P: LlmProvider + 'static> LlmProvider for ResilientProvider<P> {
179 fn name(&self) -> &str {
180 "resilient"
182 }
183
184 async fn is_available(&self) -> bool {
185 self.check_circuit().await.is_ok() && self.inner.is_available().await
186 }
187
188 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
189 self.total_requests.fetch_add(1, Ordering::Relaxed);
190
191 self.check_circuit().await?;
193
194 match self.inner.complete(request).await {
196 Ok(response) => {
197 self.record_success().await;
198 Ok(response)
199 }
200 Err(e) => {
201 match &e {
203 LlmError::ConnectionFailed(_)
204 | LlmError::NotAvailable
205 | LlmError::RateLimited => {
206 self.record_failure().await;
207 }
208 _ => {}
209 }
210 Err(e)
211 }
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockProvider;

    /// A healthy provider's response passes through the wrapper and leaves
    /// the circuit closed.
    #[tokio::test]
    async fn test_resilient_provider_passes_through() {
        let resilient = ResilientProvider::wrap(MockProvider::smart());

        let outcome = resilient.ask("test").await;
        assert!(outcome.is_ok());

        let state = resilient.circuit_state().await;
        assert_eq!(state, CircuitState::Closed);
    }
}