vex_llm/tools/
calculator.rs

//! Calculator tool for evaluating mathematical expressions
//!
//! Uses the `meval` crate for safe expression evaluation.
//!
//! # Security
//!
//! - Only arithmetic operations are allowed (no arbitrary code execution)
//! - Does not access filesystem, network, or environment
//! - Pure computation: safe for any sandbox
//!
//! # Supported Operations
//!
//! - Basic arithmetic: `+`, `-`, `*`, `/`, `^` (power)
//! - Parentheses: `(2 + 3) * 4`
//! - Functions: `sqrt()`, `sin()`, `cos()`, `tan()`, `ln()` (natural logarithm), `exp()`, `abs()`
//! - Constants: `pi`, `e`

18use async_trait::async_trait;
19use serde_json::Value;
20
21use crate::tool::{Capability, Tool, ToolDefinition};
22use crate::tool_error::ToolError;
23
/// Calculator tool for evaluating mathematical expressions.
///
/// A successful call returns a JSON object of the form
/// `{"expression": <input string>, "result": <number>}`.
///
/// # Example
///
/// ```ignore
/// use serde_json::json;
/// use vex_llm::CalculatorTool;
/// use vex_llm::Tool;
///
/// let calc = CalculatorTool::new();
/// let result = calc.execute(json!({"expression": "2 + 3 * 4"})).await?;
/// assert_eq!(result["result"], 14.0);
/// ```
pub struct CalculatorTool {
    /// Tool name, description and JSON parameter schema, built once in `CalculatorTool::new`.
    definition: ToolDefinition,
}
39
40impl CalculatorTool {
41    /// Create a new calculator tool
42    pub fn new() -> Self {
43        Self {
44            definition: ToolDefinition::new(
45                "calculator",
46                "Evaluate mathematical expressions. Supports arithmetic operators (+, -, *, /, ^), \
47                 functions (sqrt, sin, cos, tan, log, exp, abs), and constants (pi, e).",
48                r#"{
49                    "type": "object",
50                    "properties": {
51                        "expression": {
52                            "type": "string",
53                            "description": "Mathematical expression to evaluate, e.g. '2 + 3 * 4' or 'sqrt(16)'"
54                        }
55                    },
56                    "required": ["expression"]
57                }"#,
58            ),
59        }
60    }
61}
62
63impl Default for CalculatorTool {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69#[async_trait]
70impl Tool for CalculatorTool {
71    fn definition(&self) -> &ToolDefinition {
72        &self.definition
73    }
74
75    fn capabilities(&self) -> Vec<Capability> {
76        vec![Capability::PureComputation] // Safe: no I/O
77    }
78
79    fn validate(&self, args: &Value) -> Result<(), ToolError> {
80        // Check required field exists
81        let expr = args
82            .get("expression")
83            .and_then(|e| e.as_str())
84            .ok_or_else(|| {
85                ToolError::invalid_args("calculator", "Missing required field 'expression'")
86            })?;
87
88        // Basic length check to prevent DoS
89        if expr.len() > 1000 {
90            return Err(ToolError::invalid_args(
91                "calculator",
92                "Expression too long (max 1000 characters)",
93            ));
94        }
95
96        Ok(())
97    }
98
99    async fn execute(&self, args: Value) -> Result<Value, ToolError> {
100        let expr = args["expression"]
101            .as_str()
102            .ok_or_else(|| ToolError::invalid_args("calculator", "Missing 'expression' field"))?;
103
104        // Evaluate the expression using meval
105        // meval is safe: only arithmetic, no arbitrary code execution
106        let result = meval::eval_str(expr).map_err(|e| {
107            ToolError::execution_failed(
108                "calculator",
109                format!("Failed to evaluate expression: {}", e),
110            )
111        })?;
112
113        // Check for NaN or Infinity
114        if result.is_nan() {
115            return Err(ToolError::execution_failed(
116                "calculator",
117                "Result is not a number (NaN)",
118            ));
119        }
120        if result.is_infinite() {
121            return Err(ToolError::execution_failed(
122                "calculator",
123                "Result is infinite (division by zero or overflow)",
124            ));
125        }
126
127        Ok(serde_json::json!({
128            "expression": expr,
129            "result": result
130        }))
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Run a fresh calculator on `expression` and return the raw tool output.
    async fn run(expression: &str) -> Result<Value, ToolError> {
        CalculatorTool::new()
            .execute(serde_json::json!({ "expression": expression }))
            .await
    }

    #[tokio::test]
    async fn test_basic_arithmetic() {
        let result = run("2 + 3 * 4").await.unwrap();
        assert_eq!(result["result"], 14.0);
    }

    #[tokio::test]
    async fn test_result_echoes_expression() {
        let result = run("2 + 3 * 4").await.unwrap();
        assert_eq!(result["expression"], "2 + 3 * 4");
    }

    #[tokio::test]
    async fn test_with_parentheses() {
        let result = run("(2 + 3) * 4").await.unwrap();
        assert_eq!(result["result"], 20.0);
    }

    #[tokio::test]
    async fn test_power() {
        let result = run("2 ^ 10").await.unwrap();
        assert_eq!(result["result"], 1024.0);
    }

    #[tokio::test]
    async fn test_functions() {
        let result = run("sqrt(16)").await.unwrap();
        assert_eq!(result["result"], 4.0);
    }

    #[tokio::test]
    async fn test_natural_log() {
        // The advertised logarithm is `ln`; make sure it really evaluates.
        let result = run("ln(e)").await.unwrap();
        let value = result["result"].as_f64().unwrap();
        assert!((value - 1.0).abs() < 1e-12);
    }

    #[tokio::test]
    async fn test_constants() {
        let result = run("pi").await.unwrap();
        let pi = result["result"].as_f64().unwrap();
        assert!((pi - std::f64::consts::PI).abs() < 0.0001);
    }

    #[tokio::test]
    async fn test_invalid_expression() {
        let result = run("invalid ++ syntax").await;
        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_division_by_zero() {
        let result = run("1/0").await;
        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_nan_result_is_rejected() {
        let result = run("sqrt(-1)").await;
        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_execute_missing_expression() {
        let result = CalculatorTool::new().execute(serde_json::json!({})).await;
        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    // `validate` is synchronous, so these tests need no async runtime.

    #[test]
    fn test_missing_expression() {
        let result = CalculatorTool::new().validate(&serde_json::json!({}));
        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    #[test]
    fn test_non_string_expression() {
        let result = CalculatorTool::new().validate(&serde_json::json!({ "expression": 42 }));
        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    #[test]
    fn test_expression_at_length_limit() {
        let expr = "1".repeat(1000);
        let result = CalculatorTool::new().validate(&serde_json::json!({ "expression": expr }));
        assert!(result.is_ok());
    }

    #[test]
    fn test_expression_too_long() {
        let long_expr = "1+".repeat(600);
        let result =
            CalculatorTool::new().validate(&serde_json::json!({ "expression": long_expr }));
        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }
}