// vex_api/server.rs

1//! VEX API Server with graceful shutdown
2
3use axum::{middleware, Router};
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::signal;
8use tower::Service;
9use tower_http::compression::CompressionLayer;
10
11use crate::auth::JwtAuth;
12use crate::error::ApiError;
13use crate::middleware::{
14    auth_middleware, body_limit_layer, cors_layer, rate_limit_middleware, request_id_middleware,
15    timeout_layer, tracing_middleware,
16};
17use crate::routes::api_router;
18use vex_llm::{Metrics, RateLimitConfig};
19// use vex_persist::StorageBackend; // Not dealing with trait directly here
20// use vex_queue::WorkerPool;
21
/// TLS configuration for HTTPS
///
/// Holds filesystem paths to the PEM-encoded certificate chain and
/// private key; the files themselves are read later in `VexServer::run`.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path to certificate file (PEM format)
    pub cert_path: String,
    /// Path to private key file (PEM format)
    pub key_path: String,
}

impl TlsConfig {
    /// Build a TLS config from certificate and key file paths.
    pub fn new(cert_path: &str, key_path: &str) -> Self {
        Self {
            cert_path: cert_path.to_owned(),
            key_path: key_path.to_owned(),
        }
    }

    /// Read the config from the `VEX_TLS_CERT` and `VEX_TLS_KEY`
    /// environment variables; `None` unless both are set.
    pub fn from_env() -> Option<Self> {
        match (std::env::var("VEX_TLS_CERT"), std::env::var("VEX_TLS_KEY")) {
            (Ok(cert), Ok(key)) => Some(Self::new(&cert, &key)),
            _ => None,
        }
    }
}
47
/// Server configuration
///
/// Built either via [`Default`] (development defaults) or
/// [`ServerConfig::from_env`] (reads `VEX_*` environment variables).
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Socket address the server binds to.
    pub addr: SocketAddr,
    /// Per-request timeout, applied via `timeout_layer` in `VexServer::router`.
    pub timeout: Duration,
    /// Max request body size in bytes, applied via `body_limit_layer`.
    pub max_body_size: usize,
    /// Enable compression
    /// NOTE(review): this flag is not consulted by `VexServer::router`, which
    /// always adds `CompressionLayer` — confirm whether it should gate the layer.
    pub compression: bool,
    /// Rate limit config
    /// NOTE(review): appears unused by the visible code in this file
    /// (`rate_limit_middleware` takes state, not this field) — verify.
    pub rate_limit: RateLimitConfig,
    /// Optional TLS configuration for HTTPS; when set, `run` serves TLS.
    pub tls: Option<TlsConfig>,
    /// Whether to strictly enforce HTTPS (startup fails if TLS is not configured).
    pub enforce_https: bool,
}
66
67impl Default for ServerConfig {
68    fn default() -> Self {
69        Self {
70            addr: "0.0.0.0:8080".parse().unwrap(),
71            timeout: Duration::from_secs(30),
72            max_body_size: 1024 * 1024, // 1MB
73            compression: true,
74            rate_limit: RateLimitConfig::default(),
75            tls: None,
76            enforce_https: false,
77        }
78    }
79}
80
81impl ServerConfig {
82    /// Create from environment variables
83    pub fn from_env() -> Self {
84        let port: u16 = std::env::var("VEX_PORT")
85            .ok()
86            .and_then(|p| p.parse().ok())
87            .unwrap_or(8080);
88
89        let timeout_secs: u64 = std::env::var("VEX_TIMEOUT_SECS")
90            .ok()
91            .and_then(|t| t.parse().ok())
92            .unwrap_or(30);
93
94        let enforce_https = std::env::var("VEX_ENFORCE_HTTPS").is_ok()
95            || std::env::var("VEX_ENV")
96                .map(|e| e == "production")
97                .unwrap_or(false);
98
99        Self {
100            addr: SocketAddr::from(([0, 0, 0, 0], port)),
101            timeout: Duration::from_secs(timeout_secs),
102            enforce_https,
103            ..Default::default()
104        }
105    }
106}
107
108use crate::state::AppState;
109
/// VEX API Server
///
/// Owns the server configuration and the shared application state
/// (auth, rate limiter, metrics, persistence, worker pool, LLM router)
/// that is cloned into the middleware layers and request handlers.
pub struct VexServer {
    // Network/timeout/TLS settings consumed by `router` and `run`.
    config: ServerConfig,
    // Shared state handed to `api_router` and the `from_fn_with_state` layers.
    app_state: AppState,
}
115
impl VexServer {
    /// Create a new server: wires up auth, rate limiting, metrics,
    /// SQLite persistence, a persistent job queue with a worker pool,
    /// the LLM provider stack, and the smart-routing layer, then bundles
    /// everything into [`AppState`].
    ///
    /// # Errors
    /// Returns [`ApiError`] if JWT configuration is missing/invalid or the
    /// database backend fails to initialize.
    pub async fn new(config: ServerConfig) -> Result<Self, ApiError> {
        use crate::jobs::agent::{AgentExecutionJob, AgentJobPayload};
        use crate::tenant_rate_limiter::{RateLimitTier, TenantRateLimiter};
        use vex_llm::{
            CachedProvider, DeepSeekProvider, LlmProvider, MockProvider, ResilientProvider,
        };
        use vex_queue::{QueueBackend, WorkerConfig, WorkerPool};

        let jwt_auth = JwtAuth::from_env()?;
        let rate_limiter = Arc::new(TenantRateLimiter::new(RateLimitTier::Standard));
        let metrics = Arc::new(Metrics::new());

        // Initialize Persistence (SQLite). Falls back to an in-memory DB when
        // DATABASE_URL is unset — data is lost on restart in that mode.
        let db_url =
            std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite::memory:".to_string());
        let db = vex_persist::sqlite::SqliteBackend::new(&db_url)
            .await
            .map_err(|e| ApiError::Internal(format!("DB Init failed: {}", e)))?;

        // Initialize Queue (Persistent SQLite) — shares the same connection pool.
        let queue_backend = vex_persist::queue::SqliteQueueBackend::new(db.pool().clone());

        // Use dynamic dispatch for the worker pool backend
        let worker_pool = WorkerPool::new_with_arc(
            Arc::new(queue_backend) as Arc<dyn QueueBackend>,
            WorkerConfig::default(),
        );

        // Initialize Intelligence (LLM) with resilience and caching.
        // NOTE(review): `_base_llm` is constructed but never used — the router
        // below is what actually serves requests. Either feed this provider
        // into the router or delete this block; as written the DeepSeek
        // provider (and its log lines) are dead setup. TODO confirm intent.
        let _base_llm: Arc<dyn LlmProvider> = if let Ok(key) = std::env::var("DEEPSEEK_API_KEY") {
            tracing::info!("Initializing Resilient+Cached DeepSeek Provider");
            let base = DeepSeekProvider::chat(&key);
            // Wrap with resilience first, then caching
            let resilient = ResilientProvider::new(base, vex_llm::LlmCircuitConfig::conservative());
            let cached = CachedProvider::wrap(resilient);
            Arc::new(cached)
        } else {
            tracing::warn!("DEEPSEEK_API_KEY not found. Using Mock Provider.");
            Arc::new(MockProvider::smart())
        };

        // Initialize Router (Smart Routing Layer) — this is the provider the
        // rest of the app actually talks to.
        let router = vex_router::Router::builder()
            .strategy(vex_router::RoutingStrategy::Auto)
            .build();
        let router_arc = Arc::new(router);
        let llm: Arc<dyn LlmProvider> = router_arc.clone();

        // Create shared result store for job results
        let result_store = crate::jobs::new_result_store();

        // Register the "agent_execution" job factory. The closure captures
        // clones of the LLM handle and result store so each job instance
        // shares them.
        let llm_clone = llm.clone();
        let result_store_clone = result_store.clone();
        worker_pool.register_job_factory("agent_execution", move |payload| {
            // On malformed payloads we degrade to a sentinel job rather than
            // panic inside the worker; the job itself will surface the error.
            let job_payload: AgentJobPayload =
                serde_json::from_value(payload).unwrap_or_else(|_| AgentJobPayload {
                    agent_id: "unknown".to_string(),
                    prompt: "payload error".to_string(),
                    context_id: None,
                });
            let job_id = uuid::Uuid::new_v4();
            Box::new(AgentExecutionJob::new(
                job_id,
                job_payload,
                llm_clone.clone(),
                result_store_clone.clone(),
            ))
        });

        let a2a_state = Arc::new(crate::a2a::handler::A2aState::default());

        let app_state = AppState::new(
            jwt_auth,
            rate_limiter,
            metrics,
            Arc::new(db),
            Arc::new(worker_pool),
            a2a_state,
            llm.clone(),
            Some(router_arc),
        );

        Ok(Self { config, app_state })
    }

