vex_llm/
tool_executor.rs

1//! Tool Executor with Merkle audit integration
2//!
3//! This module provides `ToolExecutor` which wraps tool execution with:
4//! - Timeout protection (DoS prevention)
5//! - Input validation
6//! - Cryptographic result hashing
7//! - Merkle audit trail integration
8//!
9//! # VEX Innovation
10//!
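//! Every successful tool execution yields a hashed `ToolResult` carrying:
//! - Tool name and argument hash (not raw args for privacy)
//! - Result hash for verification
//! - Execution time metrics
//!
//! This module itself only emits `tracing` events for each execution; persisting
//! these hashes to the Merkle audit chain (`AuditStore`) happens in the runtime,
//! and that record is what enables cryptographic proof of which tools were used.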
17//!
18//! # Security Considerations
19//!
20//! - All executions have configurable timeouts
21//! - Validation runs before execution
22//! - Audit logging is non-fatal (doesn't break execution)
23//! - Arguments are hashed before logging (privacy protection)
24
25use std::time::Instant;
26use tokio::time::timeout;
27use tracing::{debug, error, info, warn};
28
29use crate::tool::{Capability, ToolRegistry};
30use crate::tool_error::ToolError;
31use crate::tool_result::ToolResult;
32
33/// Tool executor with automatic audit logging and timeout protection.
34///
35/// The executor provides a safe, audited interface to execute tools:
36/// 1. Validates arguments against tool's schema
37/// 2. Executes with timeout protection
38/// 3. Hashes results for Merkle chain
39/// 4. Logs execution to audit trail (if configured)
40///
41/// # Example
42///
43/// ```ignore
44/// use vex_llm::{ToolExecutor, ToolRegistry};
45///
46/// let registry = ToolRegistry::with_builtins();
47/// let executor = ToolExecutor::new(registry);
48///
49/// let result = executor
50///     .execute("calculator", json!({"expression": "2+2"}))
51///     .await?;
52///
53/// println!("Result: {}", result.output);
54/// println!("Hash: {}", result.hash);
55/// ```
56pub struct ToolExecutor {
57    registry: ToolRegistry,
58    /// Enable/disable audit logging
59    audit_enabled: bool,
60    /// Maximum parallel executions (0 = unlimited)
61    max_parallel: usize,
62    /// Allowed capabilities for this executor (Security Sandbox)
63    allowed_capabilities: Vec<Capability>,
64}
65
66impl ToolExecutor {
67    /// Create a new executor with the given registry
68    pub fn new(registry: ToolRegistry) -> Self {
69        Self {
70            registry,
71            audit_enabled: true,
72            max_parallel: 0, // Unlimited by default
73            allowed_capabilities: vec![
74                Capability::PureComputation,
75                Capability::Network,
76                Capability::FileSystem,
77                Capability::Subprocess,
78                Capability::Environment,
79                Capability::Cryptography,
80            ],
81        }
82    }
83
84    /// Create executor with audit logging disabled
85    pub fn without_audit(registry: ToolRegistry) -> Self {
86        Self {
87            registry,
88            audit_enabled: false,
89            max_parallel: 0,
90            allowed_capabilities: vec![
91                Capability::PureComputation,
92                Capability::Network,
93                Capability::FileSystem,
94                Capability::Subprocess,
95                Capability::Environment,
96                Capability::Cryptography,
97            ],
98        }
99    }
100
101    /// Set maximum parallel executions
102    pub fn with_max_parallel(mut self, max: usize) -> Self {
103        self.max_parallel = max;
104        self
105    }
106
107    /// Set allowed capabilities for the sandbox
108    pub fn with_allowed_capabilities(mut self, caps: Vec<Capability>) -> Self {
109        self.allowed_capabilities = caps;
110        self
111    }
112
113    /// Execute a tool by name with given arguments.
114    ///
115    /// # Arguments
116    ///
117    /// * `tool_name` - Name of the tool to execute
118    /// * `args` - JSON arguments to pass to the tool
119    ///
120    /// # Returns
121    ///
122    /// * `Ok(ToolResult)` - Execution result with hash for verification
123    /// * `Err(ToolError)` - If tool not found, validation failed, execution error, or timeout
124    ///
125    /// # Security
126    ///
127    /// - Tool lookup prevents arbitrary code execution
128    /// - Timeout prevents DoS from hanging tools
129    /// - Result hash enables tamper detection
130    pub async fn execute(
131        &self,
132        tool_name: &str,
133        args: serde_json::Value,
134    ) -> Result<ToolResult, ToolError> {
135        // 1. Get tool from registry
136        let tool = self.registry.get(tool_name).ok_or_else(|| {
137            warn!(tool = tool_name, "Tool not found");
138            ToolError::not_found(tool_name)
139        })?;
140
141        // 1.5. Check capabilities (Sandbox)
142        for cap in tool.capabilities() {
143            if !self.allowed_capabilities.contains(&cap) {
144                warn!(
145                    tool = tool_name,
146                    capability = ?cap,
147                    "Tool requires missing capability"
148                );
149                return Err(ToolError::unavailable(
150                    tool_name,
151                    format!("Sandbox violation: tool requires {:?}", cap),
152                ));
153            }
154        }
155
156        // 2. Check availability
157        if !tool.is_available() {
158            warn!(tool = tool_name, "Tool is unavailable");
159            return Err(ToolError::unavailable(
160                tool_name,
161                "Tool is currently disabled",
162            ));
163        }
164
165        // 3. Validate arguments against JSON Schema
166        debug!(tool = tool_name, "Validating arguments against schema");
167        let schema_str = tool.definition().parameters;
168        if !schema_str.is_empty() && schema_str != "{}" {
169            let schema_json: serde_json::Value = serde_json::from_str(schema_str).map_err(|e| {
170                ToolError::execution_failed(tool_name, format!("Invalid tool schema: {}", e))
171            })?;
172
173            let compiled = jsonschema::JSONSchema::compile(&schema_json).map_err(|e| {
174                ToolError::execution_failed(
175                    tool_name,
176                    format!("Failed to compile tool schema: {}", e),
177                )
178            })?;
179
180            if !compiled.is_valid(&args) {
181                warn!(tool = tool_name, "Schema validation failed");
182                return Err(ToolError::invalid_args(
183                    tool_name,
184                    "Arguments do not match tool schema",
185                ));
186            }
187        }
188
189        // 4. Custom validation
190        debug!(tool = tool_name, "Running custom validation");
191        tool.validate(&args)?;
192
193        // 5. Execute with timeout
194        let tool_timeout = tool.timeout();
195        let start = Instant::now();
196
197        debug!(
198            tool = tool_name,
199            timeout_ms = tool_timeout.as_millis(),
200            "Executing tool"
201        );
202
203        let output = timeout(tool_timeout, tool.execute(args.clone()))
204            .await
205            .map_err(|_| {
206                error!(
207                    tool = tool_name,
208                    timeout_ms = tool_timeout.as_millis(),
209                    "Tool execution timed out"
210                );
211                ToolError::timeout(tool_name, tool_timeout.as_millis() as u64)
212            })??;
213
214        let elapsed = start.elapsed();
215
216        // 5. Create result with cryptographic hash
217        let result = ToolResult::new(tool_name, &args, output, elapsed);
218
219        // 6. Log execution metrics
220        info!(
221            tool = tool_name,
222            execution_ms = elapsed.as_millis(),
223            hash = %result.hash,
224            "Tool executed successfully"
225        );
226
227        // 7. Audit logging would happen here (integration point)
228        // Note: We log to tracing; actual AuditStore integration is in the runtime
229        if self.audit_enabled {
230            debug!(
231                tool = tool_name,
232                result_hash = %result.hash,
233                "Audit entry created"
234            );
235        }
236
237        Ok(result)
238    }
239
240    /// Execute multiple tools in parallel.
241    ///
242    /// # Arguments
243    ///
244    /// * `calls` - Vector of (tool_name, args) pairs
245    ///
246    /// # Returns
247    ///
248    /// Vector of results in the same order as input.
249    /// Each result is independent (one failure doesn't affect others).
250    ///
251    /// # Security
252    ///
253    /// - Respects max_parallel limit to prevent resource exhaustion
254    /// - Each tool has its own timeout
255    pub async fn execute_parallel(
256        &self,
257        calls: Vec<(String, serde_json::Value)>,
258    ) -> Vec<Result<ToolResult, ToolError>> {
259        debug!(count = calls.len(), "Executing tools in parallel");
260
261        if self.max_parallel > 0 {
262            use futures::stream::{self, StreamExt};
263
264            stream::iter(calls)
265                .map(|(name, args)| async move { self.execute(&name, args).await })
266                .buffered(self.max_parallel)
267                .collect()
268                .await
269        } else {
270            let futures: Vec<_> = calls
271                .into_iter()
272                .map(|(name, args)| {
273                    // Create an owned future that doesn't borrow the iterator
274                    async move { self.execute(&name, args).await }
275                })
276                .collect();
277
278            futures::future::join_all(futures).await
279        }
280    }
281
282    /// Get a reference to the tool registry
283    pub fn registry(&self) -> &ToolRegistry {
284        &self.registry
285    }
286
287    /// Get a mutable reference to the tool registry
288    pub fn registry_mut(&mut self) -> &mut ToolRegistry {
289        &mut self.registry
290    }
291
292    /// Check if a tool exists
293    pub fn has_tool(&self, name: &str) -> bool {
294        self.registry.contains(name)
295    }
296
297    /// List all available tool names
298    pub fn tool_names(&self) -> Vec<&str> {
299        self.registry.names()
300    }
301}
302
303impl std::fmt::Debug for ToolExecutor {
304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305        f.debug_struct("ToolExecutor")
306            .field("tools", &self.registry.names())
307            .field("audit_enabled", &self.audit_enabled)
308            .field("max_parallel", &self.max_parallel)
309            .finish()
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::{Tool, ToolDefinition};
    use async_trait::async_trait;
    use std::sync::Arc;
    use std::time::Duration;

    // Test tool that returns arguments
    struct EchoTool {
        definition: ToolDefinition,
    }

    impl EchoTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new(
                    "echo",
                    "Echo back the input",
                    r#"{"type": "object"}"#,
                ),
            }
        }
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(serde_json::json!({ "echo": args }))
        }
    }

    // Test tool that always fails
    struct FailTool {
        definition: ToolDefinition,
    }

    impl FailTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new("fail", "Always fails", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for FailTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Err(ToolError::execution_failed("fail", "Intentional failure"))
        }
    }

    // Test tool that times out
    struct SlowTool {
        definition: ToolDefinition,
    }

    impl SlowTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new("slow", "Takes forever", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for SlowTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        fn timeout(&self) -> Duration {
            Duration::from_millis(50) // Very short timeout for testing
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(serde_json::json!({"done": true}))
        }
    }

    /// Registry containing only `EchoTool`, shared by several tests.
    fn echo_registry() -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        registry
    }

    #[tokio::test]
    async fn test_execute_success() {
        let executor = ToolExecutor::new(echo_registry());
        let result = executor
            .execute("echo", serde_json::json!({"message": "hello"}))
            .await
            .unwrap();

        assert_eq!(result.tool_name, "echo");
        assert!(result.output["echo"]["message"] == "hello");
        assert!(!result.hash.to_string().is_empty());
    }

    #[tokio::test]
    async fn test_execute_not_found() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);

        let result = executor.execute("nonexistent", serde_json::json!({})).await;

        assert!(matches!(result, Err(ToolError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_execute_failure() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(FailTool::new()));

        let executor = ToolExecutor::new(registry);
        let result = executor.execute("fail", serde_json::json!({})).await;

        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_execute_timeout() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(SlowTool::new()));

        let executor = ToolExecutor::new(registry);
        let result = executor.execute("slow", serde_json::json!({})).await;

        assert!(matches!(result, Err(ToolError::Timeout { .. })));
    }

    #[tokio::test]
    async fn test_execute_rejects_args_not_matching_schema() {
        let executor = ToolExecutor::new(echo_registry());

        // EchoTool's schema requires an object; a bare number must be rejected.
        let result = executor.execute("echo", serde_json::json!(42)).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_execute_parallel() {
        let executor = ToolExecutor::new(echo_registry());

        let calls = vec![
            ("echo".to_string(), serde_json::json!({"n": 1})),
            ("echo".to_string(), serde_json::json!({"n": 2})),
            ("echo".to_string(), serde_json::json!({"n": 3})),
        ];

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_execute_parallel_bounded_preserves_order() {
        // Exercises the `max_parallel > 0` (buffered) branch.
        let executor = ToolExecutor::new(echo_registry()).with_max_parallel(2);

        let calls: Vec<_> = (0..5usize)
            .map(|n| ("echo".to_string(), serde_json::json!({ "n": n })))
            .collect();

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 5);
        for (n, result) in results.into_iter().enumerate() {
            let result = result.unwrap();
            assert_eq!(result.output["echo"]["n"], serde_json::json!(n));
        }
    }

    #[tokio::test]
    async fn test_execute_parallel_isolates_failures() {
        let mut registry = echo_registry();
        registry.register(Arc::new(FailTool::new()));
        let executor = ToolExecutor::new(registry);

        let calls = vec![
            ("echo".to_string(), serde_json::json!({})),
            ("fail".to_string(), serde_json::json!({})),
            ("missing".to_string(), serde_json::json!({})),
        ];

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 3);
        assert!(results[0].is_ok());
        assert!(matches!(results[1], Err(ToolError::ExecutionFailed { .. })));
        assert!(matches!(results[2], Err(ToolError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_without_audit_still_executes() {
        let executor = ToolExecutor::without_audit(echo_registry());

        assert!(format!("{:?}", executor).contains("audit_enabled: false"));

        let result = executor
            .execute("echo", serde_json::json!({"message": "hi"}))
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_has_tool() {
        let executor = ToolExecutor::new(echo_registry());

        assert!(executor.has_tool("echo"));
        assert!(!executor.has_tool("nonexistent"));
    }
}