Skip to content

Commit

Permalink
Check server capability before client sends request
Browse files Browse the repository at this point in the history
  • Loading branch information
salman1993 committed Jan 8, 2025
1 parent b0aed23 commit d0c52ba
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 39 deletions.
20 changes: 2 additions & 18 deletions crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use mcp_core::protocol::ListResourcesResult;
use rust_decimal_macros::dec;
use tokio::time::timeout;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;

use super::system::{SystemConfig, SystemError, SystemInfo, SystemResult};
Expand Down Expand Up @@ -35,7 +32,7 @@ impl Capabilities {
/// Add a new MCP system based on the provided client type
// TODO IMPORTANT need to ensure this times out if the system command is broken!
pub async fn add_system(&mut self, config: SystemConfig) -> SystemResult<()> {
let client: McpClient = match config {
let mut client: McpClient = match config {
SystemConfig::Sse { ref uri } => {
let transport = SseTransport::new(uri);
McpClient::new(transport.start().await?)
Expand Down Expand Up @@ -149,23 +146,10 @@ impl Capabilities {
pub async fn get_resources(
&self,
) -> SystemResult<HashMap<String, HashMap<String, (Resource, String)>>> {
println!("In get_resources");
let mut client_resource_content = HashMap::new();
for (name, client) in &self.clients {
let client_guard = client.lock().await;

// Add timeout of 3 seconds, return empty vec if timeout occurs
let resources: ListResourcesResult = match timeout(Duration::from_secs(3), client_guard.list_resources()).await {
Ok(Ok(resources)) => resources,
Ok(Err(e)) => return Err(e.into()), // Preserve original errors
Err(_) => {
println!("Timeout occurred while fetching resources for client {}", name);
ListResourcesResult{resources: vec![]} // Skip this client and continue with others
}
};

// let resources = client_guard.list_resources().await?;
println!("In get_resources, list_resources ({}): {:?}", resources.resources.len(), resources);
let resources = client_guard.list_resources().await?;

let mut resource_content = HashMap::new();
for resource in resources.resources {
Expand Down
19 changes: 5 additions & 14 deletions crates/goose/src/agents/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@ impl DefaultAgent {
model_name: &str,
resource_content: &HashMap<String, HashMap<String, (Resource, String)>>,
) -> SystemResult<Vec<Message>> {
println!("In prepare_inference");
// Flatten all resource content into a vector of strings
let mut resources = Vec::new();
for system_resources in resource_content.values() {
for (_, content) in system_resources.values() {
resources.push(content.clone());
}
}
println!("Resources: {:?}", resources);

let approx_count = self.token_counter.count_everything(
system_prompt,
Expand All @@ -59,7 +57,6 @@ impl DefaultAgent {
&resources,
Some(model_name),
);
println!("Approx count: {:?}", approx_count);
let mut status_content: Vec<String> = Vec::new();

if approx_count > target_limit {
Expand Down Expand Up @@ -131,7 +128,6 @@ impl DefaultAgent {
}
}
} else {
println!("No trimming needed");
// Create status messages from all resources when no trimming needed
for resources in resource_content.values() {
for (resource, content) in resources.values() {
Expand All @@ -140,8 +136,6 @@ impl DefaultAgent {
}
}

println!("Status content: {:?}", status_content);

// Join remaining status content and create status message
let status_str = status_content.join("\n");

Expand Down Expand Up @@ -201,7 +195,6 @@ impl Agent for DefaultAgent {
&self,
messages: &[Message],
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
println!("In goose agent reply");
let mut capabilities = self.capabilities.lock().await;
let tools = capabilities.get_prefixed_tools().await?;
let system_prompt = capabilities.get_system_prompt().await;
Expand All @@ -211,28 +204,26 @@ impl Agent for DefaultAgent {
.get_estimated_limit();

// Update conversation history for the start of the reply
println!("Preparing inference (before loop)");
let model_name = capabilities.provider().get_model_config().model_name.clone();
println!("Model name: {:?}", model_name);
let resources = capabilities.get_resources().await?;
println!("Resources: {:?}", resources);

let mut messages = self
.prepare_inference(
&system_prompt,
&tools,
messages,
&Vec::new(),
estimated_limit,
&model_name,
&capabilities
.provider()
.get_model_config()
.model_name
.clone(),
&resources,
)
.await?;

Ok(Box::pin(async_stream::try_stream! {
loop {
// Get completion from provider
println!("(in loop) Getting completion from provider with messages: {:?}", messages);
let (response, usage) = capabilities.provider().complete(
&system_prompt,
&messages,
Expand Down
2 changes: 1 addition & 1 deletion crates/mcp-client/examples/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async fn main() -> Result<()> {
let handle = transport.start().await?;

// Create client
let client = McpClient::new(handle);
let mut client = McpClient::new(handle);
println!("Client created\n");

// Initialize
Expand Down
6 changes: 5 additions & 1 deletion crates/mcp-client/examples/stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async fn main() -> Result<(), ClientError> {
let transport_handle = transport.start().await?;

// 3) Create the client
let client = McpClient::new(transport_handle);
let mut client = McpClient::new(transport_handle);

// Initialize
let server_info = client
Expand All @@ -45,5 +45,9 @@ async fn main() -> Result<(), ClientError> {
.await?;
println!("Tool result: {tool_result:?}\n");

// List resources
let resources = client.list_resources().await?;
println!("Available resources: {resources:?}\n");

Ok(())
}
2 changes: 1 addition & 1 deletion crates/mcp-client/examples/stdio_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async fn main() -> Result<(), ClientError> {
let transport_handle = transport.start().await.unwrap();

// Create client
let client = McpClient::new(transport_handle);
let mut client = McpClient::new(transport_handle);

// Initialize
let server_info = client
Expand Down
69 changes: 65 additions & 4 deletions crates/mcp-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::sync::atomic::{AtomicU64, Ordering};

use crate::transport::TransportHandle;
use mcp_core::protocol::{
CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification,
JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult,
ServerCapabilities, METHOD_NOT_FOUND,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use tokio::sync::Mutex;
use tower::{Service, ServiceExt};

use crate::transport::TransportHandle; // for Service::ready()
use tower::{Service, ServiceExt}; // for Service::ready()

/// Error type for MCP client operations.
#[derive(Debug, Error)]
Expand All @@ -27,6 +27,9 @@ pub enum Error {
#[error("Unexpected response from server")]
UnexpectedResponse,

#[error("Not initialized")]
NotInitialized,

#[error("Timeout or service not ready")]
NotReady,
}
Expand Down Expand Up @@ -55,6 +58,7 @@ pub struct InitializeParams {
pub struct McpClient {
service: Mutex<TransportHandle>,
next_id: AtomicU64,
server_capabilities: Option<ServerCapabilities>,
}

impl McpClient {
Expand All @@ -63,6 +67,7 @@ impl McpClient {
Self {
service: Mutex::new(transport_handle),
next_id: AtomicU64::new(1),
server_capabilities: None, // set during initialization
}
}

Expand Down Expand Up @@ -135,7 +140,7 @@ impl McpClient {
}

pub async fn initialize(
&self,
&mut self,
info: ClientInfo,
capabilities: ClientCapabilities,
) -> Result<InitializeResult, Error> {
Expand All @@ -151,24 +156,80 @@ impl McpClient {
self.send_notification("notifications/initialized", serde_json::json!({}))
.await?;

self.server_capabilities = Some(result.capabilities.clone());

Ok(result)
}

fn completed_initialization(&self) -> bool {
self.server_capabilities.is_some()
}

pub async fn list_resources(&self) -> Result<ListResourcesResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If resources is not supported, return an empty list
if self
.server_capabilities
.as_ref()
.unwrap()
.resources
.is_none()
{
return Ok(ListResourcesResult { resources: vec![] });
}

self.send_request("resources/list", serde_json::json!({}))
.await
}

pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If resources is not supported, return an error
if self
.server_capabilities
.as_ref()
.unwrap()
.resources
.is_none()
{
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'resources' capability".to_string(),
});
}

let params = serde_json::json!({ "uri": uri });
self.send_request("resources/read", params).await
}

pub async fn list_tools(&self) -> Result<ListToolsResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If tools is not supported, return an empty list
if self.server_capabilities.as_ref().unwrap().tools.is_none() {
return Ok(ListToolsResult { tools: vec![] });
}

self.send_request("tools/list", serde_json::json!({})).await
}

pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If tools is not supported, return an error
if self.server_capabilities.as_ref().unwrap().tools.is_none() {
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'tools' capability".to_string(),
});
}

let params = serde_json::json!({ "name": name, "arguments": arguments });
self.send_request("tools/call", params).await
}
Expand Down

0 comments on commit d0c52ba

Please sign in to comment.