1use std::time::Instant;
26use tokio::time::timeout;
27use tracing::{debug, error, info, warn};
28
29use crate::tool::{Capability, ToolRegistry};
30use crate::tool_error::ToolError;
31use crate::tool_result::ToolResult;
32
33pub struct ToolExecutor {
57 registry: ToolRegistry,
58 audit_enabled: bool,
60 max_parallel: usize,
62 allowed_capabilities: Vec<Capability>,
64}
65
66impl ToolExecutor {
67 pub fn new(registry: ToolRegistry) -> Self {
69 Self {
70 registry,
71 audit_enabled: true,
72 max_parallel: 0, allowed_capabilities: vec![
74 Capability::PureComputation,
75 Capability::Network,
76 Capability::FileSystem,
77 Capability::Subprocess,
78 Capability::Environment,
79 Capability::Cryptography,
80 ],
81 }
82 }
83
84 pub fn without_audit(registry: ToolRegistry) -> Self {
86 Self {
87 registry,
88 audit_enabled: false,
89 max_parallel: 0,
90 allowed_capabilities: vec![
91 Capability::PureComputation,
92 Capability::Network,
93 Capability::FileSystem,
94 Capability::Subprocess,
95 Capability::Environment,
96 Capability::Cryptography,
97 ],
98 }
99 }
100
101 pub fn with_max_parallel(mut self, max: usize) -> Self {
103 self.max_parallel = max;
104 self
105 }
106
107 pub fn with_allowed_capabilities(mut self, caps: Vec<Capability>) -> Self {
109 self.allowed_capabilities = caps;
110 self
111 }
112
113 pub async fn execute(
131 &self,
132 tool_name: &str,
133 args: serde_json::Value,
134 ) -> Result<ToolResult, ToolError> {
135 let tool = self.registry.get(tool_name).ok_or_else(|| {
137 warn!(tool = tool_name, "Tool not found");
138 ToolError::not_found(tool_name)
139 })?;
140
141 for cap in tool.capabilities() {
143 if !self.allowed_capabilities.contains(&cap) {
144 warn!(
145 tool = tool_name,
146 capability = ?cap,
147 "Tool requires missing capability"
148 );
149 return Err(ToolError::unavailable(
150 tool_name,
151 format!("Sandbox violation: tool requires {:?}", cap),
152 ));
153 }
154 }
155
156 if !tool.is_available() {
158 warn!(tool = tool_name, "Tool is unavailable");
159 return Err(ToolError::unavailable(
160 tool_name,
161 "Tool is currently disabled",
162 ));
163 }
164
165 debug!(tool = tool_name, "Validating arguments against schema");
167 let schema_str = tool.definition().parameters;
168 if !schema_str.is_empty() && schema_str != "{}" {
169 let schema_json: serde_json::Value = serde_json::from_str(schema_str).map_err(|e| {
170 ToolError::execution_failed(tool_name, format!("Invalid tool schema: {}", e))
171 })?;
172
173 let compiled = jsonschema::JSONSchema::compile(&schema_json).map_err(|e| {
174 ToolError::execution_failed(
175 tool_name,
176 format!("Failed to compile tool schema: {}", e),
177 )
178 })?;
179
180 if !compiled.is_valid(&args) {
181 warn!(tool = tool_name, "Schema validation failed");
182 return Err(ToolError::invalid_args(
183 tool_name,
184 "Arguments do not match tool schema",
185 ));
186 }
187 }
188
189 debug!(tool = tool_name, "Running custom validation");
191 tool.validate(&args)?;
192
193 let tool_timeout = tool.timeout();
195 let start = Instant::now();
196
197 debug!(
198 tool = tool_name,
199 timeout_ms = tool_timeout.as_millis(),
200 "Executing tool"
201 );
202
203 let output = timeout(tool_timeout, tool.execute(args.clone()))
204 .await
205 .map_err(|_| {
206 error!(
207 tool = tool_name,
208 timeout_ms = tool_timeout.as_millis(),
209 "Tool execution timed out"
210 );
211 ToolError::timeout(tool_name, tool_timeout.as_millis() as u64)
212 })??;
213
214 let elapsed = start.elapsed();
215
216 let result = ToolResult::new(tool_name, &args, output, elapsed);
218
219 info!(
221 tool = tool_name,
222 execution_ms = elapsed.as_millis(),
223 hash = %result.hash,
224 "Tool executed successfully"
225 );
226
227 if self.audit_enabled {
230 debug!(
231 tool = tool_name,
232 result_hash = %result.hash,
233 "Audit entry created"
234 );
235 }
236
237 Ok(result)
238 }
239
240 pub async fn execute_parallel(
256 &self,
257 calls: Vec<(String, serde_json::Value)>,
258 ) -> Vec<Result<ToolResult, ToolError>> {
259 debug!(count = calls.len(), "Executing tools in parallel");
260
261 if self.max_parallel > 0 {
262 use futures::stream::{self, StreamExt};
263
264 stream::iter(calls)
265 .map(|(name, args)| async move { self.execute(&name, args).await })
266 .buffered(self.max_parallel)
267 .collect()
268 .await
269 } else {
270 let futures: Vec<_> = calls
271 .into_iter()
272 .map(|(name, args)| {
273 async move { self.execute(&name, args).await }
275 })
276 .collect();
277
278 futures::future::join_all(futures).await
279 }
280 }
281
282 pub fn registry(&self) -> &ToolRegistry {
284 &self.registry
285 }
286
287 pub fn registry_mut(&mut self) -> &mut ToolRegistry {
289 &mut self.registry
290 }
291
292 pub fn has_tool(&self, name: &str) -> bool {
294 self.registry.contains(name)
295 }
296
297 pub fn tool_names(&self) -> Vec<&str> {
299 self.registry.names()
300 }
301}
302
303impl std::fmt::Debug for ToolExecutor {
304 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 f.debug_struct("ToolExecutor")
306 .field("tools", &self.registry.names())
307 .field("audit_enabled", &self.audit_enabled)
308 .field("max_parallel", &self.max_parallel)
309 .finish()
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::{Tool, ToolDefinition};
    use async_trait::async_trait;
    use std::sync::Arc;
    use std::time::Duration;

    /// Returns its arguments wrapped as `{"echo": args}`.
    struct EchoTool {
        definition: ToolDefinition,
    }

    impl EchoTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new(
                    "echo",
                    "Echo back the input",
                    r#"{"type": "object"}"#,
                ),
            }
        }
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(serde_json::json!({ "echo": args }))
        }
    }

    /// Always returns an `ExecutionFailed` error.
    struct FailTool {
        definition: ToolDefinition,
    }

    impl FailTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new("fail", "Always fails", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for FailTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Err(ToolError::execution_failed("fail", "Intentional failure"))
        }
    }

    /// Sleeps far longer than its own 50 ms timeout, to exercise the timeout path.
    struct SlowTool {
        definition: ToolDefinition,
    }

    impl SlowTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new("slow", "Takes forever", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for SlowTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        fn timeout(&self) -> Duration {
            Duration::from_millis(50)
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(serde_json::json!({"done": true}))
        }
    }

    #[tokio::test]
    async fn test_execute_success() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));

        let executor = ToolExecutor::new(registry);
        let result = executor
            .execute("echo", serde_json::json!({"message": "hello"}))
            .await
            .unwrap();

        assert_eq!(result.tool_name, "echo");
        assert_eq!(result.output["echo"]["message"], serde_json::json!("hello"));
        assert!(!result.hash.to_string().is_empty());
    }

    #[tokio::test]
    async fn test_execute_not_found() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);

        let result = executor.execute("nonexistent", serde_json::json!({})).await;

        assert!(matches!(result, Err(ToolError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_execute_failure() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(FailTool::new()));

        let executor = ToolExecutor::new(registry);
        let result = executor.execute("fail", serde_json::json!({})).await;

        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_execute_timeout() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(SlowTool::new()));

        let executor = ToolExecutor::new(registry);
        let result = executor.execute("slow", serde_json::json!({})).await;

        assert!(matches!(result, Err(ToolError::Timeout { .. })));
    }

    #[tokio::test]
    async fn test_without_audit_still_executes() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));

        let executor = ToolExecutor::without_audit(registry);
        let result = executor
            .execute("echo", serde_json::json!({"k": "v"}))
            .await
            .unwrap();

        assert_eq!(result.output["echo"]["k"], serde_json::json!("v"));
    }

    #[tokio::test]
    async fn test_execute_parallel() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));

        let executor = ToolExecutor::new(registry);

        let calls = vec![
            ("echo".to_string(), serde_json::json!({"n": 1})),
            ("echo".to_string(), serde_json::json!({"n": 2})),
            ("echo".to_string(), serde_json::json!({"n": 3})),
        ];

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    /// Covers the bounded (`max_parallel > 0`) path: order is preserved and one
    /// failing call does not affect the others.
    #[tokio::test]
    async fn test_execute_parallel_bounded_preserves_order() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        registry.register(Arc::new(FailTool::new()));

        let executor = ToolExecutor::new(registry).with_max_parallel(2);

        let calls = vec![
            ("echo".to_string(), serde_json::json!({"n": 1})),
            ("fail".to_string(), serde_json::json!({})),
            ("echo".to_string(), serde_json::json!({"n": 3})),
            ("missing".to_string(), serde_json::json!({})),
        ];

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 4);
        assert_eq!(
            results[0].as_ref().unwrap().output["echo"]["n"],
            serde_json::json!(1)
        );
        assert!(matches!(results[1], Err(ToolError::ExecutionFailed { .. })));
        assert_eq!(
            results[2].as_ref().unwrap().output["echo"]["n"],
            serde_json::json!(3)
        );
        assert!(matches!(results[3], Err(ToolError::NotFound { .. })));
    }

    #[test]
    fn test_has_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));

        let executor = ToolExecutor::new(registry);

        assert!(executor.has_tool("echo"));
        assert!(!executor.has_tool("nonexistent"));
        assert!(executor.tool_names().contains(&"echo"));
    }
}
479}