vex_llm/tools/
calculator.rs1use async_trait::async_trait;
19use serde_json::Value;
20
21use crate::tool::{Capability, Tool, ToolDefinition};
22use crate::tool_error::ToolError;
23
24pub struct CalculatorTool {
37 definition: ToolDefinition,
38}
39
40impl CalculatorTool {
41 pub fn new() -> Self {
43 Self {
44 definition: ToolDefinition::new(
45 "calculator",
46 "Evaluate mathematical expressions. Supports arithmetic operators (+, -, *, /, ^), \
47 functions (sqrt, sin, cos, tan, log, exp, abs), and constants (pi, e).",
48 r#"{
49 "type": "object",
50 "properties": {
51 "expression": {
52 "type": "string",
53 "description": "Mathematical expression to evaluate, e.g. '2 + 3 * 4' or 'sqrt(16)'"
54 }
55 },
56 "required": ["expression"]
57 }"#,
58 ),
59 }
60 }
61}
62
63impl Default for CalculatorTool {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69#[async_trait]
70impl Tool for CalculatorTool {
71 fn definition(&self) -> &ToolDefinition {
72 &self.definition
73 }
74
75 fn capabilities(&self) -> Vec<Capability> {
76 vec![Capability::PureComputation] }
78
79 fn validate(&self, args: &Value) -> Result<(), ToolError> {
80 let expr = args
82 .get("expression")
83 .and_then(|e| e.as_str())
84 .ok_or_else(|| {
85 ToolError::invalid_args("calculator", "Missing required field 'expression'")
86 })?;
87
88 if expr.len() > 1000 {
90 return Err(ToolError::invalid_args(
91 "calculator",
92 "Expression too long (max 1000 characters)",
93 ));
94 }
95
96 Ok(())
97 }
98
99 async fn execute(&self, args: Value) -> Result<Value, ToolError> {
100 let expr = args["expression"]
101 .as_str()
102 .ok_or_else(|| ToolError::invalid_args("calculator", "Missing 'expression' field"))?;
103
104 let result = meval::eval_str(expr).map_err(|e| {
107 ToolError::execution_failed(
108 "calculator",
109 format!("Failed to evaluate expression: {}", e),
110 )
111 })?;
112
113 if result.is_nan() {
115 return Err(ToolError::execution_failed(
116 "calculator",
117 "Result is not a number (NaN)",
118 ));
119 }
120 if result.is_infinite() {
121 return Err(ToolError::execution_failed(
122 "calculator",
123 "Result is infinite (division by zero or overflow)",
124 ));
125 }
126
127 Ok(serde_json::json!({
128 "expression": expr,
129 "result": result
130 }))
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Executes `expression` on a fresh calculator and returns the raw outcome.
    async fn evaluate(expression: &str) -> Result<Value, ToolError> {
        let tool = CalculatorTool::new();
        tool.execute(serde_json::json!({ "expression": expression }))
            .await
    }

    /// Runs argument validation on a fresh calculator.
    fn check_args(args: Value) -> Result<(), ToolError> {
        let tool = CalculatorTool::new();
        tool.validate(&args)
    }

    #[tokio::test]
    async fn test_basic_arithmetic() {
        let output = evaluate("2 + 3 * 4").await.unwrap();

        assert_eq!(output["result"], 14.0);
    }

    #[tokio::test]
    async fn test_with_parentheses() {
        let output = evaluate("(2 + 3) * 4").await.unwrap();

        assert_eq!(output["result"], 20.0);
    }

    #[tokio::test]
    async fn test_functions() {
        let output = evaluate("sqrt(16)").await.unwrap();

        assert_eq!(output["result"], 4.0);
    }

    #[tokio::test]
    async fn test_constants() {
        let output = evaluate("pi").await.unwrap();

        let value = output["result"].as_f64().unwrap();
        assert!((value - std::f64::consts::PI).abs() < 0.0001);
    }

    #[tokio::test]
    async fn test_invalid_expression() {
        let outcome = evaluate("invalid ++ syntax").await;

        assert!(matches!(outcome, Err(ToolError::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_missing_expression() {
        let outcome = check_args(serde_json::json!({}));

        assert!(matches!(outcome, Err(ToolError::InvalidArguments { .. })));
    }

    #[tokio::test]
    async fn test_division_by_zero() {
        let outcome = evaluate("1/0").await;

        assert!(matches!(outcome, Err(ToolError::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_expression_too_long() {
        // 600 repetitions of "1+" yield 1200 characters, above the 1000 limit.
        let oversized = "1+".repeat(600);
        let outcome = check_args(serde_json::json!({ "expression": oversized }));

        assert!(matches!(outcome, Err(ToolError::InvalidArguments { .. })));
    }
}