vex_llm/tools/
regex.rs

1//! Regex tool for pattern matching
2//!
3//! Provides regex matching and extraction capabilities.
4//!
5//! # Security
6//!
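//! - Pattern length limited (bounds regex compilation cost)
//! - Text length limited; no execution timeout is applied, but the `regex`
//!   crate guarantees matching time linear in the input, which rules out
//!   catastrophic-backtracking ReDoS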
9//! - Pure computation, no I/O
10
11use async_trait::async_trait;
12use regex::Regex;
13use serde_json::Value;
14
15use crate::tool::{Capability, Tool, ToolDefinition};
16use crate::tool_error::ToolError;
17
18/// Regex tool for pattern matching and extraction.
19///
20/// # Example
21///
22/// ```ignore
23/// use vex_llm::RegexTool;
24/// use vex_llm::Tool;
25///
26/// let re = RegexTool::new();
27/// let result = re.execute(json!({
28///     "pattern": r"\d+",
29///     "text": "Order 12345",
30///     "operation": "find_all"
31/// })).await?;
32/// println!("{:?}", result["matches"]);
33/// ```
34pub struct RegexTool {
35    definition: ToolDefinition,
36}
37
38impl RegexTool {
39    /// Create a new regex tool
40    pub fn new() -> Self {
41        Self {
42            definition: ToolDefinition::new(
43                "regex",
44                "Match or extract patterns using regular expressions.",
45                r#"{
46                    "type": "object",
47                    "properties": {
48                        "pattern": {
49                            "type": "string",
50                            "description": "Regular expression pattern"
51                        },
52                        "text": {
53                            "type": "string",
54                            "description": "Text to search in"
55                        },
56                        "operation": {
57                            "type": "string",
58                            "enum": ["match", "find_all", "replace"],
59                            "default": "match",
60                            "description": "Operation: 'match' (check if matches), 'find_all' (extract all matches), 'replace' (replace matches)"
61                        },
62                        "replacement": {
63                            "type": "string",
64                            "description": "Replacement string (for 'replace' operation)"
65                        }
66                    },
67                    "required": ["pattern", "text"]
68                }"#,
69            ),
70        }
71    }
72}
73
74impl Default for RegexTool {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80#[async_trait]
81impl Tool for RegexTool {
82    fn definition(&self) -> &ToolDefinition {
83        &self.definition
84    }
85
86    fn capabilities(&self) -> Vec<Capability> {
87        vec![Capability::PureComputation]
88    }
89
90    fn validate(&self, args: &Value) -> Result<(), ToolError> {
91        let pattern = args
92            .get("pattern")
93            .and_then(|p| p.as_str())
94            .ok_or_else(|| ToolError::invalid_args("regex", "Missing required field 'pattern'"))?;
95
96        // Limit pattern length to prevent ReDoS
97        if pattern.len() > 500 {
98            return Err(ToolError::invalid_args(
99                "regex",
100                "Pattern too long (max 500 characters)",
101            ));
102        }
103
104        // Validate the regex compiles
105        Regex::new(pattern).map_err(|e| {
106            ToolError::invalid_args("regex", format!("Invalid regex pattern: {}", e))
107        })?;
108
109        // Check text is provided
110        if args.get("text").and_then(|t| t.as_str()).is_none() {
111            return Err(ToolError::invalid_args(
112                "regex",
113                "Missing required field 'text'",
114            ));
115        }
116
117        // Limit text length
118        if let Some(text) = args.get("text").and_then(|t| t.as_str()) {
119            if text.len() > 100_000 {
120                return Err(ToolError::invalid_args(
121                    "regex",
122                    "Text too long (max 100KB)",
123                ));
124            }
125        }
126
127        Ok(())
128    }
129
130    async fn execute(&self, args: Value) -> Result<Value, ToolError> {
131        let pattern = args["pattern"]
132            .as_str()
133            .ok_or_else(|| ToolError::invalid_args("regex", "Missing 'pattern' field"))?;
134
135        let text = args["text"]
136            .as_str()
137            .ok_or_else(|| ToolError::invalid_args("regex", "Missing 'text' field"))?;
138
139        let operation = args
140            .get("operation")
141            .and_then(|o| o.as_str())
142            .unwrap_or("match");
143
144        let re = Regex::new(pattern)
145            .map_err(|e| ToolError::execution_failed("regex", format!("Invalid regex: {}", e)))?;
146
147        match operation {
148            "match" => {
149                let is_match = re.is_match(text);
150                let first_match = re.find(text).map(|m| m.as_str().to_string());
151
152                Ok(serde_json::json!({
153                    "matches": is_match,
154                    "first_match": first_match,
155                    "pattern": pattern
156                }))
157            }
158            "find_all" => {
159                let matches: Vec<String> =
160                    re.find_iter(text).map(|m| m.as_str().to_string()).collect();
161
162                Ok(serde_json::json!({
163                    "matches": matches,
164                    "count": matches.len(),
165                    "pattern": pattern
166                }))
167            }
168            "replace" => {
169                let replacement = args
170                    .get("replacement")
171                    .and_then(|r| r.as_str())
172                    .unwrap_or("");
173
174                let result = re.replace_all(text, replacement).to_string();
175
176                Ok(serde_json::json!({
177                    "result": result,
178                    "pattern": pattern,
179                    "replacement": replacement
180                }))
181            }
182            _ => Err(ToolError::invalid_args(
183                "regex",
184                format!(
185                    "Unknown operation '{}'. Use 'match', 'find_all', or 'replace'",
186                    operation
187                ),
188            )),
189        }
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Executes the tool with `args` and unwraps the successful result.
    async fn run(args: Value) -> Value {
        RegexTool::new().execute(args).await.unwrap()
    }

    #[tokio::test]
    async fn test_match_found() {
        let result = run(json!({
            "pattern": r"\d+",
            "text": "Order 12345",
            "operation": "match"
        }))
        .await;

        assert_eq!(result["matches"], true);
        assert_eq!(result["first_match"], "12345");
    }

    #[tokio::test]
    async fn test_match_not_found() {
        let result = run(json!({
            "pattern": r"\d+",
            "text": "No numbers here",
            "operation": "match"
        }))
        .await;

        assert_eq!(result["matches"], false);
        assert!(result["first_match"].is_null());
    }

    #[tokio::test]
    async fn test_default_operation_is_match() {
        let result = run(json!({
            "pattern": r"\d+",
            "text": "Order 12345"
        }))
        .await;

        assert_eq!(result["matches"], true);
        assert_eq!(result["first_match"], "12345");
    }

    #[tokio::test]
    async fn test_find_all() {
        let result = run(json!({
            "pattern": r"\d+",
            "text": "Items: 10, 20, 30",
            "operation": "find_all"
        }))
        .await;

        let matches = result["matches"].as_array().unwrap();
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0], "10");
        assert_eq!(matches[1], "20");
        assert_eq!(matches[2], "30");
        assert_eq!(result["count"], 3);
    }

    #[tokio::test]
    async fn test_replace() {
        let result = run(json!({
            "pattern": r"\d+",
            "text": "Price: $100",
            "operation": "replace",
            "replacement": "XXX"
        }))
        .await;

        assert_eq!(result["result"], "Price: $XXX");
    }

    #[tokio::test]
    async fn test_replace_without_replacement_deletes_matches() {
        let result = run(json!({
            "pattern": r"\d+",
            "text": "Price: $100",
            "operation": "replace"
        }))
        .await;

        assert_eq!(result["result"], "Price: $");
    }

    #[tokio::test]
    async fn test_replace_expands_capture_groups() {
        let result = run(json!({
            "pattern": r"(\w+)@(\w+)",
            "text": "user@example",
            "operation": "replace",
            "replacement": "${2} at ${1}"
        }))
        .await;

        assert_eq!(result["result"], "example at user");
    }

    #[tokio::test]
    async fn test_unknown_operation_fails() {
        let result = RegexTool::new()
            .execute(json!({
                "pattern": r"\d+",
                "text": "Order 12345",
                "operation": "split"
            }))
            .await;

        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    #[tokio::test]
    async fn test_execute_invalid_pattern_fails() {
        let result = RegexTool::new()
            .execute(json!({
                "pattern": "[invalid(",
                "text": "test"
            }))
            .await;

        assert!(result.is_err());
    }

    #[test]
    fn test_valid_args_pass_validation() {
        let result = RegexTool::new().validate(&json!({
            "pattern": r"\d+",
            "text": "Order 12345",
            "operation": "find_all"
        }));

        assert!(result.is_ok());
    }

    #[test]
    fn test_invalid_pattern() {
        let result = RegexTool::new().validate(&json!({
            "pattern": "[invalid(",
            "text": "test"
        }));

        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    #[test]
    fn test_pattern_too_long() {
        let long_pattern = "a".repeat(600);
        let result = RegexTool::new().validate(&json!({
            "pattern": long_pattern,
            "text": "test"
        }));

        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    #[test]
    fn test_missing_text() {
        let result = RegexTool::new().validate(&json!({
            "pattern": r"\d+"
        }));

        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    #[test]
    fn test_text_too_long() {
        let long_text = "a".repeat(100_001);
        let result = RegexTool::new().validate(&json!({
            "pattern": "a",
            "text": long_text
        }));

        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }
}