vex_api/a2a/
handler.rs

1//! A2A HTTP handlers
2//!
3//! Axum handlers for A2A protocol endpoints.
4
5use axum::{
6    extract::{Path, State},
7    http::StatusCode,
8    Json, Router,
9};
10use chrono::{DateTime, Duration, Utc};
11use moka::future::Cache;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use uuid::Uuid;
15
16use super::agent_card::AgentCard;
17use super::task::{TaskRequest, TaskResponse};
18use crate::sanitize::sanitize_name;
19
20/// Nonce cache for replay protection (2025 best practice)
21pub struct NonceCache {
22    cache: Cache<String, ()>,
23}
24
25impl Default for NonceCache {
26    fn default() -> Self {
27        Self {
28            cache: Cache::builder()
29                .max_capacity(100_000)
30                .time_to_live(std::time::Duration::from_secs(300)) // 5 minutes matches max_age
31                .build(),
32        }
33    }
34}
35
36impl NonceCache {
37    pub async fn validate(
38        &self,
39        nonce: Option<&str>,
40        timestamp: DateTime<Utc>,
41    ) -> Result<(), String> {
42        let now = Utc::now();
43        let age = now.signed_duration_since(timestamp);
44
45        // TTL enforcement
46        if age > Duration::minutes(5) {
47            return Err(format!("Timestamp too old: {}s", age.num_seconds()));
48        }
49        if age < Duration::seconds(-30) {
50            // Allow for some clock skew
51            return Err("Timestamp is in the future".to_string());
52        }
53
54        if let Some(nonce) = nonce {
55            if self.cache.contains_key(nonce) {
56                return Err("Replay detected".to_string());
57            }
58            self.cache.insert(nonce.to_string(), ()).await;
59        }
60        Ok(())
61    }
62}
63
64/// Shared state for A2A handlers
65pub struct A2aState {
66    pub agent_card: RwLock<AgentCard>,
67    pub nonce_mgr: NonceCache,
68}
69
70impl Default for A2aState {
71    fn default() -> Self {
72        Self {
73            agent_card: RwLock::new(AgentCard::vex_default()),
74            nonce_mgr: NonceCache::default(),
75        }
76    }
77}
78
79/// Health check for A2A protocol (Agent Card)
80#[utoipa::path(
81    get,
82    path = "/.well-known/agent.json",
83    responses(
84        (status = 200, description = "A2A Agent Card", body = AgentCard)
85    )
86)]
87pub async fn agent_card_handler(State(a2a_state): State<Arc<A2aState>>) -> Json<AgentCard> {
88    Json(a2a_state.agent_card.read().await.clone())
89}
90
91/// Create a new task (A2A)
92#[utoipa::path(
93    post,
94    path = "/a2a/tasks",
95    request_body = TaskRequest,
96    responses(
97        (status = 202, description = "Task accepted", body = TaskResponse),
98        (status = 400, description = "Invalid request or replay detected")
99    )
100)]
101pub async fn create_task_handler(
102    State(a2a_state): State<Arc<A2aState>>,
103    Json(request): Json<TaskRequest>,
104) -> Result<(StatusCode, Json<TaskResponse>), (StatusCode, String)> {
105    if let Err(e) = a2a_state
106        .nonce_mgr
107        .validate(request.nonce.as_deref(), request.timestamp)
108        .await
109    {
110        tracing::warn!(task_id = %request.id, error = %e, "A2A replay protection failed");
111        return Err((
112            StatusCode::BAD_REQUEST,
113            format!("Replay protection failed: {}", e),
114        ));
115    }
116
117    // Sanitize skill field (Fix for #25)
118    if let Err(e) = sanitize_name(&request.skill) {
119        tracing::warn!(task_id = %request.id, error = %e, "Invalid skill in A2A task");
120        return Err((StatusCode::BAD_REQUEST, format!("Invalid skill: {}", e)));
121    }
122
123    let response = TaskResponse::pending(request.id);
124    Ok((StatusCode::ACCEPTED, Json(response)))
125}
126
127/// Get task status (A2A)
128#[utoipa::path(
129    get,
130    path = "/a2a/tasks/{id}",
131    params(
132        ("id" = Uuid, Path, description = "Task ID")
133    ),
134    responses(
135        (status = 200, description = "Current task status", body = TaskResponse),
136        (status = 404, description = "Task not found")
137    )
138)]
139pub async fn get_task_handler(
140    State(_a2a_state): State<Arc<A2aState>>,
141    Path(task_id): Path<Uuid>,
142) -> Result<Json<TaskResponse>, StatusCode> {
143    let response = TaskResponse::pending(task_id);
144    Ok(Json(response))
145}
146
147/// Build A2A routes decoupled from main AppState
148pub fn a2a_routes() -> Router<Arc<A2aState>> {
149    use axum::routing::{get, post};
150
151    Router::new()
152        .route("/.well-known/agent.json", get(agent_card_handler))
153        .route("/a2a/tasks", post(create_task_handler))
154        .route("/a2a/tasks/{id}", get(get_task_handler))
155}
156
#[cfg(test)]
mod tests {
    //! Endpoint smoke tests plus unit tests for the replay-protection logic in `NonceCache`.

    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use tower::ServiceExt;

