vex_llm/
streaming_tool.rs

1//! Streaming tool support for long-running operations
2//!
3//! This module provides traits and types for tools that produce streaming output.
4//!
5//! # Security Considerations
6//!
7//! - **Backpressure**: Streams MUST respect consumer pace (DoS prevention)
8//! - **Timeouts**: Per-chunk timeouts in addition to total timeout
9//! - **Resource Limits**: Maximum chunks per stream to prevent memory exhaustion
10//! - **Cancellation**: Streams support graceful cancellation via Drop
11//!
12//! # Example
13//!
14//! ```ignore
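//! use futures::StreamExt;
//! use vex_llm::streaming_tool::{StreamingTool, ToolChunk};
//!
//! // `ToolStream` is already a `Pin<Box<...>>`, so it can be polled directly without `pin_mut!`.
//! let mut stream = tool.execute_stream(args, tool.stream_config());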
19//! while let Some(chunk) = stream.next().await {
20//!     match chunk {
21//!         ToolChunk::Progress { percent, message } => println!("{}% - {}", percent, message),
22//!         ToolChunk::Complete { result } => println!("Done: {:?}", result.hash),
23//!         _ => {}
24//!     }
25//! }
26//! ```
27
28use std::pin::Pin;
29use std::time::Duration;
30
31use async_trait::async_trait;
32use futures::Stream;
33use serde::{Deserialize, Serialize};
34use serde_json::Value;
35
36use crate::tool::Tool;
37use crate::tool_error::ToolError;
38use crate::tool_result::ToolResult;
39
/// A chunk of streaming output from a tool.
///
/// A well-formed stream yields zero or more non-terminal chunks (`Progress`, `Partial`)
/// followed by one terminal chunk (`Complete` or `Error`); see [`ToolChunk::is_terminal`].
///
/// # Security
///
/// - Progress updates should be rate-limited by the producer (at most one per 100ms is the
///   recommended ceiling; see `StreamConfig::min_progress_interval`). This type itself does
///   not enforce any rate.
/// - Partial data is NOT hashed until Complete (prevents hash oracle attacks)
/// - An `Error` chunk ends the stream immediately
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolChunk {
    /// Progress update (percentage and message)
    Progress {
        /// Progress percentage, nominally 0.0 to 100.0.
        /// `ToolChunk::progress` clamps into this range; values constructed directly or
        /// obtained by deserialization are not checked.
        percent: f32,
        /// Human-readable status message
        message: String,
    },

    /// Partial data chunk (intermediate result)
    Partial {
        /// Partial data. Intentionally NOT hashed; only the `Complete` result carries a hash.
        data: Value,
        /// 0-based index of this chunk, as assigned by the producer
        index: usize,
    },

    /// Final complete result (terminates the stream)
    Complete {
        /// Final result, including its cryptographic hash (exposed as `result.hash`)
        result: ToolResult,
    },

    /// Error during streaming (terminates the stream)
    ///
    /// Note: fields are plain `String`s, rather than a `ToolError`, so that this enum can
    /// derive `Clone`/`Serialize` without placing those constraints on `ToolError`.
    Error {
        /// Name of the tool that failed (or a category label such as "serialization"/"audit"
        /// when built via `ToolChunk::from_error` from an error not tied to a tool)
        tool: String,
        /// Human-readable error message.
        /// NOTE(review): `from_error` copies the error's `Display` text verbatim; confirm it
        /// is safe to expose to stream consumers.
        message: String,
        /// Whether the error is retryable
        retryable: bool,
    },
}
82
83impl ToolChunk {
84    /// Create a progress chunk
85    pub fn progress(percent: f32, message: impl Into<String>) -> Self {
86        Self::Progress {
87            percent: percent.clamp(0.0, 100.0),
88            message: message.into(),
89        }
90    }
91
92    /// Create a partial data chunk
93    pub fn partial(data: Value, index: usize) -> Self {
94        Self::Partial { data, index }
95    }
96
97    /// Create a complete chunk from a tool result
98    pub fn complete(result: ToolResult) -> Self {
99        Self::Complete { result }
100    }
101
102    /// Create an error chunk from a ToolError
103    pub fn from_error(error: &ToolError) -> Self {
104        Self::Error {
105            tool: match error {
106                ToolError::NotFound { name } => name.clone(),
107                ToolError::InvalidArguments { tool, .. } => tool.clone(),
108                ToolError::ExecutionFailed { tool, .. } => tool.clone(),
109                ToolError::Timeout { tool, .. } => tool.clone(),
110                ToolError::Unavailable { name, .. } => name.clone(),
111                ToolError::Serialization(_) => "serialization".to_string(),
112                ToolError::AuditFailed(_) => "audit".to_string(),
113            },
114            message: error.to_string(),
115            retryable: error.is_retryable(),
116        }
117    }
118
119    /// Create a simple error chunk
120    pub fn error(tool: impl Into<String>, message: impl Into<String>) -> Self {
121        Self::Error {
122            tool: tool.into(),
123            message: message.into(),
124            retryable: false,
125        }
126    }
127
128    /// Check if this chunk terminates the stream
129    pub fn is_terminal(&self) -> bool {
130        matches!(self, Self::Complete { .. } | Self::Error { .. })
131    }
132}
133
134/// Type alias for a boxed async stream of tool chunks
135pub type ToolStream = Pin<Box<dyn Stream<Item = ToolChunk> + Send>>;
136
/// Configuration for streaming tool execution.
///
/// # Security
///
/// - `max_chunks`: caps how many chunks a stream may emit, bounding memory growth (DoS)
/// - `chunk_timeout`: caps the wait for any single chunk, so a stalled stream cannot hang (DoS)
/// - `max_duration`: caps the total execution time of the stream
/// - `min_progress_interval`: rate-limits `Progress` updates
///
/// This type only carries the limits. Enforcing them is the responsibility of the code
/// that produces or consumes the stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamConfig {
    /// Maximum number of chunks before forced termination
    pub max_chunks: usize,
    /// Timeout for each individual chunk
    pub chunk_timeout: Duration,
    /// Maximum total duration for the stream
    pub max_duration: Duration,
    /// Minimum interval between progress updates (rate limiting)
    pub min_progress_interval: Duration,
}

impl Default for StreamConfig {
    /// General-purpose limits:
    /// - 1000 chunks;
    /// - 30 s per chunk;
    /// - 5 min total;
    /// - progress at most every 100 ms.
    fn default() -> Self {
        Self {
            max_chunks: 1000,
            chunk_timeout: Duration::from_secs(30),
            max_duration: Duration::from_secs(300), // 5 minutes
            min_progress_interval: Duration::from_millis(100),
        }
    }
}

impl StreamConfig {
    /// Create config for short operations.
    ///
    /// Limits:
    /// - 100 chunks;
    /// - 5 s per chunk;
    /// - 30 s total;
    /// - progress at most every 50 ms.
    #[must_use]
    pub fn short() -> Self {
        Self {
            max_chunks: 100,
            chunk_timeout: Duration::from_secs(5),
            max_duration: Duration::from_secs(30),
            min_progress_interval: Duration::from_millis(50),
        }
    }

    /// Create config for long operations.
    ///
    /// Limits:
    /// - 10000 chunks;
    /// - 60 s per chunk;
    /// - 1 h total;
    /// - progress at most every 500 ms.
    #[must_use]
    pub fn long() -> Self {
        Self {
            max_chunks: 10000,
            chunk_timeout: Duration::from_secs(60),
            max_duration: Duration::from_secs(3600), // 1 hour
            min_progress_interval: Duration::from_millis(500),
        }
    }
}
188
189/// Trait for tools that produce streaming output.
190///
191/// # Security
192///
193/// Implementors MUST:
194/// - Respect cancellation (check for stream drop)
195/// - Limit output size (respect StreamConfig)
196/// - Hash only the final result (not intermediate chunks)
197/// - Sanitize all output data
198#[async_trait]
199pub trait StreamingTool: Tool {
200    /// Execute with streaming output
201    ///
202    /// Returns a stream of `ToolChunk` values. The stream MUST:
203    /// - Emit at least one `Complete` or `Error` chunk before ending
204    /// - Respect the provided configuration limits
205    /// - Be cancellable (stop when dropped)
206    fn execute_stream(&self, args: Value, config: StreamConfig) -> ToolStream;
207
208    /// Get the default stream configuration for this tool
209    fn stream_config(&self) -> StreamConfig {
210        StreamConfig::default()
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_tool_chunk_progress() {
        let chunk = ToolChunk::progress(50.0, "Halfway there");
        match chunk {
            ToolChunk::Progress { percent, message } => {
                assert_eq!(percent, 50.0);
                assert_eq!(message, "Halfway there");
            }
            _ => panic!("Expected Progress chunk"),
        }
    }

    #[test]
    fn test_tool_chunk_progress_clamped() {
        let chunk = ToolChunk::progress(150.0, "Over 100");
        match chunk {
            ToolChunk::Progress { percent, .. } => {
                assert_eq!(percent, 100.0);
            }
            _ => panic!("Expected Progress chunk"),
        }
    }

    #[test]
    fn test_tool_chunk_progress_clamped_below_zero() {
        // The lower bound of the clamp must hold as well as the upper one.
        let chunk = ToolChunk::progress(-25.0, "Under 0");
        match chunk {
            ToolChunk::Progress { percent, .. } => {
                assert_eq!(percent, 0.0);
            }
            _ => panic!("Expected Progress chunk"),
        }
    }

    #[test]
    fn test_tool_chunk_partial_fields() {
        let data = serde_json::json!({ "rows": 3 });
        match ToolChunk::partial(data.clone(), 7) {
            ToolChunk::Partial { data: got, index } => {
                assert_eq!(got, data);
                assert_eq!(index, 7);
            }
            _ => panic!("Expected Partial chunk"),
        }
    }

    #[test]
    fn test_tool_chunk_error_fields() {
        match ToolChunk::error("fetch", "connection refused") {
            ToolChunk::Error {
                tool,
                message,
                retryable,
            } => {
                assert_eq!(tool, "fetch");
                assert_eq!(message, "connection refused");
                assert!(!retryable, "simple error chunks are never retryable");
            }
            _ => panic!("Expected Error chunk"),
        }
    }

    #[test]
    fn test_tool_chunk_is_terminal() {
        assert!(!ToolChunk::progress(50.0, "").is_terminal());
        assert!(!ToolChunk::partial(serde_json::json!({}), 0).is_terminal());

        let result = ToolResult::new(
            "test",
            &serde_json::json!({}),
            serde_json::json!({}),
            Duration::from_secs(1),
        );
        assert!(ToolChunk::complete(result).is_terminal());

        assert!(ToolChunk::error("test", "not found").is_terminal());
    }

    #[test]
    fn test_stream_config_default() {
        let config = StreamConfig::default();
        assert_eq!(config.max_chunks, 1000);
        assert_eq!(config.chunk_timeout, Duration::from_secs(30));
        assert_eq!(config.max_duration, Duration::from_secs(300));
        assert_eq!(config.min_progress_interval, Duration::from_millis(100));
    }

    #[test]
    fn test_stream_config_short() {
        let config = StreamConfig::short();
        assert_eq!(config.max_chunks, 100);
        assert_eq!(config.chunk_timeout, Duration::from_secs(5));
        assert!(config.max_duration < Duration::from_secs(60));
        assert_eq!(config.min_progress_interval, Duration::from_millis(50));
    }

    #[test]
    fn test_stream_config_long() {
        let config = StreamConfig::long();
        assert_eq!(config.max_chunks, 10000);
        assert_eq!(config.chunk_timeout, Duration::from_secs(60));
        assert_eq!(config.max_duration, Duration::from_secs(3600));
        assert_eq!(config.min_progress_interval, Duration::from_millis(500));
    }
}