vex_llm/
streaming_tool.rs1use std::pin::Pin;
29use std::time::Duration;
30
31use async_trait::async_trait;
32use futures::Stream;
33use serde::{Deserialize, Serialize};
34use serde_json::Value;
35
36use crate::tool::Tool;
37use crate::tool_error::ToolError;
38use crate::tool_result::ToolResult;
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
48pub enum ToolChunk {
49 Progress {
51 percent: f32,
53 message: String,
55 },
56
57 Partial {
59 data: Value,
61 index: usize,
63 },
64
65 Complete {
67 result: ToolResult,
69 },
70
71 Error {
74 tool: String,
76 message: String,
78 retryable: bool,
80 },
81}
82
83impl ToolChunk {
84 pub fn progress(percent: f32, message: impl Into<String>) -> Self {
86 Self::Progress {
87 percent: percent.clamp(0.0, 100.0),
88 message: message.into(),
89 }
90 }
91
92 pub fn partial(data: Value, index: usize) -> Self {
94 Self::Partial { data, index }
95 }
96
97 pub fn complete(result: ToolResult) -> Self {
99 Self::Complete { result }
100 }
101
102 pub fn from_error(error: &ToolError) -> Self {
104 Self::Error {
105 tool: match error {
106 ToolError::NotFound { name } => name.clone(),
107 ToolError::InvalidArguments { tool, .. } => tool.clone(),
108 ToolError::ExecutionFailed { tool, .. } => tool.clone(),
109 ToolError::Timeout { tool, .. } => tool.clone(),
110 ToolError::Unavailable { name, .. } => name.clone(),
111 ToolError::Serialization(_) => "serialization".to_string(),
112 ToolError::AuditFailed(_) => "audit".to_string(),
113 },
114 message: error.to_string(),
115 retryable: error.is_retryable(),
116 }
117 }
118
119 pub fn error(tool: impl Into<String>, message: impl Into<String>) -> Self {
121 Self::Error {
122 tool: tool.into(),
123 message: message.into(),
124 retryable: false,
125 }
126 }
127
128 pub fn is_terminal(&self) -> bool {
130 matches!(self, Self::Complete { .. } | Self::Error { .. })
131 }
132}
133
134pub type ToolStream = Pin<Box<dyn Stream<Item = ToolChunk> + Send>>;
136
/// Limits passed to [`StreamingTool::execute_stream`] for one streaming run.
///
/// This type only carries values; enforcing them is up to the tool
/// implementation or the stream consumer. Presets: [`StreamConfig::default`],
/// [`StreamConfig::short`] and [`StreamConfig::long`]. Configs can be compared
/// for equality and hashed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StreamConfig {
    /// Upper bound on the number of chunks in a stream.
    pub max_chunks: usize,
    /// Upper bound on the wait for any single chunk.
    pub chunk_timeout: Duration,
    /// Upper bound on the total duration of a stream.
    pub max_duration: Duration,
    /// Minimum spacing between consecutive progress chunks.
    pub min_progress_interval: Duration,
}

impl Default for StreamConfig {
    /// Returns the default limits:
    /// - 1000 chunks
    /// - 30 s per chunk
    /// - 300 s total
    /// - 100 ms minimum progress interval
    fn default() -> Self {
        Self {
            max_chunks: 1000,
            chunk_timeout: Duration::from_secs(30),
            max_duration: Duration::from_secs(300),
            min_progress_interval: Duration::from_millis(100),
        }
    }
}

impl StreamConfig {
    /// Preset for quick operations:
    /// - 100 chunks
    /// - 5 s per chunk
    /// - 30 s total
    /// - 50 ms minimum progress interval
    ///
    /// `const`, so it can initialize constants and statics.
    pub const fn short() -> Self {
        Self {
            max_chunks: 100,
            chunk_timeout: Duration::from_secs(5),
            max_duration: Duration::from_secs(30),
            min_progress_interval: Duration::from_millis(50),
        }
    }

    /// Preset for long-running operations:
    /// - 10 000 chunks
    /// - 60 s per chunk
    /// - 3600 s total
    /// - 500 ms minimum progress interval
    ///
    /// `const`, so it can initialize constants and statics.
    pub const fn long() -> Self {
        Self {
            max_chunks: 10000,
            chunk_timeout: Duration::from_secs(60),
            max_duration: Duration::from_secs(3600),
            min_progress_interval: Duration::from_millis(500),
        }
    }
}
188
189#[async_trait]
199pub trait StreamingTool: Tool {
200 fn execute_stream(&self, args: Value, config: StreamConfig) -> ToolStream;
207
208 fn stream_config(&self) -> StreamConfig {
210 StreamConfig::default()
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Splits a `Progress` chunk into `(percent, message)`; fails the test on
    /// any other variant.
    fn unwrap_progress(chunk: ToolChunk) -> (f32, String) {
        match chunk {
            ToolChunk::Progress { percent, message } => (percent, message),
            _ => panic!("Expected Progress chunk"),
        }
    }

    #[test]
    fn test_tool_chunk_progress() {
        let (percent, message) = unwrap_progress(ToolChunk::progress(50.0, "Halfway there"));
        assert_eq!(percent, 50.0);
        assert_eq!(message, "Halfway there");
    }

    #[test]
    fn test_tool_chunk_progress_clamped() {
        let (percent, _) = unwrap_progress(ToolChunk::progress(150.0, "Over 100"));
        assert_eq!(percent, 100.0);
    }

    #[test]
    fn test_tool_chunk_is_terminal() {
        // Progress and partial chunks are intermediate.
        let intermediate = [
            ToolChunk::progress(50.0, ""),
            ToolChunk::partial(serde_json::json!({}), 0),
        ];
        assert!(intermediate.iter().all(|chunk| !chunk.is_terminal()));

        // Completion and error chunks are terminal.
        let result = ToolResult::new(
            "test",
            &serde_json::json!({}),
            serde_json::json!({}),
            Duration::from_secs(1),
        );
        let terminal = [
            ToolChunk::complete(result),
            ToolChunk::error("test", "not found"),
        ];
        assert!(terminal.iter().all(ToolChunk::is_terminal));
    }

    #[test]
    fn test_stream_config_default() {
        let StreamConfig {
            max_chunks,
            max_duration,
            ..
        } = StreamConfig::default();
        assert_eq!(max_chunks, 1000);
        assert_eq!(max_duration, Duration::from_secs(300));
    }

    #[test]
    fn test_stream_config_short() {
        let StreamConfig {
            max_chunks,
            max_duration,
            ..
        } = StreamConfig::short();
        assert_eq!(max_chunks, 100);
        assert!(max_duration < Duration::from_secs(60));
    }

    #[test]
    fn test_stream_config_long() {
        let StreamConfig {
            max_chunks,
            max_duration,
            ..
        } = StreamConfig::long();
        assert_eq!(max_chunks, 10000);
        assert_eq!(max_duration, Duration::from_secs(3600));
    }
}