1use axum::{middleware, Router};
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::signal;
8use tower::Service;
9use tower_http::compression::CompressionLayer;
10
11use crate::auth::JwtAuth;
12use crate::error::ApiError;
13use crate::middleware::{
14 auth_middleware, body_limit_layer, cors_layer, rate_limit_middleware, request_id_middleware,
15 timeout_layer, tracing_middleware,
16};
17use crate::routes::api_router;
18use vex_llm::{Metrics, RateLimitConfig};
/// Locations of the certificate and private-key files used to serve HTTPS.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path of the certificate file.
    pub cert_path: String,
    /// Path of the private-key file.
    pub key_path: String,
}

impl TlsConfig {
    /// Creates a configuration from a certificate path and a key path.
    ///
    /// Both paths are copied into owned strings; nothing is opened or
    /// validated here.
    pub fn new(cert_path: &str, key_path: &str) -> Self {
        let cert_path = String::from(cert_path);
        let key_path = String::from(key_path);
        Self {
            cert_path,
            key_path,
        }
    }

    /// Builds a configuration from `VEX_TLS_CERT` and `VEX_TLS_KEY`.
    ///
    /// Returns `None` unless both variables can be read.
    pub fn from_env() -> Option<Self> {
        match (std::env::var("VEX_TLS_CERT"), std::env::var("VEX_TLS_KEY")) {
            (Ok(cert_path), Ok(key_path)) => Some(Self {
                cert_path,
                key_path,
            }),
            _ => None,
        }
    }
}
47
48#[derive(Debug, Clone)]
50pub struct ServerConfig {
51 pub addr: SocketAddr,
53 pub timeout: Duration,
55 pub max_body_size: usize,
57 pub compression: bool,
59 pub rate_limit: RateLimitConfig,
61 pub tls: Option<TlsConfig>,
63 pub enforce_https: bool,
65}
66
67impl Default for ServerConfig {
68 fn default() -> Self {
69 Self {
70 addr: "0.0.0.0:8080".parse().unwrap(),
71 timeout: Duration::from_secs(30),
72 max_body_size: 1024 * 1024, compression: true,
74 rate_limit: RateLimitConfig::default(),
75 tls: None,
76 enforce_https: false,
77 }
78 }
79}
80
81impl ServerConfig {
82 pub fn from_env() -> Self {
84 let port: u16 = std::env::var("VEX_PORT")
85 .ok()
86 .and_then(|p| p.parse().ok())
87 .unwrap_or(8080);
88
89 let timeout_secs: u64 = std::env::var("VEX_TIMEOUT_SECS")
90 .ok()
91 .and_then(|t| t.parse().ok())
92 .unwrap_or(30);
93
94 let enforce_https = std::env::var("VEX_ENFORCE_HTTPS").is_ok()
95 || std::env::var("VEX_ENV")
96 .map(|e| e == "production")
97 .unwrap_or(false);
98
99 Self {
100 addr: SocketAddr::from(([0, 0, 0, 0], port)),
101 timeout: Duration::from_secs(timeout_secs),
102 enforce_https,
103 ..Default::default()
104 }
105 }
106}
107
108use crate::state::AppState;
109
110pub struct VexServer {
112 config: ServerConfig,
113 app_state: AppState,
114}
115
116impl VexServer {
117 pub async fn new(config: ServerConfig) -> Result<Self, ApiError> {
119 use crate::jobs::agent::{AgentExecutionJob, AgentJobPayload};
120 use crate::tenant_rate_limiter::{RateLimitTier, TenantRateLimiter};
121 use vex_llm::{
122 CachedProvider, DeepSeekProvider, LlmProvider, MockProvider, ResilientProvider,
123 };
124 use vex_queue::{QueueBackend, WorkerConfig, WorkerPool};
125
126 let jwt_auth = JwtAuth::from_env()?;
127 let rate_limiter = Arc::new(TenantRateLimiter::new(RateLimitTier::Standard));
128 let metrics = Arc::new(Metrics::new());
129
130 let db_url =
132 std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite::memory:".to_string());
133 let db = vex_persist::sqlite::SqliteBackend::new(&db_url)
134 .await
135 .map_err(|e| ApiError::Internal(format!("DB Init failed: {}", e)))?;
136
137 let queue_backend = vex_persist::queue::SqliteQueueBackend::new(db.pool().clone());
139
140 let worker_pool = WorkerPool::new_with_arc(
142 Arc::new(queue_backend) as Arc<dyn QueueBackend>,
143 WorkerConfig::default(),
144 );
145
146 let _base_llm: Arc<dyn LlmProvider> = if let Ok(key) = std::env::var("DEEPSEEK_API_KEY") {
148 tracing::info!("Initializing Resilient+Cached DeepSeek Provider");
149 let base = DeepSeekProvider::chat(&key);
150 let resilient = ResilientProvider::new(base, vex_llm::LlmCircuitConfig::conservative());
152 let cached = CachedProvider::wrap(resilient);
153 Arc::new(cached)
154 } else {
155 tracing::warn!("DEEPSEEK_API_KEY not found. Using Mock Provider.");
156 Arc::new(MockProvider::smart())
157 };
158
159 let router = vex_router::Router::builder()
161 .strategy(vex_router::RoutingStrategy::Auto)
162 .build();
163 let router_arc = Arc::new(router);
164 let llm: Arc<dyn LlmProvider> = router_arc.clone();
165
166 let result_store = crate::jobs::new_result_store();
168
169 let llm_clone = llm.clone();
171 let result_store_clone = result_store.clone();
172 worker_pool.register_job_factory("agent_execution", move |payload| {
173 let job_payload: AgentJobPayload =
174 serde_json::from_value(payload).unwrap_or_else(|_| AgentJobPayload {
175 agent_id: "unknown".to_string(),
176 prompt: "payload error".to_string(),
177 context_id: None,
178 });
179 let job_id = uuid::Uuid::new_v4();
180 Box::new(AgentExecutionJob::new(
181 job_id,
182 job_payload,
183 llm_clone.clone(),
184 result_store_clone.clone(),
185 ))
186 });
187
188 let a2a_state = Arc::new(crate::a2a::handler::A2aState::default());
189
190 let app_state = AppState::new(
191 jwt_auth,
192 rate_limiter,
193 metrics,
194 Arc::new(db),
195 Arc::new(worker_pool),
196 a2a_state,
197 llm.clone(),
198 Some(router_arc),
199 );
200
201 Ok(Self { config, app_state })
202 }
203
204 pub fn router(&self) -> Router {
206 let mut app = api_router(self.app_state.clone());
207
208 app = app
210 .layer(CompressionLayer::new())
212 .layer(body_limit_layer(self.config.max_body_size))
214 .layer(timeout_layer(self.config.timeout))
216 .layer(cors_layer())
218 .layer(middleware::from_fn(request_id_middleware))
220 .layer(middleware::from_fn_with_state(
222 self.app_state.clone(),
223 tracing_middleware,
224 ))
225 .layer(middleware::from_fn_with_state(
227 self.app_state.clone(),
228 rate_limit_middleware,
229 ))
230 .layer(middleware::from_fn_with_state(
232 self.app_state.clone(),
233 auth_middleware,
234 ));
235
236 app
237 }
238
239 pub async fn run(self) -> Result<(), ApiError> {
246 let app = self.router();
247 let addr = self.config.addr;
248
249 let queue = self.app_state.queue();
251 tokio::spawn(async move {
252 queue.start().await;
253 });
254
255 if let Some(tls_config) = &self.config.tls {
257 tracing::info!("🔒 Starting VEX API server with HTTPS on {}", addr);
259
260 use rustls_pki_types::pem::PemObject;
262 use rustls_pki_types::{CertificateDer, PrivateKeyDer};
263 use std::io::Read;
264 use tokio_rustls::rustls::ServerConfig;
265
266 let mut cert_file = std::fs::File::open(&tls_config.cert_path)
267 .map_err(|e| ApiError::Internal(format!("Failed to open cert file: {}", e)))?;
268 let mut key_file = std::fs::File::open(&tls_config.key_path)
269 .map_err(|e| ApiError::Internal(format!("Failed to open key file: {}", e)))?;
270
271 let mut cert_pem = Vec::new();
272 cert_file
273 .read_to_end(&mut cert_pem)
274 .map_err(|e| ApiError::Internal(format!("Failed to read cert file: {}", e)))?;
275
276 let mut key_pem = Vec::new();
277 key_file
278 .read_to_end(&mut key_pem)
279 .map_err(|e| ApiError::Internal(format!("Failed to read key file: {}", e)))?;
280
281 let certs = CertificateDer::pem_slice_iter(&cert_pem)
282 .collect::<Result<Vec<_>, _>>()
283 .map_err(|e| ApiError::Internal(format!("Failed to parse certs: {}", e)))?;
284
285 let mut keys = PrivateKeyDer::pem_slice_iter(&key_pem)
286 .collect::<Result<Vec<_>, _>>()
287 .map_err(|e| ApiError::Internal(format!("Failed to parse key: {}", e)))?;
288
289 if keys.is_empty() {
290 return Err(ApiError::Internal("No private keys found".to_string()));
291 }
292
293 let mut server_config = ServerConfig::builder()
294 .with_no_client_auth()
295 .with_single_cert(certs, keys.remove(0))
296 .map_err(|e| ApiError::Internal(format!("Failed to build TLS config: {}", e)))?;
297
298 server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
299
300 let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config));
301 let tcp_listener = tokio::net::TcpListener::bind(addr).await?;
302
303 tracing::info!("✅ VEX API listening on https://{}", addr);
304
305 loop {
306 let (tcp_stream, remote_addr) = tcp_listener
307 .accept()
308 .await
309 .map_err(|e| ApiError::Internal(format!("Accept error: {}", e)))?;
310
311 let tls_acceptor = tls_acceptor.clone();
312 let app = app.clone();
313
314 tokio::spawn(async move {
315 let tls_stream = match tls_acceptor.accept(tcp_stream).await {
316 Ok(s) => s,
317 Err(e) => {
318 tracing::error!("TLS handshake failed: {}", e);
319 return;
320 }
321 };
322
323 let tower_service = app.clone();
324 let hyper_service = hyper::service::service_fn(
325 move |request: hyper::Request<hyper::body::Incoming>| {
326 tower_service.clone().call(request)
327 },
328 );
329
330 if let Err(e) = hyper::server::conn::http1::Builder::new()
331 .serve_connection(hyper_util::rt::TokioIo::new(tls_stream), hyper_service)
332 .await
333 {
334 tracing::error!(
335 "Error serving HTTPS connection from {}: {}",
336 remote_addr,
337 e
338 );
339 }
340 });
341 }
342 } else {
343 if self.config.enforce_https {
345 tracing::error!("FATAL: HTTPS enforcement is enabled but TLS certificates are missing (VEX_TLS_CERT/VEX_TLS_KEY)");
346 return Err(ApiError::Internal("HTTPS enforcement error".to_string()));
347 }
348
349 tracing::warn!(
351 "⚠️ Starting VEX API server WITHOUT HTTPS on {} - NOT for production!",
352 addr
353 );
354
355 let listener = tokio::net::TcpListener::bind(addr).await?;
356
357 axum::serve(
358 listener,
359 app.into_make_service_with_connect_info::<std::net::SocketAddr>(),
360 )
361 .with_graceful_shutdown(shutdown_signal())
362 .await
363 .map_err(|e| ApiError::Internal(format!("Server error: {}", e)))?;
364 }
365
366 tracing::info!("Server shutdown complete");
367 Ok(())
368 }
369
370 pub fn metrics(&self) -> Arc<Metrics> {
372 self.app_state.metrics()
373 }
374}
375
376async fn shutdown_signal() {
378 let ctrl_c = async {
379 signal::ctrl_c()
380 .await
381 .expect("Failed to install Ctrl+C handler");
382 };
383
384 #[cfg(unix)]
385 let terminate = async {
386 signal::unix::signal(signal::unix::SignalKind::terminate())
387 .expect("Failed to install SIGTERM handler")
388 .recv()
389 .await;
390 };
391
392 #[cfg(not(unix))]
393 let terminate = std::future::pending::<()>();
394
395 tokio::select! {
396 _ = ctrl_c => {
397 tracing::info!("Received Ctrl+C, starting graceful shutdown");
398 }
399 _ = terminate => {
400 tracing::info!("Received SIGTERM, starting graceful shutdown");
401 }
402 }
403}
404
405pub fn init_tracing() {
407 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
408
409 let filter = EnvFilter::try_from_default_env()
410 .unwrap_or_else(|_| EnvFilter::new("info,vex_api=debug,tower_http=debug"));
411
412 tracing_subscriber::registry()
413 .with(filter)
414 .with(tracing_subscriber::fmt::layer().with_target(true))
415 .init();
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration listens on port 8080 with a 30 second timeout.
    #[test]
    fn test_server_config_default() {
        let ServerConfig { addr, timeout, .. } = ServerConfig::default();

        assert_eq!(addr.port(), 8080);
        assert_eq!(timeout, Duration::from_secs(30));
    }
}