1use axum::{
6 extract::{Path, State},
7 http::StatusCode,
8 Json, Router,
9};
10use chrono::{DateTime, Duration, Utc};
11use moka::future::Cache;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use uuid::Uuid;
15
16use super::agent_card::AgentCard;
17use super::task::{TaskRequest, TaskResponse};
18use crate::sanitize::sanitize_name;
19
20pub struct NonceCache {
22 cache: Cache<String, ()>,
23}
24
25impl Default for NonceCache {
26 fn default() -> Self {
27 Self {
28 cache: Cache::builder()
29 .max_capacity(100_000)
30 .time_to_live(std::time::Duration::from_secs(300)) .build(),
32 }
33 }
34}
35
36impl NonceCache {
37 pub async fn validate(
38 &self,
39 nonce: Option<&str>,
40 timestamp: DateTime<Utc>,
41 ) -> Result<(), String> {
42 let now = Utc::now();
43 let age = now.signed_duration_since(timestamp);
44
45 if age > Duration::minutes(5) {
47 return Err(format!("Timestamp too old: {}s", age.num_seconds()));
48 }
49 if age < Duration::seconds(-30) {
50 return Err("Timestamp is in the future".to_string());
52 }
53
54 if let Some(nonce) = nonce {
55 if self.cache.contains_key(nonce) {
56 return Err("Replay detected".to_string());
57 }
58 self.cache.insert(nonce.to_string(), ()).await;
59 }
60 Ok(())
61 }
62}
63
64pub struct A2aState {
66 pub agent_card: RwLock<AgentCard>,
67 pub nonce_mgr: NonceCache,
68}
69
70impl Default for A2aState {
71 fn default() -> Self {
72 Self {
73 agent_card: RwLock::new(AgentCard::vex_default()),
74 nonce_mgr: NonceCache::default(),
75 }
76 }
77}
78
79#[utoipa::path(
81 get,
82 path = "/.well-known/agent.json",
83 responses(
84 (status = 200, description = "A2A Agent Card", body = AgentCard)
85 )
86)]
87pub async fn agent_card_handler(State(a2a_state): State<Arc<A2aState>>) -> Json<AgentCard> {
88 Json(a2a_state.agent_card.read().await.clone())
89}
90
91#[utoipa::path(
93 post,
94 path = "/a2a/tasks",
95 request_body = TaskRequest,
96 responses(
97 (status = 202, description = "Task accepted", body = TaskResponse),
98 (status = 400, description = "Invalid request or replay detected")
99 )
100)]
101pub async fn create_task_handler(
102 State(a2a_state): State<Arc<A2aState>>,
103 Json(request): Json<TaskRequest>,
104) -> Result<(StatusCode, Json<TaskResponse>), (StatusCode, String)> {
105 if let Err(e) = a2a_state
106 .nonce_mgr
107 .validate(request.nonce.as_deref(), request.timestamp)
108 .await
109 {
110 tracing::warn!(task_id = %request.id, error = %e, "A2A replay protection failed");
111 return Err((
112 StatusCode::BAD_REQUEST,
113 format!("Replay protection failed: {}", e),
114 ));
115 }
116
117 if let Err(e) = sanitize_name(&request.skill) {
119 tracing::warn!(task_id = %request.id, error = %e, "Invalid skill in A2A task");
120 return Err((StatusCode::BAD_REQUEST, format!("Invalid skill: {}", e)));
121 }
122
123 let response = TaskResponse::pending(request.id);
124 Ok((StatusCode::ACCEPTED, Json(response)))
125}
126
127#[utoipa::path(
129 get,
130 path = "/a2a/tasks/{id}",
131 params(
132 ("id" = Uuid, Path, description = "Task ID")
133 ),
134 responses(
135 (status = 200, description = "Current task status", body = TaskResponse),
136 (status = 404, description = "Task not found")
137 )
138)]
139pub async fn get_task_handler(
140 State(_a2a_state): State<Arc<A2aState>>,
141 Path(task_id): Path<Uuid>,
142) -> Result<Json<TaskResponse>, StatusCode> {
143 let response = TaskResponse::pending(task_id);
144 Ok(Json(response))
145}
146
147pub fn a2a_routes() -> Router<Arc<A2aState>> {
149 use axum::routing::{get, post};
150
151 Router::new()
152 .route("/.well-known/agent.json", get(agent_card_handler))
153 .route("/a2a/tasks", post(create_task_handler))
154 .route("/a2a/tasks/{id}", get(get_task_handler))
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use tower::ServiceExt;

    fn create_test_state() -> Arc<A2aState> {
        Arc::new(A2aState::default())
    }

    #[tokio::test]
    async fn test_agent_card_endpoint() {
        let state = create_test_state();
        let app = a2a_routes().with_state(state);

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/.well-known/agent.json")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_create_task_endpoint() {
        let state = create_test_state();
        let app = a2a_routes().with_state(state);

        let task_req = TaskRequest::new("verify", serde_json::json!({"claim": "test"}));
        let body = serde_json::to_string(&task_req).unwrap();

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/a2a/tasks")
                    .header("content-type", "application/json")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::ACCEPTED);
    }

    #[tokio::test]
    async fn test_get_task_endpoint() {
        let state = create_test_state();
        let app = a2a_routes().with_state(state);
        let task_id = Uuid::new_v4();

        let response = app
            .oneshot(
                Request::builder()
                    .uri(format!("/a2a/tasks/{}", task_id))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    // Replay protection is the security-critical part of task creation, so it
    // is exercised directly on `NonceCache` as well as through the router.

    #[tokio::test]
    async fn test_nonce_replay_is_rejected() {
        let nonces = NonceCache::default();
        let now = Utc::now();

        assert!(nonces.validate(Some("nonce-1"), now).await.is_ok());
        assert_eq!(
            nonces.validate(Some("nonce-1"), now).await,
            Err("Replay detected".to_string())
        );
        // A different nonce is unaffected.
        assert!(nonces.validate(Some("nonce-2"), now).await.is_ok());
    }

    #[tokio::test]
    async fn test_stale_timestamp_is_rejected() {
        let nonces = NonceCache::default();
        let stale = Utc::now() - Duration::minutes(10);

        let err = nonces.validate(Some("stale"), stale).await.unwrap_err();
        assert!(err.starts_with("Timestamp too old"), "unexpected error: {err}");
    }

    #[tokio::test]
    async fn test_future_timestamp_is_rejected() {
        let nonces = NonceCache::default();
        let future = Utc::now() + Duration::minutes(5);

        assert_eq!(
            nonces.validate(Some("future"), future).await,
            Err("Timestamp is in the future".to_string())
        );
    }

    #[tokio::test]
    async fn test_rejected_timestamp_does_not_consume_nonce() {
        let nonces = NonceCache::default();
        let stale = Utc::now() - Duration::minutes(10);

        // Timestamp checks run before the nonce is recorded, so the same nonce
        // remains usable by a subsequent, fresh request.
        assert!(nonces.validate(Some("reuse"), stale).await.is_err());
        assert!(nonces.validate(Some("reuse"), Utc::now()).await.is_ok());
    }
}