1use super::types::{McpConfig, McpError, McpToolInfo};
7use crate::tool::{Capability, Tool, ToolDefinition};
8use crate::tool_error::ToolError;
9use async_trait::async_trait;
10use futures::{SinkExt, StreamExt};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::{mpsc, oneshot, RwLock};
16use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
17use tracing::{error, info};
18
19#[derive(Debug, Serialize)]
21struct JsonRpcRequest {
22 jsonrpc: String,
23 method: String,
24 params: Value,
25 id: u64,
26}
27
28#[derive(Debug, Deserialize)]
30struct JsonRpcResponse {
31 _jsonrpc: String,
32 result: Option<Value>,
33 error: Option<JsonRpcError>,
34 id: Option<u64>,
35}
36
37#[derive(Debug, Deserialize)]
39struct JsonRpcError {
40 _code: i32,
41 message: String,
42 #[allow(dead_code)]
43 data: Option<Value>,
44}
45
46enum McpCommand {
47 Call {
48 method: String,
49 params: Value,
50 resp_tx: oneshot::Sender<Result<Value, McpError>>,
51 },
52 Shutdown,
53}
54
55pub struct McpClient {
57 server_url: String,
58 command_tx: mpsc::Sender<McpCommand>,
59 connected: Arc<RwLock<bool>>,
60 tools_cache: Arc<RwLock<Option<Vec<McpToolInfo>>>>,
61}
62
63impl McpClient {
64 pub async fn connect(url: &str, config: McpConfig) -> Result<Self, McpError> {
66 let is_localhost = url.contains("localhost")
67 || url.contains("127.0.0.1")
68 || url.contains("[::1]")
69 || url.contains("0.0.0.0");
70
71 if config.require_tls
72 && !is_localhost
73 && !url.starts_with("wss://")
74 && !url.starts_with("https://")
75 {
76 return Err(McpError::TlsRequired);
77 }
78
79 let (ws_stream, _) = connect_async(url)
80 .await
81 .map_err(|e| McpError::ConnectionFailed(e.to_string()))?;
82
83 info!(url = url, "Connected to MCP server");
84
85 let (command_tx, mut command_rx) = mpsc::channel::<McpCommand>(32);
86 let connected = Arc::new(RwLock::new(true));
87 let connected_clone = connected.clone();
88
89 tokio::spawn(async move {
91 let (mut ws_tx, mut ws_rx) = ws_stream.split();
92 let mut pending_requests: HashMap<u64, oneshot::Sender<Result<Value, McpError>>> =
93 HashMap::new();
94 let mut next_id = 1u64;
95
96 loop {
97 tokio::select! {
98 Some(cmd) = command_rx.recv() => {
100 match cmd {
101 McpCommand::Call { method, params, resp_tx } => {
102 let id = next_id;
103 next_id += 1;
104
105 let req = JsonRpcRequest {
106 jsonrpc: "2.0".to_string(),
107 method,
108 params,
109 id,
110 };
111
112 let json = serde_json::to_string(&req).unwrap();
113 if let Err(e) = ws_tx.send(Message::Text(json)).await {
114 error!("WS send failed: {}", e);
115 let _ = resp_tx.send(Err(McpError::ConnectionFailed(e.to_string())));
116 break;
117 }
118 pending_requests.insert(id, resp_tx);
119 }
120 McpCommand::Shutdown => break,
121 }
122 }
123
124 Some(msg) = ws_rx.next() => {
126 match msg {
127 Ok(Message::Text(text)) => {
128 if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(&text) {
129 if let Some(id) = resp.id {
130 if let Some(tx) = pending_requests.remove(&id) {
131 if let Some(err) = resp.error {
132 let _ = tx.send(Err(McpError::ExecutionFailed(err.message)));
133 } else {
134 let _ = tx.send(Ok(resp.result.unwrap_or(Value::Null)));
135 }
136 }
137 }
138 }
139 }
140 Ok(Message::Close(_)) => {
141 info!("MCP server closed connection");
142 break;
143 }
144 Err(e) => {
145 error!("WS read error: {}", e);
146 break;
147 }
148 _ => {}
149 }
150 }
151 }
152 }
153
154 *connected_clone.write().await = false;
155 });
156
157 let client = Self {
158 server_url: url.to_string(),
159 command_tx,
160 connected,
161 tools_cache: Arc::new(RwLock::new(None)),
162 };
163
164 client.initialize().await?;
166
167 Ok(client)
168 }
169
170 async fn initialize(&self) -> Result<(), McpError> {
171 let params = serde_json::json!({
172 "protocolVersion": "2024-11-05",
173 "capabilities": {},
174 "clientInfo": {
175 "name": "vex-client",
176 "version": "0.1.5"
177 }
178 });
179
180 self.call_raw("initialize", params).await?;
181 Ok(())
183 }
184
185 async fn call_raw(&self, method: &str, params: Value) -> Result<Value, McpError> {
186 let (resp_tx, resp_rx) = oneshot::channel();
187 self.command_tx
188 .send(McpCommand::Call {
189 method: method.to_string(),
190 params,
191 resp_tx,
192 })
193 .await
194 .map_err(|_| McpError::ConnectionFailed("Channel closed".into()))?;
195
196 resp_rx
197 .await
198 .map_err(|_| McpError::ConnectionFailed("Response channel closed".into()))?
199 }
200
201 pub async fn list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
203 if let Some(ref tools) = *self.tools_cache.read().await {
204 return Ok(tools.clone());
205 }
206
207 let resp = self.call_raw("tools/list", Value::Null).await?;
208 let tools: Vec<McpToolInfo> = serde_json::from_value(resp["tools"].clone())
209 .map_err(|e| McpError::Serialization(e.to_string()))?;
210
211 *self.tools_cache.write().await = Some(tools.clone());
212 Ok(tools)
213 }
214
215 pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpError> {
217 let params = serde_json::json!({
218 "name": name,
219 "arguments": args
220 });
221
222 self.call_raw("tools/call", params).await
223 }
224
225 pub fn server_url(&self) -> &str {
226 &self.server_url
227 }
228
229 pub async fn is_connected(&self) -> bool {
230 *self.connected.read().await
231 }
232
233 pub async fn disconnect(&self) {
234 let _ = self.command_tx.send(McpCommand::Shutdown).await;
235 }
236}
237
238pub struct McpToolAdapter {
240 client: Arc<McpClient>,
241 info: McpToolInfo,
242 definition: ToolDefinition,
243}
244
245impl McpToolAdapter {
246 pub fn new(client: Arc<McpClient>, info: McpToolInfo) -> Self {
247 let name: &'static str = Box::leak(info.name.clone().into_boxed_str());
248 let description: &'static str = Box::leak(info.description.clone().into_boxed_str());
249 let parameters: &'static str = Box::leak(
250 serde_json::to_string(&info.input_schema)
251 .unwrap_or_default()
252 .into_boxed_str(),
253 );
254
255 let definition = ToolDefinition::new(name, description, parameters);
256
257 Self {
258 client,
259 info,
260 definition,
261 }
262 }
263}
264
265#[async_trait]
266impl Tool for McpToolAdapter {
267 fn definition(&self) -> &ToolDefinition {
268 &self.definition
269 }
270
271 fn capabilities(&self) -> Vec<Capability> {
272 vec![Capability::Network]
273 }
274
275 fn timeout(&self) -> std::time::Duration {
276 std::time::Duration::from_secs(30)
277 }
278
279 async fn execute(&self, args: Value) -> Result<Value, ToolError> {
280 self.client
281 .call_tool(&self.info.name, args)
282 .await
283 .map_err(|e| ToolError::execution_failed(&self.info.name, e.to_string()))
284 }
285}