    /// Build the complete application router: the API routes wrapped in the
    /// full middleware stack (compression, body limit, timeout, CORS,
    /// request ID, tracing, rate limiting, authentication).
    pub fn router(&self) -> Router {
        let mut app = api_router(self.app_state.clone());

        // Apply middleware layers (order matters - bottom to top execution)
        app = app
            // Compression (outermost - compresses response)
            // NOTE(review): added unconditionally; `config.compression` is
            // never checked — confirm whether the flag should gate this layer.
            .layer(CompressionLayer::new())
            // Body size limit
            .layer(body_limit_layer(self.config.max_body_size))
            // Timeout
            .layer(timeout_layer(self.config.timeout))
            // CORS
            .layer(cors_layer())
            // Request ID
            .layer(middleware::from_fn(request_id_middleware))
            // Tracing
            .layer(middleware::from_fn_with_state(
                self.app_state.clone(),
                tracing_middleware,
            ))
            // Rate limiting
            .layer(middleware::from_fn_with_state(
                self.app_state.clone(),
                rate_limit_middleware,
            ))
            // Authentication (innermost - runs first)
            .layer(middleware::from_fn_with_state(
                self.app_state.clone(),
                auth_middleware,
            ));

        app
    }

    /// Run the server.
    ///
    /// # HTTPS Support
    /// When `config.tls` is set, the server terminates TLS itself using
    /// rustls + tokio-rustls and a manual accept loop. Without TLS, the
    /// server refuses to start if `config.enforce_https` is set, to prevent
    /// accidental plaintext deployment in production.
    ///
    /// # Shutdown
    /// The plain-HTTP path shuts down gracefully on Ctrl+C/SIGTERM via
    /// [`shutdown_signal`].
    /// NOTE(review): the HTTPS accept loop below never listens for the
    /// shutdown signal and only exits on an accept error — graceful shutdown
    /// does not apply to TLS mode. Worth fixing.
    pub async fn run(self) -> Result<(), ApiError> {
        let app = self.router();
        let addr = self.config.addr;

        // Start Worker Pool in background; it runs for the lifetime of the
        // process (the JoinHandle is intentionally detached).
        let queue = self.app_state.queue();
        tokio::spawn(async move {
            queue.start().await;
        });

        // HTTPS with TLS
        if let Some(tls_config) = &self.config.tls {
            tracing::info!("🔒 Starting VEX API server with HTTPS on {}", addr);

            // Load TLS certificates.
            // NOTE: this `ServerConfig` is rustls's config type and shadows
            // the crate-level `ServerConfig` for the rest of this scope.
            use rustls_pki_types::pem::PemObject;
            use rustls_pki_types::{CertificateDer, PrivateKeyDer};
            use std::io::Read;
            use tokio_rustls::rustls::ServerConfig;

            let mut cert_file = std::fs::File::open(&tls_config.cert_path)
                .map_err(|e| ApiError::Internal(format!("Failed to open cert file: {}", e)))?;
            let mut key_file = std::fs::File::open(&tls_config.key_path)
                .map_err(|e| ApiError::Internal(format!("Failed to open key file: {}", e)))?;

            let mut cert_pem = Vec::new();
            cert_file
                .read_to_end(&mut cert_pem)
                .map_err(|e| ApiError::Internal(format!("Failed to read cert file: {}", e)))?;

            let mut key_pem = Vec::new();
            key_file
                .read_to_end(&mut key_pem)
                .map_err(|e| ApiError::Internal(format!("Failed to read key file: {}", e)))?;

            // Parse every certificate in the chain; any parse error aborts startup.
            let certs = CertificateDer::pem_slice_iter(&cert_pem)
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| ApiError::Internal(format!("Failed to parse certs: {}", e)))?;

            let mut keys = PrivateKeyDer::pem_slice_iter(&key_pem)
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| ApiError::Internal(format!("Failed to parse key: {}", e)))?;

