Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/v1.0' into baxen/system-configs
Browse files Browse the repository at this point in the history
* origin/v1.0:
  feat: env and secrets configuration for mcp server (#565)
  Add Databricks moderation (#540)
  feat: add pagination support for tools/list and resources/list (#566)
  Add resource capabilties to MCP servers that use it (#576)
  Add goose versions to the UI (#526)
  • Loading branch information
salman1993 committed Jan 13, 2025
2 parents d75e1d2 + a258b76 commit 14764c5
Show file tree
Hide file tree
Showing 25 changed files with 702 additions and 161 deletions.
5 changes: 3 additions & 2 deletions crates/goose-mcp/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
### Test with MCP Inspector

Update examples/mcp.rs to use the appropriate the MCP server (eg. DeveloperRouter)

```bash
npx @modelcontextprotocol/inspector cargo run -p developer
npx @modelcontextprotocol/inspector cargo run -p jetbrains
npx @modelcontextprotocol/inspector cargo run -p goose-mcp --example mcp
```

Then visit the Inspector in the browser window and test the different endpoints.
36 changes: 36 additions & 0 deletions crates/goose-mcp/examples/mcp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// An example script to run an MCP server
use anyhow::Result;
use goose_mcp::DeveloperRouter;
use mcp_server::router::RouterService;
use mcp_server::{ByteTransport, Server};
use tokio::io::{stdin, stdout};
use tracing_appender::rolling::{RollingFileAppender, Rotation};
use tracing_subscriber::{self, EnvFilter};

#[tokio::main]
async fn main() -> Result<()> {
// Set up file appender for logging
let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "mcp-server.log");

// Initialize the tracing subscriber with file and stdout logging
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into()))
.with_writer(file_appender)
.with_target(false)
.with_thread_ids(true)
.with_file(true)
.with_line_number(true)
.init();

tracing::info!("Starting MCP server");

// Create an instance of our counter router
let router = RouterService(DeveloperRouter::new());

// Create and run the server
let server = Server::new(router);
let transport = ByteTransport::new(stdin(), stdout());

tracing::info!("Server initialized and ready to handle requests");
Ok(server.run(transport).await?)
}
5 changes: 4 additions & 1 deletion crates/goose-mcp/src/developer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,10 @@ impl Router for DeveloperRouter {
}

fn capabilities(&self) -> ServerCapabilities {
CapabilitiesBuilder::new().with_tools(true).build()
CapabilitiesBuilder::new()
.with_tools(false)
.with_resources(false, false)
.build()
}

fn list_tools(&self) -> Vec<Tool> {
Expand Down
5 changes: 4 additions & 1 deletion crates/goose-mcp/src/developer2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,10 @@ impl Router for Developer2Router {
}

fn capabilities(&self) -> ServerCapabilities {
CapabilitiesBuilder::new().with_tools(true).build()
CapabilitiesBuilder::new()
.with_tools(false)
.with_resources(false, false)
.build()
}

fn list_tools(&self) -> Vec<Tool> {
Expand Down
5 changes: 4 additions & 1 deletion crates/goose-mcp/src/nondeveloper/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,10 @@ impl Router for NonDeveloperRouter {
}

fn capabilities(&self) -> ServerCapabilities {
CapabilitiesBuilder::new().with_tools(true).build()
CapabilitiesBuilder::new()
.with_tools(false)
.with_resources(false, false)
.build()
}

fn list_tools(&self) -> Vec<Tool> {
Expand Down
98 changes: 81 additions & 17 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use axum::{
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::message::{Message, MessageContent};
use goose::providers::base::ModerationError;
use mcp_core::{content::Content, role::Role};
use serde::Deserialize;
use serde_json::{json, Value};
Expand Down Expand Up @@ -159,6 +160,23 @@ impl ProtocolFormatter {
format!("a:{}\n", response)
}

fn format_error(error: &str) -> String {
// Error messages start with "3:" in the new protocol.
format!("3:{}\n", error)
}

fn format_moderation_error(error: &ModerationError) -> String {
let error_part = match error {
ModerationError::ContentFlagged { categories, .. } => {
format!(
"Content was flagged by moderation in the following categories: {}",
categories
)
}
};
format!("3:\"{}\"\n", error_part)
}

fn format_finish(reason: &str) -> String {
// Finish messages start with "d:"
let finish = json!({
Expand Down Expand Up @@ -193,8 +211,12 @@ async fn stream_message(
.await?;
}
Err(err) => {
// Send an error message first
tx.send(ProtocolFormatter::format_error(&err.to_string()))
.await?;
// Then send an empty tool response to maintain the protocol
let result =
vec![Content::text(format!("Error {}", err)).with_priority(0.0)];
vec![Content::text(format!("Error: {}", err)).with_priority(0.0)];
tx.send(ProtocolFormatter::format_tool_response(
&response.id,
&result,
Expand All @@ -209,22 +231,24 @@ async fn stream_message(
for content in message.content {
match content {
MessageContent::ToolRequest(request) => {
if let Ok(tool_call) = request.tool_call {
tx.send(ProtocolFormatter::format_tool_call(
&request.id,
&tool_call.name,
&tool_call.arguments,
))
.await?;
} else {
// if the llm generates an invalid object tool call, we still have
// to include it in the history. It always comes with a response indicating the error
tx.send(ProtocolFormatter::format_tool_call(
&request.id,
"invalid name",
&json!({}),
))
.await?;
match request.tool_call {
Ok(tool_call) => {
tx.send(ProtocolFormatter::format_tool_call(
&request.id,
&tool_call.name,
&tool_call.arguments,
))
.await?;
}
Err(err) => {
// Send a placeholder tool call to maintain protocol
tx.send(ProtocolFormatter::format_tool_call(
&request.id,
"invalid_tool",
&json!({"error": err.to_string()}),
))
.await?;
}
}
}
MessageContent::Text(text) => {
Expand Down Expand Up @@ -278,6 +302,18 @@ async fn handler(
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
// Check if it's a moderation error
if let Some(moderation_error) = e.downcast_ref::<ModerationError>() {
let _ = tx
.send(ProtocolFormatter::format_moderation_error(moderation_error))
.await;
// Kill the stream since we encountered a moderation error
} else {
// Send a generic error message
let _ = tx
.send(ProtocolFormatter::format_error(&e.to_string()))
.await;
}
// Send a finish message with error as the reason
let _ = tx.send(ProtocolFormatter::format_finish("error")).await;
return;
Expand All @@ -291,11 +327,18 @@ async fn handler(
Ok(Some(Ok(message))) => {
if let Err(e) = stream_message(message, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await;
break;
}
}
Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
// Check if it's a moderation error
if let Some(moderation_error) = e.downcast_ref::<ModerationError>() {
let _ = tx.send(ProtocolFormatter::format_moderation_error(moderation_error)).await;
} else {
let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await;
}
break;
}
Ok(None) => {
Expand Down Expand Up @@ -503,6 +546,27 @@ mod tests {
assert!(formatted.starts_with("a:"));
assert!(formatted.contains("\"toolCallId\":\"123\""));

// Test error formatting
let formatted = ProtocolFormatter::format_error("Test error");
println!("Formatted error: {}", formatted);
assert!(formatted.starts_with("3:"));
assert!(formatted.contains("Test error"));

// Test moderation error formatting
let moderation_error = ModerationError::ContentFlagged {
categories: "hate, violence".to_string(),
category_scores: Some(json!({
"hate": 0.9,
"violence": 0.8
})),
};
let formatted = ProtocolFormatter::format_moderation_error(&moderation_error);
println!("{}", formatted);
assert!(formatted.starts_with("3:"));
assert!(
formatted.contains("Content was flagged by moderation in the following categories:")
);

// Test finish formatting
let formatted = ProtocolFormatter::format_finish("stop");
assert!(formatted.starts_with("d:"));
Expand Down
38 changes: 22 additions & 16 deletions crates/goose/src/providers/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use tokio::select;
use tokio::sync::RwLock;

Expand All @@ -12,6 +13,15 @@ use crate::message::{Message, MessageContent};
use mcp_core::role::Role;
use mcp_core::tool::Tool;

#[derive(Error, Debug)]
pub enum ModerationError {
#[error("Content was flagged for moderation in categories: {categories}")]
ContentFlagged {
categories: String,
category_scores: Option<serde_json::Value>,
},
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderUsage {
pub model: String,
Expand Down Expand Up @@ -197,10 +207,10 @@ pub trait Provider: Send + Sync + Moderation {
let categories = result.categories
.unwrap_or_else(|| vec!["unknown".to_string()])
.join(", ");
return Err(anyhow::anyhow!(
"Content was flagged for moderation in categories: {}",
categories
));
return Err(ModerationError::ContentFlagged {
categories,
category_scores: result.category_scores,
}.into());
}

// Moderation passed, wait for completion
Expand All @@ -215,10 +225,10 @@ pub trait Provider: Send + Sync + Moderation {
let categories = moderation_result.categories
.unwrap_or_else(|| vec!["unknown".to_string()])
.join(", ");
return Err(anyhow::anyhow!(
"Content was flagged for moderation in categories: {}",
categories
));
return Err(ModerationError::ContentFlagged {
categories,
category_scores: moderation_result.category_scores,
}.into());
}

Ok(completion_result)
Expand Down Expand Up @@ -338,10 +348,8 @@ mod tests {
let result = provider.complete("system", &[test_message], &[]).await;

assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Content was flagged"));
let err = result.unwrap_err();
assert!(err.downcast_ref::<ModerationError>().is_some());
}

#[tokio::test]
Expand Down Expand Up @@ -407,10 +415,8 @@ mod tests {
let result = provider.complete("system", &[test_message], &[]).await;

assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Content was flagged"));
let err = result.unwrap_err();
assert!(err.downcast_ref::<ModerationError>().is_some());
}

#[tokio::test]
Expand Down
Loading

0 comments on commit 14764c5

Please sign in to comment.