vex_api/
middleware.rs

1//! Tower middleware for VEX API
2
3use axum::{
4    extract::{Request, State},
5    http::header,
6    middleware::Next,
7    response::Response,
8};
9use std::time::Instant;
10
11use crate::auth::{Claims, JwtAuth};
12use crate::error::ApiError;
13use crate::state::AppState;
// `vex_llm`'s rate limiter and metrics are reached through the `AppState` accessors
// (`rate_limiter()`, `metrics()`), so no direct `vex_llm` import is needed here.
15
16/// Authentication middleware
17pub async fn auth_middleware(
18    State(state): State<AppState>,
19    mut request: Request,
20    next: Next,
21) -> Result<Response, ApiError> {
22    // Skip auth for health check and public endpoints
23    let path = request.uri().path();
24    if path == "/health" || path.starts_with("/public/") {
25        return Ok(next.run(request).await);
26    }
27
28    // Extract token from header
29    let auth_header = request
30        .headers()
31        .get(header::AUTHORIZATION)
32        .and_then(|v| v.to_str().ok())
33        .ok_or_else(|| ApiError::Unauthorized("Missing Authorization header".to_string()))?;
34
35    let token = JwtAuth::extract_from_header(auth_header)?;
36    let claims = state.jwt_auth().decode(token)?;
37
38    // Insert claims into request extensions for handlers
39    request.extensions_mut().insert(claims);
40
41    Ok(next.run(request).await)
42}
43
44/// Rate limiting middleware
45pub async fn rate_limit_middleware(
46    State(state): State<AppState>,
47    request: Request,
48    next: Next,
49) -> Result<Response, ApiError> {
50    // Extract tenant identifier (prioritize authenticated sub from JWT)
51    let tenant_id = request
52        .extensions()
53        .get::<Claims>()
54        .map(|c| c.sub.clone())
55        .or_else(|| {
56            request
57                .headers()
58                .get("x-client-id")
59                .and_then(|h| h.to_str().ok())
60                .map(|s| s.to_string())
61        })
62        .unwrap_or_else(|| "anonymous".to_string());
63
64    // Check rate limit
65    state
66        .rate_limiter()
67        .check(&tenant_id)
68        .await
69        .map_err(|_| ApiError::RateLimited)?;
70
71    Ok(next.run(request).await)
72}
73
74/// Request tracing middleware
75pub async fn tracing_middleware(
76    State(state): State<AppState>,
77    request: Request,
78    next: Next,
79) -> Response {
80    let start = Instant::now();
81    let method = request.method().clone();
82    let uri = request.uri().clone();
83    let path = uri.path().to_string();
84    // Extract IDs for tracing
85    let request_id = request
86        .extensions()
87        .get::<RequestId>()
88        .map(|id| id.0.clone())
89        .unwrap_or_else(|| "unknown".to_string());
90    let tenant_id = request
91        .extensions()
92        .get::<Claims>()
93        .map(|c| c.sub.clone())
94        .unwrap_or_else(|| "anonymous".to_string());
95
96    // Create span for this request
97    let span = tracing::info_span!(
98        "http_request",
99        method = %method,
100        path = %path,
101        request_id = %request_id,
102        tenant_id = %tenant_id,
103        status = tracing::field::Empty,
104        latency_ms = tracing::field::Empty,
105    );
106
107    let response = {
108        let _enter = span.enter();
109        next.run(request).await
110    };
111
112    let latency = start.elapsed();
113    let status = response.status();
114
115    // Record metrics
116    state.metrics().record_llm_call(0, !status.is_success());
117
118    // Log request
119    tracing::info!(
120        method = %method,
121        path = %path,
122        status = %status.as_u16(),
123        latency_ms = %latency.as_millis(),
124        "Request completed"
125    );
126
127    response
128}
129
130/// Request ID middleware
131pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
132    let request_id = uuid::Uuid::new_v4().to_string();
133
134    // Add to request extensions
135    request
136        .extensions_mut()
137        .insert(RequestId(request_id.clone()));
138
139    let mut response = next.run(request).await;
140
141    // Add to response headers
142    response
143        .headers_mut()
144        .insert("X-Request-ID", request_id.parse().unwrap());
145
146    response
147}
148
/// Per-request identifier kept in the request extensions.
///
/// `request_id_middleware` inserts it (wrapping the string form of a freshly
/// generated v4 UUID, also sent as the `X-Request-ID` response header), and
/// `tracing_middleware` reads it to tag the request span.
///
/// Equality and hashing are derived so IDs can be compared and used as map keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);
152
153/// CORS configuration helper
154/// Reads allowed origins from VEX_CORS_ORIGINS env var (comma-separated)
155/// Falls back to restrictive default if not set
156pub fn cors_layer() -> tower_http::cors::CorsLayer {
157    use tower_http::cors::{AllowOrigin, CorsLayer};
158
159    let origins = std::env::var("VEX_CORS_ORIGINS").ok();
160
161    let allow_origin = match origins {
162        Some(origins_str) if !origins_str.is_empty() => {
163            let origins: Vec<axum::http::HeaderValue> = origins_str
164                .split(',')
165                .filter_map(|s| s.trim().parse().ok())
166                .collect();
167            if origins.is_empty() {
168                tracing::warn!("VEX_CORS_ORIGINS is set but contains no valid origins, using restrictive default");
169                AllowOrigin::exact("https://localhost".parse().unwrap())
170            } else {
171                tracing::info!("CORS configured for {} origin(s)", origins.len());
172                AllowOrigin::list(origins)
173            }
174        }
175        _ => {
176            // No CORS_ORIGINS set - use restrictive default for security
177            tracing::warn!("VEX_CORS_ORIGINS not set, using restrictive CORS (localhost only)");
178            AllowOrigin::exact("https://localhost".parse().unwrap())
179        }
180    };
181
182    CorsLayer::new()
183        .allow_origin(allow_origin)
184        .allow_methods([
185            axum::http::Method::GET,
186            axum::http::Method::POST,
187            axum::http::Method::PUT,
188            axum::http::Method::DELETE,
189            axum::http::Method::OPTIONS,
190        ])
191        .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT])
192        .max_age(std::time::Duration::from_secs(3600))
193}
194
195/// Timeout layer helper
196#[allow(deprecated)]
197pub fn timeout_layer(duration: std::time::Duration) -> tower_http::timeout::TimeoutLayer {
198    tower_http::timeout::TimeoutLayer::new(duration)
199}
200
201/// Request body size limit
202pub fn body_limit_layer(limit: usize) -> tower_http::limit::RequestBodyLimitLayer {
203    tower_http::limit::RequestBodyLimitLayer::new(limit)
204}
205
206/// Security headers middleware
207/// Adds standard security headers to all responses
208pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
209    let mut response = next.run(request).await;
210
211    let headers = response.headers_mut();
212
213    // Prevent MIME type sniffing
214    headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
215
216    // Prevent clickjacking
217    headers.insert("X-Frame-Options", "DENY".parse().unwrap());
218
219    // XSS protection (legacy, but still useful)
220    headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
221
222    // Content Security Policy
223    headers.insert(
224        "Content-Security-Policy",
225        "default-src 'self'; frame-ancestors 'none'"
226            .parse()
227            .unwrap(),
228    );
229
230    // HSTS - Enable in production by setting VEX_ENABLE_HSTS=1
231    if std::env::var("VEX_ENABLE_HSTS").is_ok() {
232        headers.insert(
233            "Strict-Transport-Security",
234            "max-age=31536000; includeSubDomains".parse().unwrap(),
235        );
236    }
237
238    // Referrer policy
239    headers.insert(
240        "Referrer-Policy",
241        "strict-origin-when-cross-origin".parse().unwrap(),
242    );
243
244    // Permissions policy
245    headers.insert(
246        "Permissions-Policy",
247        "geolocation=(), microphone=(), camera=()".parse().unwrap(),
248    );
249
250    response
251}
252
#[cfg(test)]
mod tests {
    use super::RequestId;

    #[test]
    fn test_request_id() {
        let id1 = uuid::Uuid::new_v4().to_string();
        let id2 = uuid::Uuid::new_v4().to_string();
        assert_ne!(id1, id2);
    }

    /// `request_id_middleware` unwraps the conversion of a UUID string into a
    /// header value; pin that the conversion succeeds and round-trips.
    #[test]
    fn uuid_string_is_valid_header_value() {
        let id = uuid::Uuid::new_v4().to_string();
        let value: axum::http::HeaderValue = id
            .parse()
            .expect("a UUID string must be a valid header value");
        assert_eq!(value.to_str().expect("UUID header is ASCII"), id);
    }

    #[test]
    fn request_id_clone_keeps_value() {
        let original = RequestId("abc-123".to_string());
        let copy = original.clone();
        assert_eq!(copy.0, original.0);
    }
}