            if keys.is_empty() {
                return Err(ApiError::Internal("No private keys found".to_string()));
            }

            // First key in the file wins; extra keys are ignored.
            let mut server_config = ServerConfig::builder()
                .with_no_client_auth()
                .with_single_cert(certs, keys.remove(0))
                .map_err(|e| ApiError::Internal(format!("Failed to build TLS config: {}", e)))?;

            // NOTE(review): ALPN advertises h2, but connections are served with
            // hyper's http1 builder below — an h2-negotiating client will fail.
            // Either drop "h2" here or serve with an auto/http2-capable builder.
            server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

            let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config));
            let tcp_listener = tokio::net::TcpListener::bind(addr).await?;

            tracing::info!("✅ VEX API listening on https://{}", addr);

            // Manual accept loop: one spawned task per connection.
            loop {
                let (tcp_stream, remote_addr) = tcp_listener
                    .accept()
                    .await
                    .map_err(|e| ApiError::Internal(format!("Accept error: {}", e)))?;

                let tls_acceptor = tls_acceptor.clone();
                let app = app.clone();

                tokio::spawn(async move {
                    // TLS handshake failures are logged and the connection dropped;
                    // they must not take down the accept loop.
                    let tls_stream = match tls_acceptor.accept(tcp_stream).await {
                        Ok(s) => s,
                        Err(e) => {
                            tracing::error!("TLS handshake failed: {}", e);
                            return;
                        }
                    };

                    // Bridge the axum (tower) service into a hyper service.
                    let tower_service = app.clone();
                    let hyper_service = hyper::service::service_fn(
                        move |request: hyper::Request<hyper::body::Incoming>| {
                            tower_service.clone().call(request)
                        },
                    );

                    if let Err(e) = hyper::server::conn::http1::Builder::new()
                        .serve_connection(hyper_util::rt::TokioIo::new(tls_stream), hyper_service)
                        .await
                    {
                        tracing::error!(
                            "Error serving HTTPS connection from {}: {}",
                            remote_addr,
                            e
                        );
                    }
                });
            }
        } else {
            // Refuse to start in plaintext when enforcement is on.
            if self.config.enforce_https {
                tracing::error!("FATAL: HTTPS enforcement is enabled but TLS certificates are missing (VEX_TLS_CERT/VEX_TLS_KEY)");
                return Err(ApiError::Internal("HTTPS enforcement error".to_string()));
            }

            // HTTP (development only)
            tracing::warn!(
                "⚠️  Starting VEX API server WITHOUT HTTPS on {} - NOT for production!",
                addr
            );

            let listener = tokio::net::TcpListener::bind(addr).await?;

            // `into_make_service_with_connect_info` exposes the peer
            // SocketAddr to handlers/extractors.
            axum::serve(
                listener,
                app.into_make_service_with_connect_info::<std::net::SocketAddr>(),
            )
            .with_graceful_shutdown(shutdown_signal())
            .await
            .map_err(|e| ApiError::Internal(format!("Server error: {}", e)))?;
        }

        // Only reachable from the HTTP branch; the TLS loop never breaks.
        tracing::info!("Server shutdown complete");
        Ok(())
    }

    /// Get a handle to the shared server metrics.
    pub fn metrics(&self) -> Arc<Metrics> {
        self.app_state.metrics()
    }
}
375
376/// Graceful shutdown signal handler
377async fn shutdown_signal() {
378    let ctrl_c = async {
379        signal::ctrl_c()
380            .await
381            .expect("Failed to install Ctrl+C handler");
382    };
383
384    #[cfg(unix)]
385    let terminate = async {
386        signal::unix::signal(signal::unix::SignalKind::terminate())
387            .expect("Failed to install SIGTERM handler")
388            .recv()
389            .await;
390    };
391
392    #[cfg(not(unix))]
393    let terminate = std::future::pending::<()>();
394
395    tokio::select! {
396        _ = ctrl_c => {
397            tracing::info!("Received Ctrl+C, starting graceful shutdown");
398        }
399        _ = terminate => {
400            tracing::info!("Received SIGTERM, starting graceful shutdown");
401        }
402    }
403}
404
405/// Initialize tracing subscriber
406pub fn init_tracing() {
407    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
408
409    let filter = EnvFilter::try_from_default_env()
410        .unwrap_or_else(|_| EnvFilter::new("info,vex_api=debug,tower_http=debug"));
411
412    tracing_subscriber::registry()
413        .with(filter)
414        .with(tracing_subscriber::fmt::layer().with_target(true))
415        .init();
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults must bind to port 8080 with a 30-second request timeout.
    #[test]
    fn test_server_config_default() {
        let ServerConfig { addr, timeout, .. } = ServerConfig::default();
        assert_eq!(addr.port(), 8080);
        assert_eq!(timeout, Duration::from_secs(30));
    }
}