    fn create_test_state() -> Arc<A2aState> {
        Arc::new(A2aState::default())
    }

    #[tokio::test]
    async fn test_agent_card_endpoint() {
        let state = create_test_state();
        let app = a2a_routes().with_state(state);

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/.well-known/agent.json")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_create_task_endpoint() {
        let state = create_test_state();
        let app = a2a_routes().with_state(state);

        let task_req = TaskRequest::new("verify", serde_json::json!({"claim": "test"}));
        let body = serde_json::to_string(&task_req).unwrap();

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/a2a/tasks")
                    .header("content-type", "application/json")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::ACCEPTED);
    }

    #[tokio::test]
    async fn test_get_task_endpoint() {
        let state = create_test_state();
        let app = a2a_routes().with_state(state);
        let task_id = Uuid::new_v4();

        let response = app
            .oneshot(
                Request::builder()
                    .uri(format!("/a2a/tasks/{}", task_id))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    /// A nonce is accepted once; presenting it again is reported as a replay.
    #[tokio::test]
    async fn test_nonce_replay_rejected() {
        let cache = NonceCache::default();
        let now = Utc::now();

        assert!(cache.validate(Some("nonce-1"), now).await.is_ok());
        let err = cache.validate(Some("nonce-1"), now).await.unwrap_err();
        assert_eq!(err, "Replay detected");

        // An unrelated nonce is unaffected by the first one.
        assert!(cache.validate(Some("nonce-2"), now).await.is_ok());
    }

    /// Timestamps older than the 5-minute window are rejected.
    #[tokio::test]
    async fn test_stale_timestamp_rejected() {
        let cache = NonceCache::default();
        let stale = Utc::now() - Duration::minutes(10);

        let err = cache.validate(Some("stale"), stale).await.unwrap_err();
        assert!(err.starts_with("Timestamp too old"), "unexpected error: {err}");
    }

    /// Timestamps further ahead than the 30-second skew allowance are rejected.
    #[tokio::test]
    async fn test_future_timestamp_rejected() {
        let cache = NonceCache::default();
        let future = Utc::now() + Duration::minutes(5);

        let err = cache.validate(Some("future"), future).await.unwrap_err();
        assert_eq!(err, "Timestamp is in the future");
    }

    /// A request rejected for its timestamp must not burn its nonce.
    #[tokio::test]
    async fn test_rejected_timestamp_does_not_consume_nonce() {
        let cache = NonceCache::default();
        let stale = Utc::now() - Duration::minutes(10);

        assert!(cache.validate(Some("retry"), stale).await.is_err());
        assert!(cache.validate(Some("retry"), Utc::now()).await.is_ok());
    }
}