Skip to content

Commit

Permalink
checkout files from origin/v1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
salman1993 committed Jan 13, 2025
1 parent e8ba87f commit d75e1d2
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 71 deletions.
35 changes: 22 additions & 13 deletions crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ impl Capabilities {
// TODO IMPORTANT need to ensure this times out if the system command is broken!
pub async fn add_system(&mut self, config: SystemConfig) -> SystemResult<()> {
let mut client: McpClient = match config {
SystemConfig::Sse { ref uri } => {
let transport = SseTransport::new(uri);
SystemConfig::Sse { ref uri, ref envs } => {
let transport = SseTransport::new(uri, envs.get_env());
McpClient::new(transport.start().await?)
}
SystemConfig::Stdio {
ref cmd,
ref args,
ref env,
ref envs,
} => {
let transport = StdioTransport::new(cmd, args.to_vec()).with_env(env.clone());
let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env());
McpClient::new(transport.start().await?)
}
};
Expand Down Expand Up @@ -187,14 +187,23 @@ impl Capabilities {
let mut tools = Vec::new();
for (name, client) in &self.clients {
let client_guard = client.lock().await;
let client_tools = client_guard.list_tools().await?;

for tool in client_tools.tools {
tools.push(Tool::new(
format!("{}__{}", name, tool.name),
&tool.description,
tool.input_schema,
));
let mut client_tools = client_guard.list_tools(None).await?;

loop {
for tool in client_tools.tools {
tools.push(Tool::new(
format!("{}__{}", name, tool.name),
&tool.description,
tool.input_schema,
));
}

// exit loop when there are no more pages
if client_tools.next_cursor.is_none() {
break;
}

client_tools = client_guard.list_tools(client_tools.next_cursor).await?;
}
}
Ok(tools)
Expand All @@ -206,7 +215,7 @@ impl Capabilities {

for (name, client) in &self.clients {
let client_guard = client.lock().await;
let resources = client_guard.list_resources().await?;
let resources = client_guard.list_resources(None).await?;

for resource in resources.resources {
// Skip reading the resource if it's not marked active
Expand Down
84 changes: 43 additions & 41 deletions crates/goose/src/agents/system.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;

use mcp_client::client::Error as ClientError;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;

/// Errors from System operation
Expand All @@ -16,30 +17,63 @@ pub enum SystemError {

pub type SystemResult<T> = Result<T, SystemError>;

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Envs {
/// A map of environment variables to set, e.g. API_KEY -> some_secret, HOST -> host
#[serde(default)]
#[serde(flatten)]
map: HashMap<String, String>,
}

impl Envs {
pub fn new(map: HashMap<String, String>) -> Self {
Self { map }
}

pub fn default() -> Self {
Self::new(HashMap::new())
}

pub fn get_env(&self) -> HashMap<String, String> {
self.map
.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect()
}
}

/// Represents the different types of MCP systems that can be added to the manager
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum SystemConfig {
/// Server-sent events client with a URI endpoint
Sse { uri: String },
Sse {
uri: String,
#[serde(default)]
envs: Envs,
},
/// Standard I/O client with command and arguments
Stdio {
cmd: String,
args: Vec<String>,
env: Option<HashMap<String, String>>,
#[serde(default)]
envs: Envs,
},
}

impl SystemConfig {
pub fn sse<S: Into<String>>(uri: S) -> Self {
Self::Sse { uri: uri.into() }
Self::Sse {
uri: uri.into(),
envs: Envs::default(),
}
}

pub fn stdio<S: Into<String>>(cmd: S) -> Self {
Self::Stdio {
cmd: cmd.into(),
args: vec![],
env: None,
envs: Envs::default(),
}
}

Expand All @@ -49,31 +83,10 @@ impl SystemConfig {
S: Into<String>,
{
match self {
Self::Stdio { cmd, env, .. } => Self::Stdio {
Self::Stdio { cmd, envs, .. } => Self::Stdio {
cmd,
envs,
args: args.into_iter().map(Into::into).collect(),
env,
},
other => other,
}
}

pub fn with_env<I, K, V>(self, env_vars: I) -> Self
where
I: IntoIterator<Item = (K, V)>,
K: Into<String>,
V: Into<String>,
{
match self {
Self::Stdio { cmd, args, .. } => Self::Stdio {
cmd,
args,
env: Some(
env_vars
.into_iter()
.map(|(k, v)| (k.into(), v.into()))
.collect(),
),
},
other => other,
}
Expand All @@ -83,19 +96,8 @@ impl SystemConfig {
impl std::fmt::Display for SystemConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SystemConfig::Sse { uri } => write!(f, "SSE({})", uri),
SystemConfig::Stdio { cmd, args, env } => {
let env_str = env.as_ref().map_or(String::new(), |e| {
format!(
" with env: {}",
e.iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join(",")
)
});
write!(f, "Stdio({} {}{})", cmd, args.join(" "), env_str)
}
SystemConfig::Sse { uri, .. } => write!(f, "SSE({})", uri),
SystemConfig::Stdio { cmd, args, .. } => write!(f, "Stdio({} {})", cmd, args.join(" ")),
}
}
}
Expand Down
29 changes: 12 additions & 17 deletions crates/mcp-client/src/transport/stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,39 +104,34 @@ impl StdioActor {
pub struct StdioTransport {
command: String,
args: Vec<String>,
env: Option<HashMap<String, String>>,
env: HashMap<String, String>,
}

impl StdioTransport {
pub fn new<S: Into<String>>(command: S, args: Vec<String>) -> Self {
pub fn new<S: Into<String>>(
command: S,
args: Vec<String>,
env: HashMap<String, String>,
) -> Self {
Self {
command: command.into(),
args,
env: None,
env: env,
}
}

pub fn with_env(mut self, env: Option<HashMap<String, String>>) -> Self {
self.env = env;
self
}

async fn spawn_process(&self) -> Result<(Child, ChildStdin, ChildStdout), Error> {
let mut command = Command::new(&self.command);
command
let mut process = Command::new(&self.command)
.envs(&self.env)
.args(&self.args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.kill_on_drop(true)
// 0 sets the process group ID equal to the process ID
.process_group(0); // don't inherit signal handling from parent process

if let Some(env) = &self.env {
command.envs(env);
}

let mut process = command.spawn().map_err(|e| Error::Other(e.to_string()))?;
.process_group(0) // don't inherit signal handling from parent process
.spawn()
.map_err(|e| Error::Other(e.to_string()))?;

let stdin = process
.stdin
Expand Down

0 comments on commit d75e1d2

Please sign in to comment.