vex_llm/mcp/
client.rs

1//! MCP client implementation
2//!
3//! Provides a client for connecting to MCP servers and calling tools.
4//! This implementation uses tokio-tungstenite for WebSocket communication.
5
6use super::types::{McpConfig, McpError, McpToolInfo};
7use crate::tool::{Capability, Tool, ToolDefinition};
8use crate::tool_error::ToolError;
9use async_trait::async_trait;
10use futures::{SinkExt, StreamExt};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::{mpsc, oneshot, RwLock};
16use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
17use tracing::{error, info};
18
/// JSON-RPC 2.0 request frame sent to the MCP server.
///
/// serde serializes the fields in declaration order, so the wire form is
/// `{"jsonrpc":…,"method":…,"params":…,"id":…}`.
#[derive(Debug, Serialize)]
struct JsonRpcRequest {
    /// Protocol version string; the client always fills in `"2.0"`.
    jsonrpc: String,
    /// Name of the remote method, e.g. `"tools/call"`.
    method: String,
    /// Method parameters, embedded verbatim in the frame.
    params: Value,
    /// Client-assigned id used to match the server's response to this request.
    id: u64,
}
27
28/// JSON-RPC Response
29#[derive(Debug, Deserialize)]
30struct JsonRpcResponse {
31    _jsonrpc: String,
32    result: Option<Value>,
33    error: Option<JsonRpcError>,
34    id: Option<u64>,
35}
36
37/// JSON-RPC Error
38#[derive(Debug, Deserialize)]
39struct JsonRpcError {
40    _code: i32,
41    message: String,
42    #[allow(dead_code)]
43    data: Option<Value>,
44}
45
46enum McpCommand {
47    Call {
48        method: String,
49        params: Value,
50        resp_tx: oneshot::Sender<Result<Value, McpError>>,
51    },
52    Shutdown,
53}
54
55/// MCP Client for connecting to and interacting with MCP servers.
56pub struct McpClient {
57    server_url: String,
58    command_tx: mpsc::Sender<McpCommand>,
59    connected: Arc<RwLock<bool>>,
60    tools_cache: Arc<RwLock<Option<Vec<McpToolInfo>>>>,
61}
62
63impl McpClient {
64    /// Connect to an MCP server.
65    pub async fn connect(url: &str, config: McpConfig) -> Result<Self, McpError> {
66        let is_localhost = url.contains("localhost")
67            || url.contains("127.0.0.1")
68            || url.contains("[::1]")
69            || url.contains("0.0.0.0");
70
71        if config.require_tls
72            && !is_localhost
73            && !url.starts_with("wss://")
74            && !url.starts_with("https://")
75        {
76            return Err(McpError::TlsRequired);
77        }
78
79        let (ws_stream, _) = connect_async(url)
80            .await
81            .map_err(|e| McpError::ConnectionFailed(e.to_string()))?;
82
83        info!(url = url, "Connected to MCP server");
84
85        let (command_tx, mut command_rx) = mpsc::channel::<McpCommand>(32);
86        let connected = Arc::new(RwLock::new(true));
87        let connected_clone = connected.clone();
88
89        // Background task for WebSocket handling
90        tokio::spawn(async move {
91            let (mut ws_tx, mut ws_rx) = ws_stream.split();
92            let mut pending_requests: HashMap<u64, oneshot::Sender<Result<Value, McpError>>> =
93                HashMap::new();
94            let mut next_id = 1u64;
95
96            loop {
97                tokio::select! {
98                    // Handle commands from the client
99                    Some(cmd) = command_rx.recv() => {
100                        match cmd {
101                            McpCommand::Call { method, params, resp_tx } => {
102                                let id = next_id;
103                                next_id += 1;
104
105                                let req = JsonRpcRequest {
106                                    jsonrpc: "2.0".to_string(),
107                                    method,
108                                    params,
109                                    id,
110                                };
111
112                                let json = serde_json::to_string(&req).unwrap();
113                                if let Err(e) = ws_tx.send(Message::Text(json)).await {
114                                    error!("WS send failed: {}", e);
115                                    let _ = resp_tx.send(Err(McpError::ConnectionFailed(e.to_string())));
116                                    break;
117                                }
118                                pending_requests.insert(id, resp_tx);
119                            }
120                            McpCommand::Shutdown => break,
121                        }
122                    }
123
124                    // Handle messages from the server
125                    Some(msg) = ws_rx.next() => {
126                        match msg {
127                            Ok(Message::Text(text)) => {
128                                if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(&text) {
129                                    if let Some(id) = resp.id {
130                                        if let Some(tx) = pending_requests.remove(&id) {
131                                            if let Some(err) = resp.error {
132                                                let _ = tx.send(Err(McpError::ExecutionFailed(err.message)));
133                                            } else {
134                                                let _ = tx.send(Ok(resp.result.unwrap_or(Value::Null)));
135                                            }
136                                        }
137                                    }
138                                }
139                            }
140                            Ok(Message::Close(_)) => {
141                                info!("MCP server closed connection");
142                                break;
143                            }
144                            Err(e) => {
145                                error!("WS read error: {}", e);
146                                break;
147                            }
148                            _ => {}
149                        }
150                    }
151                }
152            }
153
154            *connected_clone.write().await = false;
155        });
156
157        let client = Self {
158            server_url: url.to_string(),
159            command_tx,
160            connected,
161            tools_cache: Arc::new(RwLock::new(None)),
162        };
163
164        // Initialize MCP protocol
165        client.initialize().await?;
166
167        Ok(client)
168    }
169
170    async fn initialize(&self) -> Result<(), McpError> {
171        let params = serde_json::json!({
172            "protocolVersion": "2024-11-05",
173            "capabilities": {},
174            "clientInfo": {
175                "name": "vex-client",
176                "version": "0.1.5"
177            }
178        });
179
180        self.call_raw("initialize", params).await?;
181        // Note: notifications/initialized is often skipped in simple clients but can be added
182        Ok(())
183    }
184
185    async fn call_raw(&self, method: &str, params: Value) -> Result<Value, McpError> {
186        let (resp_tx, resp_rx) = oneshot::channel();
187        self.command_tx
188            .send(McpCommand::Call {
189                method: method.to_string(),
190                params,
191                resp_tx,
192            })
193            .await
194            .map_err(|_| McpError::ConnectionFailed("Channel closed".into()))?;
195
196        resp_rx
197            .await
198            .map_err(|_| McpError::ConnectionFailed("Response channel closed".into()))?
199    }
200
201    /// List available tools from the MCP server.
202    pub async fn list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
203        if let Some(ref tools) = *self.tools_cache.read().await {
204            return Ok(tools.clone());
205        }
206
207        let resp = self.call_raw("tools/list", Value::Null).await?;
208        let tools: Vec<McpToolInfo> = serde_json::from_value(resp["tools"].clone())
209            .map_err(|e| McpError::Serialization(e.to_string()))?;
210
211        *self.tools_cache.write().await = Some(tools.clone());
212        Ok(tools)
213    }
214
215    /// Call a tool on the MCP server.
216    pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpError> {
217        let params = serde_json::json!({
218            "name": name,
219            "arguments": args
220        });
221
222        self.call_raw("tools/call", params).await
223    }
224
225    pub fn server_url(&self) -> &str {
226        &self.server_url
227    }
228
229    pub async fn is_connected(&self) -> bool {
230        *self.connected.read().await
231    }
232
233    pub async fn disconnect(&self) {
234        let _ = self.command_tx.send(McpCommand::Shutdown).await;
235    }
236}
237
238/// Adapter that wraps an MCP tool to be used as a VEX Tool.
239pub struct McpToolAdapter {
240    client: Arc<McpClient>,
241    info: McpToolInfo,
242    definition: ToolDefinition,
243}
244
245impl McpToolAdapter {
246    pub fn new(client: Arc<McpClient>, info: McpToolInfo) -> Self {
247        let name: &'static str = Box::leak(info.name.clone().into_boxed_str());
248        let description: &'static str = Box::leak(info.description.clone().into_boxed_str());
249        let parameters: &'static str = Box::leak(
250            serde_json::to_string(&info.input_schema)
251                .unwrap_or_default()
252                .into_boxed_str(),
253        );
254
255        let definition = ToolDefinition::new(name, description, parameters);
256
257        Self {
258            client,
259            info,
260            definition,
261        }
262    }
263}
264
265#[async_trait]
266impl Tool for McpToolAdapter {
267    fn definition(&self) -> &ToolDefinition {
268        &self.definition
269    }
270
271    fn capabilities(&self) -> Vec<Capability> {
272        vec![Capability::Network]
273    }
274
275    fn timeout(&self) -> std::time::Duration {
276        std::time::Duration::from_secs(30)
277    }
278
279    async fn execute(&self, args: Value) -> Result<Value, ToolError> {
280        self.client
281            .call_tool(&self.info.name, args)
282            .await
283            .map_err(|e| ToolError::execution_failed(&self.info.name, e.to_string()))
284    }
285}