1use async_trait::async_trait;
12use regex::Regex;
13use serde_json::Value;
14
15use crate::tool::{Capability, Tool, ToolDefinition};
16use crate::tool_error::ToolError;
17
18pub struct RegexTool {
35 definition: ToolDefinition,
36}
37
38impl RegexTool {
39 pub fn new() -> Self {
41 Self {
42 definition: ToolDefinition::new(
43 "regex",
44 "Match or extract patterns using regular expressions.",
45 r#"{
46 "type": "object",
47 "properties": {
48 "pattern": {
49 "type": "string",
50 "description": "Regular expression pattern"
51 },
52 "text": {
53 "type": "string",
54 "description": "Text to search in"
55 },
56 "operation": {
57 "type": "string",
58 "enum": ["match", "find_all", "replace"],
59 "default": "match",
60 "description": "Operation: 'match' (check if matches), 'find_all' (extract all matches), 'replace' (replace matches)"
61 },
62 "replacement": {
63 "type": "string",
64 "description": "Replacement string (for 'replace' operation)"
65 }
66 },
67 "required": ["pattern", "text"]
68 }"#,
69 ),
70 }
71 }
72}
73
74impl Default for RegexTool {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80#[async_trait]
81impl Tool for RegexTool {
82 fn definition(&self) -> &ToolDefinition {
83 &self.definition
84 }
85
86 fn capabilities(&self) -> Vec<Capability> {
87 vec![Capability::PureComputation]
88 }
89
90 fn validate(&self, args: &Value) -> Result<(), ToolError> {
91 let pattern = args
92 .get("pattern")
93 .and_then(|p| p.as_str())
94 .ok_or_else(|| ToolError::invalid_args("regex", "Missing required field 'pattern'"))?;
95
96 if pattern.len() > 500 {
98 return Err(ToolError::invalid_args(
99 "regex",
100 "Pattern too long (max 500 characters)",
101 ));
102 }
103
104 Regex::new(pattern).map_err(|e| {
106 ToolError::invalid_args("regex", format!("Invalid regex pattern: {}", e))
107 })?;
108
109 if args.get("text").and_then(|t| t.as_str()).is_none() {
111 return Err(ToolError::invalid_args(
112 "regex",
113 "Missing required field 'text'",
114 ));
115 }
116
117 if let Some(text) = args.get("text").and_then(|t| t.as_str()) {
119 if text.len() > 100_000 {
120 return Err(ToolError::invalid_args(
121 "regex",
122 "Text too long (max 100KB)",
123 ));
124 }
125 }
126
127 Ok(())
128 }
129
130 async fn execute(&self, args: Value) -> Result<Value, ToolError> {
131 let pattern = args["pattern"]
132 .as_str()
133 .ok_or_else(|| ToolError::invalid_args("regex", "Missing 'pattern' field"))?;
134
135 let text = args["text"]
136 .as_str()
137 .ok_or_else(|| ToolError::invalid_args("regex", "Missing 'text' field"))?;
138
139 let operation = args
140 .get("operation")
141 .and_then(|o| o.as_str())
142 .unwrap_or("match");
143
144 let re = Regex::new(pattern)
145 .map_err(|e| ToolError::execution_failed("regex", format!("Invalid regex: {}", e)))?;
146
147 match operation {
148 "match" => {
149 let is_match = re.is_match(text);
150 let first_match = re.find(text).map(|m| m.as_str().to_string());
151
152 Ok(serde_json::json!({
153 "matches": is_match,
154 "first_match": first_match,
155 "pattern": pattern
156 }))
157 }
158 "find_all" => {
159 let matches: Vec<String> =
160 re.find_iter(text).map(|m| m.as_str().to_string()).collect();
161
162 Ok(serde_json::json!({
163 "matches": matches,
164 "count": matches.len(),
165 "pattern": pattern
166 }))
167 }
168 "replace" => {
169 let replacement = args
170 .get("replacement")
171 .and_then(|r| r.as_str())
172 .unwrap_or("");
173
174 let result = re.replace_all(text, replacement).to_string();
175
176 Ok(serde_json::json!({
177 "result": result,
178 "pattern": pattern,
179 "replacement": replacement
180 }))
181 }
182 _ => Err(ToolError::invalid_args(
183 "regex",
184 format!(
185 "Unknown operation '{}'. Use 'match', 'find_all', or 'replace'",
186 operation
187 ),
188 )),
189 }
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Executes the tool with `args` and unwraps the successful result.
    async fn run(args: Value) -> Value {
        RegexTool::new()
            .execute(args)
            .await
            .expect("execute should succeed for valid arguments")
    }

    #[tokio::test]
    async fn test_match_found() {
        let result = run(serde_json::json!({
            "pattern": r"\d+",
            "text": "Order 12345",
            "operation": "match"
        }))
        .await;

        assert_eq!(result["matches"], true);
        assert_eq!(result["first_match"], "12345");
    }

    #[tokio::test]
    async fn test_match_not_found() {
        let result = run(serde_json::json!({
            "pattern": r"\d+",
            "text": "No numbers here",
            "operation": "match"
        }))
        .await;

        assert_eq!(result["matches"], false);
        assert!(result["first_match"].is_null());
    }

    #[tokio::test]
    async fn test_operation_defaults_to_match() {
        let result = run(serde_json::json!({
            "pattern": r"\d+",
            "text": "Order 12345"
        }))
        .await;

        assert_eq!(result["matches"], true);
        assert_eq!(result["first_match"], "12345");
    }

    #[tokio::test]
    async fn test_find_all() {
        let result = run(serde_json::json!({
            "pattern": r"\d+",
            "text": "Items: 10, 20, 30",
            "operation": "find_all"
        }))
        .await;

        let matches = result["matches"].as_array().unwrap();
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0], "10");
        assert_eq!(matches[1], "20");
        assert_eq!(matches[2], "30");
        assert_eq!(result["count"], 3);
    }

    #[tokio::test]
    async fn test_replace() {
        let result = run(serde_json::json!({
            "pattern": r"\d+",
            "text": "Price: $100",
            "operation": "replace",
            "replacement": "XXX"
        }))
        .await;

        assert_eq!(result["result"], "Price: $XXX");
    }

    #[tokio::test]
    async fn test_replace_defaults_to_empty_replacement() {
        let result = run(serde_json::json!({
            "pattern": r"\d+",
            "text": "a1b22c",
            "operation": "replace"
        }))
        .await;

        assert_eq!(result["result"], "abc");
    }

    #[tokio::test]
    async fn test_unknown_operation() {
        let result = RegexTool::new()
            .execute(serde_json::json!({
                "pattern": r"\d+",
                "text": "Order 12345",
                "operation": "split"
            }))
            .await;

        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    // `validate` is synchronous, so these tests need no async runtime.

    #[test]
    fn test_valid_arguments() {
        let result = RegexTool::new().validate(&serde_json::json!({
            "pattern": r"\d+",
            "text": "test 1"
        }));

        assert!(result.is_ok());
    }

    #[test]
    fn test_invalid_pattern() {
        let result = RegexTool::new().validate(&serde_json::json!({
            "pattern": "[invalid(",
            "text": "test"
        }));

        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    #[test]
    fn test_pattern_too_long() {
        let long_pattern = "a".repeat(600);
        let result = RegexTool::new().validate(&serde_json::json!({
            "pattern": long_pattern,
            "text": "test"
        }));

        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    #[test]
    fn test_missing_text() {
        let result = RegexTool::new().validate(&serde_json::json!({
            "pattern": r"\d+"
        }));

        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    #[test]
    fn test_text_too_long() {
        let long_text = "a".repeat(100_001);
        let result = RegexTool::new().validate(&serde_json::json!({
            "pattern": "a",
            "text": long_text
        }));

        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }
}