1use axum::{
4 extract::{Request, State},
5 http::header,
6 middleware::Next,
7 response::Response,
8};
9use std::time::Instant;
10
11use crate::auth::{Claims, JwtAuth};
12use crate::error::ApiError;
13use crate::state::AppState;
14pub async fn auth_middleware(
18 State(state): State<AppState>,
19 mut request: Request,
20 next: Next,
21) -> Result<Response, ApiError> {
22 let path = request.uri().path();
24 if path == "/health" || path.starts_with("/public/") {
25 return Ok(next.run(request).await);
26 }
27
28 let auth_header = request
30 .headers()
31 .get(header::AUTHORIZATION)
32 .and_then(|v| v.to_str().ok())
33 .ok_or_else(|| ApiError::Unauthorized("Missing Authorization header".to_string()))?;
34
35 let token = JwtAuth::extract_from_header(auth_header)?;
36 let claims = state.jwt_auth().decode(token)?;
37
38 request.extensions_mut().insert(claims);
40
41 Ok(next.run(request).await)
42}
43
44pub async fn rate_limit_middleware(
46 State(state): State<AppState>,
47 request: Request,
48 next: Next,
49) -> Result<Response, ApiError> {
50 let tenant_id = request
52 .extensions()
53 .get::<Claims>()
54 .map(|c| c.sub.clone())
55 .or_else(|| {
56 request
57 .headers()
58 .get("x-client-id")
59 .and_then(|h| h.to_str().ok())
60 .map(|s| s.to_string())
61 })
62 .unwrap_or_else(|| "anonymous".to_string());
63
64 state
66 .rate_limiter()
67 .check(&tenant_id)
68 .await
69 .map_err(|_| ApiError::RateLimited)?;
70
71 Ok(next.run(request).await)
72}
73
74pub async fn tracing_middleware(
76 State(state): State<AppState>,
77 request: Request,
78 next: Next,
79) -> Response {
80 let start = Instant::now();
81 let method = request.method().clone();
82 let uri = request.uri().clone();
83 let path = uri.path().to_string();
84 let request_id = request
86 .extensions()
87 .get::<RequestId>()
88 .map(|id| id.0.clone())
89 .unwrap_or_else(|| "unknown".to_string());
90 let tenant_id = request
91 .extensions()
92 .get::<Claims>()
93 .map(|c| c.sub.clone())
94 .unwrap_or_else(|| "anonymous".to_string());
95
96 let span = tracing::info_span!(
98 "http_request",
99 method = %method,
100 path = %path,
101 request_id = %request_id,
102 tenant_id = %tenant_id,
103 status = tracing::field::Empty,
104 latency_ms = tracing::field::Empty,
105 );
106
107 let response = {
108 let _enter = span.enter();
109 next.run(request).await
110 };
111
112 let latency = start.elapsed();
113 let status = response.status();
114
115 state.metrics().record_llm_call(0, !status.is_success());
117
118 tracing::info!(
120 method = %method,
121 path = %path,
122 status = %status.as_u16(),
123 latency_ms = %latency.as_millis(),
124 "Request completed"
125 );
126
127 response
128}
129
130pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
132 let request_id = uuid::Uuid::new_v4().to_string();
133
134 request
136 .extensions_mut()
137 .insert(RequestId(request_id.clone()));
138
139 let mut response = next.run(request).await;
140
141 response
143 .headers_mut()
144 .insert("X-Request-ID", request_id.parse().unwrap());
145
146 response
147}
148
/// Per-request identifier inserted into the request extensions by
/// `request_id_middleware`.
///
/// The inner string is the same value that middleware sends back in the
/// `X-Request-ID` response header. `PartialEq`/`Eq`/`Hash` are derived so ids
/// can be compared and used as map/set keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);
152
153pub fn cors_layer() -> tower_http::cors::CorsLayer {
157 use tower_http::cors::{AllowOrigin, CorsLayer};
158
159 let origins = std::env::var("VEX_CORS_ORIGINS").ok();
160
161 let allow_origin = match origins {
162 Some(origins_str) if !origins_str.is_empty() => {
163 let origins: Vec<axum::http::HeaderValue> = origins_str
164 .split(',')
165 .filter_map(|s| s.trim().parse().ok())
166 .collect();
167 if origins.is_empty() {
168 tracing::warn!("VEX_CORS_ORIGINS is set but contains no valid origins, using restrictive default");
169 AllowOrigin::exact("https://localhost".parse().unwrap())
170 } else {
171 tracing::info!("CORS configured for {} origin(s)", origins.len());
172 AllowOrigin::list(origins)
173 }
174 }
175 _ => {
176 tracing::warn!("VEX_CORS_ORIGINS not set, using restrictive CORS (localhost only)");
178 AllowOrigin::exact("https://localhost".parse().unwrap())
179 }
180 };
181
182 CorsLayer::new()
183 .allow_origin(allow_origin)
184 .allow_methods([
185 axum::http::Method::GET,
186 axum::http::Method::POST,
187 axum::http::Method::PUT,
188 axum::http::Method::DELETE,
189 axum::http::Method::OPTIONS,
190 ])
191 .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT])
192 .max_age(std::time::Duration::from_secs(3600))
193}
194
195#[allow(deprecated)]
197pub fn timeout_layer(duration: std::time::Duration) -> tower_http::timeout::TimeoutLayer {
198 tower_http::timeout::TimeoutLayer::new(duration)
199}
200
201pub fn body_limit_layer(limit: usize) -> tower_http::limit::RequestBodyLimitLayer {
203 tower_http::limit::RequestBodyLimitLayer::new(limit)
204}
205
206pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
209 let mut response = next.run(request).await;
210
211 let headers = response.headers_mut();
212
213 headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
215
216 headers.insert("X-Frame-Options", "DENY".parse().unwrap());
218
219 headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
221
222 headers.insert(
224 "Content-Security-Policy",
225 "default-src 'self'; frame-ancestors 'none'"
226 .parse()
227 .unwrap(),
228 );
229
230 if std::env::var("VEX_ENABLE_HSTS").is_ok() {
232 headers.insert(
233 "Strict-Transport-Security",
234 "max-age=31536000; includeSubDomains".parse().unwrap(),
235 );
236 }
237
238 headers.insert(
240 "Referrer-Policy",
241 "strict-origin-when-cross-origin".parse().unwrap(),
242 );
243
244 headers.insert(
246 "Permissions-Policy",
247 "geolocation=(), microphone=(), camera=()".parse().unwrap(),
248 );
249
250 response
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Two freshly generated request ids must never collide.
    #[test]
    fn test_request_id() {
        let id1 = uuid::Uuid::new_v4().to_string();
        let id2 = uuid::Uuid::new_v4().to_string();
        assert_ne!(id1, id2);
    }

    /// `request_id_middleware` unwraps the header conversion of a generated
    /// id, so every generated id must be a valid header value.
    #[test]
    fn generated_request_id_is_valid_header_value() {
        let id = uuid::Uuid::new_v4().to_string();
        assert!(id.parse::<axum::http::HeaderValue>().is_ok());
    }

    /// Cloning a `RequestId` (as done when reading it from extensions) keeps
    /// the same inner value.
    #[test]
    fn request_id_clone_preserves_value() {
        let id = RequestId("abc".to_string());
        assert_eq!(id.clone().0, "abc");
    }
}